diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index cc1a22d0d48a18..d8e1cc0ecef88e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, } void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { - // Initialize the iteration argument to the loop initiale values. + // Initialize the iteration argument to the loop initial values. for (auto [arg, operand] : llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) { setValueMapping(arg, operand.get(), 0); @@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { if (annotateFn) annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { - setValueMapping(op->getResult(destId), newOp->getResult(destId), - i - stages[op]); + Value source = newOp->getResult(destId); // If the value is a loop carried dependency update the loop argument - // mapping. for (OpOperand &operand : yield->getOpOperands()) { if (operand.get() != op->getResult(destId)) continue; + if (predicates[predicateIdx] && + !forOp.getResult(operand.getOperandNumber()).use_empty()) { + // If the value is used outside the loop, we need to make sure we + // return the correct version of it. + Value prevValue = valueMapping + [forOp.getRegionIterArgs()[operand.getOperandNumber()]] + [i - stages[op]]; + source = rewriter.create( + loc, predicates[predicateIdx], source, prevValue); + } setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], - newOp->getResult(destId), i - stages[op] + 1); + source, i - stages[op] + 1); } + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); } } } diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index 46e7feca4329ee..9687f80f5ddfc8 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref, %result: memref) { // ----- // NOEPILOGUE-LABEL: stage_0_value_escape( -func.func @stage_0_value_escape(%A: memref, %result: memref) { +func.func @stage_0_value_escape(%A: memref, %result: memref, %ub: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index %cf = arith.constant 1.0 : f32 -// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index -// NOEPILOGUE: %[[A:.+]] = arith.addf -// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]], -// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index -// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32 -// NOEPILOGUE: scf.yield %[[S]] - %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) { +// NOEPILOGUE: %[[UB:[^,]+]]: index) +// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index +// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index +// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00 +// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]] +// NOEPILOGUE: scf.if +// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]] +// NOEPILOGUE: %[[A:.+]] = arith.addf +// NOEPILOGUE: scf.yield %[[A]] +// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]] +// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]], +// NOEPILOGUE: %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index +// NOEPILOGUE: %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index +// NOEPILOGUE: %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32 +// NOEPILOGUE: scf.yield %[[S1]] + %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) { %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32 memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref