cherry pick #2701, #2736, #2756, #3191 (#3418)
Cherry-pick a few PRs to make PyTorch 2.3 work with Triton 2.2.x. This also makes
sure that the Python bindings needed by user-defined Triton kernels exist in this
Triton branch.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
Co-authored-by: Manman Ren <manman.ren@gmail.com>
Co-authored-by: Manman Ren <mren@meta.com>
4 people authored Mar 19, 2024
1 parent 996b6c0 commit 79c6c9b
Showing 21 changed files with 704 additions and 937 deletions.
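For context, the Python bindings this commit exposes (a module-level walk plus operation, value, and block introspection) can be driven roughly as in the sketch below. This is a condensed, hypothetical variant of the new python/test/unit/runtime/test_bindings.py included in this commit; it assumes a CUDA device and the ASTSource/CUDABackend APIs as they exist on this branch, not a stable public interface.

import torch
import triton
import triton.language as tl
from triton.compiler.backends.cuda import CUDABackend
from triton.runtime.driver import driver

@triton.jit
def add_kernel(x_ptr, y_ptr, n_elements, out_ptr, BLOCK_SIZE: "tl.constexpr"):
    offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# Build TTIR for the kernel without launching it (mirrors test_bindings.py).
args = [torch.empty(1024, device="cuda"), torch.empty(1024, device="cuda"),
        1024, torch.empty(1024, device="cuda"), 16]
src = triton.compiler.compiler.ASTSource(
    fn=add_kernel,
    signature={i: add_kernel._type_of(add_kernel._key_of(arg))
               for i, arg in enumerate(args) if i not in add_kernel.constexprs},
    constants={i: arg for i, arg in enumerate(args) if not isinstance(arg, torch.Tensor)},
    attrs=add_kernel._get_config(*args),
)
triton._C.libtriton.ir = triton._C.libtriton.triton.ir  # alias the IR bindings, as the test does
backend = CUDABackend(driver.get_current_target())
ttir_module = src.make_ir(backend.parse_options(dict()))

# New bindings: walk every operation, read its name, and resolve the callee
# symbol that tt.call ops carry (None for everything else).
ttir_module.walk(lambda op: print(op.get_name(), op.get_flat_symbol_ref_attr("callee")))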
3 changes: 1 addition & 2 deletions .github/workflows/integration-tests.yml
@@ -34,7 +34,6 @@ jobs:
echo '::set-output name=matrix-optional::["ubuntu-latest"]'
fi
Integration-Tests:
needs: Runner-Preparation

@@ -49,7 +48,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: 'true'
submodules: "true"
- name: Set CUDA ENV
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}}
run: |
3 changes: 2 additions & 1 deletion python/setup.py
@@ -343,7 +343,7 @@ def build_extension(self, ext):

setup(
name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
version="2.2.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
version="2.3.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""),
author="Philippe Tillet",
author_email="phil@openai.com",
description="A language and compiler for custom Deep Learning operations",
@@ -353,6 +353,7 @@ def build_extension(self, ext):
"triton/_C",
"triton/common",
"triton/compiler",
"triton/compiler/backends",
"triton/language",
"triton/language/extra",
"triton/ops",
87 changes: 71 additions & 16 deletions python/src/triton.cc
@@ -66,6 +66,7 @@

#include <pybind11/numpy.h>
namespace py = pybind11;
using namespace mlir;

PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy);

Expand Down Expand Up @@ -170,7 +171,7 @@ class TritonOpBuilder {
private:
std::unique_ptr<mlir::OpBuilder> builder;
std::unique_ptr<mlir::Location> lastLoc;
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
bool lineInfoEnabled = !::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
};

static std::string locationToString(mlir::Location loc) {
@@ -347,15 +348,23 @@ void init_triton_ir(py::module &&m) {
[](mlir::Value &self, mlir::Value &newValue) {
self.replaceAllUsesWith(newValue);
})
.def("get_type", &mlir::Value::getType);
.def("get_type", &mlir::Value::getType)
.def("id", [](Value &self) {
// The Value is identified by and compared with
// other Values via the underlying ValueImpl
return (uint64_t)self.getImpl();
});

py::class_<OpResult, Value>(m, "op_result", py::module_local());

py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument",
py::module_local());

py::class_<mlir::Region>(m, "region", py::module_local())
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
.def("size", [](mlir::Region &self) { return self.getBlocks().size(); })
.def("empty", &mlir::Region::empty);
.def("empty", &mlir::Region::empty)
.def("id", [](Region &self) { return (uint64_t)&self; });

py::class_<mlir::Block>(m, "block", py::module_local())
.def("arg",
@@ -368,6 +377,7 @@ void init_triton_ir(py::module &&m) {
self.addArgument(ty, loc);
})
.def("get_num_arguments", &mlir::Block::getNumArguments)
.def("get_argument", &Block::getArgument)
.def("dump", &mlir::Block::dump)
.def("move_before", &mlir::Block::moveBefore)
.def("insert_before", &mlir::Block::insertBefore)
@@ -414,7 +424,8 @@ void init_triton_ir(py::module &&m) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
})
.def("erase", [](mlir::Block &self) { self.erase(); });
.def("erase", [](mlir::Block &self) { self.erase(); })
.def("id", [](Block &self) { return (uint64_t)&self; });

// using eattr = ir::attribute_kind_t;
// py::enum_<eattr>(m, "attribute_kind")
@@ -461,7 +472,9 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpState &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
self->print(os);
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.enableDebugInfo();
self->print(os, printingFlags);
return str;
})
.def("append_operand",
@@ -489,6 +502,35 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp",
py::module_local());

py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(
m, "operation", py::module_local())
.def("get_name",
[](Operation &self) {
llvm::StringRef opName = self.getName().getStringRef();
return opName.str();
})
.def("get_num_operands", &Operation::getNumOperands)
.def("get_operand", &Operation::getOperand)
.def("get_num_results", &Operation::getNumResults)
.def("get_result", &Operation::getResult)
.def("get_num_regions", &Operation::getNumRegions)
.def("get_region", &Operation::getRegion, ret::reference)
.def("get_block", &Operation::getBlock, ret::reference)
.def("get_str_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<StringAttr>(name);
if (!ret)
return py::none();
return py::str(ret.getValue().str());
})
.def("get_flat_symbol_ref_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
if (!ret)
return py::none();
return py::str(ret.getValue().str());
});

// dynamic_attr is used to transfer ownership of the MLIR context to the
// module
py::class_<mlir::ModuleOp, mlir::OpState>(m, "module", py::module_local(),
@@ -498,7 +540,9 @@
[](mlir::ModuleOp &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.enableDebugInfo();
self.print(os, printingFlags);
return str;
})
.def("bytecode",
@@ -532,6 +576,17 @@ void init_triton_ir(py::module &&m) {
if (funcs.size() != 1)
throw std::runtime_error("Expected a single function");
return funcs[0];
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
if (!ret)
return py::none();
return py::int_(ret.getInt());
})
.def("walk",
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
self.walk(fn);
});

m.def("make_attr",
@@ -1685,9 +1740,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::triton::createReorderBroadcastPass());
})
.def("add_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::triton::createRewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::triton::createRewriteTensorPointerPass(capability));
})
.def("add_tritongpu_ws_feasibility_checking_pass",
[](mlir::PassManager &self, int computeCapability) {
@@ -1761,9 +1816,9 @@ void init_triton_ir(py::module &&m) {
self.addPass(mlir::createTritonGPUReorderInstructionsPass());
})
.def("add_tritongpu_rewrite_tensor_pointer_pass",
[](mlir::PassManager &self, int computeCapability) {
self.addPass(mlir::createTritonGPURewriteTensorPointerPass(
computeCapability));
[](mlir::PassManager &self, int capability) {
self.addPass(
mlir::createTritonGPURewriteTensorPointerPass(capability));
})
.def("add_tritongpu_decompose_conversions_pass",
[](mlir::PassManager &self) {
@@ -1794,8 +1849,8 @@ void init_triton_env_vars(py::module &m) {
void init_triton_env_vars(py::module &m) {
m.def("get_env_vars", []() -> std::map<std::string, bool> {
std::map<std::string, bool> envVars;
for (const auto &envVar : triton::ENV_VARS) {
envVars[envVar] = triton::tools::getBoolEnv(envVar);
for (const auto &envVar : ::triton::ENV_VARS) {
envVars[envVar] = ::triton::tools::getBoolEnv(envVar);
}
return envVars;
});
@@ -1896,7 +1951,7 @@ void init_triton_translation(py::module &m) {
"lineno: " + std::to_string(error.getLineNo()));
}
// translate module to PTX
auto ptxCode = triton::translateLLVMIRToPTX(*module, capability,
auto ptxCode = ::triton::translateLLVMIRToPTX(*module, capability,
version, enable_fp_fusion);
return ptxCode;
},
@@ -1925,7 +1980,7 @@ void init_triton_translation(py::module &m) {
ofs.close();

auto lineInfoOption =
triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO")
? ""
: " -lineinfo";
auto fmadOption = enable_fp_fusion ? "" : " --fmad=false";
10 changes: 5 additions & 5 deletions python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
@@ -27,7 +27,6 @@
import triton
import triton.language as tl
from triton.runtime import driver
from triton.runtime.jit import get_current_device


# kernel used to query max clusters for persistent kernel when NUM_CTAS > 1
@@ -899,12 +898,13 @@ def process_epilogue(d, bias, w, epilogue):

NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
src = triton.compiler.ASTSource(fn=empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel = triton.compile(src)
null_kernel._init_handles()
device = driver.get_current_device()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
1, 1)
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.function, max_shared_mem, NUM_CTAS, 1,
1)
NUM_SMS = num_clusters

def grid(META):
82 changes: 82 additions & 0 deletions python/test/unit/runtime/test_bindings.py
@@ -0,0 +1,82 @@
import triton
import triton.language as tl
from triton.compiler.backends.cuda import CUDABackend
from triton.runtime.driver import driver

import torch


@triton.jit
def add_helper(x, y):
    return x + y


@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    n_elements,
    out_ptr,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    output = add_helper(x, y)
    tl.store(out_ptr + offsets, output, mask=mask)


def test_module_walk():
    """
    Test the MLIR bindings exposed for the out-of-tree walk.
    """

    def walk_fn(op):
        name = op.get_name()
        for i in range(op.get_num_results()):
            op.get_result(i).id()
        for i in range(op.get_num_operands()):
            op.get_operand(i).id()
        for i in range(op.get_num_regions()):
            op.get_region(i).id()
        block = op.get_block()
        if block is not None:
            block.id()
            for i in range(block.get_num_arguments()):
                block.get_argument(i)
        if name == "tt.func":
            op.get_str_attr("sym_name")
        if name == "tt.call":
            op.get_flat_symbol_ref_attr("callee")

    kernel = add_kernel
    args = [
        torch.empty((32, 32), device="cuda"),  # in_ptr0
        torch.empty((32, 32), device="cuda"),  # in_ptr1
        1024,  # n_elements
        torch.empty((32, 32), device="cuda"),  # out_ptr
        16,  # BLOCK_SIZE
    ]
    src = triton.compiler.compiler.ASTSource(
        fn=kernel,
        signature={i: kernel._type_of(kernel._key_of(arg))
                   for i, arg in enumerate(args)
                   if i not in kernel.constexprs},
        constants={i: arg
                   for i, arg in enumerate(args)
                   if not isinstance(arg, torch.Tensor)},
        attrs=kernel._get_config(*args),
    )

    triton._C.libtriton.ir = triton._C.libtriton.triton.ir
    context = triton._C.libtriton.ir.context()

    target = driver.get_current_target()
    backend = CUDABackend(target)
    options = backend.parse_options(dict())

    ttir_module = src.make_ir(options)
    ttir_module.walk(walk_fn)
