Skip to content

Commit

Permalink
Fix yield conversion of scf.if/scf.for to emitc (#401)
Browse files Browse the repository at this point in the history
* Fix conversion for scf.for and scf.if
  • Loading branch information
josel-amd authored Nov 22, 2024
1 parent 72cbeca commit 7326995
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 25 deletions.
63 changes: 45 additions & 18 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,31 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(SmallVector<Value> &resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);
SmallVector<Value> yieldOperands;
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);

return success();
}

LogicalResult
Expand All @@ -118,22 +124,32 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());

// Propagate any attributes from the ODS forOp to the lowered emitc::for op.
loweredFor->setAttrs(forOp->getAttrs());

Block *loweredBody = loweredFor.getBody();

// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

// Convert the original region types into the new types by adding unrealized
// casts in the beginning of the loop. This performs the conversion in place.
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
}

// Register the replacements for the block arguments and inline the body of
// the scf.for loop into the body of the emitc::for loop.
Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(resultVariables.begin(), resultVariables.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

auto result = lowerYield(forOp, resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

Block *adaptorBody = &(adaptor.getRegion().front());
rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
if (failed(result)) {
return result;
}

rewriter.replaceOp(forOp, resultVariables);
return success();
Expand Down Expand Up @@ -169,11 +185,16 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
auto lowerRegion = [&resultVariables, &rewriter,
&ifOp](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
auto result = lowerYield(ifOp, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
if (failed(result)) {
return result;
}
return success();
};

Region &thenRegion = adaptor.getThenRegion();
Expand All @@ -185,11 +206,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
auto result = lowerRegion(thenRegion, loweredThenRegion);
if (failed(result)) {
return result;
}

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
auto result = lowerRegion(elseRegion, loweredElseRegion);
if (failed(result)) {
return result;
}
}

rewriter.replaceOp(ifOp, resultVariables);
Expand Down
58 changes: 51 additions & 7 deletions mlir/test/Conversion/SCFToEmitC/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,55 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-NEXT: return %[[VAL_4]] : f32
// CHECK-NEXT: }

func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) {
scf.for %i0 = %arg0 to %arg1 step %arg2 {
%c1 = arith.constant 1 : index
} {test.value = 5 : index}
return
func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
scf.yield %acc : index
}
return %r : index
}
// CHECK-LABEL: func.func @loop_with_attr
// CHECK: {test.value = 5 : index}

// CHECK-LABEL: func.func @for_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_8]] : index
// CHECK: }


func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
%sn = arith.addi %acc, %acc : index
scf.yield %sn: index
}
return %r : index
}

// CHECK-LABEL: func.func @for_yield_update_loop_carried_var(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_9]] : index
// CHECK: }
27 changes: 27 additions & 0 deletions mlir/test/Conversion/SCFToEmitC/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }


func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%x = scf.if %arg0 -> (index) {
scf.yield %0 : index
} else {
scf.yield %1 : index
}
return
}

// CHECK: func.func @test_if_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.if %[[ARG_0]] {
// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: } else {
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: }
// CHECK: return
// CHECK: }

0 comments on commit 7326995

Please sign in to comment.