PR #10510: [ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes.

Imported from GitHub PR openxla/xla#10510

First commit of the series for enabling Triton in XLA for ROCm.
Copybara import of the project:

--
f680e98aa6908d02021bc0c8bb3d24edfc6b2a5d by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes.

Merging this change closes #10510

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10510 from ROCm:rocm_triton_backend f680e98aa6908d02021bc0c8bb3d24edfc6b2a5d
PiperOrigin-RevId: 615707894
zoranjovanovic-ns authored and tensorflower-gardener committed Mar 14, 2024
1 parent c70e360 commit 3ad5991
Showing 6 changed files with 47 additions and 15 deletions.
3 changes: 2 additions & 1 deletion third_party/xla/xla/pjrt/gpu/BUILD
@@ -243,7 +243,6 @@ cc_library(
"//xla/service:local_service",
"//xla/service:local_service_utils",
"//xla/service/gpu:executable_proto_cc",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/platform",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:casts",
@@ -253,11 +252,13 @@ cc_library(
"//xla/service/gpu:gpu_compiler",
]) + if_cuda([
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor/cuda:cuda_platform_id",
"//xla/stream_executor/cuda:cuda_activation_header",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
"//xla/service/gpu:nvptx_compiler_impl",
]) + if_rocm([
"@local_config_rocm//rocm:rocm_headers",
"//xla/stream_executor/rocm:rocm_platform_id",
"//xla/service/gpu:amdgpu_compiler_impl",
]),
alwayslink = True,
3 changes: 2 additions & 1 deletion third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -42,13 +42,14 @@ limitations under the License.
#include "xla/service/hlo_proto_util.h"
#include "xla/service/local_service.h"
#include "xla/service/local_service_utils.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#endif

#if GOOGLE_CUDA
#include "xla/service/gpu/nvptx_compiler.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#elif TENSORFLOW_USE_ROCM
#include "xla/service/gpu/amdgpu_compiler.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#endif

namespace xla {
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/conv_algorithm_picker.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "third_party/gpus/cudnn/cudnn_version.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -75,6 +74,7 @@ limitations under the License.

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
#include "third_party/gpus/cudnn/cudnn_version.h"
#if CUDNN_VERSION >= 90000
#include "third_party/gpus/cudnn/cudnn_ops.h"
#else
@@ -16,13 +16,13 @@ limitations under the License.
#include <string>

#include <gtest/gtest.h>
#include "third_party/gpus/cudnn/cudnn_version.h"
#include "xla/error_spec.h"
#include "xla/stream_executor/device_description.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: keep
#include "third_party/gpus/cudnn/cudnn_version.h"
#endif

#include "xla/service/gpu/tests/gpu_codegen_test.h"
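
The two files above move the cudnn_version.h include inside the existing GOOGLE_CUDA guard so that ROCm builds never reach cuDNN headers. A minimal standalone sketch of that guard pattern; the main() wrapper and printed strings are illustrative additions, not code from this commit:

#if defined(GOOGLE_CUDA) && GOOGLE_CUDA
// cuDNN headers and version macros are only visible in CUDA builds.
#include "third_party/gpus/cudnn/cudnn.h"
#include "third_party/gpus/cudnn/cudnn_version.h"
#endif

#include <cstdio>

int main() {
#if defined(GOOGLE_CUDA) && GOOGLE_CUDA && CUDNN_VERSION >= 90000
  std::printf("CUDA build with cuDNN 9 or newer\n");
#else
  // ROCm (or CPU-only) builds take this branch without ever including cuDNN.
  std::printf("build without cuDNN 9+\n");
#endif
  return 0;
}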
48 changes: 38 additions & 10 deletions third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
@@ -704,9 +704,14 @@ class GemmRewriterTritonVisitor : public DfsHloRewriteVisitor {
// If a GEMM requiring padding for cuBLAS is encountered here this
// happened because earlier ShouldTritonHandleGEMM() accepted it and padding
// was skipped. Accept it ignoring profitability checks.
if (!CublasRequiresPadding(*Cast<HloDotInstruction>(dot), gpu_version_) &&
!should_fuse) {
return absl::OkStatus();
// TODO(rocm): check ROCM padding requirements.
if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
if (!CublasRequiresPadding(
*Cast<HloDotInstruction>(dot),
std::get<se::CudaComputeCapability>(gpu_version_)) &&
!should_fuse) {
return OkStatus();
}
}

HloComputation* computation =
@@ -752,17 +757,31 @@ absl::StatusOr<bool> RunOnComputation(
return visitor.changed();
}

bool IsSupportedByTriton(
PrecisionConfig::Algorithm algorithm,
const se::CudaComputeCapability& cuda_compute_capability) {
bool IsSupportedByTriton(PrecisionConfig::Algorithm algorithm,
const se::GpuComputeCapability& gpu_version) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version);
auto rocm_compute_capability =
std::get_if<se::RocmComputeCapability>(&gpu_version);
switch (algorithm) {
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
if (rocm_compute_capability) {
return rocm_compute_capability->has_bf16_dtype_support();
}
return true;

case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
if (rocm_compute_capability) {
return false;
}
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
return cuda_compute_capability.IsAtLeastAmpere();
if (rocm_compute_capability) {
return rocm_compute_capability->has_bf16_dtype_support();
} else if (cuda_compute_capability) {
return cuda_compute_capability->IsAtLeastAmpere();
}
return false;

// TODO(b/326579472): Fix the support of this algorithm and maybe allow it
// here.
@@ -780,8 +799,12 @@ FusionDecision CanTritonHandleGEMM(
const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version);
auto rocm_compute_capability =
std::get_if<se::RocmComputeCapability>(&gpu_version);

if (!cuda_compute_capability) return "Non CUDA device.";
if (!cuda_compute_capability && !rocm_compute_capability) {
return "Non CUDA or ROCM device.";
}

if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) {
if (!tsl::tensor_float_32_execution_enabled() ||
@@ -802,8 +825,13 @@
case F32:
return true;
case BF16:
return cuda_compute_capability->IsAtLeast(
stream_executor::CudaComputeCapability::AMPERE);
if (cuda_compute_capability) {
return cuda_compute_capability->IsAtLeast(
stream_executor::CudaComputeCapability::AMPERE);
} else if (rocm_compute_capability) {
return rocm_compute_capability->has_bf16_dtype_support();
}
return false;
default:
return false;
}
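
The gemm_rewriter_triton.cc changes above repeatedly dispatch on se::GpuComputeCapability, a variant over the CUDA and ROCm capability types. A standalone sketch of that dispatch, using simplified stand-in structs rather than the real StreamExecutor classes (the major-version cutoff of 8 for Ampere and the field names are illustrative):

#include <iostream>
#include <variant>

struct CudaCapability {            // stand-in for se::CudaComputeCapability
  int major = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};

struct RocmCapability {            // stand-in for se::RocmComputeCapability
  bool bf16 = false;
  bool has_bf16_dtype_support() const { return bf16; }
};

using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// BF16 dots are allowed on Ampere-or-newer CUDA devices and on ROCm devices
// that report BF16 dtype support; anything else is rejected.
bool SupportsBf16Dot(const GpuCapability& gpu_version) {
  if (const auto* cuda = std::get_if<CudaCapability>(&gpu_version)) {
    return cuda->IsAtLeastAmpere();
  }
  if (const auto* rocm = std::get_if<RocmCapability>(&gpu_version)) {
    return rocm->has_bf16_dtype_support();
  }
  return false;  // unexpected alternative
}

int main() {
  std::cout << SupportsBf16Dot(CudaCapability{9}) << "\n";      // prints 1
  std::cout << SupportsBf16Dot(RocmCapability{false}) << "\n";  // prints 0
}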
4 changes: 3 additions & 1 deletion third_party/xla/xla/service/gpu/triton_tiling_propagation.cc
@@ -1037,7 +1037,9 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
std::move(std::get<DimOrdersAndReqs>(result_or_error));
int fusion_level =
hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
if (!std::get<se::CudaComputeCapability>(gpu_version)
// TODO(ROCm) Check fusion level for ROCm.
if (std::holds_alternative<se::CudaComputeCapability>(gpu_version) &&
!std::get<se::CudaComputeCapability>(gpu_version)
.IsAtLeast(se::CudaComputeCapability::AMPERE)) {
fusion_level = std::min(fusion_level, 1);
}
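
The triton_tiling_propagation.cc hunk above clamps the fusion level only when the capability variant actually holds a pre-Ampere CUDA value, leaving ROCm devices untouched. A standalone sketch of that guard, again with illustrative stand-in types rather than the real se:: classes:

#include <algorithm>
#include <iostream>
#include <variant>

struct CudaCapability {
  int major = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCapability {};

using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// Pre-Ampere CUDA devices are limited to fusion level 1; ROCm devices and
// Ampere-or-newer CUDA devices keep the requested level unchanged.
int EffectiveFusionLevel(int requested, const GpuCapability& gpu_version) {
  if (std::holds_alternative<CudaCapability>(gpu_version) &&
      !std::get<CudaCapability>(gpu_version).IsAtLeastAmpere()) {
    return std::min(requested, 1);
  }
  return requested;
}

int main() {
  std::cout << EffectiveFusionLevel(3, CudaCapability{7}) << "\n";  // prints 1
  std::cout << EffectiveFusionLevel(3, RocmCapability{}) << "\n";   // prints 3
}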