OpenXLA-specific changes
chsigg committed Sep 13, 2024
1 parent fd02f65 commit b39a103
Showing 49 changed files with 2,330 additions and 203 deletions.
904 changes: 904 additions & 0 deletions BUILD

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions include/triton/Analysis/AxisInfo.h
@@ -180,8 +180,8 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
for (auto funcOp : llvm::reverse(sortedFuncs)) {
initialize(funcOp);
funcOp.walk([&](CallOpInterface callOp) {
auto callee =
dyn_cast<FunctionOpInterface>(callOp.resolveCallable(&symbolTable));
auto callee = dyn_cast<FunctionOpInterface>(
callOp.resolveCallableInTable(&symbolTable));
update(callOp, callee);
});
}
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -316,7 +316,7 @@ template <typename T> class CallGraph {
moduleOp.walk([&](Operation *op) {
auto caller = op->getParentOfType<FunctionOpInterface>();
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto *callee = callOp.resolveCallable(&symbolTable);
auto *callee = callOp.resolveCallableInTable(&symbolTable);
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
if (funcOp) {
graph[caller].emplace_back(
@@ -91,6 +91,10 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<I
let results = (outs TT_FpIntTensor:$d);

let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";

let extraClassDeclaration = [{
bool needsPartialAccumulator();
}];
}

def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
5 changes: 5 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -488,6 +488,11 @@ bool supportMMA(triton::DotOp op, int version) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
return false;
auto retType = op.getType();
RankedTensorType typeA = op.getA().getType();
int k = typeA.getShape().back();
// If k size is smaller than the native mma size, we cannot use MMA.
if (k < 256 / aElemTy.getIntOrFloatBitWidth())
return false;
auto retShapePerCTA = getShapePerCTA(retType);
auto rank = retShapePerCTA.size();
auto mod = op->getParentOfType<ModuleOp>();
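
Note on the new K-size guard in supportMMA() above: the check rejects any K smaller than 256 / bitwidth, so the minimum usable K scales inversely with the operand element width. A minimal arithmetic sketch (Python, hypothetical helper name, not part of the change):

def min_mma_k(elem_bitwidth: int) -> int:
    # Smallest K accepted by the `k < 256 / bitwidth` guard added above.
    return 256 // elem_bitwidth

assert min_mma_k(16) == 16  # fp16 / bf16 operands
assert min_mma_k(8) == 32   # fp8 operands
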
10 changes: 5 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
if ((inBitWidth == 16 && ouBitWidth == 32) ||
(inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
@@ -610,10 +611,9 @@ struct IndexCastOpLowering
if (targetBits == sourceBits)
return {operands[0][0]};
if (targetBits < sourceBits)
return {rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, elemTy,
operands[0][0])};
return {
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, elemTy, operands[0][0])};
return {
rewriter.create<LLVM::TruncOp>(op.getLoc(), elemTy, operands[0][0])};
return {rewriter.create<LLVM::SExtOp>(op.getLoc(), elemTy, operands[0][0])};
}
};

3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -87,8 +87,9 @@ struct ArithConstantSplatOpConversion
// LLVM IR.
if (type::isFloat8(elemType))
elemType = rewriter.getIntegerType(8);
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto typeConverter = getTypeConverter();
auto constOp = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(elemType), val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
rewriter.replaceOp(op, llStruct);
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2717,6 +2717,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
24 changes: 24 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding.
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
argType.getEncoding())) {
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
// then pass it to the LocalAllocOp.
auto newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
auto dotOperandToBlockedCvt =
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
dotOperandToBlockedCvt);
}

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
// Dot operand layout assignment for predicates (i1) is not currently
// supported when lowering from TritonGPU to LLVM for MMA cases. This
// condition hides the original bit-width from the layout assignment so
// that predicates are not considered and kWidth can never be 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
7 changes: 6 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp
@@ -14,7 +14,12 @@ namespace gpu {
namespace {
bool dotSupportsAccInitFlag(Operation *op) {
assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
return isa<triton::nvidia_gpu::WarpGroupDotOp>(op);
if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
// Partial accumulation would require a select op to handle the
// initialization, which would degrade performance.
return !wgDotOp.needsPartialAccumulator();
}
return false;
}

std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

// Quick fix for loading issues: the original-bitwidth computation cannot
// tell that this is a mixed-precision dot (hence kWidth = 1) but would
// still try to hoist through the type conversion.
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
@@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(src)) {
Type srcType = getElementTypeOrSelf(src->getOperand(0));
if (srcType.isInteger(1))
return failure();
}

// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMemorySpace()),
v, offsetsVal);

// We need to assign kwidth to zero in the case where the parent layout is
// Blocked, otherwise the verifier emits a failure. The parent layout is
// Blocked only when Tensor Cores are disabled.
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
? 0
: prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -187,6 +193,15 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
// Similar to the issues faced in the HoistLayoutConversion pattern in
// OptimizeDotOperands.cpp, we can't propagate through type casts from
// predicates, as they aren't supported in Triton when encoded with the
// dot_op layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
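
The kWidth choice added to generatePrefetch() above amounts to a one-line rule. A hedged runnable sketch follows (the helper name is hypothetical, and 128 is only an example input, not the actual prefetchWidth from the code):

def prefetch_kwidth(parent_is_blocked: bool, prefetch_width: int) -> int:
    # kWidth must be 0 when the parent layout is Blocked (Tensor Cores
    # disabled); otherwise it is prefetchWidth / 8, as in the change above.
    return 0 if parent_is_blocked else prefetch_width // 8

assert prefetch_kwidth(True, 128) == 0    # Blocked parent: kWidth forced to 0
assert prefetch_kwidth(False, 128) == 16  # otherwise prefetchWidth / 8
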
12 changes: 12 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -70,6 +70,18 @@ void WarpGroupDotOp::getEffects(
mlir::triton::gpu::SharedMemory::get());
}

bool WarpGroupDotOp::needsPartialAccumulator() {
const auto &a = getA();
const auto &d = getD();
auto aTensorTy = cast<TensorOrMemDesc>(a.getType());
auto aElTy = cast<TensorOrMemDesc>(a.getType()).getElementType();
bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
bool accFP32 = cast<TensorOrMemDesc>(d.getType()).getElementType().isF32();
uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
return isFP8 && accFP32 && maxNumImpreciseAcc <= aTensorTy.getShape()[1];
}

// -- WarpGroupDotWaitOp --
LogicalResult WarpGroupDotWaitOp::inferReturnTypes(
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
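
In words, the new needsPartialAccumulator() above returns true exactly when the A operand has an FP8 element type, the result accumulates in FP32, and maxNumImpreciseAcc is no larger than K (the second dimension of A). A hedged standalone Python sketch of that predicate (hypothetical names, mirroring the C++ above):

def needs_partial_accumulator(a_is_fp8: bool, acc_is_f32: bool,
                              max_num_imprecise_acc: int, k: int) -> bool:
    # Mirrors the predicate above: partial accumulation is only needed for
    # FP8 inputs accumulated in FP32 with maxNumImpreciseAcc <= K.
    return a_is_fp8 and acc_is_f32 and max_num_imprecise_acc <= k

# Example: an FP8 dot with K = 128, FP32 accumulator, maxNumImpreciseAcc = 32.
assert needs_partial_accumulator(True, True, 32, 128)
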
77 changes: 77 additions & 0 deletions python/BUILD
@@ -0,0 +1,77 @@
# NOTE: Do not depend on any targets from this directory,
# but use //third_party/py/triton instead.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = ["//:license"],
default_visibility = [
"//third_party/py/triton:__pkg__",
"@triton//python:__subpackages__",
],
)

cc_library(
name = "passes",
hdrs = ["src/passes.h"],
includes = ["src"],
visibility = ["@triton//third_party:__subpackages__"],
)

pybind_extension(
name = "libtriton",
srcs = [
"src/interpreter.cc",
"src/ir.cc",
"src/llvm.cc",
"src/main.cc",
"src/passes.cc",
],
copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
deps = [
":passes",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:InstCombine",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:MC",
"@llvm-project//llvm:Passes",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:ConversionPasses",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:IndexDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMIRTransforms",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"//:TritonAnalysis",
"//:TritonDialects",
"//:TritonGPUToLLVM",
"//:TritonGPUTransforms",
"//:TritonHSACO",
"//:TritonLLVMIR",
"//:TritonNvidiaGPUTransforms",
"//:TritonPTX",
"//:TritonToTritonGPU",
"//:TritonTools",
"//:TritonTransforms",
"@triton//third_party/nvidia:triton_nvidia",
],
)

filegroup(
name = "files",
srcs = glob(
include = ["triton/**/*.py"],
),
)
6 changes: 3 additions & 3 deletions python/src/llvm.cc
@@ -1,4 +1,4 @@
#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp
#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
@@ -346,8 +346,8 @@ void init_triton_llvm(py::module &&m) {
// and break the lowering of some target specific intrinsics.
std::unique_ptr<TargetMachine> targetMachine = nullptr;
if (!arch.empty() && pluginFile.empty())
targetMachine = std::move(
createTargetMachine(mod, arch, enable_fp_fusion, features));
targetMachine =
createTargetMachine(mod, arch, enable_fp_fusion, features);
PassBuilder pb(/*targetMachine=*/targetMachine.get(), tuningOptions,
std::nullopt, instrCbPtr);

26 changes: 26 additions & 0 deletions python/test/regression/BUILD
@@ -0,0 +1,26 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
name = "tests",
size = "large",
srcs = ["conftest.py"],
shard_count = 10,
tags = [
"config-cuda-only",
"requires-gpu-sm80",
],
tests = glob(
include = ["test_*.py"],
exclude = [
"test_performance.py", #TODO(b/321005767): fix failing test
],
),
deps = [
"//third_party/py/torch:pytorch",
"//third_party/py/triton",
],
)
12 changes: 12 additions & 0 deletions python/test/regression/conftest.py
@@ -0,0 +1,12 @@
# content of conftest.py

import pytest


def pytest_addoption(parser):
parser.addoption("--device", action="store", default='cuda')


@pytest.fixture
def device(request):
return request.config.getoption("--device")
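
For reference, a minimal usage sketch of the fixture defined above (the test file name and body are hypothetical): any test in this directory can accept device as an argument and will receive the value passed via --device, defaulting to "cuda".

# test_example.py (hypothetical)
def test_uses_device_option(device):
    # Run with e.g. `pytest --device=cpu` to override the default.
    assert isinstance(device, str)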