Skip to content

Commit

Permalink
[ROCm] Introducing dump support for AMDGCN (triton-lang#25)
Browse files Browse the repository at this point in the history
* add `amdgcn` target for tools/aot.py

* clang-format fix

* [ROCm] added AMDGPU kernel call conversion

* [fix] Fixing AMDGPU calling convection
  • Loading branch information
B1tway committed Nov 22, 2022
1 parent dcb5b8d commit 8094861
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 12 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ target_link_libraries(triton
TritonGPUTransforms
TritonLLVMIR
TritonPTX
TritonAMDGCN
${dialect_libs}
${conversion_libs}
# optimizations
Expand Down
19 changes: 19 additions & 0 deletions include/triton/Target/AMDGCN/AMDGCNTranslation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef TRITON_TARGET_AMDGCNTRANSLATION_H
#define TRITON_TARGET_AMDGCNTRANSLATION_H

#include <memory>
#include <string>

namespace llvm {
class Module;
} // namespace llvm

namespace triton {

// Translate LLVM IR to AMDGCN code.
std::string translateLLVMIRToAMDGCN(llvm::Module &module,
const std::string &_proc);

} // namespace triton

#endif
3 changes: 1 addition & 2 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,10 @@ struct FuncOpConversion : public FuncOpConversionBase {

auto ctx = funcOp->getContext();

#ifndef USE_ROCM
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr(NVVMMetadataField::Kernel,
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));

#ifdef USE_ROCM
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
Expand Down
99 changes: 99 additions & 0 deletions lib/Target/AMDGCN/AMDGCNTranslation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include "triton/Target/AMDGCN/AMDGCNTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"

#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Cloning.h"

namespace triton {

static void init_llvm() {
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
}

static std::string llir_to_amdgcn(llvm::Module *module,
const std::string &_proc) {
init_llvm();

llvm::SmallVector<char, 0> buffer;
std::string triple = "amdgcn-amd-amdhsa";
std::string layout = "";
std::string features = "+sramecc,-xnack";
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
llvm::TargetOptions opt;

opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;

llvm::TargetMachine *machine = target->createTargetMachine(
module->getTargetTriple(), _proc, features, opt, llvm::Reloc::PIC_,
llvm::None, llvm::CodeGenOpt::None);

// set data layout
if (layout.empty())
module->setDataLayout(machine->createDataLayout());
else
module->setDataLayout(layout);
// emit machine code
for (llvm::Function &f : module->functions()) {
f.addFnAttr(llvm::Attribute::AlwaysInline);
}

llvm::legacy::PassManager pass;
llvm::raw_svector_ostream stream(buffer);

// emit
machine->addPassesToEmitFile(pass, stream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(*module);

std::string amdgcn(buffer.begin(), buffer.end());

return amdgcn;
}

std::string translateLLVMIRToAMDGCN(llvm::Module &module,
const std::string &_proc) {
auto gcnCode = llir_to_amdgcn(&module, _proc);
return gcnCode;
}

} // namespace triton
12 changes: 12 additions & 0 deletions lib/Target/AMDGCN/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_translation_library(TritonAMDGCN
AMDGCNTranslation.cpp

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
)
1 change: 1 addition & 0 deletions lib/Target/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(AMDGCN)
add_subdirectory(LLVMIR)
add_subdirectory(PTX)
18 changes: 18 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Target/AMDGCN/AMDGCNTranslation.h"
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "triton/Target/PTX/PTXTranslation.h"
#include "triton/tools/sys/getenv.hpp"
Expand Down Expand Up @@ -1273,6 +1274,23 @@ void init_triton_translation(py::module &m) {
},
ret::take_ownership);

m.def(
"translate_llvmir_to_amdgcn",
[](const std::string llvmIR, int gfx_number) -> std::string {
// create LLVM module from C++
llvm::LLVMContext context;
std::unique_ptr<llvm::MemoryBuffer> buffer =
llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str());
llvm::SMDiagnostic error;
std::unique_ptr<llvm::Module> module =
llvm::parseIR(buffer->getMemBufferRef(), error, context);
// translate module to AMDGCN
std::string target = "gfx" + std::to_string(gfx_number);
auto gcnCode = triton::translateLLVMIRToAMDGCN(*module, target);
return gcnCode;
},
ret::take_ownership);

m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability) -> py::object {
Expand Down
2 changes: 2 additions & 0 deletions python/triton/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,8 @@ def optimize_tritongpu_ir(mod, num_stages):
def make_llvm_ir(mod):
return _triton.translate_triton_gpu_to_llvmir(mod)

def make_amdgcn(mod: Any, gfx_number: int):
return _triton.translate_llvmir_to_amdgcn(mod, gfx_number)

def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
'''
Expand Down
25 changes: 15 additions & 10 deletions python/triton/tools/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if __name__ == '__main__':

# valid source and target formats
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn']

# set up the argument parser
# TODO: conditional requirements
Expand All @@ -16,7 +16,7 @@
help="Target format, one of: " + ', '.join(VALID_FORMATS))
parser.add_argument('--sm', type=int, help="Compute capability to compile for")
parser.add_argument('--ptx-version', type=int, help="PTX version to compile for")

parser.add_argument('--gfx', type=int, help="AMDGPU target to compile for")
# parse the args
args = parser.parse_args()

Expand Down Expand Up @@ -50,12 +50,17 @@
print(module)
exit(0)

if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")

# llvm-ir -> ptx
module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)
assert args.target == 'ptx'
if args.target == 'ptx':
if not args.sm:
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
if not args.ptx_version:
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
# llvm-ir -> ptx
module = triton.compiler.make_ptx(module, compute_capability=args.sm, ptx_version=args.ptx_version)

if args.target == 'amdgcn':
if not args.gfx:
raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation")
# llvm-ir -> amdgcn
module = triton.compiler.make_amdgcn(module, args.gfx)
print(module)

0 comments on commit 8094861

Please sign in to comment.