PR #10510: [ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes.

Imported from GitHub PR openxla/xla#10510

First commit of the series for enabling Triton in XLA for ROCm.
Copybara import of the project:

--
832d7253db1f8972862252f034b216d1eed1da29 by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes.

Merging this change closes #10510

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10510 from ROCm:rocm_triton_backend 832d7253db1f8972862252f034b216d1eed1da29
PiperOrigin-RevId: 616862111
zoranjovanovic-ns authored and tensorflower-gardener committed Mar 18, 2024
1 parent b9752df commit 12bd258
Showing 2 changed files with 48 additions and 13 deletions.
57 changes: 45 additions & 12 deletions third_party/xla/xla/service/gpu/gemm_fusion.cc
@@ -728,9 +728,14 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {
     // If a GEMM requiring padding for cuBLAS is encountered here this
     // happened because earlier ShouldTritonHandleGEMM() accepted it and padding
     // was skipped. Accept it ignoring profitability checks.
-    if (!CublasRequiresPadding(*Cast<HloDotInstruction>(dot), gpu_version_) &&
-        !should_fuse) {
-      return absl::OkStatus();
+    // TODO(rocm): check ROCM padding requirements.
+    if (std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+      if (!CublasRequiresPadding(
+              *Cast<HloDotInstruction>(dot),
+              std::get<se::CudaComputeCapability>(gpu_version_)) &&
+          !should_fuse) {
+        return OkStatus();
+      }
     }

     HloComputation* computation =
@@ -776,17 +781,35 @@ absl::StatusOr<bool> RunOnComputation(
   return visitor.changed();
 }

-bool IsSupportedByTriton(
-    PrecisionConfig::Algorithm algorithm,
-    const se::CudaComputeCapability& cuda_compute_capability) {
+bool IsSupportedByTriton(PrecisionConfig::Algorithm algorithm,
+                         const se::GpuComputeCapability& gpu_version) {
+  auto cuda_compute_capability =
+      std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
   switch (algorithm) {
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
-      return true;
-
+      if (cuda_compute_capability) {
+        return true;
+      }
+      if (rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      return false;
     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+      if (cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeastAmpere();
+      }
+      return false;
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      return cuda_compute_capability.IsAtLeastAmpere();
+      if (cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeastAmpere();
+      }
+      if (rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      return false;

     // TODO(b/326579472): Fix the support of this algorithm and maybe allow it
     // here.
@@ -804,8 +827,12 @@ FusionDecision CanTritonHandleGEMM(
     const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);

-  if (!cuda_compute_capability) return "Non CUDA device.";
+  if (!cuda_compute_capability && !rocm_compute_capability) {
+    return "Non CUDA or ROCM device.";
+  }

   if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) {
     if (!tsl::tensor_float_32_execution_enabled() ||
@@ -826,8 +853,14 @@ FusionDecision CanTritonHandleGEMM(
       case F32:
         return true;
       case BF16:
-        return cuda_compute_capability->IsAtLeast(
-            stream_executor::CudaComputeCapability::AMPERE);
+        if (cuda_compute_capability) {
+          return cuda_compute_capability->IsAtLeast(
+              stream_executor::CudaComputeCapability::AMPERE);
+        }
+        if (rocm_compute_capability) {
+          return rocm_compute_capability->has_bf16_dtype_support();
+        }
+        return false;
       default:
         return false;
     }
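Note on the pattern used throughout this file: se::GpuComputeCapability is a std::variant over the CUDA and ROCm compute-capability types, so the rewriter probes it with std::get_if and falls back to a conservative default when neither alternative matches. Below is a minimal standalone sketch of that dispatch style; CudaCapability, RocmCapability, and SupportsBf16Dot are simplified stand-ins invented for illustration, not the actual stream_executor or XLA definitions.

#include <iostream>
#include <variant>

// Simplified stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaCapability {
  int major = 8;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCapability {
  bool bf16_support = true;
  bool has_bf16_dtype_support() const { return bf16_support; }
};
using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// Mirrors the dispatch style of IsSupportedByTriton(): probe the variant with
// std::get_if and return false for any back end that is neither CUDA nor ROCm.
bool SupportsBf16Dot(const GpuCapability& gpu_version) {
  if (auto* cuda = std::get_if<CudaCapability>(&gpu_version)) {
    return cuda->IsAtLeastAmpere();
  }
  if (auto* rocm = std::get_if<RocmCapability>(&gpu_version)) {
    return rocm->has_bf16_dtype_support();
  }
  return false;
}

int main() {
  std::cout << SupportsBf16Dot(CudaCapability{7}) << " "   // 0: pre-Ampere CUDA
            << SupportsBf16Dot(CudaCapability{9}) << " "   // 1: Ampere or newer
            << SupportsBf16Dot(RocmCapability{}) << "\n";  // 1: ROCm with bf16
}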
4 changes: 3 additions & 1 deletion third_party/xla/xla/service/gpu/triton_tiling_propagation.cc
@@ -1037,7 +1037,9 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
       std::move(std::get<DimOrdersAndReqs>(result_or_error));
   int fusion_level =
       hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
-  if (!std::get<se::CudaComputeCapability>(gpu_version)
+  // TODO(ROCm): Check fusion level for ROCm.
+  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version) &&
+      !std::get<se::CudaComputeCapability>(gpu_version)
            .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     fusion_level = std::min(fusion_level, 1);
   }
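This guard matters because std::get on a variant that holds the other alternative throws std::bad_variant_access; checking std::holds_alternative first lets the ROCm path skip the CUDA-only fusion-level clamp instead of crashing. A small self-contained sketch of the difference, again using simplified stand-in types rather than the real se:: classes:

#include <algorithm>
#include <iostream>
#include <variant>

struct CudaCapability { bool IsAtLeastAmpere() const { return false; } };
struct RocmCapability {};
using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// Without the holds_alternative check, std::get<CudaCapability> would throw
// std::bad_variant_access whenever the variant holds RocmCapability. With it,
// the condition short-circuits and ROCm leaves fusion_level untouched.
int ClampFusionLevel(const GpuCapability& gpu_version, int fusion_level) {
  if (std::holds_alternative<CudaCapability>(gpu_version) &&
      !std::get<CudaCapability>(gpu_version).IsAtLeastAmpere()) {
    fusion_level = std::min(fusion_level, 1);
  }
  return fusion_level;
}

int main() {
  std::cout << ClampFusionLevel(RocmCapability{}, 3) << " "   // 3: ROCm, no clamp
            << ClampFusionLevel(CudaCapability{}, 3) << "\n";  // 1: pre-Ampere CUDA
}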
