diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index f6314e3dd3f48f..9ac6ea38409e1d 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -502,7 +502,8 @@ cc_library(
     name = "ir_emitter_triton",
     srcs = if_cuda_is_configured(["ir_emitter_triton.cc"]) + if_rocm_hipblaslt([
         "ir_emitter_triton.cc",
-    ]),
+    ]) + if_cuda_is_configured(["ir_emitter_triton_cuda.cc"]) +
+        if_rocm_is_configured(["ir_emitter_triton_rocm.cc"]),
     hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]),
     deps = [
         ":hlo_traversal",
diff --git a/xla/service/gpu/fusions/triton.cc b/xla/service/gpu/fusions/triton.cc
index 2fc6d15898da6c..ebbaccdb0bd742 100644
--- a/xla/service/gpu/fusions/triton.cc
+++ b/xla/service/gpu/fusions/triton.cc
@@ -42,7 +42,7 @@ limitations under the License.
 #include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "xla/service/gpu/ir_emitter_triton.h"
 #else
 #include "absl/status/status.h"
@@ -98,7 +98,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
     IrEmitterContext& ir_emitter_context,
     const HloFusionInstruction& fusion) const {
   llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext());
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   VLOG(3) << fusion.ToString();
   std::string suggested_kernel_name = std::string(fusion.name());
   TF_ASSIGN_OR_RETURN(
@@ -137,7 +137,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
       TF_ASSIGN_OR_RETURN(
           triton_wrapper_result,
           TritonWrapper(analysis, impl_fn_name, hlo_computation,
-                        ir_emitter_context.cuda_compute_capability(),
+                        ir_emitter_context.gpu_compute_capability(),
                         ir_emitter_context.gpu_device_info(), config,
                         ir_emitter_context.llvm_module(), &EmitSoftMax,
                         *ir_emitter_context.mlir_context()));
@@ -164,7 +164,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
       TF_ASSIGN_OR_RETURN(
           triton_wrapper_result,
           TritonWrapper(analysis, impl_fn_name, hlo_computation,
-                        ir_emitter_context.cuda_compute_capability(),
+                        ir_emitter_context.gpu_compute_capability(),
                         ir_emitter_context.gpu_device_info(), config,
                         ir_emitter_context.llvm_module(), &EmitMatMul,
                         *ir_emitter_context.mlir_context()));
@@ -212,7 +212,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
   return result;
 #else
-  return absl::UnimplementedError("Triton support requires CUDA");
+  return absl::UnimplementedError("Triton support requires CUDA or ROCm");
 #endif
 }
diff --git a/xla/service/gpu/hlo_fusion_analysis.cc b/xla/service/gpu/hlo_fusion_analysis.cc
index 30742675d2e489..529ebf9008d5d9 100644
--- a/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/xla/service/gpu/hlo_fusion_analysis.cc
@@ -202,7 +202,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
     return EmitterFusionKind::kCustomFusion;
   }
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   if (fusion_backend_config_.kind() == kTritonGemmFusionKind ||
       fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) {
     return EmitterFusionKind::kTriton;
diff --git a/xla/service/gpu/ir_emitter_triton.cc b/xla/service/gpu/ir_emitter_triton.cc
index 427ac4dbd6b4fa..2d226d240bfda3 100644
--- a/xla/service/gpu/ir_emitter_triton.cc
+++ b/xla/service/gpu/ir_emitter_triton.cc
@@ -30,8 +30,6 @@ limitations under the License.
 #include
 #include
-#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
-#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
@@ -54,6 +52,7 @@ limitations under the License.
#include "llvm/TargetParser/Triple.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -90,6 +89,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "xla/autotuning.pb.h" @@ -144,6 +144,7 @@ limitations under the License. #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + namespace xla { namespace gpu { @@ -424,12 +425,16 @@ absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, mlir::getElementTypeOrSelf(inputs[0]).isF64()) { auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); if (dev_fn_id.ok()) { - return b.create( - inputs[0].getType(), inputs, "libdevice", libdevice_path, - ObtainDeviceFunctionName(dev_fn_id.value(), - hlo.shape().element_type(), - llvm::Triple("nvptx64-unknown-unknown")), - /*pure=*/true); + llvm::Triple triple("nvptx64-unknown-unknown"); + if (std::holds_alternative + (device_info.gpu_compute_capability())) { + triple.setTriple("amdgcn-unknown-unknown"); + } + return b.create( + inputs[0].getType(), inputs, "libdevice", libdevice_path, + ObtainDeviceFunctionName(dev_fn_id.value(), + hlo.shape().element_type(), triple), + /*pure=*/true); } } const bool is_integer = @@ -932,81 +937,6 @@ absl::StatusOr EmitScope( return values[instructions.back()]; } -// Create Triton pipeline. -// -// `out_cluster_info` must be kept alive at least until pm.run() is called. -// It should be read after that. We have to pass the cluster dims to -// LaunchDimensions. Triton currently uses this as an out-parameter to return -// the cluster dims determined based on `config.num_ctas` and a heuristic. There -// are some signs that show that this was intended to be used as an in-out -// parameter which would give a hint to Triton which cluster dims we prefer to -// use, but that's not the case currently. 
-absl::Status CreateTritonPipeline(
-    mlir::OpPassManager& pm, const se::CudaComputeCapability& cc,
-    const TritonGemmConfig& config,
-    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
-  const int ccAsInt = cc.major * 10 + cc.minor;
-  const int threadsPerWarp = 32;
-
-  // Based on make_ttir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mlir::createInlinerPass());
-  pm.addPass(mt::createRewriteTensorPointerPass());
-  pm.addPass(mt::createCombineOpsPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mt::createReorderBroadcastPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-
-  // Based on make_ttgir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mt::createConvertTritonToTritonGPUPass(
-      config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
-  pm.addPass(mt::gpu::createCoalescePass());
-  pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
-  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
-  pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
-  pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
-  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
-  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
-  pm.addPass(mlir::createCSEPass());
-
-  pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
-                                         config.num_ctas, ccAsInt));
-
-  if (!cc.IsAtLeastHopper()) {
-    pm.addPass(mt::gpu::createPrefetchPass());
-  }
-
-  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
-  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
-  pm.addPass(mt::gpu::createReduceDataDuplicationPass());
-  pm.addPass(mt::gpu::createReorderInstructionsPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  if (cc.IsAtLeastHopper()) {
-    pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
-  }
-  pm.addPass(mlir::createCanonicalizerPass());
-
-  // Based on make_llir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
-  pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
-  pm.addPass(mlir::createConvertSCFToCFPass());
-  pm.addPass(mlir::createConvertIndexToLLVMPass());
-  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
-  pm.addPass(mt::createConvertNVGPUToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(mlir::createCSEPass());
-  pm.addPass(mlir::createSymbolDCEPass());
-  // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
-
-  return absl::OkStatus();
-}
-
 // Extract additional attributes from an LLVM function that are not passed
 // to the builder directly.
 SmallVector<mlir::NamedAttribute> GetExtraAttrs(ml::LLVMFuncOp func) {
@@ -2653,6 +2583,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
   mlir::registerBuiltinDialectTranslation(registry);
   mlir::registerLLVMDialectTranslation(registry);
   mlir::registerNVVMDialectTranslation(registry);
+  mlir::registerROCDLDialectTranslation(registry);
   module->getContext()->appendDialectRegistry(registry);
 
   std::unique_ptr<llvm::Module> llvmModule =
@@ -2677,14 +2608,6 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
   return llvmModule;
 }
 
-namespace {
-
-std::string GetLibdevicePath(const HloModuleConfig& hlo_config) {
-  return nvptx::LibDevicePath(
-      hlo_config.debug_options().xla_gpu_cuda_data_dir());
-}
-
-}  // namespace
 
 absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
@@ -2724,7 +2647,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
   b.setInsertionPointToStart(&fn.front());
 
   TF_RETURN_IF_ERROR(
-      ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()),
+      ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config(),
+                                     device_info),
                  device_info, analysis, hlo_computation, fn, config));
 
   b.create<mt::ReturnOp>(loc);
@@ -2746,13 +2670,16 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
 
 absl::StatusOr<TritonWrapperResult> TritonWrapper(
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
-    const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
+    const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
     const se::DeviceDescription& device_info, const TritonGemmConfig& config,
     llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
    mlir::MLIRContext& mlir_context) {
-  if (!cc.IsAtLeastAmpere()) {
-    return absl::FailedPreconditionError(
-        "Triton support is only enabled for Ampere GPUs and up.");
+  if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
+    auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+    if (!ccCuda.IsAtLeastAmpere()) {
+      return absl::FailedPreconditionError(
+          "Triton support is only enabled for Ampere GPUs and up.");
+    }
   }
 
   auto debug_options = GetDebugOptionsFromFlags();
@@ -2780,13 +2707,16 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
 // TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction.
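TritonWrapper (and CompileTritonToLLVM below) now receive the se::GpuComputeCapability variant instead of a CUDA-only capability, so the Ampere requirement is enforced only when the variant actually holds a se::CudaComputeCapability. A minimal standalone sketch of that dispatch, using simplified stand-in types rather than the real se:: classes:

#include <cassert>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaCC {
  int major = 8;
  int minor = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmCC {};
using GpuCC = std::variant<CudaCC, RocmCC>;

bool TritonSupported(const GpuCC& cc) {
  if (std::holds_alternative<CudaCC>(cc)) {
    // Pre-Ampere CUDA devices are rejected, mirroring the gate above.
    return std::get<CudaCC>(cc).IsAtLeastAmpere();
  }
  return true;  // ROCm targets are not gated here; the ROCm pipeline decides.
}

int main() {
  assert(!TritonSupported(CudaCC{7, 5}));
  assert(TritonSupported(CudaCC{8, 0}));
  assert(TritonSupported(RocmCC{}));
  return 0;
}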
 absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
-    const se::CudaComputeCapability& cc,
+    const se::GpuComputeCapability& cc,
     const se::DeviceDescription& device_info, const TritonGemmConfig& config,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
     mlir::MLIRContext& mlir_context) {
-  if (!cc.IsAtLeastAmpere()) {
-    return absl::FailedPreconditionError(
-        "Triton support is only enabled for Ampere GPUs and up.");
+  if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
+    auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+    if (!ccCuda.IsAtLeastAmpere()) {
+      return absl::FailedPreconditionError(
+          "Triton support is only enabled for Ampere GPUs and up.");
+    }
   }
 
   bool should_verify =
@@ -2860,7 +2790,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
       triton_module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
           .getInt();
   VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B";
-  if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
+  if (std::holds_alternative<se::CudaComputeCapability>(cc)
+      && shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
     return absl::ResourceExhaustedError(absl::StrFormat(
         "Shared memory size limit exceeded: requested %d, available: %d",
         shared_mem_bytes, device_info.shared_memory_per_block_optin()));
@@ -2869,7 +2800,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<llvm::Module> ll_triton_module,
       TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module,
-                            GetLibdevicePath(hlo_config)));
+                            GetLibdevicePath(hlo_config, device_info)));
   VLogModule(5, *ll_triton_module);
   if (should_verify) {
     VerifyModule(*ll_triton_module);
diff --git a/xla/service/gpu/ir_emitter_triton.h b/xla/service/gpu/ir_emitter_triton.h
index 96ca55139bb196..7855f160443cd9 100644
--- a/xla/service/gpu/ir_emitter_triton.h
+++ b/xla/service/gpu/ir_emitter_triton.h
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/OwningOpRef.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "xla/autotuning.pb.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/service/gpu/hlo_traversal.h"
@@ -40,10 +41,13 @@ limitations under the License.
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/launch_dim.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 
 namespace xla {
 namespace gpu {
 
+namespace mt = ::mlir::triton;
+
 struct TritonWrapperResult {
   int64_t shmem_bytes = 0;
   std::optional<se::ClusterDim> cluster_dim;
@@ -89,7 +93,7 @@ using TritonIrEmitter = std::function
 
 absl::StatusOr<TritonWrapperResult> TritonWrapper(
     const TritonFusionAnalysis& analysis, absl::string_view fn_name,
-    const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
+    const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
     const se::DeviceDescription& device_info, const TritonGemmConfig& config,
     llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
     mlir::MLIRContext& mlir_context);
@@ -105,11 +109,28 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
 
 // Compiles a given Triton module to LLVM IR.
 absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
     const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
-    const se::CudaComputeCapability& cc,
+    const se::GpuComputeCapability& cc,
     const se::DeviceDescription& device_info, const TritonGemmConfig& config,
     mlir::ModuleOp triton_module, llvm::Module* llvm_module,
     mlir::MLIRContext& mlir_context);
 
+// Create Triton pipeline.
+//
+// `out_cluster_info` must be kept alive at least until pm.run() is called.
+// It should be read after that. We have to pass the cluster dims to
+// LaunchDimensions. Triton currently uses this as an out-parameter to return
+// the cluster dims determined based on `config.num_ctas` and a heuristic. There
+// are some signs that show that this was intended to be used as an in-out
+// parameter which would give a hint to Triton which cluster dims we prefer to
+// use, but that's not the case currently.
+absl::Status CreateTritonPipeline(
+    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
+    const TritonGemmConfig& config,
+    mt::nvidia_gpu::ClusterInfo& out_cluster_info);
+
+std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
+                             const se::DeviceDescription& device_info);
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/service/gpu/ir_emitter_triton_cuda.cc b/xla/service/gpu/ir_emitter_triton_cuda.cc
new file mode 100644
index 00000000000000..4e30d95ef18add
--- /dev/null
+++ b/xla/service/gpu/ir_emitter_triton_cuda.cc
@@ -0,0 +1,127 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"  // from @llvm-project
+#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
+#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
+#include "triton/Dialect/Triton/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
+#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
+
+namespace xla {
+namespace gpu {
+
+namespace ma = ::mlir::arith;
+namespace mm = ::mlir::math;
+namespace ml = ::mlir::LLVM;
+namespace mn = ::mlir::NVVM;
+namespace mt = ::mlir::triton;
+
+using ::llvm::SmallVector;
+using mlir::ArrayRef;
+using mlir::ImplicitLocOpBuilder;
+using ::mlir::ShapedType;
+using ::mlir::Type;
+using ::mlir::Value;
+using mlir::ValueRange;
+
+absl::Status CreateTritonPipeline(
+    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
+    const TritonGemmConfig& config,
+    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+  auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+  const int ccAsInt = ccCuda.major * 10 + ccCuda.minor;
+  const int threadsPerWarp = 32;
+
+  // Based on make_ttir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mlir::createInlinerPass());
+  pm.addPass(mt::createRewriteTensorPointerPass());
+  pm.addPass(mt::createCombineOpsPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mt::createReorderBroadcastPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+
+  // Based on make_ttgir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mt::createConvertTritonToTritonGPUPass(
+      config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
+  pm.addPass(mt::gpu::createCoalescePass());
+  pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
+  pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
+  pm.addPass(mlir::createCSEPass());
+
+  pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
+                                         config.num_ctas, ccAsInt));
+
+  if (!ccCuda.IsAtLeastHopper()) {
+    pm.addPass(mt::gpu::createPrefetchPass());
+  }
+
+  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createReduceDataDuplicationPass());
+  pm.addPass(mt::gpu::createReorderInstructionsPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  if (ccCuda.IsAtLeastHopper()) {
+    pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
+  }
+  pm.addPass(mlir::createCanonicalizerPass());
+
+  // Based on make_llir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
+  pm.addPass(mlir::createConvertSCFToCFPass());
+  pm.addPass(mlir::createConvertIndexToLLVMPass());
+  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
+  pm.addPass(mt::createConvertNVGPUToLLVMPass());
+  pm.addPass(mlir::createArithToLLVMConversionPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
+
+  return absl::OkStatus();
+}
+
+std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
+                             const se::DeviceDescription& device_info) {
+  return nvptx::LibDevicePath(
+      hlo_config.debug_options().xla_gpu_cuda_data_dir());
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/xla/service/gpu/ir_emitter_triton_rocm.cc b/xla/service/gpu/ir_emitter_triton_rocm.cc
new file mode 100644
index 00000000000000..0bd0179ef4de8f
--- /dev/null
+++ b/xla/service/gpu/ir_emitter_triton_rocm.cc
@@ -0,0 +1,122 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/service/hlo_module_config.h"
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
+#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
+#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
+#include "triton/Dialect/Triton/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+#include "tsl/platform/rocm_rocdl_path.h"
+
+namespace xla {
+namespace gpu {
+
+namespace ma = ::mlir::arith;
+namespace mm = ::mlir::math;
+namespace ml = ::mlir::LLVM;
+namespace mt = ::mlir::triton;
+
+using ::llvm::SmallVector;
+using mlir::ArrayRef;
+using ::mlir::ShapedType;
+using ::mlir::Type;
+using ::mlir::Value;
+using mlir::ValueRange;
+
+absl::Status CreateTritonPipeline(
+    mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
+    const TritonGemmConfig& config,
+    mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
+  // TODO(ROCm): Check whether a value different from 0 can be used.
+  const int ccAsInt = 0;
+  // TODO(ROCm): Check why some tests fail when threadsPerWarp is set to 64.
+  const int threadsPerWarp = 32;
+
+  // Based on make_ttir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mlir::createInlinerPass());
+  pm.addPass(mt::createRewriteTensorPointerPass());
+  pm.addPass(mt::createCombineOpsPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mt::createReorderBroadcastPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createLoopInvariantCodeMotionPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+
+  // Based on make_ttgir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mt::createConvertTritonToTritonGPUPass(
+      config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
+  pm.addPass(mt::gpu::createCoalescePass());
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
+  pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
+                                         config.num_ctas, ccAsInt));
+  pm.addPass(mt::gpu::createPrefetchPass());
+
+  pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
+  pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
+  pm.addPass(mt::gpu::createReduceDataDuplicationPass());
+  pm.addPass(mt::gpu::createReorderInstructionsPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+
+  // Based on make_llir() in
+  // @triton//:third_party/nvidia/backend/compiler.py
+  pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
+  pm.addPass(mlir::createConvertSCFToCFPass());
+  pm.addPass(mlir::createConvertIndexToLLVMPass());
+  pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
+  pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
+  pm.addPass(mlir::createArithToLLVMConversionPass());
+  pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createCSEPass());
+  pm.addPass(mlir::createSymbolDCEPass());
+  // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
+  pm.addPass(mlir::createConvertSCFToCFPass());
+  pm.addPass(mlir::createConvertControlFlowToLLVMPass());
+
+  // There are no clusters in ROCm for now.
+  out_cluster_info.clusterDimX = 1;
+  out_cluster_info.clusterDimY = 1;
+  out_cluster_info.clusterDimZ = 1;
+
+  return absl::OkStatus();
+}
+
+std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
+                             const se::DeviceDescription& device_info) {
+  std::string libdevice_dir = tsl::RocdlRoot();
+  auto compute_capability = device_info.rocm_compute_capability();
+  const std::string libdevice_path =
+      amdgpu::LibDevicePath(compute_capability.gcn_arch_name(), libdevice_dir);
+  return libdevice_path;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 6c6a6c20fbe27f..a38b6a2e6f0f0d 100644
--- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -941,6 +941,7 @@ void AMDGPUBackendInit(const DebugOptions& debug_options,
   LLVMInitializeAMDGPUTarget();
   LLVMInitializeAMDGPUTargetInfo();
   LLVMInitializeAMDGPUTargetMC();
+  LLVMInitializeAMDGPUAsmParser();
   LLVMInitializeAMDGPUAsmPrinter();
 #endif
 
@@ -952,6 +953,19 @@ void AMDGPUBackendInit(const DebugOptions& debug_options,
 }  // namespace
 
 namespace amdgpu {
+
+std::string LibDevicePath(std::string gcn_arch_name,
+                          const std::string& rocdl_dir_path) {
+
+  auto libdevice_dir_paths = GetROCDLPaths(gcn_arch_name, rocdl_dir_path);
+  for (const auto& libdevice_dir_path : libdevice_dir_paths) {
+    if (libdevice_dir_path.find("ocml.bc") != std::string::npos) {
+      return libdevice_dir_path;
+    }
+  }
+  return "";
+}
+
 absl::StatusOr<std::vector<uint8_t>> CompileToHsaco(
     llvm::Module* module, se::GpuComputeCapability gpu_version,
     const DebugOptions& debug_options,
diff --git a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
index 38a4687987533a..a5f320fac54c3d 100644
--- a/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
+++ b/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -60,6 +60,9 @@ absl::StatusOr<std::string> CompileToPtx(
 }  // namespace nvptx
 
 namespace amdgpu {
+// Get path to the libdevice (ROCm device library) bitcode file.
+std::string LibDevicePath(std::string gcn_arch_name,
+                          const std::string& rocdl_dir_path);
 // Compiles the argument module and returns it with LLVM AMDGPU backend.
 // rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries.
 // The contents of the module may be changed.
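The GetLibdevicePath overloads above make the device-library lookup backend-dependent: the CUDA variant keeps deriving the libdevice path from xla_gpu_cuda_data_dir, while the ROCm variant asks amdgpu::LibDevicePath to locate ocml.bc among the ROCDL bitcode files. The sketch below is a standalone illustration of that dispatch; the stand-in types, directories, and file names are illustrative only and do not reproduce the real se::DeviceDescription or GetROCDLPaths API.

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Simplified stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaCapability {};
struct RocmCapability {
  std::string gcn_arch_name = "gfx90a";
};
using GpuCapability = std::variant<CudaCapability, RocmCapability>;

// Pick the ocml bitcode library out of a ROCDL file list.
std::string PickOcml(const std::vector<std::string>& rocdl_paths) {
  for (const std::string& path : rocdl_paths) {
    if (path.find("ocml.bc") != std::string::npos) return path;
  }
  return "";
}

// Backend-dependent libdevice lookup, mirroring GetLibdevicePath.
std::string LibdevicePathFor(const GpuCapability& cc) {
  if (std::holds_alternative<CudaCapability>(cc)) {
    return "/usr/local/cuda/nvvm/libdevice/libdevice.10.bc";  // illustrative
  }
  return PickOcml({"/opt/rocm/amdgcn/bitcode/ockl.bc",
                   "/opt/rocm/amdgcn/bitcode/ocml.bc"});  // illustrative
}

int main() {
  std::cout << LibdevicePathFor(CudaCapability{}) << "\n";
  std::cout << LibdevicePathFor(RocmCapability{}) << "\n";
  return 0;
}

Note the explicit std::string::npos comparison in the search loop: std::string::find returns an index, so any non-npos result (including 0) indicates a match.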