diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 40f82557d2eb8a..9545610f10be7c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -864,6 +864,24 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
 
     Operation *innerTerminator = innerLoop.getBody()->getTerminator();
     auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+    // The remapping below relies on the inner loop being initialized with
+    // exactly the iteration arguments of the outer loop.
+    assert(llvm::equal(outerLoop.getRegionIterArgs(),
+                       innerLoop.getInitArgs()) &&
+           "expected inner loop init args to be the outer loop iter args");
+    for (Value &yieldedVal : yieldedVals) {
+      // The yielded value may be an iteration argument of the inner loop
+      // which is about to be inlined.
+      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
+      if (iter != innerLoop.getRegionIterArgs().end()) {
+        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
+        // `outerLoop` iter args are identical to the `innerLoop` init args,
+        // so yield the init arg at the same position instead.
+        assert(iterArgIndex < innerLoop.getInitArgs().size() &&
+               "iter arg index must be within the init args");
+        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
+      }
+    }
     rewriter.eraseOp(innerTerminator);
 
     SmallVector innerBlockArgs;
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 0235000aeac538..45dd299295f640 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -114,6 +114,132 @@ func.func @unnormalized_loops() {
   return
 }
 
+// Check coalescing when the innermost loop yields its own iter arg.
+// CHECK-LABEL: @normalized_loops_with_yielded_iter_args
+func.func @normalized_loops_with_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index) {
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_1]] : index
+        scf.yield %arg2 : index
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
+// Check coalescing when the innermost loop yields its iter args shuffled.
+// CHECK-LABEL: @normalized_loops_with_shuffled_yielded_iter_args
+func.func @normalized_loops_with_shuffled_yielded_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
+  %2:2 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0, %arg1 = %c0) -> (index, index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index) {
+      %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
+        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        "use"(%i, %j, %k) : (index, index, index) -> ()
+        // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
+        scf.yield %arg5, %arg4 : index, index
+      }
+      scf.yield %0#0, %0#1 : index, index
+    }
+    scf.yield %1#0, %1#1 : index, index
+  }
+  return
+}
+
+// Check coalescing when the innermost loop yields a non-iter-arg value.
+// CHECK-LABEL: @normalized_loops_with_yielded_non_iter_args
+func.func @normalized_loops_with_yielded_non_iter_args() {
+  // CHECK: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK: %[[orig_step:.*]] = arith.constant 1
+  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %c42 = arith.constant 42 : index
+  %c56 = arith.constant 56 : index
+  // The range of the new scf.for is the product of the original upper bounds.
+  // CHECK: %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
+  // CHECK-NEXT: %[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+
+  // Updated loop bounds.
+  // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
+  %2:1 = scf.for %i = %c0 to %c42 step %c1 iter_args(%arg0 = %c0) -> (index) {
+    // Inner loops must have been removed.
+    // CHECK-NOT: scf.for
+
+    // Reconstruct original IVs from the linearized one.
+    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
+    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index) {
+      %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
+        // CHECK: %[[res:.*]] = "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        %res = "use"(%i, %j, %k) : (index, index, index) -> (index)
+        // CHECK: scf.yield %[[res]] : index
+        scf.yield %res : index
+      }
+      scf.yield %0#0 : index
+    }
+    scf.yield %1#0 : index
+  }
+  return
+}
+
 // Check with parametric loop bounds and steps, capture the bounds here.
 // CHECK-LABEL: @parametric
 // CHECK-SAME: %[[orig_lb1:[A-Za-z0-9]+]]: