diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index 61e1a97a2961..072b23c6fb4a 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -1,6 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS Combine.td) +mlir_tablegen(TritonCombine.inc -gen-rewriters) +add_public_tablegen_target(TritonCombineIncGen) + add_mlir_dialect_library(TritonTransforms Combine.cpp DEPENDS TritonTransformsIncGen + TritonCombineIncGen ) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index ca5841aadc3f..69f0ced6a019 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -13,143 +13,43 @@ using namespace mlir; namespace { -// dot(a, b, 0) + c => dot(a, b, c) -class CombineDotOp : public mlir::RewritePattern { -public: - CombineDotOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, - context) {} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - if (llvm::isa(op)) { - if (isCandidate(op->getOperand(0)).succeeded()) { - auto dotOp = op->getOperand(0).getDefiningOp(); - rewriter.replaceOpWithNewOp( - op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(), - op->getOperand(1), dotOp.allowTF32()); - return mlir::success(); - } else if (isCandidate(op->getOperand(1)).succeeded()) { - auto dotOp = op->getOperand(1).getDefiningOp(); - rewriter.replaceOpWithNewOp( - op, dotOp->getResultTypes().front(), dotOp.a(), dotOp.b(), - op->getOperand(0), dotOp.allowTF32()); - return mlir::success(); - } - } - return mlir::failure(); - } - -private: - // Is this value a dot and has 0 as `c`. 
- mlir::LogicalResult isCandidate(mlir::Value val) const { - if (auto dot = val.getDefiningOp()) { - if (isZero(dot.c())) - return mlir::success(); - } - return mlir::failure(); - } - bool isZero(mlir::Value val) const { - if (mlir::matchPattern(val, mlir::m_Zero()) || - mlir::matchPattern(val, mlir::m_AnyZeroFloat())) +bool isZero(mlir::Value val) { + if (mlir::matchPattern(val, mlir::m_Zero()) || + mlir::matchPattern(val, mlir::m_AnyZeroFloat())) + return true; + // broadcast(constant_0) + if (auto bc = val.getDefiningOp()) { + if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || + mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat())) return true; - // broadcast(constant_0) - if (auto bc = val.getDefiningOp()) { - if (mlir::matchPattern(bc.src(), mlir::m_Zero()) || - mlir::matchPattern(bc.src(), mlir::m_AnyZeroFloat())) - return true; - } - return false; } -}; - -// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1)) -// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect -// (ref: ArithmeticCanonicalization.td) -class CombineGEPOp : public mlir::RewritePattern { -public: - CombineGEPOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, - context) {} + return false; +} - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - if (llvm::isa(op)) { - if (auto gep2 = op->getOperand(0).getDefiningOp()) { - auto loc = op->getLoc(); - mlir::Value newIdx = rewriter.create( - loc, op->getOperand(1), gep2->getOperand(1)); - rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), gep2->getOperand(0), newIdx); - return mlir::success(); - } - } - return mlir::failure(); +bool isBroadcastConstantCombinable(Attribute value) { + if (auto denseValue = value.dyn_cast()) { + return denseValue.isSplat(); } -}; + return value.isa(); +} -// select(cond, load(ptrs, broadcast(cond), ???), other) -// => load(ptrs, broadcast(cond), other) -class 
CombineSelectMaskedLoadOp : public mlir::RewritePattern { -public: - CombineSelectMaskedLoadOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, - context) {} +DenseElementsAttr getConstantValue(Builder &builder, Attribute value, + Value bcast_res) { - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - if (llvm::isa(op)) { - if (auto load = op->getOperand(1).getDefiningOp()) { - mlir::Value cond = op->getOperand(0); - if (auto bc = load.mask().getDefiningOp()) { - if (bc.src().getDefiningOp() == cond.getDefiningOp()) { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), load.ptr(), load.mask(), - op->getOperand(2), load.cache(), load.evict(), - load.isVolatile()); - return mlir::success(); - } - } - } - } - return mlir::failure(); + Type resType = bcast_res.getType(); + DenseElementsAttr res; + if (auto denseValue = value.dyn_cast()) { + res = + DenseElementsAttr::get(resType, denseValue.getSplatValue()); + } else { + res = DenseElementsAttr::get(resType, value); } -}; + return res; +} -// broadcast(cst) => cst -// TODO: move this to .td file -class CombineBroadcastConstantOp : public mlir::RewritePattern { -public: - CombineBroadcastConstantOp(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::RewritePattern::MatchAnyOpTypeTag(), 1, - context) {} +#include "TritonCombine.inc" - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (auto broadcast = llvm::dyn_cast(op)) { - if (auto cst = broadcast.src().getDefiningOp()) { - Attribute value = cst.getValue(); - Type resType = broadcast.getResult().getType(); - if (auto denseValue = value.dyn_cast()) { - if (!denseValue.isSplat()) - return failure(); - value = DenseElementsAttr::get(resType, - denseValue.getSplatValue()); - } else { - if (!value.isa()) - return failure(); - value = DenseElementsAttr::get(resType, value); - } - 
rewriter.replaceOpWithNewOp(op, value, resType); - return success(); - } - } - return failure(); - } -}; } // anonymous namespace #define GEN_PASS_CLASSES @@ -162,11 +62,15 @@ class CombineOpsPass : public TritonCombineOpsBase { mlir::RewritePatternSet patterns(context); mlir::ModuleOp m = getOperation(); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - // patterns.add(context); + // Dot Add %{ + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + // %} + patterns.add(context); + patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td new file mode 100644 index 000000000000..6decc9539f54 --- /dev/null +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -0,0 +1,53 @@ +#ifndef TRITON_PATTERNS +#define TRITON_PATTERNS + +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" +include "triton/Dialect/Triton/IR/TritonOps.td" + + +// AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) +// AddFOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) + +// AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +// AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) +def CombineDotAddIPattern : Pat< + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), + (TT_DotOp $a, $b, $d, $allowTF32), + [(Constraint> $c)]>; +def CombineDotAddFPattern : Pat< + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), + (TT_DotOp $a, $b, $d, $allowTF32), + [(Constraint> $c)]>; + +def CombineDotAddIRevPattern : Pat< + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), + (TT_DotOp $a, $b, $d, $allowTF32), + [(Constraint> $c)]>; +def CombineDotAddFRevPattern : Pat< + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), + (TT_DotOp $a, $b, $d, $allowTF32), + 
[(Constraint> $c)]>; + + +// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1)) +// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect +// (ref: ArithmeticCanonicalization.td) +def CombineGEPPattern : Pat< + (TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1), + (TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>; + +// select(cond, load(ptrs, broadcast(cond), ???), other) +// => load(ptrs, broadcast(cond), other) +def CombineSelectMaskedLoadPattern : Pat< + (SelectOp $cond, (TT_LoadOp $ptrs, (TT_BroadcastOp:$bcast_res $cond), $other, $cache, $evict, $isVolatile), $falseValue), + (TT_LoadOp $ptrs, $bcast_res, $falseValue, $cache, $evict, $isVolatile)>; + +// broadcast(cst) => cst +def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; +def CombineBroadcastConstantPattern : Pat< + (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), + (Arith_ConstantOp (getConstantValue $value, $bcast_res)), + [(Constraint> $value)]>; + +#endif diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir new file mode 100644 index 000000000000..a85d57e655bd --- /dev/null +++ b/test/Triton/combine.mlir @@ -0,0 +1,68 @@ +// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine +// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s + +// CHECK-LABEL: @test_combine_dot_add_pattern +func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32> + // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32> + // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32> + %a = arith.constant dense<1.0> : tensor<128x128xf32> + %b = arith.constant dense<2.0> : tensor<128x128xf32> + %zero = arith.constant dense<0.0> : tensor<128x128xf32> + %d = arith.constant dense<3.0> : tensor<128x128xf32> + + %dot_out = tt.dot %a, %b, %zero {allowTF32 = true} : tensor<128x128xf32> * 
tensor<128x128xf32> -> tensor<128x128xf32> + + // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> + + // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %res1 = arith.addf %d, %dot_out : tensor<128x128xf32> + + return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> +} + +// CHECK-LABEL: @test_combine_gep_pattern +func @test_combine_gep_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> { + %off0 = arith.constant 10 : i32 + %off1 = arith.constant 15 : i32 + + // 10 + 15 = 25 + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32> + + %base_ = tt.broadcast %base : (!tt.ptr) -> tensor<8x!tt.ptr> + + // CHECK-NEXT: %[[tmp0:.*]] = tt.broadcast %{{.*}} : (!tt.ptr) -> tensor<8x!tt.ptr> + + %idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32> + %idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32> + + // CHECK-NEXT: %{{.*}} = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr> + %ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr> + %ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr> + + return %ptr1 : tensor<8x!tt.ptr> +} + +// CHECK-LABEL: @test_combine_select_masked_load_pattern +func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> tensor<8xf32> { + %mask = tt.broadcast %cond : (i1) -> tensor<8xi1> + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // CHECK: %[[res:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %0 = select %cond, %x, %false_val : tensor<8xf32> + + // CHECK: return %[[res]] : tensor<8xf32> + return %0 : tensor<8xf32> +} + +// CHECK-LABEL: 
@test_combine_broadcast_constant_pattern +func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8xf32> + %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32> + + // CHECK-NEXT: return %[[cst]] : tensor<8x2xf32> + return %bst_out : tensor<8x2xf32> +}