-
Notifications
You must be signed in to change notification settings - Fork 406
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[XLA:GPU] propagate the algorithm flag of dot op to cublasGemm custom…
… call. We already have the algorithm flag on the dot op and handle it in the Triton emitter; now let's push it to cuBLAS via gemm_rewriter. Otherwise the cuBLAS call uses the default f32_f32_f32 algorithm and loses the competition with Triton. As a result of this change it became clear that only Ampere ran the bf16 version of the cuBLAS kernel. Hopper uses tf32 for that because it does not have the bf16 version for this case. DotBF16ForBf16Bf16F32Tests was removed because the algorithm BF16_BF16_F32 expects F32 input and F32 output with BF16 arithmetic inside cuBLAS. PiperOrigin-RevId: 679595014
- Loading branch information
1 parent
3ccf821
commit 31b180c
Showing
9 changed files
with
261 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ | ||
#define XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
namespace xla::gpu { | ||
|
||
// Reports the name of the GPU kernel that was actually executed.
//
// Sometimes the HLO alone does not tell us which kernel ran, for example when
// a fusion wraps a custom call into cuBLAS or another third-party library.
// Bracket the execution of interest with start() and stop() to obtain the
// kernel name.
class KernelNameTracer {
 public:
  virtual ~KernelNameTracer() = default;

  // Returns a tracer instance. The concrete type is chosen by whichever
  // definition of Create() is linked into the binary.
  static std::unique_ptr<KernelNameTracer> Create();

  // Begins tracing.
  virtual void start() = 0;

  // Ends tracing and returns the traced kernel name as a string.
  virtual std::string stop() = 0;
};
|
||
} // namespace xla::gpu | ||
|
||
#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <memory> | ||
|
||
#include "xla/backends/profiler/gpu/cupti_collector.h" | ||
#include "xla/backends/profiler/gpu/cupti_tracer.h" | ||
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" | ||
#include "tsl/profiler/utils/time_utils.h" | ||
|
||
namespace xla::gpu { | ||
|
||
// This class allows to get the name of the kernel that was used. | ||
// It works only on CUDA. It uses CuptiTracer to get the kernel name. | ||
class KernelNameTracerCuda : public KernelNameTracer { | ||
public: | ||
KernelNameTracerCuda() | ||
: cupti_tracer_(profiler::CuptiTracer::GetCuptiTracerSingleton()) {} | ||
|
||
void start() override; | ||
|
||
// As of now it returns the name of the first kernel that was executed on | ||
// GPU:0. | ||
std::string stop() override; | ||
|
||
private: | ||
std::unique_ptr<profiler::CuptiTracer> cupti_tracer_; | ||
std::unique_ptr<profiler::CuptiTraceCollector> cupti_collector_; | ||
}; | ||
|
||
std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() { | ||
return std::make_unique<KernelNameTracerCuda>(); | ||
} | ||
|
||
void KernelNameTracerCuda::start() { | ||
profiler::CuptiTracerCollectorOptions collector_options; | ||
collector_options.num_gpus = profiler::CuptiTracer::NumGpus(); | ||
auto start_gputime_ns = profiler::CuptiTracer::GetTimestamp(); | ||
auto start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); | ||
cupti_collector_ = profiler::CreateCuptiCollector( | ||
collector_options, start_walltime_ns, start_gputime_ns); | ||
profiler::CuptiTracerOptions options; | ||
options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL}; | ||
cupti_tracer_->Enable(options, cupti_collector_.get()); | ||
} | ||
|
||
std::string KernelNameTracerCuda::stop() { | ||
cupti_tracer_->Disable(); | ||
uint64_t end_gpu_ns = cupti_collector_->GetTracingEndTimeNs(); | ||
auto space = std::make_unique<tensorflow::profiler::XSpace>(); | ||
cupti_collector_->Export(space.get(), end_gpu_ns); | ||
for (const auto& plane : space->planes()) { | ||
if (plane.name() == "/device:GPU:0") { | ||
return plane.event_metadata().at(1).name(); | ||
} | ||
} | ||
return ""; | ||
} | ||
|
||
} // namespace xla::gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include <memory> | ||
#include <string> | ||
|
||
#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" | ||
|
||
namespace xla::gpu { | ||
|
||
class KernelNameTracerNoop : public KernelNameTracer { | ||
public: | ||
void start() override {}; | ||
std::string stop() override { return "kernel_name_tracer_not_implemented"; }; | ||
}; | ||
|
||
std::unique_ptr<KernelNameTracer> KernelNameTracer::Create() { | ||
return std::make_unique<KernelNameTracerNoop>(); | ||
} | ||
|
||
} // namespace xla::gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.