[ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes.
zoranjovanovic-ns committed Mar 18, 2024
1 parent f4b77bb · commit a7ea1f4
Showing 6 changed files with 114 additions and 16 deletions.
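
The recurring pattern across these files is swapping the CUDA-only se::CudaComputeCapability for the variant-based se::GpuComputeCapability and widening the #if GOOGLE_CUDA guards to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM. Below is a minimal standalone sketch of that dispatch pattern, using simplified stand-in types rather than XLA's real definitions (as the diff's use of std::get suggests, the real type is a std::variant over the CUDA and ROCm capability types); as in the diff, the ROCm path falls back to 0 for ccAsInt.

    #include <string>
    #include <variant>

    // Simplified stand-ins for se::CudaComputeCapability and
    // se::RocmComputeCapability; not XLA's real definitions.
    struct CudaComputeCapability { int major = 8; int minor = 0; };
    struct RocmComputeCapability { std::string gcn_arch_name = "gfx90a"; };
    using GpuComputeCapability =
        std::variant<CudaComputeCapability, RocmComputeCapability>;

    // Mirrors how the new CreateTritonPipeline derives ccAsInt:
    // major * 10 + minor on CUDA, 0 on ROCm.
    int CcAsInt(const GpuComputeCapability& cc) {
      if (const auto* cuda = std::get_if<CudaComputeCapability>(&cc)) {
        return cuda->major * 10 + cuda->minor;  // e.g. 80 for sm_80 (Ampere)
      }
      return 0;
    }

    int main() {
      return CcAsInt(CudaComputeCapability{9, 0}) == 90 &&
                     CcAsInt(RocmComputeCapability{}) == 0
                 ? 0
                 : 1;
    }
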
10 changes: 5 additions & 5 deletions xla/service/gpu/fusions/triton.cc
@@ -42,7 +42,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "xla/service/gpu/ir_emitter_triton.h"
#else
#include "absl/status/status.h"
@@ -98,7 +98,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext());
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
VLOG(3) << fusion.ToString();
std::string suggested_kernel_name = std::string(fusion.name());
TF_ASSIGN_OR_RETURN(
@@ -138,7 +138,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonSoftmaxFusionKind,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitSoftMax,
*ir_emitter_context.mlir_context()));
@@ -165,7 +165,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
kTritonGemmFusionKind,
ir_emitter_context.cuda_compute_capability(),
ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitMatMul,
*ir_emitter_context.mlir_context()));
@@ -212,7 +212,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(

return result;
#else
return absl::UnimplementedError("Triton support requires CUDA");
return absl::UnimplementedError("Triton support requires CUDA or ROCm");
#endif
}

2 changes: 1 addition & 1 deletion xla/service/gpu/hlo_fusion_analysis.cc
@@ -202,7 +202,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
return EmitterFusionKind::kCustomFusion;
}

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (fusion_backend_config_.kind() == kTritonGemmFusionKind ||
fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) {
return EmitterFusionKind::kTriton;
97 changes: 89 additions & 8 deletions xla/service/gpu/ir_emitter_triton.cc
@@ -28,8 +28,10 @@ limitations under the License.
#include <utility>
#include <vector>

#if GOOGLE_CUDA
#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#endif
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -50,6 +52,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
@@ -83,6 +86,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "xla/autotuning.pb.h"
@@ -122,14 +126,21 @@ limitations under the License.
#include "tsl/platform/path.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
#ifdef TENSORFLOW_USE_ROCM
#include "tsl/platform/rocm_rocdl_path.h"
#endif
#include "tsl/platform/tensor_float_32_utils.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#ifndef TENSORFLOW_USE_ROCM
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#else
#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
#endif

namespace xla {
namespace gpu {
@@ -410,10 +421,17 @@ Value EmitElementwise(ImplicitLocOpBuilder& b, absl::string_view libdevice_path,
if (dev_fn_id.ok()) {
return b.create<mt::ExternElementwiseOp>(
inputs[0].getType(), inputs, "libdevice", libdevice_path,
#ifdef TENSORFLOW_USE_ROCM
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(),
llvm::Triple("amdgcn-unknown-unknown")),
/*pure=*/true);
#else
ObtainDeviceFunctionName(dev_fn_id.value(),
hlo.shape().element_type(),
llvm::Triple("nvptx64-unknown-unknown")),
/*pure=*/true);
#endif
}
}
const bool is_integer =
@@ -744,11 +762,23 @@ absl::StatusOr<Value> EmitScope(
// are some signs that show that this was intended to be used as an in-out
// parameter which would give a hint to Triton which cluster dims we prefer to
// use, but that's not the case currently.
#if GOOGLE_CUDA
absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::CudaComputeCapability& cc,
mlir::OpPassManager& pm, const se::GpuComputeCapability& gpu_version,
const TritonGemmConfig& config,
mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
const int ccAsInt = cc.major * 10 + cc.minor;
#else
absl::Status CreateTritonPipeline(
mlir::OpPassManager& pm, const se::GpuComputeCapability& gpu_version,
const TritonGemmConfig& config) {
#endif

#ifndef TENSORFLOW_USE_ROCM
auto cc = std::get<se::CudaComputeCapability>(gpu_version);
const int ccAsInt = cc.major * 10 + cc.minor;
#else
const int ccAsInt = 0;
#endif
const int threadsPerWarp = 32;

// Based on make_ttir() in
@@ -767,31 +797,40 @@ absl::Status CreateTritonPipeline(
pm.addPass(mt::createConvertTritonToTritonGPUPass(
config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createCoalescePass());
#ifndef TENSORFLOW_USE_ROCM
pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
#endif
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mlir::createCSEPass());

#ifndef TENSORFLOW_USE_ROCM
if (cc.IsAtLeastAmpere()) {
pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));
}
if (!cc.IsAtLeastHopper()) {
pm.addPass(mt::gpu::createPrefetchPass());
}
#else
pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
config.num_ctas, ccAsInt));
pm.addPass(mt::gpu::createPrefetchPass());
#endif

pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
#ifndef TENSORFLOW_USE_ROCM
if (cc.IsAtLeastHopper()) {
pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
}
#endif
pm.addPass(mlir::createCanonicalizerPass());

// Based on make_llir() in
@@ -800,13 +839,21 @@ absl::Status CreateTritonPipeline(
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
#ifndef TENSORFLOW_USE_ROCM
pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
pm.addPass(mt::createConvertNVGPUToLLVMPass());
#else
pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
#endif
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
// Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
#ifdef TENSORFLOW_USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertControlFlowToLLVMPass());
#endif

return absl::OkStatus();
}
@@ -2172,6 +2219,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
mlir::registerROCDLDialectTranslation(registry);
module->getContext()->appendDialectRegistry(registry);

std::unique_ptr<llvm::Module> llvmModule =
@@ -2198,9 +2246,18 @@

namespace {

std::string GetLibdevicePath(const HloModuleConfig& hlo_config) {
std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
const se::DeviceDescription& device_info) {
#ifdef TENSORFLOW_USE_ROCM
std::string libdevice_dir = tsl::RocdlRoot();
auto compute_capability = device_info.rocm_compute_capability();
const std::string libdevice_path =
amdgpu::LibDevicePath(compute_capability.gcn_arch_name(), libdevice_dir);
return libdevice_path;
#else
return nvptx::LibDevicePath(
hlo_config.debug_options().xla_gpu_cuda_data_dir());
#endif
}

} // namespace
@@ -2241,7 +2298,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
b.setInsertionPointToStart(&fn.front());

TF_RETURN_IF_ERROR(
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()),
ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config(),
device_info),
device_info, analysis, hlo_computation, fn, config));

b.create<mt::ReturnOp>(loc);
@@ -2259,7 +2317,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, absl::string_view fusion_kind,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context) {
@@ -2325,7 +2383,7 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context) {
@@ -2369,8 +2427,13 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
}
}


#if GOOGLE_CUDA
mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
if (!CreateTritonPipeline(pm, cc, config, /*out*/ cluster_info).ok()) {
#else
if (!CreateTritonPipeline(pm, cc, config).ok()) {
#endif
return Internal("Failed to create Triton pipeline.");
}
if (log_stream.has_value()) {
@@ -2383,6 +2446,18 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
// llvm::Linker::linkModules() segfaults if we don't strip locations.
pm.addPass(mlir::createStripDebugInfoPass());

// TODO(ROCm): Check why call to loadAllAvailableDialects is necessary here.
#ifdef TENSORFLOW_USE_ROCM
{
mlir::DialectRegistry registry;
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerROCDLDialectTranslation(registry);
triton_module->getContext()->appendDialectRegistry(registry);
triton_module->getContext()->loadAllAvailableDialects();
}
#endif // TENSORFLOW_USE_ROCM

bool succeeded = mlir::succeeded(pm.run(triton_module));

if (log_stream.has_value()) {
@@ -2397,16 +2472,18 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
triton_module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
.getInt();
VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B";
#ifndef TENSORFLOW_USE_ROCM
if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
return absl::ResourceExhaustedError(absl::StrFormat(
"Shared memory size limit exceeded: requested %d, available: %d",
shared_mem_bytes, device_info.shared_memory_per_block_optin()));
}
#endif

TF_ASSIGN_OR_RETURN(
std::unique_ptr<llvm::Module> ll_triton_module,
TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module,
GetLibdevicePath(hlo_config)));
GetLibdevicePath(hlo_config, device_info)));
VLogModule(5, *ll_triton_module);
if (should_verify) {
VerifyModule(*ll_triton_module);
@@ -2425,6 +2502,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
VerifyModule(*llvm_module);
}

#if GOOGLE_CUDA
// `cluster_info` must be read after pm.run().
std::optional<se::ClusterDim> cluster_dim;
if (config.num_ctas > 1) {
Expand All @@ -2443,6 +2521,9 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
cluster_info.clusterDimZ == 1);
}
return {{shared_mem_bytes, cluster_dim}};
#else
return {{shared_mem_bytes}};
#endif
}

} // namespace gpu
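
The two CreateTritonPipeline signatures above exist because the ROCm build must not reference the NVIDIA-only mt::nvidia_gpu::ClusterInfo type, so the call site in CompileTritonToLLVM is guarded the same way. A self-contained analogue of that compile-time selection, with hypothetical stand-in names rather than the real XLA/Triton types (build with -DGOOGLE_CUDA=1 to get the CUDA variant):

    #include <iostream>

    #if GOOGLE_CUDA
    // Stand-in for mt::nvidia_gpu::ClusterInfo, which only exists in CUDA builds.
    struct ClusterInfo { int x = 1, y = 1, z = 1; };

    void CreatePipeline(int num_stages, ClusterInfo& out_cluster_info) {
      out_cluster_info = ClusterInfo{2, 1, 1};
      std::cout << "CUDA pipeline, num_stages=" << num_stages << "\n";
    }
    #else
    void CreatePipeline(int num_stages) {
      std::cout << "ROCm pipeline, num_stages=" << num_stages << "\n";
    }
    #endif

    int main() {
    #if GOOGLE_CUDA
      ClusterInfo cluster_info;  // must stay alive until the pass manager runs
      CreatePipeline(/*num_stages=*/3, cluster_info);
    #else
      CreatePipeline(/*num_stages=*/3);
    #endif
      return 0;
    }
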
4 changes: 2 additions & 2 deletions xla/service/gpu/ir_emitter_triton.h
@@ -83,7 +83,7 @@ using TritonIrEmitter = std::function<Status(
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
const HloComputation* hlo_computation, absl::string_view fusion_kind,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context);
Expand All @@ -99,7 +99,7 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
// Compiles a given Triton module to LLVM IR.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
const se::CudaComputeCapability& cc,
const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context);
14 changes: 14 additions & 0 deletions xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -941,6 +941,7 @@ void AMDGPUBackendInit(const DebugOptions& debug_options,
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmParser();
LLVMInitializeAMDGPUAsmPrinter();
#endif

@@ -952,6 +953,19 @@
} // namespace

namespace amdgpu {

std::string LibDevicePath(std::string gcn_arch_name,
const std::string& rocdl_dir_path) {
auto libdevice_dir_paths = GetROCDLPaths(gcn_arch_name, rocdl_dir_path);
for (const auto& libdevice_dir_path : libdevice_dir_paths) {
if (libdevice_dir_path.find("ocml.bc") != std::string::npos) {
return libdevice_dir_path;
}
}
return "";
}

absl::StatusOr<std::vector<uint8_t>> CompileToHsaco(
llvm::Module* module, se::GpuComputeCapability gpu_version,
const DebugOptions& debug_options,
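
LibDevicePath above selects the ocml bitcode library out of the ROCDL file list by substring search. Note that std::string::find returns a position, so the result has to be compared against npos rather than used as a boolean; a small standalone illustration with hypothetical paths (in the real helper the list comes from GetROCDLPaths):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // Hypothetical ROCm-Device-Libs bitcode paths.
      const std::vector<std::string> paths = {
          "/opt/rocm/amdgcn/bitcode/hip.bc",
          "/opt/rocm/amdgcn/bitcode/ocml.bc",
          "/opt/rocm/amdgcn/bitcode/ockl.bc"};
      std::string ocml;
      for (const auto& path : paths) {
        // npos comparison, not a bool cast of the returned position.
        if (path.find("ocml.bc") != std::string::npos) {
          ocml = path;
          break;
        }
      }
      assert(ocml == "/opt/rocm/amdgcn/bitcode/ocml.bc");
      return 0;
    }
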
3 changes: 3 additions & 0 deletions xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -60,6 +60,9 @@ absl::StatusOr<std::string> CompileToPtx(
} // namespace nvptx

namespace amdgpu {
// Get path to libdevice file.
std::string LibDevicePath(std::string gcn_arch_name,
const std::string& rocdl_dir_path);
// Compiles the argument module and returns it with LLVM AMDGPU backend.
// rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries.
// The contents of the module may be changed.
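
A hypothetical caller of the new amdgpu::LibDevicePath declaration, mirroring how GetLibdevicePath in ir_emitter_triton.cc combines it with tsl::RocdlRoot(); the xla::gpu namespace qualification and the sample gcn_arch_name are assumptions for illustration, not taken from the diff:

    #include <iostream>
    #include <string>

    #include "tsl/platform/rocm_rocdl_path.h"
    #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"

    int main() {
      const std::string rocdl_dir = tsl::RocdlRoot();
      // In the emitter the architecture comes from
      // device_info.rocm_compute_capability().gcn_arch_name().
      const std::string libdevice =
          xla::gpu::amdgpu::LibDevicePath("gfx90a", rocdl_dir);
      std::cout << (libdevice.empty() ? "ocml.bc not found" : libdevice) << "\n";
      return 0;
    }
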
