diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 51490c79ce49042..41c69eed208608c 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -79,25 +79,31 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. -static void assignValues(ValueRange values, SmallVector &variables, +static void assignValues(ValueRange values, ValueRange variables, ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create(loc, var, value); } -static void lowerYield(SmallVector &resultVariables, - ConversionPatternRewriter &rewriter, - scf::YieldOp yield) { +static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); - ValueRange operands = yield.getOperands(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); - assignValues(operands, resultVariables, rewriter, loc); + SmallVector yieldOperands; + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); + } + + assignValues(yieldOperands, resultVariables, rewriter, loc); rewriter.create(loc); rewriter.eraseOp(yield); + + return success(); } LogicalResult @@ -118,22 +124,32 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, emitc::ForOp loweredFor = rewriter.create( loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); - // Propagate any attributes from the ODS forOp to the lowered emitc::for op. - loweredFor->setAttrs(forOp->getAttrs()); - Block *loweredBody = loweredFor.getBody(); // Erase the auto-generated terminator for the lowered for op. rewriter.eraseOp(loweredBody->getTerminator()); + // Convert the original region types into the new types by adding unrealized + // casts in the beginning of the loop. This performs the conversion in place. + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); + } + + // Register the replacements for the block arguments and inline the body of + // the scf.for loop into the body of the emitc::for loop. + Block *scfBody = &(forOp.getRegion().front()); SmallVector replacingValues; replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(resultVariables.begin(), resultVariables.end()); + rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); + + auto result = lowerYield(forOp, resultVariables, rewriter, + cast(loweredBody->getTerminator())); - Block *adaptorBody = &(adaptor.getRegion().front()); - rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues); - lowerYield(resultVariables, rewriter, - cast(loweredBody->getTerminator())); + if (failed(result)) { + return result; + } rewriter.replaceOp(forOp, resultVariables); return success(); @@ -169,11 +185,16 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, // emitc::if regions, but the scf::yield is replaced not only with an // emitc::yield, but also with a sequence of emitc::assign ops that set the // yielded values into the result variables. - auto lowerRegion = [&resultVariables, &rewriter](Region ®ion, - Region &loweredRegion) { + auto lowerRegion = [&resultVariables, &rewriter, + &ifOp](Region ®ion, Region &loweredRegion) { rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); Operation *terminator = loweredRegion.back().getTerminator(); - lowerYield(resultVariables, rewriter, cast(terminator)); + auto result = lowerYield(ifOp, resultVariables, rewriter, + cast(terminator)); + if (failed(result)) { + return result; + } + return success(); }; Region &thenRegion = adaptor.getThenRegion(); @@ -185,11 +206,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, rewriter.create(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); - lowerRegion(thenRegion, loweredThenRegion); + auto result = lowerRegion(thenRegion, loweredThenRegion); + if (failed(result)) { + return result; + } if (hasElseBlock) { Region &loweredElseRegion = loweredIf.getElseRegion(); - lowerRegion(elseRegion, loweredElseRegion); + auto result = lowerRegion(elseRegion, loweredElseRegion); + if (failed(result)) { + return result; + } } rewriter.replaceOp(ifOp, resultVariables); diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index b422aaa4545d9b2..79a53ec8fd4c08d 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -99,11 +99,55 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 // CHECK-NEXT: return %[[VAL_4]] : f32 // CHECK-NEXT: } -func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) { - scf.for %i0 = %arg0 to %arg1 step %arg2 { - %c1 = arith.constant 1 : index - } {test.value = 5 : index} - return +func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + scf.yield %acc : index + } + return %r : index } -// CHECK-LABEL: func.func @loop_with_attr -// CHECK: {test.value = 5 : index} + +// CHECK-LABEL: func.func @for_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] { +// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: } +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: return %[[VAL_8]] : index +// CHECK: } + + +func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + %sn = arith.addi %acc, %acc : index + scf.yield %sn: index + } + return %r : index + } + +// CHECK-LABEL: func.func @for_yield_update_loop_carried_var( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] { +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t +// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: } +// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: return %[[VAL_9]] : index +// CHECK: } diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir index afc9abc761eb4c1..eba1dda213e7062 100644 --- a/mlir/test/Conversion/SCFToEmitC/if.mlir +++ b/mlir/test/Conversion/SCFToEmitC/if.mlir @@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + + +func.func @test_if_yield_index(%arg0: i1, %arg1: f32) { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %x = scf.if %arg0 -> (index) { + scf.yield %0 : index + } else { + scf.yield %1 : index + } + return +} + +// CHECK: func.func @test_if_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.if %[[ARG_0]] { +// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t +// CHECK: } else { +// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t +// CHECK: } +// CHECK: return +// CHECK: }