Skip to content

Commit

Permalink
OpenXLA-specific changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aliia Khasanova authored and khasanovaa committed May 22, 2024
1 parent 4af164a commit 38c82bd
Show file tree
Hide file tree
Showing 39 changed files with 1,787 additions and 92 deletions.
878 changes: 878 additions & 0 deletions BUILD

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#pragma once
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#ifdef USE_ROCM
// Below headers will allow registration to ROCm passes
#include "TritonAMDGPUToLLVM/Passes.h"
#include "TritonAMDGPUTransforms/Passes.h"
#include "TritonAMDGPUTransforms/TritonGPUConversion.h"
#endif

#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
Expand Down Expand Up @@ -48,6 +50,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerDecomposeUnsupportedNVIDIAConversions();
mlir::registerLLVMDIScope();

#ifdef USE_ROCM
// TritonAMDGPUToLLVM passes
mlir::triton::registerConvertTritonAMDGPUToLLVM();
mlir::triton::registerConvertBuiltinFuncToLLVM();
Expand All @@ -58,6 +61,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipeline();
#endif // USE_ROCM

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
Expand Down
6 changes: 0 additions & 6 deletions include/triton/Conversion/MLIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ inline Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }

// Returns true if `type` is one of the floating-point types handled here.
// NOTE(review): bf16 is included so that isInt() below does not misclassify
// bf16 as an integer — Type::isIntOrFloat() is true for bf16, and the
// previous version of isFloat() omitted it, making isInt(bf16) return true.
// Callers that need bf16 handled (see ArithConstantSplatOpConversion, which
// checked `isBF16() || isFloat(...)` explicitly) remain correct.
inline bool isFloat(Type type) {
  return type.isF32() || type.isF64() || type.isF16() || type.isBF16() ||
         type.isF128();
}

// Returns true if `type` is an integer type (an int-or-float that is not a
// recognized float type).
inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

} // namespace type
} // namespace triton
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}

// True if the elements allocated to a thread are contiguous within the
// axis. This is not the case in MMA-like encodings, where a thread might
// own elements (0,0), (0,1) and (8,0), (8,1), for example. This matters
// because the deduplication mechanism below assumes that when, e.g.,
// constancy=4 and elements/thread=4, a single thread holds a fully
// constant run of elements.
bool contiguouslyMapped(Attribute encoding) const {
  // A slice encoding inherits the thread mapping of its parent encoding.
  if (auto slice = dyn_cast<triton::gpu::SliceEncodingAttr>(encoding)) {
    return contiguouslyMapped(slice.getParent());
  }
  // Blocked encodings assign contiguous elements to each thread.
  return isa<triton::gpu::BlockedEncodingAttr>(encoding);
}

// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
Expand All @@ -87,8 +99,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
if (!encoding)
// encoding not available
return resultVals;
if (!dyn_cast<BlockedEncodingAttr>(encoding) &&
!dyn_cast<SliceEncodingAttr>(encoding)) {
if (!contiguouslyMapped(encoding)) {
// TODO: constraining the encoding type here is necessary for avoiding
// crashes in the getElemsPerThread call below happening in the
// test_core::test_fp8_dot_acc
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"AMDGCN_ENABLE_DUMP",
"DISABLE_FAST_REDUCTION",
"DISABLE_LLVM_OPT",
"DISABLE_MMA_V3",
"ENABLE_MMA_V3",
"DISABLE_PTXAS_OPT",
"LLVM_IR_ENABLE_DUMP",
"LLVM_ENABLE_TIMING",
Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ bool supportMMA(triton::DotOp op, int version) {
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
if ((inBitWidth == 16 && ouBitWidth == 32) ||
(inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
Expand Down
7 changes: 4 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,18 @@ struct ArithConstantSplatOpConversion
auto values = mlir::dyn_cast<SplatElementsAttr>(op.getValue());
auto elemType = values.getElementType();
Attribute val;
if (elemType.isBF16() || type::isFloat(elemType)) {
if (isa<FloatType>(elemType)) {
val = values.getValues<FloatAttr>()[0];
} else if (type::isInt(elemType)) {
} else if (isa<IntegerType>(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto typeConverter = getTypeConverter();
auto constOp = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(elemType), val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
rewriter.replaceOp(op, llStruct);
Expand Down
11 changes: 10 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using ttg::SliceEncodingAttr;
// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
int baseVersion = 0;
if (computeCapability < 75) {
if (computeCapability < 80) {
baseVersion = 1;
} else if (computeCapability < 90) {
baseVersion = 2;
Expand Down Expand Up @@ -133,6 +133,15 @@ class BlockedToMMA : public mlir::RewritePattern {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
// Dot operand layout assignment to predicates is not currently supported
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
// condition limits visibility of the original bit-width so that predicates
// are not considered; hence, kWidth can never be 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}
return op->getNumOperands() == 1 &&
(isa<tt::FpToFpOp, tt::BitcastOp, ttg::ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
Expand Down
17 changes: 16 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
Expand All @@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

// Quick handling to fix loading issues when computing the original
// bitwidth is unable to realize that there is a mixed-precision dot
// (hence kWidth = 1) but wants to hoist through the type conversion.
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
return failure();

// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
Expand All @@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(src)) {
Type srcType = getElementTypeOrSelf(src->getOperand(0));
if (srcType.isInteger(1))
return failure();
}

// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
// Similar to issues faced in HoistLayoutConversion pattern in
// OptimizeDotOperands.cpp, we can't propagate through type casts from
// predicates as they aren't supported in Triton when encoded with dot_op
// layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
foundConvertFromShared = true;
Expand Down
45 changes: 1 addition & 44 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,6 @@ class LayoutRematerialization {
ConvertLayoutOp convertOp);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
// Existing tuples of (value, layout) that needs to be updated when recreating
// scf ops. This prevents keeping track of Values that have been delete when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// map of the values remat based on encoding.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
Expand All @@ -189,7 +184,6 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
}

// Remove unneeded values now that we are done with the rematMapping.
Expand Down Expand Up @@ -845,31 +839,6 @@ bool canBeRemat(Operation *op) {
return true;
}

// Migrate remat bookkeeping after a batch of value replacements.
// `values` holds (oldValue, newValue) pairs produced while rewriting a
// slice; any rematMapping/mappedValues entries keyed on an old value are
// re-keyed on its replacement so later lookups still find the
// rematerialized value. Called from rewriteSlice once replacements are
// known, before the old values are erased.
void LayoutRematerialization::updateRematMapping(
    SmallVector<std::tuple<Value, Value>> &values) {
  for (auto [old, newV] : values) {
    auto it = mappedValues.find(old);
    if (it != mappedValues.end()) {
      // `old` had a remat entry under this encoding; move it to `newV`.
      Attribute encoding = it->second;
      auto rematIt = rematMapping.find({old, it->second});
      assert(rematIt != rematMapping.end());
      Value replacedValue = rematIt->second;
      // Erase before re-inserting so stale keys never linger in the maps.
      rematMapping.erase(rematIt);
      mappedValues.erase(it);
      // Loop through the replacement value to find the new version of remat
      // value. This should be okay as the number of values should be small.
      for (auto [before, after] : values) {
        if (before == replacedValue) {
          replacedValue = after;
          break;
        }
      }
      rematMapping[{newV, encoding}] = replacedValue;
      mappedValues[newV] = encoding;
    }
  }
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
Expand All @@ -878,13 +847,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
// Keep track of yield operands that need to be duplicated.
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we already have a remat value for this value, use it.
if (hasRematValue(v, layoutIt->second)) {
mapping.map(v, getRematValue(v, layoutIt->second));
continue;
}
if (v.getDefiningOp()) {
opsToRewrite.insert(v.getDefiningOp());
if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
Expand Down Expand Up @@ -973,8 +935,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
if (slice.count(res)) {
// Why can't we use res instead of ifOp.getResult(oldIdx)?
mapping.map(ifOp.getResult(oldIdx), newIfOp.getResult(newIdx));
addRematValue(ifOp.getResult(oldIdx), layout[res],
newIfOp.getResult(newIdx));
addRematValue(res, layout[res], newIfOp.getResult(newIdx));
++newIdx;
}
++oldIdx;
Expand Down Expand Up @@ -1005,8 +966,6 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
auto cvt = builder.create<ConvertLayoutOp>(op->getLoc(), newType,
newOp->getResult(0));
mapping.map(op->getResult(0), cvt.getResult());
addRematValue(op->getResult(0), layout[op->getResult(0)],
cvt.getResult());
continue;
}
Operation *newOp = builder.clone(*op, mapping);
Expand All @@ -1018,14 +977,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
cast<RankedTensorType>(old.getType()).getShape(),
cast<RankedTensorType>(old.getType()).getElementType(), it->second);
newV.setType(newType);
addRematValue(old, it->second, newV);
}
}
// Check mapping and see if there are existing convertOps on the old Argument
convertOp.replaceAllUsesWith(mapping.lookup(convertOp.getSrc()));
opToDelete.insert(convertOp);

updateRematMapping(replacements);
for (auto &kv : replacements) {
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct FenceInsertionPass
// Only insert fences for compute capability 9.0
if (computeCapability < 90)
return;
if (::triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return;
ModuleOp mod = getOperation();
mod.walk([&](Operation *op) {
Expand Down
7 changes: 4 additions & 3 deletions lib/Tools/LinearLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,11 @@ LinearLayout LinearLayout::compose(const LinearLayout &outer) const {
for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) {
bases.push_back({outDim, b});
}
auto newBases = outer.apply(bases);
auto newBasesRange = llvm::make_second_range(newBases);

auto outerBases =
llvm::to_vector(llvm::make_second_range(outer.apply(bases)));
newInDimBases.push_back(
std::vector<int32_t>(newBasesRange.begin(), newBasesRange.end()));
std::vector<int32_t>(outerBases.begin(), outerBases.end()));
}
}
return LinearLayout(std::move(newBases), outer.getOutDimNames());
Expand Down
Loading

0 comments on commit 38c82bd

Please sign in to comment.