[ROCm] Triton in XLA for ROCm - ir_emitter_triton related changes.
zoranjovanovic-ns committed Apr 3, 2024
1 parent 4c8a74b commit 23d442f
Showing 9 changed files with 330 additions and 111 deletions.
xla/service/gpu/BUILD: 3 changes (2 additions & 1 deletion)
@@ -502,7 +502,8 @@ cc_library(
name = "ir_emitter_triton",
srcs = if_cuda_is_configured(["ir_emitter_triton.cc"]) + if_rocm_hipblaslt([
"ir_emitter_triton.cc",
- ]),
+ ]) + if_cuda_is_configured(["ir_emitter_triton_cuda.cc"])
+ + if_rocm_is_configured(["ir_emitter_triton_rocm.cc"]),
hdrs = if_gpu_is_configured(["ir_emitter_triton.h"]),
deps = [
":hlo_traversal",
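The srcs change above splits the Triton emitter into a shared ir_emitter_triton.cc plus one backend-specific translation unit, selected at build time by if_cuda_is_configured / if_rocm_is_configured; the new per-backend files are not expanded in this view. Below is only a minimal, self-contained sketch of that linking pattern, not their actual contents; the function name BackendLibdeviceTriple and the fallback branch are illustrative assumptions.

```cpp
#include <iostream>
#include <string>

// Shared interface (analogous to ir_emitter_triton.h): one declaration, with
// exactly one backend translation unit providing the definition. In the BUILD
// rule above, if_cuda_is_configured()/if_rocm_is_configured() pick which one
// is compiled and linked.
std::string BackendLibdeviceTriple();

#if defined(GOOGLE_CUDA)
// Analogous to ir_emitter_triton_cuda.cc.
std::string BackendLibdeviceTriple() { return "nvptx64-unknown-unknown"; }
#elif defined(TENSORFLOW_USE_ROCM)
// Analogous to ir_emitter_triton_rocm.cc.
std::string BackendLibdeviceTriple() { return "amdgcn-unknown-unknown"; }
#else
// Fallback so this sketch builds standalone, outside either GPU configuration.
std::string BackendLibdeviceTriple() { return "unknown"; }
#endif

int main() {
  // The shared emitter code only sees the declaration; the build system
  // decides which definition gets linked in.
  std::cout << BackendLibdeviceTriple() << "\n";
  return 0;
}
```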
xla/service/gpu/fusions/triton.cc: 10 changes (5 additions & 5 deletions)
@@ -42,7 +42,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

- #if GOOGLE_CUDA
+ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "xla/service/gpu/ir_emitter_triton.h"
#else
#include "absl/status/status.h"
@@ -98,7 +98,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
IrEmitterContext& ir_emitter_context,
const HloFusionInstruction& fusion) const {
llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext());
- #if GOOGLE_CUDA
+ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
VLOG(3) << fusion.ToString();
std::string suggested_kernel_name = std::string(fusion.name());
TF_ASSIGN_OR_RETURN(
@@ -137,7 +137,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
TF_ASSIGN_OR_RETURN(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
- ir_emitter_context.cuda_compute_capability(),
+ ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitSoftMax,
*ir_emitter_context.mlir_context()));
@@ -164,7 +164,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
TF_ASSIGN_OR_RETURN(
triton_wrapper_result,
TritonWrapper(analysis, impl_fn_name, hlo_computation,
- ir_emitter_context.cuda_compute_capability(),
+ ir_emitter_context.gpu_compute_capability(),
ir_emitter_context.gpu_device_info(), config,
ir_emitter_context.llvm_module(), &EmitMatMul,
*ir_emitter_context.mlir_context()));
@@ -212,7 +212,7 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(

return result;
#else
- return absl::UnimplementedError("Triton support requires CUDA");
+ return absl::UnimplementedError("Triton support requires CUDA or ROCm");
#endif
}

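Throughout this commit, CUDA-only se::CudaComputeCapability parameters become se::GpuComputeCapability, and call sites branch on the active alternative (see the std::holds_alternative checks in ir_emitter_triton.cc below). A minimal, self-contained sketch of that dispatch follows; the types here are stand-ins for the real ones in xla/stream_executor/device_description.h, and the helper names are hypothetical.

```cpp
#include <iostream>
#include <string>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaComputeCapability {
  int major = 8, minor = 0;
  bool IsAtLeastAmpere() const { return major >= 8; }
};
struct RocmComputeCapability {
  std::string gcn_arch_name = "gfx90a";
};
using GpuComputeCapability =
    std::variant<CudaComputeCapability, RocmComputeCapability>;

// Mirrors the gating added to TritonWrapper/CompileTritonToLLVM below: the
// Ampere check only applies when the capability holds the CUDA alternative.
bool TritonSupported(const GpuComputeCapability& cc) {
  if (const auto* cuda = std::get_if<CudaComputeCapability>(&cc)) {
    return cuda->IsAtLeastAmpere();
  }
  return true;  // ROCm path: no CUDA architecture gate here.
}

// Mirrors the EmitElementwise change: pick the libdevice target triple from
// the active alternative.
std::string LibdeviceTriple(const GpuComputeCapability& cc) {
  return std::holds_alternative<RocmComputeCapability>(cc)
             ? "amdgcn-unknown-unknown"
             : "nvptx64-unknown-unknown";
}

int main() {
  GpuComputeCapability cc = RocmComputeCapability{};
  std::cout << TritonSupported(cc) << " " << LibdeviceTriple(cc) << "\n";
  return 0;
}
```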
xla/service/gpu/hlo_fusion_analysis.cc: 2 changes (1 addition & 1 deletion)
@@ -202,7 +202,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind()
return EmitterFusionKind::kCustomFusion;
}

- #if GOOGLE_CUDA
+ #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (fusion_backend_config_.kind() == kTritonGemmFusionKind ||
fusion_backend_config_.kind() == kTritonSoftmaxFusionKind) {
return EmitterFusionKind::kTriton;
xla/service/gpu/ir_emitter_triton.cc: 135 changes (33 additions & 102 deletions)
@@ -30,8 +30,6 @@ limitations under the License.
#include <variant>
#include <vector>

#include "nvidia/include/NVGPUToLLVM/NVGPUToLLVMPass.h"
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -54,6 +52,7 @@ limitations under the License.
#include "llvm/TargetParser/Triple.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" // from @llvm-project
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
@@ -90,6 +89,7 @@ limitations under the License.
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "xla/autotuning.pb.h"
@@ -144,6 +144,7 @@ limitations under the License.
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"


namespace xla {
namespace gpu {

@@ -424,12 +425,16 @@ absl::StatusOr<Value> EmitElementwise(ImplicitLocOpBuilder& b,
mlir::getElementTypeOrSelf(inputs[0]).isF64()) {
auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode());
if (dev_fn_id.ok()) {
- return b.create<mt::ExternElementwiseOp>(
- inputs[0].getType(), inputs, "libdevice", libdevice_path,
- ObtainDeviceFunctionName(dev_fn_id.value(),
- hlo.shape().element_type(),
- llvm::Triple("nvptx64-unknown-unknown")),
- /*pure=*/true);
+ llvm::Triple triple("nvptx64-unknown-unknown");
+ if (std::holds_alternative<se::RocmComputeCapability>
+ (device_info.gpu_compute_capability())) {
+ triple.setTriple("amdgcn-unknown-unknown");
+ }
+ return b.create<mt::ExternElementwiseOp>(
+ inputs[0].getType(), inputs, "libdevice", libdevice_path,
+ ObtainDeviceFunctionName(dev_fn_id.value(),
+ hlo.shape().element_type(), triple),
+ /*pure=*/true);
}
}
const bool is_integer =
@@ -932,81 +937,6 @@ absl::StatusOr<Value> EmitScope(
return values[instructions.back()];
}

- // Create Triton pipeline.
- //
- // `out_cluster_info` must be kept alive at least until pm.run() is called.
- // It should be read after that. We have to pass the cluster dims to
- // LaunchDimensions. Triton currently uses this as an out-parameter to return
- // the cluster dims determined based on `config.num_ctas` and a heuristic. There
- // are some signs that show that this was intended to be used as an in-out
- // parameter which would give a hint to Triton which cluster dims we prefer to
- // use, but that's not the case currently.
- absl::Status CreateTritonPipeline(
- mlir::OpPassManager& pm, const se::CudaComputeCapability& cc,
- const TritonGemmConfig& config,
- mt::nvidia_gpu::ClusterInfo& out_cluster_info) {
- const int ccAsInt = cc.major * 10 + cc.minor;
- const int threadsPerWarp = 32;
-
- // Based on make_ttir() in
- // @triton//:third_party/nvidia/backend/compiler.py
- pm.addPass(mlir::createInlinerPass());
- pm.addPass(mt::createRewriteTensorPointerPass());
- pm.addPass(mt::createCombineOpsPass());
- pm.addPass(mlir::createCanonicalizerPass());
- pm.addPass(mt::createReorderBroadcastPass());
- pm.addPass(mlir::createCSEPass());
- pm.addPass(mlir::createLoopInvariantCodeMotionPass());
- pm.addPass(mlir::createSymbolDCEPass());
-
- // Based on make_ttgir() in
- // @triton//:third_party/nvidia/backend/compiler.py
- pm.addPass(mt::createConvertTritonToTritonGPUPass(
- config.num_warps, threadsPerWarp, config.num_ctas, ccAsInt));
- pm.addPass(mt::gpu::createCoalescePass());
- pm.addPass(mlir::createTritonNvidiaGPUPlanCTAPass(&out_cluster_info));
- pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
- pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
- pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
- pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
- pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
- pm.addPass(mlir::createCSEPass());
-
- pm.addPass(mt::gpu::createPipelinePass(config.num_stages, config.num_warps,
- config.num_ctas, ccAsInt));
-
- if (!cc.IsAtLeastHopper()) {
- pm.addPass(mt::gpu::createPrefetchPass());
- }
-
- pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
- pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
- pm.addPass(mt::gpu::createReduceDataDuplicationPass());
- pm.addPass(mt::gpu::createReorderInstructionsPass());
- pm.addPass(mlir::createCSEPass());
- pm.addPass(mlir::createSymbolDCEPass());
- if (cc.IsAtLeastHopper()) {
- pm.addPass(mlir::createTritonNvidiaGPUFenceInsertionPass(ccAsInt));
- }
- pm.addPass(mlir::createCanonicalizerPass());
-
- // Based on make_llir() in
- // @triton//:third_party/nvidia/backend/compiler.py
- pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
- pm.addPass(mlir::createConvertSCFToCFPass());
- pm.addPass(mlir::createConvertIndexToLLVMPass());
- pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
- pm.addPass(mt::createConvertTritonGPUToLLVMPass(ccAsInt));
- pm.addPass(mt::createConvertNVGPUToLLVMPass());
- pm.addPass(mlir::createArithToLLVMConversionPass());
- pm.addPass(mlir::createCanonicalizerPass());
- pm.addPass(mlir::createCSEPass());
- pm.addPass(mlir::createSymbolDCEPass());
- // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
-
- return absl::OkStatus();
- }

// Extract additional attributes from an LLVM function that are not passed
// to the builder directly.
SmallVector<mlir::NamedAttribute> GetExtraAttrs(ml::LLVMFuncOp func) {
@@ -2653,6 +2583,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
mlir::registerBuiltinDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
+ mlir::registerROCDLDialectTranslation(registry);
module->getContext()->appendDialectRegistry(registry);

std::unique_ptr<llvm::Module> llvmModule =
@@ -2677,14 +2608,6 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> TranslateLLVMToLLVMIR(
return llvmModule;
}

- namespace {
-
- std::string GetLibdevicePath(const HloModuleConfig& hlo_config) {
- return nvptx::LibDevicePath(
- hlo_config.debug_options().xla_gpu_cuda_data_dir());
- }
-
- } // namespace

absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
@@ -2724,7 +2647,8 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
b.setInsertionPointToStart(&fn.front());

TF_RETURN_IF_ERROR(
- ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config()),
+ ir_emitter(b, GetLibdevicePath(hlo_computation->parent()->config(),
+ device_info),
device_info, analysis, hlo_computation, fn, config));

b.create<mt::ReturnOp>(loc);
@@ -2746,13 +2670,16 @@

absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
- const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
+ const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context) {
- if (!cc.IsAtLeastAmpere()) {
- return absl::FailedPreconditionError(
- "Triton support is only enabled for Ampere GPUs and up.");
+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
+ auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+ if (!ccCuda.IsAtLeastAmpere()) {
+ return absl::FailedPreconditionError(
+ "Triton support is only enabled for Ampere GPUs and up.");
+ }
}

auto debug_options = GetDebugOptionsFromFlags();
@@ -2780,13 +2707,16 @@
// TODO(b/325220878): Replace TritonGemmConfig with a more generic abstraction.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
- const se::CudaComputeCapability& cc,
+ const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context) {
- if (!cc.IsAtLeastAmpere()) {
- return absl::FailedPreconditionError(
- "Triton support is only enabled for Ampere GPUs and up.");
+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
+ auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+ if (!ccCuda.IsAtLeastAmpere()) {
+ return absl::FailedPreconditionError(
+ "Triton support is only enabled for Ampere GPUs and up.");
+ }
}

bool should_verify =
@@ -2860,7 +2790,8 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
triton_module->getAttrOfType<mlir::IntegerAttr>("triton_gpu.shared")
.getInt();
VLOG(2) << "Shared memory usage: " << shared_mem_bytes << " B";
- if (shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
+ if (std::holds_alternative<se::CudaComputeCapability>(cc)
+ && shared_mem_bytes > device_info.shared_memory_per_block_optin()) {
return absl::ResourceExhaustedError(absl::StrFormat(
"Shared memory size limit exceeded: requested %d, available: %d",
shared_mem_bytes, device_info.shared_memory_per_block_optin()));
@@ -2869,7 +2800,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<llvm::Module> ll_triton_module,
TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module,
- GetLibdevicePath(hlo_config)));
+ GetLibdevicePath(hlo_config, device_info)));
VLogModule(5, *ll_triton_module);
if (should_verify) {
VerifyModule(*ll_triton_module);
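The translation setup in TranslateLLVMToLLVMIR now registers the ROCDL dialect translation alongside NVVM, so modules lowered for either backend can be exported to LLVM IR. A hedged sketch of that registration sequence is below; the wrapper name ToLLVMIR is hypothetical, while the registration calls mirror the ones in the hunk above.

```cpp
#include <memory>

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"

// Register LLVM-IR translations for both GPU backends, then export the MLIR
// module to LLVM IR. Returns nullptr if the translation fails.
std::unique_ptr<llvm::Module> ToLLVMIR(mlir::ModuleOp module,
                                       llvm::LLVMContext& llvm_context) {
  mlir::DialectRegistry registry;
  mlir::registerBuiltinDialectTranslation(registry);
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  mlir::registerROCDLDialectTranslation(registry);  // added by this commit
  module->getContext()->appendDialectRegistry(registry);
  return mlir::translateModuleToLLVMIR(module, llvm_context);
}
```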
xla/service/gpu/ir_emitter_triton.h: 25 changes (23 additions & 2 deletions)
@@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/hlo_traversal.h"
@@ -40,10 +41,13 @@ limitations under the License.
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/launch_dim.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

namespace xla {
namespace gpu {

+ namespace mt = ::mlir::triton;
+
struct TritonWrapperResult {
int64_t shmem_bytes = 0;
std::optional<se::ClusterDim> cluster_dim;
@@ -89,7 +93,7 @@ using TritonIrEmitter = std::function<Status(
// MatMul and SoftMax above are some such IR generators.
absl::StatusOr<TritonWrapperResult> TritonWrapper(
const TritonFusionAnalysis& analysis, absl::string_view fn_name,
- const HloComputation* hlo_computation, const se::CudaComputeCapability& cc,
+ const HloComputation* hlo_computation, const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
llvm::Module* llvm_module, TritonIrEmitter ir_emitter,
mlir::MLIRContext& mlir_context);
@@ -105,11 +109,28 @@ absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> CreateTritonModule(
// Compiles a given Triton module to LLVM IR.
absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
const HloModuleConfig& hlo_config, absl::string_view hlo_module_name,
- const se::CudaComputeCapability& cc,
+ const se::GpuComputeCapability& cc,
const se::DeviceDescription& device_info, const TritonGemmConfig& config,
mlir::ModuleOp triton_module, llvm::Module* llvm_module,
mlir::MLIRContext& mlir_context);

+ // Create Triton pipeline.
+ //
+ // `out_cluster_info` must be kept alive at least until pm.run() is called.
+ // It should be read after that. We have to pass the cluster dims to
+ // LaunchDimensions. Triton currently uses this as an out-parameter to return
+ // the cluster dims determined based on `config.num_ctas` and a heuristic. There
+ // are some signs that show that this was intended to be used as an in-out
+ // parameter which would give a hint to Triton which cluster dims we prefer to
+ // use, but that's not the case currently.
+ absl::Status CreateTritonPipeline(
+ mlir::OpPassManager& pm, const se::GpuComputeCapability& cc,
+ const TritonGemmConfig& config,
+ mt::nvidia_gpu::ClusterInfo& out_cluster_info);
+
+ std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
+ const se::DeviceDescription& device_info);
+
} // namespace gpu
} // namespace xla

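For reference, a hedged sketch of how a caller is expected to use the newly exported CreateTritonPipeline, following the lifetime rule in the comment above. This is roughly what CompileTritonToLLVM does; the wrapper name RunTritonPipeline is hypothetical and error handling is abbreviated.

```cpp
// Sketch only: assumes ir_emitter_triton.h and "mlir/Pass/PassManager.h" are
// included and that a backend-specific definition of CreateTritonPipeline
// (from ir_emitter_triton_cuda.cc or ir_emitter_triton_rocm.cc) is linked in.
absl::Status RunTritonPipeline(mlir::ModuleOp triton_module,
                               const se::GpuComputeCapability& cc,
                               const TritonGemmConfig& config) {
  mlir::PassManager pm(triton_module->getContext());
  // Must stay alive at least until pm.run(); Triton writes the cluster
  // dimensions it picked into it, and they should only be read afterwards.
  mt::nvidia_gpu::ClusterInfo cluster_info;
  TF_RETURN_IF_ERROR(CreateTritonPipeline(pm, cc, config, cluster_info));
  if (mlir::failed(pm.run(triton_module))) {
    return absl::InternalError("Failed to run the Triton pipeline.");
  }
  // cluster_info now holds the cluster dims chosen by the pipeline;
  // CompileTritonToLLVM forwards them via TritonWrapperResult::cluster_dim.
  return absl::OkStatus();
}
```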
(The remaining 4 changed files are not shown here.)
