diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 31456f23ae41..7f0a6aed9617 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -211,6 +211,7 @@ def TT_AddPtrOp : TT_Op<"addptr",
     let results = (outs TT_PtrLike:$result);
 
     let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result) `,` type($offset)";
+    let hasFolder = 1;
 }
 
 def TT_AdvanceOp : TT_Op<"advance",
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index 1ac2d8cb53f8..60f9bdf95771 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -848,6 +848,15 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state,
                builder.getDenseI32ArrayAttr(order));
 }
 
+//-- AddPtrOp --
+OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) {
+  // addptr(ptr, 0) -> ptr
+  if (matchPattern(adaptor.getOffset(), m_Zero())) {
+    return getPtr();
+  }
+  return {};
+}
+
 //-- AdvanceOp --
 OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
   // advance(ptr, 0, 0) -> ptr
diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir
index fd31e5c782fe..9b7804d44e2c 100644
--- a/test/Triton/canonicalize.mlir
+++ b/test/Triton/canonicalize.mlir
@@ -11,6 +11,8 @@ tt.func @dead_load(%ptr: tensor<32x128x!tt.ptr<f16>>) {
   tt.return
 }
 
+// -----
+
 // CHECK-LABEL: make_range
 tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) {
   // CHECK-DAG: %[[c:.*]] = arith.constant dense<0> : tensor<128x1xi32>
@@ -25,6 +27,32 @@ tt.func @make_range() -> (tensor<128x1xi32>, tensor<1xi32>) {
   tt.return %c, %d : tensor<128x1xi32>, tensor<1xi32>
 }
 
+// -----
+
+// CHECK-LABEL: fold_addptr
+tt.func @fold_addptr(%arg: tensor<64x64x!tt.ptr<f32>>) -> (tensor<64x64x!tt.ptr<f32>>) {
+  // CHECK-NOT: tt.addptr
+  // CHECK-NOT: arith.constant
+  // CHECK: tt.return %arg
+  %c0_i32 = arith.constant dense<0> : tensor<64x64xi32>
+  %0 = tt.addptr %arg, %c0_i32 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
+  tt.return %0 : tensor<64x64x!tt.ptr<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: fold_addptr_scalar
+tt.func @fold_addptr_scalar(%arg: !tt.ptr<f32>) -> (!tt.ptr<f32>) {
+  // CHECK-NOT: tt.addptr
+  // CHECK-NOT: arith.constant
+  // CHECK: tt.return %arg
+  %c0_i32 = arith.constant 0 : i32
+  %0 = tt.addptr %arg, %c0_i32 : !tt.ptr<f32>, i32
+  tt.return %0 : !tt.ptr<f32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_advance
 tt.func @fold_advance(%arg: !tt.ptr<tensor<256x32xf16>>) -> (!tt.ptr<tensor<256x32xf16>>) {
   %c0_i32 = arith.constant 0 : i32
@@ -34,7 +62,6 @@ tt.func @fold_advance(%arg: !tt.ptr<tensor<256x32xf16>>) -> (!tt.ptr<tensor<256
   tt.return %0 : !tt.ptr<tensor<256x32xf16>>
 }
 
-
 // -----
 
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir
index 0b30ccb4191b..f541cacd0e8b 100644
--- a/test/TritonGPU/loop-pipeline-hopper.mlir
+++ b/test/TritonGPU/loop-pipeline-hopper.mlir
@@ -617,7 +617,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: async_following_sync
   tt.func @async_following_sync(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) {
-    %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked>
+    %cst = arith.constant dense<64> : tensor<64x16xi32, #blocked>
     %c0_i32 = arith.constant 0 : i32
     %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
     %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>