[ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes. #10510

Closed · wants to merge 1 commit
45 changes: 38 additions & 7 deletions xla/service/gpu/gemm_fusion.cc
@@ -703,9 +703,14 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {
     // If a GEMM requiring padding for cuBLAS is encountered here this
     // happened because earlier ShouldTritonHandleGEMM() accepted it and padding
     // was skipped. Accept it ignoring profitability checks.
-    if (!CublasRequiresPadding(*Cast<HloDotInstruction>(dot), gpu_version_) &&
-        !should_fuse) {
-      return absl::OkStatus();
+    // TODO(rocm): check ROCM padding requirements.
+    if(std::holds_alternative<se::CudaComputeCapability>(gpu_version_)) {
+      if (!CublasRequiresPadding(
+              *Cast<HloDotInstruction>(dot),
+              std::get<se::CudaComputeCapability>(gpu_version_)) &&
+          !should_fuse) {
+        return OkStatus();
+      }
     }

     HloComputation* computation =
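Note: the guard above is the dispatch pattern this PR uses throughout. se::GpuComputeCapability is a variant over CUDA and ROCm capabilities, so CUDA-only checks are wrapped in std::holds_alternative/std::get, while per-backend branches use std::get_if. Below is a minimal, self-contained sketch of both forms; CudaCap, RocmCap, and GpuCap are stand-in types for illustration, not the XLA/stream_executor definitions.

// Sketch of the variant-dispatch pattern used in this PR (stand-in types).
#include <iostream>
#include <variant>

struct CudaCap {
  int major;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCap {
  bool has_bf16_dtype_support() const { return true; }
};
using GpuCap = std::variant<CudaCap, RocmCap>;

// CUDA-only guard, as in the hunk above: only enter the check when the
// variant actually holds a CUDA capability; std::get is then safe.
bool IsPreAmpereCuda(const GpuCap& cap) {
  if (std::holds_alternative<CudaCap>(cap)) {
    return !std::get<CudaCap>(cap).IsAtLeastAmpere();
  }
  return false;  // ROCm (or anything else) is not a pre-Ampere CUDA device.
}

// Per-backend branch, as in IsSupportedByTriton below: std::get_if returns a
// pointer that is null unless the variant holds that alternative.
bool SupportsBf16Dot(const GpuCap& cap) {
  if (const auto* rocm = std::get_if<RocmCap>(&cap)) {
    return rocm->has_bf16_dtype_support();
  }
  if (const auto* cuda = std::get_if<CudaCap>(&cap)) {
    return cuda->IsAtLeastAmpere();
  }
  return false;
}

int main() {
  GpuCap ampere = CudaCap{8};
  GpuCap mi250 = RocmCap{};
  std::cout << IsPreAmpereCuda(ampere) << " " << SupportsBf16Dot(mi250) << "\n";
  // Prints: 0 1
}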
@@ -753,15 +758,31 @@ absl::StatusOr<bool> RunOnComputation(

 bool IsSupportedByTriton(
     PrecisionConfig::Algorithm algorithm,
-    const se::CudaComputeCapability& cuda_compute_capability) {
+    const se::GpuComputeCapability& gpu_version) {
+  auto cuda_compute_capability =
+      std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
   switch (algorithm) {
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
       return true;

     case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
+      if(rocm_compute_capability) {
+        return false;
+      }
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
-      return cuda_compute_capability.IsAtLeastAmpere();
+      if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      else if (cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeastAmpere();
+      }
+      return false;

     // TODO(b/326579472): Fix the support of this algorithm and maybe allow it
     // here.
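Note on a subtle point in the hunk above: after the early ROCm rejection in ALG_DOT_TF32_TF32_F32, CUDA deliberately falls through into the BF16_X3/X6 cases, so it keeps the original Ampere gate. A compact sketch of that control flow follows; the enum and boolean flags are stand-ins for the PrecisionConfig algorithms and capability queries, and [[fallthrough]] is added only to make the intent explicit.

#include <iostream>

enum class Alg { kTf32Tf32F32, kBf16Bf16F32X3, kBf16Bf16F32X6 };

// Stand-in flags for the capability queries used in the PR:
//   is_rocm        -> rocm_compute_capability != nullptr
//   rocm_has_bf16  -> rocm_compute_capability->has_bf16_dtype_support()
//   cuda_is_ampere -> cuda_compute_capability->IsAtLeastAmpere()
bool Supported(Alg alg, bool is_rocm, bool rocm_has_bf16, bool cuda_is_ampere) {
  switch (alg) {
    case Alg::kTf32Tf32F32:
      if (is_rocm) return false;  // TF32 is rejected outright on ROCm.
      [[fallthrough]];            // CUDA: reuse the gate of the X3/X6 cases.
    case Alg::kBf16Bf16F32X3:
    case Alg::kBf16Bf16F32X6:
      return is_rocm ? rocm_has_bf16 : cuda_is_ampere;
  }
  return false;
}

int main() {
  // CUDA Ampere, TF32: falls through and passes the Ampere gate.
  std::cout << Supported(Alg::kTf32Tf32F32, false, false, true) << "\n";   // 1
  // ROCm with BF16 support, X3: handled directly in the shared case.
  std::cout << Supported(Alg::kBf16Bf16F32X3, true, true, false) << "\n";  // 1
}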
@@ -779,8 +800,12 @@ FusionDecision CanTritonHandleGEMM(
     const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version);
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);

-  if (!cuda_compute_capability) return "Non CUDA device.";
+  if (!cuda_compute_capability && !rocm_compute_capability) {
+    return "Non CUDA or ROCM device.";
+  }

   if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) {
     if (!tsl::tensor_float_32_execution_enabled() ||
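Note: CanTritonHandleGEMM returns a FusionDecision, which is why the hunk above can return a bare string such as "Non CUDA or ROCM device." as the rejection reason. The sketch below illustrates that idiom with a stand-in FusionDecisionSketch class; it is an assumption about the general shape of such a type, not XLA's actual implementation.

#include <iostream>
#include <optional>
#include <string>

// Stand-in for a FusionDecision-like type: default construction means "can
// fuse"; constructing from a string records why fusion is rejected.
class FusionDecisionSketch {
 public:
  FusionDecisionSketch() = default;
  FusionDecisionSketch(const char* reason) : reason_(reason) {}
  explicit operator bool() const { return !reason_.has_value(); }
  std::string Explain() const { return reason_.value_or("can fuse"); }

 private:
  std::optional<std::string> reason_;
};

FusionDecisionSketch CheckDevice(bool is_cuda, bool is_rocm) {
  if (!is_cuda && !is_rocm) {
    return "Non CUDA or ROCM device.";  // implicit conversion carries the reason
  }
  return {};
}

int main() {
  FusionDecisionSketch decision = CheckDevice(/*is_cuda=*/false, /*is_rocm=*/false);
  if (!decision) std::cout << decision.Explain() << "\n";
}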
@@ -801,8 +826,14 @@
     case F32:
       return true;
     case BF16:
-      return cuda_compute_capability->IsAtLeast(
+      if(cuda_compute_capability) {
+        return cuda_compute_capability->IsAtLeast(
           stream_executor::CudaComputeCapability::AMPERE);
+      }
+      else if(rocm_compute_capability) {
+        return rocm_compute_capability->has_bf16_dtype_support();
+      }
+      return false;
     default:
       return false;
   }
4 changes: 3 additions & 1 deletion xla/service/gpu/triton_tiling_propagation.cc
@@ -1037,7 +1037,9 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible(
       std::move(std::get<DimOrdersAndReqs>(result_or_error));
   int fusion_level =
       hlo.GetModule()->config().debug_options().xla_gpu_triton_fusion_level();
-  if (!std::get<se::CudaComputeCapability>(gpu_version)
+  //TODO(ROCm) Check fusion level for ROCm.
+  if (std::holds_alternative<se::CudaComputeCapability>(gpu_version)
+      && !std::get<se::CudaComputeCapability>(gpu_version)
           .IsAtLeast(se::CudaComputeCapability::AMPERE)) {
     fusion_level = std::min(fusion_level, 1);
   }
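Note: the hunk above clamps the Triton fusion level only for pre-Ampere CUDA devices and, per the TODO, leaves ROCm at whatever the xla_gpu_triton_fusion_level debug option requested. A small sketch of that behavior with stand-in types (not the XLA definitions):

#include <algorithm>
#include <iostream>
#include <variant>

struct CudaCap {
  int major;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCap {};
using GpuCap = std::variant<CudaCap, RocmCap>;

int EffectiveFusionLevel(int requested, const GpuCap& cap) {
  if (std::holds_alternative<CudaCap>(cap) &&
      !std::get<CudaCap>(cap).IsAtLeastAmpere()) {
    return std::min(requested, 1);  // pre-Ampere CUDA: conservative fusion
  }
  return requested;  // Ampere+ CUDA and ROCm: honor the requested level
}

int main() {
  std::cout << EffectiveFusionLevel(2, CudaCap{7}) << " "   // 1 (pre-Ampere)
            << EffectiveFusionLevel(2, CudaCap{9}) << " "   // 2 (Hopper)
            << EffectiveFusionLevel(2, RocmCap{}) << "\n";  // 2 (ROCm)
}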