[ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes.
zoranjovanovic-ns committed Mar 25, 2024
1 parent f4b77bb commit d7a2b5b
Showing 9 changed files with 339 additions and 105 deletions.
3 changes: 2 additions & 1 deletion xla/service/gpu/BUILD
@@ -500,7 +500,8 @@ cc_library(
name = "ir_emitter_triton",
srcs = if_cuda_is_configured(["ir_emitter_triton.cc"]) + if_rocm_hipblaslt([
"ir_emitter_triton.cc",
]),
]) + if_cuda_is_configured(["ir_emitter_triton_cuda.cc"])
+ if_rocm_is_configured(["ir_emitter_triton_rocm.cc"]),
hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]),
deps = [
":hlo_traversal",
10 changes: 5 additions & 5 deletions xla/service/gpu/fusions/triton.cc
@@ -42,7 +42,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "xla/service/gpu/ir_emitter_triton.h"
#else
#include "absl/status/status.h"
@@ -98,7 +98,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext());
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
VLOG(3) << fusion.ToString();
std::string suggested_kernel_name = std::string(fusion.name());
TF_ASSIGN_OR_RETURN(
@@ -138,7 +138,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonSoftmaxFusionKind,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitSoftMax,
*ir_emitter_context.mlir_context()));
@@ -165,7 +165,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonGemmFusionKind,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitMatMul,
*ir_emitter_context.mlir_context()));
@@ -212,7 +212,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(

return result;
#else
return absl::UnimplementedError("Triton support requires CUDA");
return absl::UnimplementedError("Triton support requires CUDA or ROCm");
#endif
}

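Note on the call-site change above: both TritonWrapper calls now pass ir_emitter_context.gpu_compute_capability() instead of the CUDA-only cuda_compute_capability(). The sketch below is illustrative only (stand-in types, not XLA's): it assumes, as the std::holds_alternative checks later in this commit suggest, that se::GpuComputeCapability is a variant over the CUDA and ROCm capability types, and shows how one call path can serve both backends.

// gpu_capability_sketch.cc -- self-contained illustration, not commit code.
#include <iostream>
#include <string>
#include <type_traits>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaComputeCapability { int major = 8; int minor = 0; };
struct RocmComputeCapability { std::string gcn_arch_name = "gfx90a"; };
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// One emitter entry point, two backends: the caller no longer needs to know
// which vendor it is compiling for.
std::string DescribeBackend(const GpuComputeCapability& cc) {
  return std::visit(
      [](const auto& c) -> std::string {
        using T = std::decay_t<decltype(c)>;
        if constexpr (std::is_same_v<T, CudaComputeCapability>) {
          return "CUDA sm_" + std::to_string(c.major * 10 + c.minor);
        } else {
          return "ROCm " + c.gcn_arch_name;
        }
      },
      cc);
}

int main() {
  std::cout << DescribeBackend(CudaComputeCapability{}) << "\n";  // CUDA sm_80
  std::cout << DescribeBackend(RocmComputeCapability{}) << "\n";  // ROCm gfx90a
}

Backend-specific code (for example the NVIDIA-only shared-memory opt-in check later in this commit) can still query std::holds_alternative<se::CudaComputeCapability>(cc) before touching CUDA-specific limits.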
2 changes: 1 addition & 1 deletion xla/service/gpu/hlo_fusion_analysis.cc
@@ -202,7 +202,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
return EmitterFusionKind::kCustomFusion;
}

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (fusion_backend_config_.kind() == kTritonGemmFusionKind ||
fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) {
return EmitterFusionKind::kTriton;
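The guard change above mirrors the one in triton.cc: every #if GOOGLE_CUDA protecting a Triton path is widened to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM. A minimal, compilable sketch of that pattern; the macros are the ones used in the diff, everything else is illustrative:

// triton_guard_sketch.cc -- illustration of the widened guard, not commit code.
#include <iostream>

// Undefined macros evaluate to 0 in #if, so this builds in all three
// configurations: CUDA, ROCm, and neither.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
constexpr bool kTritonEmitterAvailable = true;
#else
constexpr bool kTritonEmitterAvailable = false;
#endif

int main() {
  // With -DGOOGLE_CUDA=1 or -DTENSORFLOW_USE_ROCM=1 the Triton fusion kind is
  // reported; otherwise the emitter falls back, as TritonFusion::Emit above
  // now returns "Triton support requires CUDA or ROCm".
  std::cout << (kTritonEmitterAvailable ? "kTriton fusion kind enabled\n"
                                        : "Triton emitter unavailable\n");
}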
127 changes: 31 additions & 96 deletions xla/service/gpu/ir_emitter_triton.cc
@@ -28,8 +28,6 @@ limitations under the License.
#include <utility>
#include <vector>

#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -50,6 +48,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
@@ -83,6 +82,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "xla/autotuning.pb.h"
Expand Down Expand Up @@ -131,6 +131,7 @@ limitations under the License.
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"


namespace xla {
namespace gpu {

@@ -148,6 +149,15 @@ using ::mlir::Type;
using ::mlir::Value;
using mlir::ValueRange;


std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
const se::DeviceDescription& device_info);

absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
const TritonGemmConfig& config,
mt::nvidia_gpu::ClusterInfo& out_cluster_info);

namespace {

// XLA -> Triton type conversions.
@@ -408,12 +418,16 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
if (dev_fn_id.ok()) {
return b.create<mt::ExternElementwiseOp>(
inputs[0].getType(), inputs, "libdevice", libdevice_path,
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(),
llvm::Triple("nvptx64-unknown-unknown")),
/*pure=*/true);
llvm::Triple triple("nvptx64-unknown-unknown");
                                     if (std::holds_alternative<se::RocmComputeCapability>
(device_info.gpu_compute_capability())) {
triple.setTriple("amdgcn-unknown-unknown");
}
return b.create<mt::ExternElementwiseOp>(
inputs[0].getType(), inputs, "libdevice", libdevice_path,
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(), triple),
/*pure=*/true);
}
}
const bool is_integer =
@@ -735,81 +749,6 @@ absl::StatusOr<Value> EmitScope(
return values[instructions.back()];
}

// Create Triton pipeline.
//
// `out_cluster_info` must be kept alive at least until pm.run() is called.
// It should be read after that. We have to pass the cluster dims to
// LaunchDimensions. Triton currently uses this as an out-parameter to return
// the cluster dims determined based on `config.num_ctas` and a heuristic. There
// are some signs that show that this was intended to be used as an in-out
// parameter which would give a hint to Triton which cluster dims we prefer to
// use, but that's not the case currently.
absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::CudaComputeCapability& cc,
const TritonGemmConfig& config,
mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
const int ccAsInt = cc.major * 10 + cc.minor;
const int threadsPerWarp = 32;

// Based on make_ttir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mlir::createInlinerPass());
pm.addPass(mt::createRewriteTensorPointerPass());
pm.addPass(mt::createCombineOpsPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mt::createReorderBroadcastPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(mlir::createSymbolDCEPass());

// Based on make_ttgir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mt::createConvertTritonToTritonGPUPass(
config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createCoalescePass());
pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mlir::createCSEPass());

if (cc.IsAtLeastAmpere()) {
pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));
}
if (!cc.IsAtLeastHopper()) {
pm.addPass(mt::gpu::createPrefetchPass());
}

pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
if (cc.IsAtLeastHopper()) {
pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
}
pm.addPass(mlir::createCanonicalizerPass());

// Based on make_llir() in
// @triton//:third_party/nvidia/backend/compiler.py
pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
pm.addPass(mt::createConvertNVGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.

return absl::OkStatus();
}

// Extract additional attributes from an LLVM function that are not passed
// to the builder directly.
@@ -2172,6 +2111,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
mlir::registerROCDLDialectTranslation(registry);
module->getContext()->appendDialectRegistry(registry);

std::unique_ptr<llvm::Module> llvmModule =
@@ -2196,14 +2136,6 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
return llvmModule;
}

namespace {

std::string GetLibdevicePath(const HloModuleConfig& hlo_config) {
return nvptx::LibDevicePath(
hlo_config.debug_options().xla_gpu_cuda_data_dir());
}

} // namespace

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
@@ -2241,7 +2173,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
b.setInsertionPointToStart(&fn.front());

TF_RETURN_IF_ERROR(
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()),
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config(),
device_info),
device_info, analysis, hlo_computation, fn, config));

b.create<mt::ReturnOp>(loc);
@@ -2259,7 +2192,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, absl::string_view fusion_kind,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context) {
@@ -2325,7 +2258,7 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context) {
@@ -2369,6 +2302,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
}
}


mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
if (!CreateTritonPipeline(pm, cc, config, /*out*/ cluster_info).ok()) {
return Internal("Failed to create Triton pipeline.");
@@ -2397,7 +2331,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
triton_module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
.getInt();
VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B";
if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
if (std::holds_alternative<se::CudaComputeCapability>(cc)
&& shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
return absl::ResourceExhaustedError(absl::StrFormat(
"Shared memory size limit exceeded: requested %d, available: %d",
shared_mem_bytes, device_info.shared_memory_per_block_optin()));
@@ -2406,7 +2341,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<llvm::Module> ll_triton_module,
TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module,
GetLibdevicePath(hlo_config)));
GetLibdevicePath(hlo_config, device_info)));
VLogModule(5, *ll_triton_module);
if (should_verify) {
VerifyModule(*ll_triton_module);
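The nvidia-specific CreateTritonPipeline body and the anonymous-namespace GetLibdevicePath removed above are now only forward-declared in this file; per the BUILD change, their definitions move to per-backend sources (ir_emitter_triton_cuda.cc and ir_emitter_triton_rocm.cc) that are not shown in this view. The sketch below illustrates, with stand-in types, how such a backend definition could recover its concrete capability from the se::GpuComputeCapability variant before reusing the ccAsInt = major * 10 + minor convention from the deleted code; it is an assumption about the shape of those files, not their contents.

// backend_pipeline_sketch.cc -- illustrative only; stand-in types.
#include <iostream>
#include <string>
#include <variant>

struct CudaComputeCapability { int major = 9; int minor = 0; };
struct RocmComputeCapability { std::string gcn_arch_name = "gfx942"; };
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Analogue of a CUDA-side CreateTritonPipeline: the shared declaration takes
// the variant, so the definition first checks that it really holds a CUDA
// capability before deriving the integer architecture id.
bool ConfigureCudaPipeline(const GpuComputeCapability& cc) {
  const auto* cuda_cc = std::get_if<CudaComputeCapability>(&cc);
  if (cuda_cc == nullptr) {
    std::cerr << "not a CUDA capability; the ROCm definition handles this\n";
    return false;
  }
  const int ccAsInt = cuda_cc->major * 10 + cuda_cc->minor;  // e.g. 90 for Hopper
  std::cout << "building Triton pipeline for sm_" << ccAsInt << "\n";
  return true;
}

int main() {
  ConfigureCudaPipeline(CudaComputeCapability{});  // sm_90 path
  ConfigureCudaPipeline(RocmComputeCapability{});  // left to the ROCm variant
}

The same declaration/definition split applies to GetLibdevicePath: the CUDA definition can keep the deleted nvptx::LibDevicePath(xla_gpu_cuda_data_dir()) lookup, while a ROCm definition would point at AMD's device-library bitcode instead (the exact helper it uses is not shown in this diff).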
4 changes: 2 additions & 2 deletions xla/service/gpu/ir_emitter_triton.h
@@ -83,7 +83,7 @@ using TritonIrEmitter = std::function<Status(
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, absl::string_view fusion_kind,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context);
@@ -99,7 +99,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
// Compiles a given Triton module to LLVM IR.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context);
(The diffs for the remaining 4 changed files did not load in this view.)
