Revert "[Pipeliner] Handle masking for atomic_rmw (#5231)"
This reverts commit 01fb036.
peterbell10 authored Dec 3, 2024
Parent: 01fb036 · Commit: 3f7b21c
Showing 2 changed files with 0 additions and 75 deletions.
@@ -80,13 +80,6 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
    storeOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) {
    rewriter.setInsertionPoint(atomicRMWOp);
    Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(),
                             atomicRMWOp.getMask(), pred);
    atomicRMWOp.getMaskMutable().assign(mask);
    return op;
  }

  assert("don't know how to predicate this op" && false);
  return op;
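For reference, the seven deleted lines above are restated below as a commented sketch. This is not new upstream code: all names (rewriter, op, pred, getPredMask, tt::AtomicRMWOp) come from the surrounding predicateOp implementation shown in the hunk, and the branch simply mirrors the tt.StoreOp case that remains.

  // Sketch of the reverted branch, with explanatory comments added.
  if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) {
    // Insert the new mask computation right before the atomic op.
    rewriter.setInsertionPoint(atomicRMWOp);
    // Combine the op's existing mask (if any) with the pipeline predicate,
    // shaped to match the pointer tensor's type.
    Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(),
                             atomicRMWOp.getMask(), pred);
    // Re-assign the mask operand so the RMW only executes where pred holds;
    // the op itself is left in place and returned unchanged otherwise.
    atomicRMWOp.getMaskMutable().assign(mask);
    return op;
  }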
test/TritonGPU/loop-pipeline-hip.mlir (68 changes: 0 additions & 68 deletions)
@@ -266,71 +266,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
tt.return
}
}

// -----

// Check that the stream pipeliner updates atomic op in the k-loop correctly
// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw
// CHECK: scf.for
// CHECK: tt.atomic_rmw fadd, acq_rel, gpu
// CHECK: tt.dot
// CHECK: scf.yield

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
%cst = arith.constant dense<32> : tensor<32x32xi32, #blocked>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c31_i32 = arith.constant 31 : i32
%c32_i32 = arith.constant 32 : i32
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
%0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
%3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked>
%4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
%6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
%7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
%8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked>
%9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
%10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%11 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
%12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%13 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>, #blocked>
%14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr<f16>, #blocked>, tensor<32x1xi32, #blocked>
%15 = tt.broadcast %14 : tensor<32x1x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #blocked>
%16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
%18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked>
%19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked>
%20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked>
%21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
%22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
%23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked>
%24 = arith.addi %arg3, %c31_i32 : i32
%25 = arith.divsi %24, %c32_i32 : i32
%26 = arith.muli %arg4, %c32_i32 : i32
%27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked>
%28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 {
%32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked>
%33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked>
%34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
%37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
%39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
%40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked>
%41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>
}
%29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
%30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma>
%31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma>
tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma>
tt.return
}
}
