Merge pull request triton-lang#23 from dfukalov/dfukalov/work-1
[Triton-MLIR][ROCM] Preparing for ROCm support.
rsanthanam-amd authored Nov 4, 2022
2 parents 4dc2396 + 071dc9f commit 786a739
Showing 11 changed files with 97 additions and 30 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -15,6 +15,12 @@ endif()
option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)

# Default build for this branch is ROCm support
option(TRITON_USE_ROCM "Build with ROCm/AMDGPU support" ON)
if(DEFINED ENV{TRITON_USE_ROCM})
set(TRITON_USE_ROCM "$ENV{TRITON_USE_ROCM}" CACHE BOOL "" FORCE)
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Default build type: Release")
@@ -38,6 +44,15 @@ if(WIN32)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")

if (TRITON_USE_ROCM)
set(MI_GPU_ARCH $ENV{MI_GPU_ARCH})
if (NOT MI_GPU_ARCH)
set(MI_GPU_ARCH "gfx90a")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ROCM -DMI_GPU_ARCH=${MI_GPU_ARCH} -Wno-unused-result -Wno-attributes")
endif()

if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
endif()
@@ -209,6 +224,7 @@ target_link_libraries(triton
MLIRExecutionEngine
MLIRMathToLLVM
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
MLIRIR
)
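
The hunks above wire ROCm into the build: TRITON_USE_ROCM can be overridden from the environment, and MI_GPU_ARCH falls back to gfx90a when unset. As a rough sketch (not part of this commit), the editable install described in the README below could be driven with those variables set; this assumes the pip/CMake build simply inherits the calling environment:

```python
# Hypothetical build driver. TRITON_USE_ROCM and MI_GPU_ARCH are the environment
# variables read by the CMakeLists.txt hunks above; the subprocess call just runs
# the editable install from the python/ subdirectory of a checked-out tree.
import os
import subprocess

env = dict(os.environ)
env["TRITON_USE_ROCM"] = "ON"   # force the cached CMake option on
env["MI_GPU_ARCH"] = "gfx90a"   # target AMD GPU architecture (the default when unset)

subprocess.run(["pip3", "install", "-e", "."], cwd="python", env=env, check=True)
```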

22 changes: 22 additions & 0 deletions README.md
@@ -32,6 +32,28 @@ And the latest nightly release:
```bash
pip install -U --pre triton
```
# Install from source
```
git clone https://github.com/ROCmSoftwarePlatform/triton.git
cd triton
git checkout triton-mlir
```
# Build
```
cd python
pip3 install cmake; # build time dependency
pip3 install -e .
```
# Run tests:
```
# Run the Python tests
pytest
# Run the ctest
cd build/temp.linux-x86_64-3.7
ctest
# Run the lit tests
lit -v test
```
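
After building, running a small kernel is a quick way to exercise the toolchain end to end. This smoke test is not part of the commit; it assumes the standard Triton Python API and a GPU-enabled PyTorch build (under ROCm PyTorch, HIP devices are still addressed through the "cuda" device string):

```python
# Minimal vector-add smoke test (assumes triton and a GPU-enabled torch are installed).
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(1024, device="cuda")  # a HIP device under ROCm PyTorch
y = torch.rand(1024, device="cuda")
out = torch.empty_like(x)
add_kernel[(4,)](x, y, out, x.numel(), BLOCK=256)  # 4 blocks of 256 elements
assert torch.allclose(out, x + y)
```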

# Changelog

1 change: 1 addition & 0 deletions include/triton/Conversion/Passes.td
@@ -41,6 +41,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::NVVM::NVVMDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::StandardOpsDialect"];
}

1 change: 1 addition & 0 deletions lib/Conversion/PassDetail.h
@@ -4,6 +4,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
MLIRPass
MLIRGPUOps
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms
TritonAnalysis
TritonIR
11 changes: 11 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3,6 +3,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
@@ -253,6 +254,7 @@ struct FuncOpConversion : public FuncOpConversionBase {

auto ctx = funcOp->getContext();

#ifndef USE_ROCM
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr(NVVMMetadataField::Kernel,
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
@@ -261,6 +263,7 @@
// for `nvvm.annotation` metadata.
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
rewriter.getIntegerAttr(i32_ty, 32 * NumWarps));
#endif

rewriter.eraseOp(funcOp);
return success();
@@ -3693,7 +3696,11 @@ class ConvertTritonGPUToLLVM
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif

if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
@@ -3744,7 +3751,11 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
MLIRContext &ctx, mlir::LLVMTypeConverter &typeConverter)
: ConversionTarget(ctx), typeConverter(typeConverter) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
// addIllegalDialect<triton::TritonDialect>();
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
14 changes: 14 additions & 0 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -10,11 +10,13 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/tools/sys/getenv.hpp"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"

namespace mlir {
@@ -33,6 +35,7 @@ void amendLLVMFunc(llvm::Function *func, const NVVMMetadata &metadata) {
auto *module = func->getParent();
auto &ctx = func->getContext();

#ifndef USE_ROCM
if (metadata.maxntidx > 0) {
auto i32_ty = llvm::IntegerType::get(ctx, 32);
auto warps =
@@ -45,14 +48,19 @@
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
}
#endif

if (metadata.is_kernel) {
#if defined(USE_ROCM)
func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
#else // defined(USE_ROCM)
llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(func), llvm::MDString::get(ctx, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(ctx), 1))};
module->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(llvm::MDNode::get(ctx, md_args));
#endif
}
}

@@ -86,7 +94,13 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
mlir::registerLLVMDialectTranslation(registry);

#ifdef USE_ROCM
mlir::registerROCDLDialectTranslation(registry);
#else
mlir::registerNVVMDialectTranslation(registry);
#endif

context->appendDialectRegistry(registry);

llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
2 changes: 1 addition & 1 deletion python/src/triton.cc
@@ -57,7 +57,7 @@ void init_triton_runtime(py::module &&m) {
py::enum_<backend_t>(m, "backend")
.value("HOST", HOST)
.value("CUDA", CUDA)
// .value("ROCM", ROCM)
.value("ROCM", ROCM)
.export_values();
}

29 changes: 15 additions & 14 deletions python/triton/compiler.py
@@ -925,7 +925,7 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]


@functools.lru_cache
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
'''
Get the highest PTX version supported by the current CUDA driver.
@@ -990,19 +990,20 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
llvm_ir = make_llvm_ir(module)

assert device >= 0, "device should be provided."
ptxas, cuda_version = path_to_ptxas()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
ptx_version = ptx_get_version(cuda_version)
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
shem_size = _triton.get_shared_memory_size(module)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx, shem_size, kernel_name

cubin = make_cubin(ptx, ptxas, compute_capability)
if output == "cubin":
return cubin, ptx, shem_size, kernel_name
if torch.version.hip is None:
ptxas, cuda_version = path_to_ptxas()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
ptx_version = ptx_get_version(cuda_version)
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
shem_size = _triton.get_shared_memory_size(module)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx, shem_size, kernel_name

cubin = make_cubin(ptx, ptxas, compute_capability)
if output == "cubin":
return cubin, ptx, shem_size, kernel_name

assert False
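
The new branch above keys off torch.version.hip, which is a version string under a ROCm build of PyTorch and None under CUDA, so the ptxas/PTX/cubin stages only run on the CUDA path. A small illustration of the same check (the is_hip helper is hypothetical, not part of this commit):

```python
import torch

def is_hip() -> bool:
    # torch.version.hip is e.g. "5.x.y" on ROCm builds of PyTorch and None on CUDA builds.
    return torch.version.hip is not None

if is_hip():
    print("ROCm/HIP build: PTX and cubin generation are skipped on this branch")
else:
    print("CUDA build: the ptxas / PTX / cubin pipeline is used")
```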

26 changes: 13 additions & 13 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
// XHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
@@ -263,7 +263,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
func @basic_make_range() {
// CHECK: nvvm.read.ptx.sreg.tid.x
// XHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
@@ -303,7 +303,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
// XHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
return
}
@@ -549,7 +549,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
@@ -566,7 +566,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
@@ -597,12 +597,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
@@ -621,18 +621,18 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
@@ -685,12 +685,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global internal @global_smem() {addr_space = 3 : i32} : !llvm.array<2560 x i8>
// CHECK-LABEL: convert_layout_mma_block
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: nvvm.barrier0
// XHECK: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
4 changes: 2 additions & 2 deletions test/Target/tritongpu_to_llvmir.mlir
@@ -3,8 +3,8 @@
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_empty_kernel
// CHECK: !nvvm.annotations
// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
// XHECK: !nvvm.annotations
// XHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}

module attributes {"triton_gpu.num-warps" = 4 : i32} {

