[ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes.
zoranjovanovic-ns committed Mar 15, 2024
1 parent 28cfdcd commit 832d725
Showing 2 changed files with 41 additions and 8 deletions.
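Both files follow the same pattern: se::GpuComputeCapability is a std::variant over the CUDA and ROCm capability types, so CUDA-only parameters and unconditional std::get calls are replaced with std::get_if / std::holds_alternative dispatch per backend. The following is a minimal, self-contained sketch of that pattern, not XLA code: CudaCap, RocmCap, GpuCapability, and the BF16 rule are hypothetical stand-ins for illustration only.

// Minimal sketch of the variant-dispatch pattern used throughout this
// commit. All names below are hypothetical stand-ins, not XLA definitions.
#include <cassert>
#include <string>
#include <variant>

struct CudaCap {
  int major = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};

struct RocmCap {
  std::string gcn_arch;
  // Illustrative rule only; real BF16 support detection differs.
  bool has_bf16_dtype_support() const { return gcn_arch == "gfx90a"; }
};

using GpuCapability = std::variant<CudaCap, RocmCap>;

// Backend-specific decision: BF16 needs Ampere on CUDA and native BF16
// support on ROCm; any other backend is rejected.
bool SupportsBf16(const GpuCapability& gpu_version) {
  if (auto* cuda = std::get_if<CudaCap>(&gpu_version)) {
    return cuda->IsAtLeastAmpere();
  }
  if (auto* rocm = std::get_if<RocmCap>(&gpu_version)) {
    return rocm->has_bf16_dtype_support();
  }
  return false;
}

int main() {
  assert(SupportsBf16(CudaCap{8}));         // Ampere-class CUDA device
  assert(!SupportsBf16(CudaCap{7}));        // pre-Ampere CUDA device
  assert(SupportsBf16(RocmCap{"gfx90a"}));  // ROCm device with BF16
  assert(!SupportsBf16(RocmCap{"gfx906"})); // ROCm device without BF16
}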
45 changes: 38 additions & 7 deletions xla/service/gpu/gemm_fusion.cc
@@ -703,9 +703,14 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {
// If a GEMM requiring padding for cuBLAS is encountered here this
// happened because earlier ShouldTritonHandleGEMM() accepted it and padding
// was skipped. Accept it ignoring profitability checks.
-  if (!CublasRequiresPadding(*Cast<HloDotInstruction>(dot), gpu_version_) &&
-      !should_fuse) {
-    return absl::OkStatus();
+  // TODO(rocm): check ROCM padding requirements.
+  if(std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+    if (!CublasRequiresPadding(
+            *Cast<HloDotInstruction>(dot),
+            std::get<se::CudaComputeCapability>(gpu_version_)) &&
+        !should_fuse) {
+      return OkStatus();
+    }
}

HloComputation* computation =
@@ -753,15 +758,31 @@ absl::StatusOr<bool> RunOnComputation(

bool IsSupportedByTriton(
PrecisionConfig::Algorithm algorithm,
-    const se::CudaComputeCapability& cuda_compute_capability) {
+    const se::GpuComputeCapability& gpu_version) {
+  auto cuda_compute_capability =
+      std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
switch (algorithm) {
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
return true;

case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+      if(rocm_compute_capability) {
+        return false;
+      }
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      return cuda_compute_capability.IsAtLeastAmpere();
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      else if (cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeastAmpere();
+      }
+      return false;

// TODO(b/326579472): Fix the support of this algorithm and maybe allow it
// here.
@@ -779,8 +800,12 @@ FusionDecision CanTritonHandleGEMM(
const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) {
auto cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);

-  if (!cuda_compute_capability) return "Non CUDA device.";
+  if (!cuda_compute_capability && !rocm_compute_capability) {
+    return "Non CUDA or ROCM device.";
+  }

if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) {
if (!tsl::tensor_float_32_execution_enabled() ||
@@ -801,8 +826,14 @@
case F32:
return true;
case BF16:
-      return cuda_compute_capability->IsAtLeast(
+      if(cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeast(
stream_executor::CudaComputeCapability::AMPERE);
+      }
+      else if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      return false;
default:
return false;
}
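Summarizing the new IsSupportedByTriton behavior above: BF16-based algorithms on ROCm hinge on has_bf16_dtype_support(), TF32 is rejected on ROCm, and CUDA keeps its Ampere requirement (for CUDA the TF32 case intentionally falls through to the X3/X6 cases, matching the pre-change behavior). Below is a hedged sketch of that decision logic only; it reuses the hypothetical CudaCap, RocmCap, and GpuCapability stand-ins defined in the sketch near the top of this page, and Alg is a stand-in for PrecisionConfig::Algorithm, not the actual XLA enum.

// Hedged sketch of the per-algorithm decision added to IsSupportedByTriton.
// Uses the hypothetical CudaCap/RocmCap/GpuCapability stand-ins from the
// earlier sketch; Alg is an illustrative stand-in for the real enum.
enum class Alg { kBf16Bf16F32, kTf32Tf32F32, kBf16Bf16F32X3, kBf16Bf16F32X6 };

bool IsSupportedSketch(Alg alg, const GpuCapability& gpu_version) {
  auto* cuda = std::get_if<CudaCap>(&gpu_version);
  auto* rocm = std::get_if<RocmCap>(&gpu_version);
  switch (alg) {
    case Alg::kBf16Bf16F32:
      // ROCm: gated on BF16 support; CUDA: always accepted.
      if (rocm) return rocm->has_bf16_dtype_support();
      return true;
    case Alg::kTf32Tf32F32:
      // ROCm has no TF32; CUDA falls through to the Ampere check below.
      if (rocm) return false;
      [[fallthrough]];
    case Alg::kBf16Bf16F32X3:
    case Alg::kBf16Bf16F32X6:
      if (rocm) return rocm->has_bf16_dtype_support();
      if (cuda) return cuda->IsAtLeastAmpere();
      return false;
  }
  return false;
}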
4 changes: 3 additions & 1 deletion xla/service/gpu/triton_tiling_propagation.cc
@@ -1037,7 +1037,9 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
std::move(std::get<DimOrdersAndReqs>(result_or_error));
int fusion_level =
hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
-  if (!std::get<se::CudaComputeCapability>(gpu_version)
+  //TODO(ROCm) Check fusion level for ROCm.
+  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)
+      && !std::get<se::CudaComputeCapability>(gpu_version)
.IsAtLeast(se::CudaComputeCapability::AMPERE)) {
fusion_level = std::min(fusion_level, 1);
}
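In this file the guard matters because the old code called std::get<se::CudaComputeCapability>(gpu_version) unconditionally; when the variant holds a ROCm capability, std::get on the wrong alternative throws std::bad_variant_access (or terminates where exceptions are disabled). The added std::holds_alternative check therefore restricts the pre-Ampere fusion-level clamp to CUDA and leaves ROCm at the configured level, with the TODO noting that a ROCm-specific fusion-level check is still pending. A small self-contained illustration of that hazard, unrelated to the actual XLA types:

// Why the holds_alternative guard matters: std::get on the wrong variant
// alternative throws std::bad_variant_access. Illustration only.
#include <cassert>
#include <iostream>
#include <variant>

int main() {
  std::variant<int, double> v = 3.5;  // currently holds a double
  try {
    (void)std::get<int>(v);           // wrong alternative: throws
  } catch (const std::bad_variant_access&) {
    std::cout << "std::get<int> threw as expected\n";
  }
  if (std::holds_alternative<int>(v)) {
    // Guarded access, analogous to the check added in the diff; not
    // reached here because v holds a double.
    assert(std::get<int>(v) >= 0);
  }
}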
