[XLA:GPU] Add support for TF32_TF32_F32_X3 algorithm
Triton supports this algorithm directly.
cuBLAS does not support it, so we fall back to F32_F32_F32.

The Triton version is slower than the cuBLAS fallback, so the autotuner selects cuBLAS.

PiperOrigin-RevId: 679638595
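
Background note: TF32_TF32_F32_X3 is the usual "3xTF32" trick: each f32 operand is split into a TF32 "high" part plus an f32 residual, and the product is assembled from three TF32 multiplies, which recovers close to full f32 accuracy on tensor cores. The standalone C++ sketch below only illustrates that decomposition; it emulates TF32 by truncating the mantissa (the hardware rounds inside the tensor core, so this is an approximation), and the ToTf32/Tf32x3Product helpers are illustrative, not part of XLA.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Emulate TF32 rounding by zeroing the 13 low mantissa bits of an f32
// (TF32 keeps 10 explicit mantissa bits). Approximation for illustration only.
float ToTf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;  // drop the 13 least-significant mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// One scalar product in the 3xTF32 scheme: split each input into a TF32
// "high" part and an f32 residual, then accumulate three TF32-width products.
// The a_lo * b_lo term is dropped; its contribution is below f32 precision.
float Tf32x3Product(float a, float b) {
  float a_hi = ToTf32(a), b_hi = ToTf32(b);
  float a_lo = a - a_hi, b_lo = b - b_hi;
  return a_hi * b_hi + a_hi * b_lo + a_lo * b_hi;
}

int main() {
  float a = 1.0f + 1e-6f, b = 3.0f - 1e-6f;
  std::printf("tf32 only: %.9g\n", ToTf32(a) * ToTf32(b));
  std::printf("3xTF32   : %.9g\n", Tf32x3Product(a, b));
  std::printf("full f32 : %.9g\n", a * b);
  return 0;
}
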
loislo authored and Google-ML-Automation committed Sep 28, 2024
1 parent 0f30f33 commit aa5f2f6
Showing 7 changed files with 117 additions and 38 deletions.
3 changes: 3 additions & 0 deletions xla/service/algorithm_util.cc
@@ -47,6 +47,7 @@ absl::StatusOr<se::blas::ComputationType> GetBlasComputationType(
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
return se::blas::ComputationType::kF32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
return se::blas::ComputationType::kTF32AsF32;
@@ -106,6 +107,7 @@ bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm) {
case PrecisionConfig::ALG_DOT_F64_F64_F64:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
return true;
@@ -188,6 +190,7 @@ bool IsSupportedDotAlgorithmOnGpu(
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
input_storage_type == F32 && output_storage_type == F32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
input_storage_type == F32 && output_storage_type == F32;
4 changes: 3 additions & 1 deletion xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -337,7 +337,9 @@ absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
if (dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) {
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
dot->mutable_precision_config()->set_algorithm(
PrecisionConfig::ALG_DOT_F32_F32_F32);
}
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc
@@ -36,7 +36,7 @@ class KernelNameTracerCuda : public KernelNameTracer {
std::string stop() override;

private:
std::unique_ptr<profiler::CuptiTracer> cupti_tracer_;
profiler::CuptiTracer* cupti_tracer_; // Not owned.
std::unique_ptr<profiler::CuptiTraceCollector> cupti_collector_;
};

39 changes: 23 additions & 16 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -2216,6 +2216,27 @@ bool IsTf32Allowed(const HloDotInstruction* dot_instr) {
return algorithm_util::HasTf32InputType(algorithm);
}

mt::InputPrecision InferDotPrecision(const HloDotInstruction* dot_instr) {
auto algorithm = dot_instr->precision_config().algorithm();
if (algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
return mt::InputPrecision::TF32x3;
}
// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
bool is_unsupported_bitwidth =
HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) {
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
int in_width =
primitive_util::BitWidth(node->operand(0)->shape().element_type());
return in_width <= 8 && node->shape().element_type() == F32;
});

return IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
? mt::InputPrecision::TF32
: mt::InputPrecision::IEEE;
}

bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr,
mlir::OpBuilder& builder, Value dot_input_lhs,
Value dot_input_rhs,
@@ -2545,17 +2566,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
const HloInstruction* root = dot_instr->parent()->root_instruction();
TF_RET_CHECK(!root->shape().IsTuple());

// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
bool is_unsupported_bitwidth =
HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) {
if (node->opcode() != HloOpcode::kConvert) {
return false;
}
int in_width =
primitive_util::BitWidth(node->operand(0)->shape().element_type());
return in_width <= 8 && node->shape().element_type() == F32;
});

// We'll be creating a lot of instructions from a single dot, use an
// implicit loc builder so we don't have to pass around the location all the
// time.
@@ -2708,10 +2718,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
// maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a
// lower precision than the output type. The change was introduced here:
// https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a
auto input_precision =
IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
? mt::InputPrecision::TF32
: mt::InputPrecision::IEEE;
auto dot_precision = InferDotPrecision(dot_instr);

// Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32.
if (dot_instr->precision_config().algorithm() ==
Expand All @@ -2731,7 +2738,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
IsFp8Matmul(dot_instr) ? std::numeric_limits<int>::max() : 0;
accumulator_next =
b.create<mt::DotOp>(dot_input_lhs, dot_input_rhs, iter_args.back(),
/*inputPrecision=*/input_precision,
/*inputPrecision=*/dot_precision,
/*maxNumImpreciseAcc=*/max_num_imprecise_acc);
}
iter_args_next.push_back(accumulator_next);
@@ -148,7 +148,7 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

class TritonBF16BF16F32BlasTest : public TritonTest {
class BlasAlgorithmTest : public TritonTest {
public:
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
@@ -158,20 +158,16 @@ class TritonBF16BF16F32BlasTest : public TritonTest {
debug_options.set_xla_gpu_enable_triton_gemm(false);
return debug_options;
}

protected:
void SetUp() override {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
}
};

TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) {
TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) {
// We check that the algorithm is propagated to the BLAS call.
// We also check that the kernel name matches the algorithm for Ampere.
// The algorithm for Hopper is not the one we expect because it uses TF32.

if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
constexpr std::string_view kHloText = R"(
HloModule t
@@ -186,6 +182,8 @@ TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) {
)";
const std::string pattern = R"(CHECK: "algorithm":"ALG_DOT_BF16_BF16_F32")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
ASSERT_TRUE(ok);

auto tracer = KernelNameTracer::Create();
tracer->start();
@@ -216,6 +214,57 @@ TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) {
}
}

TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) {
// We check that the algorithm is propagated to the BLAS call.
// We also check that the kernel name matches the algorithm for Ampere.

constexpr std::string_view kHloText = R"(
HloModule t
ENTRY main {
lhs = f32[8512,256]{1,0} parameter(0)
rhs = f32[256,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_tf32_tf32_f32_x3,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
const std::string pattern =
R"(CHECK: "algorithm":"ALG_DOT_TF32_TF32_F32_X3")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
ASSERT_TRUE(ok);

auto tracer = KernelNameTracer::Create();
tracer->start();
EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
auto kernel_name = tracer->stop();

if (kernel_name == "kernel_name_tracer_not_implemented") return;

auto cc = GetCudaComputeCapability();
using CudaComputeCapabilities =
stream_executor::CudaComputeCapability::CudaComputeCapabilities;
switch (cc.major) {
case CudaComputeCapabilities::BLACKWELL:
GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: "
<< kernel_name;
break;
case CudaComputeCapabilities::AMPERE:
// There is no support for TF32_TF32_F32_X3 on Ampere. We use F32_F32_F32.
EXPECT_THAT(kernel_name, ::testing::HasSubstr("ampere_sgemm_128x64_nn"));
break;
case CudaComputeCapabilities::HOPPER:
// There is no support for TF32_TF32_F32_X3 on Hopper. We use F32_F32_F32.
EXPECT_THAT(kernel_name, ::testing::HasSubstr("gemm_f32f32_f32f32_f32"));
break;
default:
GTEST_SKIP() << "Unsupported compute capability: " << cc.major
<< " has the kernel name: " << kernel_name;
}
}

TEST_F(TritonGemmTest, RejectDotInt4HLO) {
constexpr std::string_view kHloText = R"(
HloModule t
@@ -5235,7 +5284,7 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32
/*arel=*/1e-5}));
}

class TritonBF16BF16F32GemmTest : public TritonTest {
class TritonAlgorithmTest : public TritonTest {
public:
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
@@ -5246,23 +5295,39 @@ class TritonBF16BF16F32GemmTest : public TritonTest {
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
return debug_options;
}
};

protected:
void SetUp() override {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) {
const std::string kHloText = R"(
HloModule t
ENTRY main {
lhs = f32[8512,64]{1,0} parameter(0)
rhs = f32[64,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_tf32_tf32_f32_x3,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
}
};
)";
const std::string pattern =
R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
EXPECT_TRUE(ok);
}

TEST_F(TritonBF16BF16F32GemmTest, WorkWithF32InputAndAlgorithm_BF16_BF16_F32) {
TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32) {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
const std::string kHloText = R"(
HloModule t
ENTRY main {
lhs = f32[32,64]{1,0} parameter(0)
rhs = f32[64,16]{1,0} parameter(1)
ROOT dot = f32[32,16]{1,0} dot(lhs, rhs),
lhs = f32[8512,64]{1,0} parameter(0)
rhs = f32[64,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_bf16_bf16_f32,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/triton/triton_support_legacy.cc
@@ -227,6 +227,7 @@ bool IsDotAlgorithmSupportedByTriton(
auto rocm_compute_capability =
std::get_if<se::RocmComputeCapability>(&gpu_version);
switch (algorithm) {
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
if (cuda_compute_capability) {
return true;
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/gemm_fusion.cc
@@ -741,6 +741,7 @@ absl::StatusOr<Decision> CreateDotFusion(
if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 ||
algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ||
dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
dot.sparse_operands()) {
return Decision::Allow();
