Commit a0475e8

OpenXLA-specific changes.
- Disable short pointer.
- Remove remat_fast_load
vwbaker committed Oct 6, 2023
1 parent 66f3e32 commit a0475e8
Showing 34 changed files with 543 additions and 152 deletions.
454 changes: 400 additions & 54 deletions BUILD

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/Dialect.h
@@ -9,7 +9,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "triton/Dialect/Triton/IR/Dialect.h.inc"
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -647,7 +647,7 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
     operand_range getArgOperands() {
       return {arg_operand_begin(), arg_operand_end()};
     }
-    MutableOperandRange getArgOperandsMutable() {
+    mlir::MutableOperandRange getArgOperandsMutable() {
       return getOperandsMutable();
     }

1 change: 0 additions & 1 deletion include/triton/Target/PTX/TmaMetadata.h
@@ -24,7 +24,6 @@
 #ifndef TRITON_TARGET_PTX_TMAMETADATA_H
 #define TRITON_TARGET_PTX_TMAMETADATA_H

-#include "python/triton/third_party/cuda/include/cuda.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Format.h"
51 changes: 51 additions & 0 deletions include/triton/Tools/cuda_compat.h
@@ -0,0 +1,51 @@
+#include "cuda.h"
+
+// Compatibility with CUDA 11
+#ifndef CU_TENSOR_MAP_NUM_QWORDS
+#define CU_TENSOR_MAP_NUM_QWORDS 16
+
+typedef struct CUtensorMap_st {
+  cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
+} CUtensorMap;
+
+typedef enum CUtensorMapDataType_enum {
+  CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
+  CU_TENSOR_MAP_DATA_TYPE_UINT16,
+  CU_TENSOR_MAP_DATA_TYPE_UINT32,
+  CU_TENSOR_MAP_DATA_TYPE_INT32,
+  CU_TENSOR_MAP_DATA_TYPE_UINT64,
+  CU_TENSOR_MAP_DATA_TYPE_INT64,
+  CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
+  CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
+  CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
+  CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
+  CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
+  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
+  CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
+} CUtensorMapDataType;
+
+typedef enum CUtensorMapInterleave_enum {
+  CU_TENSOR_MAP_INTERLEAVE_NONE = 0,
+  CU_TENSOR_MAP_INTERLEAVE_16B,
+  CU_TENSOR_MAP_INTERLEAVE_32B
+} CUtensorMapInterleave;
+
+typedef enum CUtensorMapSwizzle_enum {
+  CU_TENSOR_MAP_SWIZZLE_NONE = 0,
+  CU_TENSOR_MAP_SWIZZLE_32B,
+  CU_TENSOR_MAP_SWIZZLE_64B,
+  CU_TENSOR_MAP_SWIZZLE_128B
+} CUtensorMapSwizzle;
+
+typedef enum CUtensorMapL2promotion_enum {
+  CU_TENSOR_MAP_L2_PROMOTION_NONE = 0,
+  CU_TENSOR_MAP_L2_PROMOTION_L2_64B,
+  CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
+  CU_TENSOR_MAP_L2_PROMOTION_L2_256B
+} CUtensorMapL2promotion;
+
+typedef enum CUtensorMapFloatOOBfill_enum {
+  CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE = 0,
+  CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
+} CUtensorMapFloatOOBfill;
+#endif
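Note: this header back-fills the CUtensorMap declarations that cuda.h only ships from CUDA 12 onward, keyed on CU_TENSOR_MAP_NUM_QWORDS so the block is skipped when the toolkit already provides them. A minimal sketch of a consumer that compiles against either toolkit; the helper below is hypothetical, not part of this commit:

#include "triton/Tools/cuda_compat.h"

// Maps an element bit width to a tensor-map data type. Works under CUDA 11
// (via the typedefs above) and CUDA 12+ (via the toolkit's own cuda.h).
static CUtensorMapDataType dataTypeForBitWidth(unsigned bits) {
  switch (bits) {
  case 8:  return CU_TENSOR_MAP_DATA_TYPE_UINT8;
  case 16: return CU_TENSOR_MAP_DATA_TYPE_UINT16;
  case 64: return CU_TENSOR_MAP_DATA_TYPE_UINT64;
  default: return CU_TENSOR_MAP_DATA_TYPE_UINT32;
  }
}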
6 changes: 4 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -178,9 +178,11 @@ struct ConvertLayoutOpConversion
     Value _16 = i32_val(16);
     if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
       multiDimWarpId[0] =
-          urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0]));
+          urem(multiDimWarpId[0],
+               i32_val(ceil<unsigned>(shapePerCTA[0], instrShape[0])));
       multiDimWarpId[1] =
-          urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1]));
+          urem(multiDimWarpId[1],
+               i32_val(ceil<unsigned>(shapePerCTA[1], instrShape[1])));

       Value mmaGrpId = udiv(laneId, _4);
       Value mmaGrpIdP8 = add(mmaGrpId, _8);
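Note: the switch from plain division to ceil<unsigned> presumably guards against shapes smaller than the MMA instruction shape, where integer division yields 0 and would make the urem modulus invalid. A sketch of the conventional round-up division helper (the actual Triton helper may be defined differently):

#include <cassert>

// Smallest q such that q * b >= a, for b > 0.
template <typename T>
T ceil(T a, T b) {
  assert(b > 0 && "divisor must be positive");
  return (a + b - 1) / b;
}

// Example: shapePerCTA = 16, instrShape = 32.
//   16u / 32u                == 0  (invalid as a urem modulus)
//   ceil<unsigned>(16u, 32u) == 1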
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1230,8 +1230,8 @@ void populateElementwiseOpToLLVMPatterns(
   POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp)    // <<
   POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp)  // >>
   POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp)  // >>
-  POPULATE_BINARY_OP(arith::MinFOp, LLVM::MinNumOp) // fmin
-  POPULATE_BINARY_OP(arith::MaxFOp, LLVM::MaxNumOp) // fmax
+  POPULATE_BINARY_OP(arith::MinimumFOp, LLVM::MinNumOp) // fmin
+  POPULATE_BINARY_OP(arith::MaximumFOp, LLVM::MaxNumOp) // fmax
   POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp)  // smin
   POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp)  // smax
   POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp)  // umin
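Note: upstream MLIR split the old arith::MinFOp/arith::MaxFOp into NaN-propagating arith::MinimumFOp/arith::MaximumFOp (with separate MinNumFOp/MaxNumFOp variants), hence the mechanical renames here and below. A sketch of the two float-min semantics the renamed ops distinguish, assuming standard IEEE-754 behavior:

#include <cmath>
#include <limits>

// minnum-style: if exactly one operand is NaN, return the other operand.
// Matches std::fmin and LLVM's MinNumOp, the lowering target above.
float minNum(float a, float b) { return std::fmin(a, b); }

// minimum-style: a NaN on either side propagates to the result
// (documented semantics of arith.minimumf; signed-zero ordering omitted).
float minimum(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a < b ? a : b;
}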
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -7,6 +7,7 @@

 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
+#include "triton/Tools/cuda_compat.h"

 #include <numeric>

4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -240,7 +240,7 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc,
 Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
                   Value val, Value pred) {
   MLIRContext *ctx = rewriter.getContext();
-  unsigned bits = val.getType().getIntOrFloatBitWidth();
+  unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth());
   const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");

   PTXBuilder builder;
@@ -257,7 +257,7 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
   auto ptrTy = ptr.getType().cast<LLVMPointerType>();
   assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared");
   auto elemTy = ptrTy.getElementType();
-  unsigned bitwidth = elemTy.getIntOrFloatBitWidth();
+  unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth());

   const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r");

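Note: the std::max(8u, ...) clamp matters for sub-byte element types such as i1 predicates, whose getIntOrFloatBitWidth() is below anything the generated PTX can address. A standalone sketch of the constraint selection with the clamp applied:

#include <algorithm>

// Pick a PTX inline-asm register constraint from a value's bit width:
// "l" = 64-bit, "h" = 16-bit, "r" = 32-bit (also covers 8-bit operands).
static const char *constraintForBits(unsigned rawBits) {
  unsigned bits = std::max(8u, rawBits); // i1 etc. round up to one byte
  return bits == 64 ? "l" : (bits == 16 ? "h" : "r");
}
// constraintForBits(1)  -> "r"  (clamped from 1 bit)
// constraintForBits(16) -> "h"
// constraintForBits(64) -> "l"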
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -89,7 +89,7 @@
     ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \
   } while (0)
 #define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
-#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
+#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
 #define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)

 // Types
10 changes: 5 additions & 5 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -129,8 +129,8 @@ void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
       // Floating point
       GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
       // MaxMin
-      GenericOpPattern<arith::MaxFOp>, GenericOpPattern<arith::MaxSIOp>,
-      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinFOp>,
+      GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
+      GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
       GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
       // Floating point
       GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
@@ -728,8 +728,8 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto newOp =
         cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
-    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
-                                newOp.getLoopBody().end());
+    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+                                newOp.getRegion().end());

     // Now, update all the types.

@@ -738,7 +738,7 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
     // The entry block may have a special conversion if `entryConversion` is
     // provided. On success, the new entry block to the region is returned for
    // convenience. Otherwise, failure is returned.
-    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
+    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
                                            *getTypeConverter()))) {
       return rewriter.notifyMatchFailure(op, "could not convert body types");
     }
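Note: this file and several below track the same upstream MLIR API migration on scf::ForOp: getLoopBody() becomes getRegion(), and getIterOperands()/getNumIterOperands() become getInitArgs() (which returns an operand range, hence the llvm::to_vector copies). A rough before/after sketch, assuming the MLIR headers of this era:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"

void inspectLoop(mlir::scf::ForOp forOp) {
  // Old spellings: forOp.getLoopBody(), forOp.getIterOperands(),
  //                forOp.getNumIterOperands()
  mlir::Region &body = forOp.getRegion();           // loop body region
  llvm::SmallVector<mlir::Value> inits =
      llvm::to_vector(forOp.getInitArgs());         // initial iter values
  size_t numIterArgs = forOp.getInitArgs().size();  // iter-arg count
  (void)body;
  (void)inits;
  (void)numIterArgs;
}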
10 changes: 5 additions & 5 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
@@ -325,9 +325,9 @@ class RewriteTensorPointerPass
   Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op,
                           std::stack<Operation *> &eraser) {
     // Generate new iteration operands and set rewrited information
-    SmallVector<Value> oldIterOperands = op.getIterOperands();
-    SmallVector<Value> newIterOperands = op.getIterOperands();
-    for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
+    SmallVector<Value> oldIterOperands = llvm::to_vector(op.getInitArgs());
+    SmallVector<Value> newIterOperands = llvm::to_vector(op.getInitArgs());
+    for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size;
          ++i, ++oldI) {
       if (!triton::isTensorPointerType(newIterOperands[i].getType()))
         continue;
@@ -350,7 +350,7 @@ class RewriteTensorPointerPass
     // mapping. It may refer to a value in the old loop, but we will rewrite it
     // later
     IRMapping mapping;
-    for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
+    for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size();
          ++i, ++oldI) {
       auto oldRegionIterArg = op.getRegionIterArg(oldI);
       if (triton::isTensorPointerType(oldRegionIterArg.getType())) {
@@ -377,7 +377,7 @@ class RewriteTensorPointerPass
     }

     // Replace later usages
-    assert(op.getNumResults() == op.getNumIterOperands());
+    assert(op.getNumResults() == op.getInitArgs().size());
     for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
       auto oldResult = op.getResult(oldI);
       if (triton::isTensorPointerType(oldResult.getType())) {
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -118,7 +118,7 @@ class BlockedToMMA : public mlir::RewritePattern {
     int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
     int origBitWidth = finalBitWidth;
     SetVector<Operation *> slice;
-    mlir::getBackwardSlice(x, &slice, bwdFilter);
+    mlir::getBackwardSlice(x, &slice, {{bwdFilter}});
     Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
     if (firstOp)
       if (Value arg = firstOp->getOperand(0))
@@ -298,8 +298,10 @@ class BlockedToMMA : public mlir::RewritePattern {
     } else {

       // convert operands
-      int minBitwidth =
-          std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
+      // TODO(b/296812125): Fix minBitwidth issue upstream and uncomment.
+      // int minBitwidth =
+      //     std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
+      int minBitwidth = 0;
       Type minType = IntegerType::get(ctx, minBitwidth);
       // convert A operand
       auto newAEncoding = ttg::DotOperandEncodingAttr::get(
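Note: the {{bwdFilter}} spelling reflects the upstream change of mlir::getBackwardSlice to take a BackwardSliceOptions struct rather than a bare filter callback; the nested braces aggregate-initialize the options with the filter as the first member. Spelled out explicitly (a sketch, assuming the SliceAnalysis API of this period):

#include <functional>
#include "mlir/Analysis/SliceAnalysis.h"

void collectFilteredSlice(mlir::Value x,
                          llvm::SetVector<mlir::Operation *> &slice,
                          std::function<bool(mlir::Operation *)> bwdFilter) {
  mlir::BackwardSliceOptions options; // what {{bwdFilter}} constructs inline
  options.filter = bwdFilter;         // only traverse ops the filter accepts
  mlir::getBackwardSlice(x, &slice, options);
}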
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -252,8 +252,9 @@ class TritonGPUOptimizeDotOperandsPass

     mlir::RewritePatternSet patterns(context);
     patterns.add<ConvertTransConvert>(context);
-    if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
-      patterns.add<MoveOpAfterLayoutConversion>(context);
+    // TODO(b/283035396): Fix CUDA_ERROR_MISALIGNED_ADDRESS and uncomment.
+    // if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
+    //   patterns.add<MoveOpAfterLayoutConversion>(context);
     patterns.add<FuseTransHopper>(context);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
       signalPassFailure();
7 changes: 3 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -347,7 +347,7 @@ LogicalResult LoopPipeliner::collectOps(SetVector<Operation *> &ops) {
 void LoopPipeliner::collectValueDep(Value v, int stage,
                                     SetVector<Value> &deps) {
   // Loop-invariant value, skip
-  if (v.getParentRegion() != &forOp.getLoopBody())
+  if (v.getParentRegion() != &forOp.getRegion())
     return;

   // Since we only need to peel the loop numStages-1 times, don't worry
@@ -671,7 +671,7 @@ void LoopPipeliner::createBufferTypes() {
 }

 void LoopPipeliner::createOrderedDeps() {
-  for (Operation &op : forOp.getLoopBody().front()) {
+  for (Operation &op : *forOp.getBody()) {
     if (depOps.contains(&op))
       orderedDeps.push_back(&op);
     else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0)))
@@ -1007,8 +1007,7 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
   // We need this to update operands for yield
   // original block arg => new arg's idx
   SmallVector<Value> newLoopArgs;
-  for (auto v : forOp.getIterOperands())
-    newLoopArgs.push_back(v);
+  for (auto v : forOp.getInitArgs()) newLoopArgs.push_back(v);

   bufferIdx = newLoopArgs.size();
   for (auto loadOp : validLoads)
3 changes: 1 addition & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -269,8 +269,7 @@ scf::ForOp Prefetcher::createNewForOp() {
   OpBuilder builder(forOp);

   SmallVector<Value> loopArgs;
-  for (auto v : forOp.getIterOperands())
-    loopArgs.push_back(v);
+  for (auto v : forOp.getInitArgs()) loopArgs.push_back(v);
   for (Value dot : dots) {
     loopArgs.push_back(
         operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()]);
12 changes: 6 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -760,16 +760,16 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,

   // Create a new loop before the existing one, with the extra operands.
   rewriter.setInsertionPoint(loop);
-  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  auto operands = llvm::to_vector<4>(loop.getInitArgs());
   operands.append(newIterOperands.begin(), newIterOperands.end());
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
   newLoop.getBody()->erase();

-  newLoop.getLoopBody().getBlocks().splice(
-      newLoop.getLoopBody().getBlocks().begin(),
-      loop.getLoopBody().getBlocks());
+  newLoop.getRegion().getBlocks().splice(
+      newLoop.getRegion().getBlocks().begin(),
+      loop.getRegion().getBlocks());
   for (Value operand : newIterOperands)
     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

@@ -805,8 +805,8 @@ static void rewriteSlice(SetVector<Value> &slice,
     if (slice.count(arg)) {
       OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg);
       argMapping.push_back(
-          std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal),
-                         forOp.getNumIterOperands() + newOperands.size()));
+          std::make_pair(forOp.getResultForOpOperand(initVal).getResultNumber(),
+                         forOp.getInitArgs().size() + newOperands.size()));
       newOperands.push_back(mapping.lookup(initVal.get()));
     }
   }
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -597,7 +597,7 @@ struct ForOpDeadArgElimination : public OpRewritePattern<scf::ForOp> {
       Value yieldOperand =
           forOwner.getBody()->getTerminator()->getOperand(iterIdx);
       markLive(yieldOperand);
-      markLive(forOwner.getIterOperands()[iterIdx]);
+      markLive(forOwner.getInitArgs()[iterIdx]);
     }
   }
   SmallVector<unsigned> deadArg;
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
@@ -628,7 +628,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
           arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
           arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
           arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
-          arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
+          arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinimumFOp,
          arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
          arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
          arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
14 changes: 7 additions & 7 deletions lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp
@@ -523,9 +523,9 @@ class TritonGPURewriteTensorPointerPass
                           std::stack<Operation *> &eraser,
                           DenseSet<Value> &valueToRemove) {
     // Generate new iteration operands and set rewrited information
-    SmallVector<Value> oldIterOperands = op.getIterOperands();
-    SmallVector<Value> newIterOperands = op.getIterOperands();
-    for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
+    SmallVector<Value> oldIterOperands = llvm::to_vector(op.getInitArgs());
+    SmallVector<Value> newIterOperands = llvm::to_vector(op.getInitArgs());
+    for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size;
          ++i, ++oldI) {
       if (!tt::isTensorPointerType(newIterOperands[i].getType()))
         continue;
@@ -550,7 +550,7 @@ class TritonGPURewriteTensorPointerPass
     // mapping. It may refer to a value in the old loop, but we will rewrite it
     // later
     IRMapping mapping;
-    for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
+    for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size();
          ++i, ++oldI) {
       auto oldRegionIterArg = op.getRegionIterArg(oldI);
       if (tt::isTensorPointerType(oldRegionIterArg.getType()) &&
@@ -586,7 +586,7 @@ class TritonGPURewriteTensorPointerPass
         valueToRemove.insert(v);

     // Replace later usages
-    assert(op.getNumResults() == op.getNumIterOperands());
+    assert(op.getNumResults() == op.getInitArgs().size());
     for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
       auto oldResult = op.getResult(oldI);
       if (tt::isTensorPointerType(oldResult.getType()) &&
@@ -787,8 +787,8 @@ class TritonGPURewriteTensorPointerPass
       }
     }
     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-      SmallVector<Value> iterOperands = forOp.getIterOperands();
-      for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) {
+      SmallVector<Value> iterOperands = llvm::to_vector(forOp.getInitArgs());
+      for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) {
         if (tt::isTensorPointerType(iterOperands[i].getType())) {
           auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]);
           if (shouldRemove(makeTensorPtrOp, computeCapability))