Skip to content

Commit

Permalink
OpenXLA-specific changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jax-triton-dev authored and karupayun committed Apr 3, 2024
1 parent 6ffd95a commit 2b46efc
Show file tree
Hide file tree
Showing 30 changed files with 1,993 additions and 217 deletions.
972 changes: 972 additions & 0 deletions BUILD

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
axisAnalysisPass(axisAnalysisPass) {}

// True if elements allocated to a thread are contiguous within the axis. This
// is not the case in MMA-like encodings where a thread might have elements
// (0,0),(0,1) and (8,0),(8,1) for example. The problem with this is that the
// deduplication mechanism assumes that if, for example, constancy=4 and
// elements/thread=4, then all of a thread's elements are constant.
bool contiguouslyMapped(Attribute encoding) const {
  // A slice encoding inherits its thread mapping from its parent, so recurse.
  if (auto slice = encoding.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    return contiguouslyMapped(slice.getParent());
  }
  // Of the remaining encodings, only blocked is known to be contiguous.
  return encoding.isa<triton::gpu::BlockedEncodingAttr>();
}

// Try to deduplicate the resultVals based on the
// constancy properties of the result discovered by
// the axis analysis pass. If possible, redundant
Expand All @@ -93,8 +105,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
if (!encoding)
// encoding not available
return resultVals;
if (!encoding.dyn_cast<BlockedEncodingAttr>() &&
!encoding.dyn_cast<SliceEncodingAttr>()) {
if (!contiguouslyMapped(encoding)) {
// TODO: constraining the encoding type here is necessary for avoiding
// crashes in the getElemsPerThread call below happening in the
// test_core::test_fp8_dot_acc
Expand Down
6 changes: 5 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,8 @@ loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
srcTy.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
auto srcElemTy = srcTy.getElementType();
auto dstElemTy = dstTy.getElementType();
LDBG("loadSharedToDistributed elemTy " << elemTy << " srcElemTy " << srcElemTy
<< " dstElemTy " << dstElemTy);
auto inOrd = triton::gpu::getOrder(srcSharedLayout);
auto outOrd = triton::gpu::getOrder(dstDistributedLayout);
unsigned outVec = inOrd == outOrd
Expand Down Expand Up @@ -1281,7 +1283,7 @@ loadSharedToDistributed(Value dst, ArrayRef<SmallVector<Value>> dstIndices,
auto valVec = load(wordTy, smemAddr);
valVec.setAlignment(minVec * elemTy.getIntOrFloatBitWidth() / 8);
for (unsigned v = 0; v < minVec; ++v) {
Value currVal = extract_element(dstElemTy, valVec, i32_val(v));
Value currVal = extract_element(elemTy, valVec, i32_val(v));
outVals[i * minVec + v] = currVal;
}
}
Expand Down Expand Up @@ -1407,6 +1409,8 @@ static Value packLLElements(Location loc,
<< v.value();
}
if (v.value().getType() != elementTypes[v.index()]) {
LDBG("type " << type << " structType " << structType);
LDBG("value " << v.value());
emitError(loc) << "invalid element type in packLLEElements. Expected "
<< elementTypes[v.index()] << " but got "
<< v.value().getType();
Expand Down
21 changes: 21 additions & 0 deletions include/triton/Dialect/Triton/IR/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,27 @@ template <typename VecT> bool isConsecutive(const VecT &vec) {
return isConsecutive(ArrayRef(vec));
}

// LLVM's STLExtras.h provides a bunch of functions that work over ranges, but
// it's missing min/max_element until
// https://github.com/llvm/llvm-project/commit/fab2bb8b makes it into Triton.
// TODO(jlebar): Remove this once we have the LLVM helpers.

// Returns an iterator to the smallest element of `Range`, using ADL-aware
// begin/end so it works with any range type LLVM understands.
template <typename R> auto min_element(R &&Range) {
  return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range));
}
// Same, ordering elements with the comparator `C`.
template <typename R, typename Compare>
auto min_element(R &&Range, Compare &&C) {
  return std::min_element(llvm::adl_begin(Range), llvm::adl_end(Range),
                          std::forward<Compare>(C));
}
// Returns an iterator to the largest element of `Range`.
template <typename R> auto max_element(R &&Range) {
  return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range));
}
// Same, ordering elements with the comparator `C`.
// Fix: the previous declaration carried a stray `typename T` template
// parameter that was unused and non-deducible, which made this overload
// impossible to call via ordinary argument deduction.
template <typename R, typename Compare>
auto max_element(R &&Range, Compare &&C) {
  return std::max_element(llvm::adl_begin(Range), llvm::adl_end(Range),
                          std::forward<Compare>(C));
}

} // namespace triton
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ bool supportMMA(triton::DotOp op, int version) {
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
// TODO(b/311157761): enable mma_v3
if (!triton::tools::getBoolEnv("ENABLE_MMA_V3"))
return false;
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
Expand Down
74 changes: 74 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,26 @@ struct ExternElementwiseOpConversion
}
};

// Generic elementwise conversion pattern: lowers a single `SourceOp` to a
// single `DestOp`, forwarding the first operand group and all attributes.
// Used by the POPULATE_UNARY_OP / POPULATE_BINARY_OP macros to register
// one-to-one arith/math -> LLVM lowerings.
template <typename SourceOp, typename DestOp>
struct ElementwiseOpConversion
    : public ElementwiseOpConversionBase<
          SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
  using Base =
      ElementwiseOpConversionBase<SourceOp,
                                  ElementwiseOpConversion<SourceOp, DestOp>>;
  using Base::Base;
  using OpAdaptor = typename Base::OpAdaptor;

  // An interface to support variant DestOp builder.
  // Builds exactly one `DestOp` of element type `elemTy` from the first set of
  // converted operands, preserving the source op's attributes.
  SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
    return {rewriter.create<DestOp>(loc, elemTy, operands[0],
                                    adaptor.getAttributes().getValue())};
  }
};

struct ElementwiseInlineAsmOpConversion
: public ConvertOpToLLVMPattern<ElementwiseInlineAsmOp> {
using Base = ConvertOpToLLVMPattern<ElementwiseInlineAsmOp>;
Expand Down Expand Up @@ -720,6 +740,60 @@ void mlir::triton::populateClampFOpToLLVMPattern(
void mlir::triton::populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);

POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)
POPULATE_UNARY_OP(arith::ExtUIOp, LLVM::ZExtOp)
POPULATE_UNARY_OP(arith::FPToUIOp, LLVM::FPToUIOp)
POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
POPULATE_UNARY_OP(math::FloorOp, math::FloorOp)
POPULATE_UNARY_OP(math::LogOp, math::LogOp)
POPULATE_UNARY_OP(math::Log2Op, math::Log2Op)
POPULATE_UNARY_OP(math::CosOp, math::CosOp)
POPULATE_UNARY_OP(math::SinOp, math::SinOp)
POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
POPULATE_UNARY_OP(math::Exp2Op, math::Exp2Op)
POPULATE_UNARY_OP(math::ErfOp, math::ErfOp)
POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>( \
typeConverter, axisInfoAnalysis, benefit);

POPULATE_BINARY_OP(arith::SubIOp, LLVM::SubOp) // -
POPULATE_BINARY_OP(arith::AddIOp, LLVM::AddOp) // +
POPULATE_BINARY_OP(arith::MulIOp, LLVM::MulOp) // *
POPULATE_BINARY_OP(arith::DivSIOp, LLVM::SDivOp)
POPULATE_BINARY_OP(arith::DivUIOp, LLVM::UDivOp)
POPULATE_BINARY_OP(arith::RemFOp, LLVM::FRemOp) // %
POPULATE_BINARY_OP(arith::RemSIOp, LLVM::SRemOp)
POPULATE_BINARY_OP(arith::RemUIOp, LLVM::URemOp)
POPULATE_BINARY_OP(arith::AndIOp, LLVM::AndOp) // &
POPULATE_BINARY_OP(arith::OrIOp, LLVM::OrOp) // |
POPULATE_BINARY_OP(arith::XOrIOp, LLVM::XOrOp) // ^
POPULATE_BINARY_OP(arith::ShLIOp, LLVM::ShlOp) // <<
POPULATE_BINARY_OP(arith::ShRSIOp, LLVM::AShrOp) // >>
POPULATE_BINARY_OP(arith::ShRUIOp, LLVM::LShrOp) // >>
// fmin (return non-NaN if either op is non-NaN)
POPULATE_BINARY_OP(arith::MinNumFOp, LLVM::MinNumOp)
// fmax (return non-NaN if either op is non-NaN)
POPULATE_BINARY_OP(arith::MaxNumFOp, LLVM::MaxNumOp)
POPULATE_BINARY_OP(arith::MinSIOp, LLVM::SMinOp) // smin
POPULATE_BINARY_OP(arith::MaxSIOp, LLVM::SMaxOp) // smax
POPULATE_BINARY_OP(arith::MinUIOp, LLVM::UMinOp) // umin
POPULATE_BINARY_OP(arith::MaxUIOp, LLVM::UMaxOp) // umax
#undef POPULATE_BINARY_OP

patterns.add<ElementwiseOpConversion<math::FmaOp, LLVM::FMAOp>>(
typeConverter, axisInfoAnalysis, benefit);

patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<CmpIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<CmpFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
Expand Down
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using ttg::SliceEncodingAttr;
// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
int baseVersion = 0;
if (computeCapability < 75) {
if (computeCapability < 80) {
baseVersion = 1;
} else if (computeCapability < 90) {
baseVersion = 2;
Expand Down Expand Up @@ -307,8 +307,10 @@ class BlockedToMMA : public mlir::RewritePattern {
} else {

// convert operands
int minBitwidth =
std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
// TODO(b/296812125): Fix minBitwidth issue upstream and uncomment.
// int minBitwidth =
// std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
int minBitwidth = 0;
Type minType = IntegerType::get(ctx, minBitwidth);
// convert A operand
auto newAEncoding = ttg::DotOperandEncodingAttr::get(
Expand Down
16 changes: 15 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <algorithm>
#include <cstdlib>
#include <cctype>
#include <memory>
#include <string>

// Returns true iff the ENABLE_PIPELINING environment variable is set to an
// affirmative value ("on", "true", or "1"), compared case-insensitively.
// An unset or empty variable counts as disabled.
inline bool isPipeliningEnabled() {
  const char *raw = std::getenv("ENABLE_PIPELINING");
  if (raw == nullptr)
    return false;
  std::string lowered;
  for (const char *p = raw; *p != '\0'; ++p)
    lowered.push_back(
        static_cast<char>(std::tolower(static_cast<unsigned char>(*p))));
  return lowered == "on" || lowered == "true" || lowered == "1";
}

namespace {

Expand Down Expand Up @@ -329,7 +341,9 @@ class TritonGPUOptimizeDotOperandsPass

mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
// TODO(b/291216607): Fix crashes and enable by default.
if (isPipeliningEnabled() &&
triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
patterns.add<HoistLayoutConversion>(context);
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);
Expand Down
Loading

0 comments on commit 2b46efc

Please sign in to comment.