Skip to content

Commit

Permalink
[XLA:GPU] propagate the algorithm flag of dot op to cublasGemm custom…
Browse files Browse the repository at this point in the history
… call.

we have the algorithm flag of dot op. we handle it in triton emitter, now let's push it to cublas via gemm_rewriter. Otherwise the cublas call uses the default f32_f32_f32 algorithm and loses the competition with triton.

As a result of this change it get clear that only Ampere ran bf16 version of cublas kernel. Hopper uses tf32 for that because it does not have the b16 version for this case.

DotBF16ForBf16Bf16F32Tests was removed because the algorithm BF16_BF16_F32 expects F32 input and F32 output with the BF16 arithmetics inside cublas.

PiperOrigin-RevId: 679595014
  • Loading branch information
loislo authored and Google-ML-Automation committed Sep 27, 2024
1 parent 3ccf821 commit 31b180c
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 24 deletions.
3 changes: 2 additions & 1 deletion xla/service/algorithm_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ absl::StatusOr<se::blas::ComputationType> GetBlasComputationType(
switch (algorithm) {
case PrecisionConfig::ALG_DOT_F16_F16_F16:
return se::blas::ComputationType::kF16;
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
return se::blas::ComputationType::kBF16AsF32;
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32:
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
return se::blas::ComputationType::kF32;
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
Expand Down
9 changes: 0 additions & 9 deletions xla/service/gpu/dot_algorithm_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,6 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest,
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32),
Values(BF16), Values(BF16, F32),
Values(CC(8, 0)),
Values(SemanticVersion{6, 0, 0}),
Values(BackendRestriction::kNoRestriction),
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32),
Values(F32), Values(CC(8, 0)),
Expand Down
39 changes: 35 additions & 4 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
load("//xla:xla.bzl", "xla_cc_test")
Expand Down Expand Up @@ -206,6 +207,7 @@ xla_test(
"no_mac",
],
deps = [
":kernel_name_tracer",
":triton_fusion_emitter",
":triton_test_utils",
"//xla:autotuning_proto_cc",
Expand All @@ -223,21 +225,17 @@ xla_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:filecheck",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@llvm-project//llvm:ir_headers",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:path",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
Expand Down Expand Up @@ -286,6 +284,37 @@ xla_test(
],
)

cc_library(
name = "kernel_name_tracer_cuda",
testonly = True,
srcs = if_cuda(["kernel_name_tracer_cuda.cc"]),
hdrs = ["kernel_name_tracer.h"],
tags = ["manual"], # Need to exclude this from wildcard builds
deps = [
"//xla/backends/profiler/gpu:cupti_collector",
"//xla/backends/profiler/gpu:cupti_tracer",
"@tsl//tsl/profiler/utils:time_utils",
],
)

cc_library(
name = "kernel_name_tracer_noop",
testonly = True,
srcs = ["kernel_name_tracer_noop.cc"],
hdrs = ["kernel_name_tracer.h"],
tags = ["manual"], # Need to exclude this from wildcard builds
)

cc_library(
name = "kernel_name_tracer",
testonly = True,
hdrs = ["kernel_name_tracer.h"],
deps = if_cuda(
[":kernel_name_tracer_cuda"],
[":kernel_name_tracer_noop"],
),
)

cc_library(
name = "triton_test_utils",
testonly = True,
Expand Down Expand Up @@ -321,6 +350,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/utils:time_utils",
],
)

Expand Down Expand Up @@ -479,6 +509,7 @@ xla_test(
],
tags = ["no_mac"],
deps = [
":kernel_name_tracer",
":triton_fusion_emitter",
":triton_support",
":triton_test_utils",
Expand Down
39 changes: 39 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_
#define XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_

#include <memory>
#include <string>

namespace xla::gpu {

// In some cases we need to know what exact kernel was used. It happens when we
// have no direct way to get this information from the HLO. For example, when we
// have a fusion with a custom call to cuBLAS or another third party library.
// This class allows to get the name of the kernel that was used.
class KernelNameTracer {
public:
static std::unique_ptr<KernelNameTracer> Create();

virtual void start() = 0;
virtual std::string stop() = 0;
virtual ~KernelNameTracer() = default;
};

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_
72 changes: 72 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "xla/backends/profiler/gpu/cupti_collector.h"
#include "xla/backends/profiler/gpu/cupti_tracer.h"
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"
#include "tsl/profiler/utils/time_utils.h"

namespace xla::gpu {

// This class allows to get the name of the kernel that was used.
// It works only on CUDA. It uses CuptiTracer to get the kernel name.
class KernelNameTracerCuda : public KernelNameTracer {
public:
KernelNameTracerCuda()
: cupti_tracer_(profiler::CuptiTracer::GetCuptiTracerSingleton()) {}

void start() override;

// As of now it returns the name of the first kernel that was executed on
// GPU:0.
std::string stop() override;

private:
std::unique_ptr<profiler::CuptiTracer> cupti_tracer_;
std::unique_ptr<profiler::CuptiTraceCollector> cupti_collector_;
};

std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() {
return std::make_unique<KernelNameTracerCuda>();
}

void KernelNameTracerCuda::start() {
profiler::CuptiTracerCollectorOptions collector_options;
collector_options.num_gpus = profiler::CuptiTracer::NumGpus();
auto start_gputime_ns = profiler::CuptiTracer::GetTimestamp();
auto start_walltime_ns = tsl::profiler::GetCurrentTimeNanos();
cupti_collector_ = profiler::CreateCuptiCollector(
collector_options, start_walltime_ns, start_gputime_ns);
profiler::CuptiTracerOptions options;
options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL};
cupti_tracer_->Enable(options, cupti_collector_.get());
}

std::string KernelNameTracerCuda::stop() {
cupti_tracer_->Disable();
uint64_t end_gpu_ns = cupti_collector_->GetTracingEndTimeNs();
auto space = std::make_unique<tensorflow::profiler::XSpace>();
cupti_collector_->Export(space.get(), end_gpu_ns);
for (const auto& plane : space->planes()) {
if (plane.name() == "/device:GPU:0") {
return plane.event_metadata().at(1).name();
}
}
return "";
}

} // namespace xla::gpu
33 changes: 33 additions & 0 deletions xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"

namespace xla::gpu {

class KernelNameTracerNoop : public KernelNameTracer {
public:
void start() override {};
std::string stop() override { return "kernel_name_tracer_not_implemented"; };
};

std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() {
return std::make_unique<KernelNameTracerNoop>();
}

} // namespace xla::gpu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <variant>
Expand All @@ -37,6 +38,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h"
#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h"
#include "xla/service/gpu/fusions/triton/triton_test_utils.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
Expand All @@ -46,7 +48,6 @@ limitations under the License.
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla.pb.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -147,6 +148,74 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

class TritonBF16BF16F32BlasTest : public TritonTest {
public:
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
// Do not autotune split-k by default, since this prevents deterministically
// matching the optimized HLO.
debug_options.set_xla_gpu_enable_split_k_autotuning(false);
debug_options.set_xla_gpu_enable_triton_gemm(false);
return debug_options;
}

protected:
void SetUp() override {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
}
};

TEST_F(TritonBF16BF16F32BlasTest, PropagateAlgorithmToBlas) {
// We check that the algorithm is propagated to the BLAS call.
// We also check that the kernel name matches the algorithm for Ampere.
// The algorithm for Hopper is not the one we expect because it uses TF32.

constexpr std::string_view kHloText = R"(
HloModule t
ENTRY main {
lhs = f32[8512,256]{1,0} parameter(0)
rhs = f32[256,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_bf16_bf16_f32,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
const std::string pattern = R"(CHECK: "algorithm":"ALG_DOT_BF16_BF16_F32")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));

auto tracer = KernelNameTracer::Create();
tracer->start();
EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
auto kernel_name = tracer->stop();

if (kernel_name == "kernel_name_tracer_not_implemented") return;

auto cc = GetCudaComputeCapability();
using CudaComputeCapabilities =
stream_executor::CudaComputeCapability::CudaComputeCapabilities;
switch (cc.major) {
case CudaComputeCapabilities::BLACKWELL:
GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: "
<< kernel_name;
break;
case CudaComputeCapabilities::AMPERE:
EXPECT_THAT(kernel_name, ::testing::HasSubstr("bf16gemm_"));
break;
case CudaComputeCapabilities::HOPPER:
// Hopper does not have bf16 kernels for ALG_DOT_BF16_BF16_F32 algorithm.
// As a result it uses TF32.
EXPECT_THAT(kernel_name, ::testing::HasSubstr("gemm_f32f32_tf32f32_f32"));
break;
default:
GTEST_SKIP() << "Unsupported compute capability: " << cc.major
<< " has the kernel name: " << kernel_name;
}
}

TEST_F(TritonGemmTest, RejectDotInt4HLO) {
constexpr std::string_view kHloText = R"(
HloModule t
Expand Down Expand Up @@ -200,6 +269,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) {
rhs_batch_dims={0}
}
)";

const std::string pattern =
R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/fusions/triton/triton_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) {
CHECK(false);
}

absl::Status CreateTritonIrAndFileCheck(
HloTestBase* test, absl::string_view hlo_text,
absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) {
absl::Status CreateTritonIrAndFileCheck(HloTestBase* test,
absl::string_view hlo_text,
absl::string_view triton_fusion_name,
absl::string_view filecheck_pattern) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
test->ParseAndReturnVerifiedModule(hlo_text));
auto* comp = verified_module->GetComputationWithName(triton_fusion_name);
Expand Down
Loading

0 comments on commit 31b180c

Please sign in to comment.