Skip to content

Commit

Permalink
[mlir][SCF]-Fix loop coalescing with iteration arguments
Browse files Browse the repository at this point in the history
Fix a bug found when coalescing loops which have iteration
arguments, such that the inner loop's terminator may have
operands of the inner loop iteration arguments which are about
to be replaced by the outer loop's iteration arguments.

The current flow leads to a crash within the IR code.
  • Loading branch information
amirBish committed Aug 22, 2024
1 parent 8ac140f commit eb2c431
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,18 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,

Operation *innerTerminator = innerLoop.getBody()->getTerminator();
auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
for (Value &yieldedVal : yieldedVals) {
// The yielded value may be an iteration argument of the inner loop
// which is about to be inlined.
auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
if (iter != innerLoop.getRegionIterArgs().end()) {
unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
// `outerLoop` iter args identical to the `innerLoop` init args.
assert(iterArgIndex < innerLoop.getInitArgs().size());
yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
}
}
rewriter.eraseOp(innerTerminator);

SmallVector<Value> innerBlockArgs;
Expand Down
120 changes: 120 additions & 0 deletions mlir/test/Dialect/Affine/loop-coalescing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,126 @@ func.func @unnormalized_loops() {
return
}

// Coalesces a three-deep nest of loops (lower bound 0, step 1) that thread a
// single iteration argument through every level. The innermost loop yields its
// own iteration argument unchanged, so the coalesced loop is expected to yield
// its own iteration argument as well.
// CHECK-LABEL: @noramalized_loops_with_yielded_iter_args
func.func @noramalized_loops_with_yielded_iter_args() {
  // CHECK: %[[orig_lb:.*]] = arith.constant 0
  // CHECK: %[[orig_step:.*]] = arith.constant 1
  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c42 = arith.constant 42 : index
  %c56 = arith.constant 56 : index
  // The range of the new scf.for is the product of the three upper bounds.
  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]

  // Updated loop bounds.
  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
    // Inner loops must have been removed.
    // CHECK-NOT: scf.for

    // Reconstruct original IVs from the linearized one.
    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
        "use"(%i, %j, %k) : (index, index, index) -> ()
        // The yielded value is the innermost loop's own iteration argument;
        // after coalescing it must be the surviving loop's iteration argument.
        // CHECK: scf.yield %[[VAL_1]] : index
        scf.yield %arg2 : index
      }
      scf.yield %0#0 : index
    }
    scf.yield %1#0 : index
  }
  return
}

// Coalesces a three-deep nest of loops (lower bound 0, step 1) that thread two
// iteration arguments through every level. The innermost loop yields its two
// iteration arguments in swapped order, so the coalesced loop is expected to
// yield its own iteration arguments in the same swapped order.
// CHECK-LABEL: @noramalized_loops_with_shuffled_yielded_iter_args
func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
  // CHECK: %[[orig_lb:.*]] = arith.constant 0
  // CHECK: %[[orig_step:.*]] = arith.constant 1
  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c42 = arith.constant 42 : index
  %c56 = arith.constant 56 : index
  // The range of the new scf.for is the product of the three upper bounds.
  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]

  // Updated loop bounds.
  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
  %2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
    // Inner loops must have been removed.
    // CHECK-NOT: scf.for

    // Reconstruct original IVs from the linearized one.
    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
    %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
      %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
        "use"(%i, %j, %k) : (index, index, index) -> ()
        // The iteration arguments are yielded in reverse order; the swap must
        // be preserved on the surviving loop's iteration arguments.
        // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
        scf.yield %arg5, %arg4 : index, index
      }
      scf.yield %0#0, %0#1 : index, index
    }
    scf.yield %1#0, %1#1 : index, index
  }
  return
}

// Coalesces a three-deep nest of loops (lower bound 0, step 1) that thread a
// single iteration argument through every level. The innermost loop yields a
// value computed in its body rather than an iteration argument, so the
// coalesced loop is expected to yield that same computed value.
// CHECK-LABEL: @noramalized_loops_with_yielded_non_iter_args
func.func @noramalized_loops_with_yielded_non_iter_args() {
  // CHECK: %[[orig_lb:.*]] = arith.constant 0
  // CHECK: %[[orig_step:.*]] = arith.constant 1
  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %c42 = arith.constant 42 : index
  %c56 = arith.constant 56 : index
  // The range of the new scf.for is the product of the three upper bounds.
  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]

  // Updated loop bounds.
  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
    // Inner loops must have been removed.
    // CHECK-NOT: scf.for

    // Reconstruct original IVs from the linearized one.
    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
        // CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
        %res = "use"(%i, %j, %k) : (index, index, index) -> (index)
        // The yielded value is produced in the loop body, not an iteration
        // argument, so it must be yielded as-is after coalescing.
        // CHECK: scf.yield %[[res]] : index
        scf.yield %res : index
      }
      scf.yield %0#0 : index
    }
    scf.yield %1#0 : index
  }
  return
}

// Check with parametric loop bounds and steps, capture the bounds here.
// CHECK-LABEL: @parametric
// CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]:
Expand Down

0 comments on commit eb2c431

Please sign in to comment.