Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix yield conversion of scf.if/scf.for to emitc #401

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 49 additions & 18 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,35 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(SmallVector<Value> &resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);
SmallVector<Value> yieldOperands;
for (auto originalOperand : yield.getOperands()) {
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
Value remappedValue = rewriter.getRemappedValue(originalOperand);
if (!remappedValue) {
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}
yieldOperands.push_back(remappedValue);
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);

return success();
}

LogicalResult
Expand All @@ -118,22 +128,32 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());

// Propagate any attributes from the ODS forOp to the lowered emitc::for op.
loweredFor->setAttrs(forOp->getAttrs());

Block *loweredBody = loweredFor.getBody();

// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

// Convert the original region types into the new types by adding unrealized
// casts in the beginning of the loop. This performs the conversion in place.
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
}

// Register the replacements for the block arguments and inline the body of
// the scf.for loop into the body of the emitc::for loop.
Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(resultVariables.begin(), resultVariables.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

auto result = lowerYield(forOp, resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

Block *adaptorBody = &(adaptor.getRegion().front());
rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
if (failed(result)) {
return result;
}

rewriter.replaceOp(forOp, resultVariables);
return success();
Expand Down Expand Up @@ -169,11 +189,16 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
auto lowerRegion = [&resultVariables, &rewriter,
&ifOp](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
auto result = lowerYield(ifOp, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
if (failed(result)) {
return result;
}
return success();
};

Region &thenRegion = adaptor.getThenRegion();
Expand All @@ -185,11 +210,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
auto result = lowerRegion(thenRegion, loweredThenRegion);
if (failed(result)) {
return result;
}

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
auto result = lowerRegion(elseRegion, loweredElseRegion);
if (failed(result)) {
return result;
}
}

rewriter.replaceOp(ifOp, resultVariables);
Expand Down
58 changes: 51 additions & 7 deletions mlir/test/Conversion/SCFToEmitC/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,55 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-NEXT: return %[[VAL_4]] : f32
// CHECK-NEXT: }

func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) {
scf.for %i0 = %arg0 to %arg1 step %arg2 {
%c1 = arith.constant 1 : index
} {test.value = 5 : index}
return
func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
scf.yield %acc : index
}
return %r : index
}
// CHECK-LABEL: func.func @loop_with_attr
// CHECK: {test.value = 5 : index}

// CHECK-LABEL: func.func @for_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_8]] : index
// CHECK: }


func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
%sn = arith.addi %acc, %acc : index
scf.yield %sn: index
}
return %r : index
}

// CHECK-LABEL: func.func @for_yield_update_loop_carried_var(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_9]] : index
// CHECK: }
27 changes: 27 additions & 0 deletions mlir/test/Conversion/SCFToEmitC/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }


func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%x = scf.if %arg0 -> (index) {
scf.yield %0 : index
} else {
scf.yield %1 : index
}
return
}

// CHECK: func.func @test_if_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.if %[[ARG_0]] {
// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: } else {
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: }
// CHECK: return
// CHECK: }