Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix yield conversion of scf.if/scf.for to emitc #401

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -79,22 +80,36 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
ConversionPatternRewriter &rewriter, Location loc) {
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc,
const TypeConverter *typeConverter = nullptr) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(SmallVector<Value> &resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
static void lowerYield(ValueRange resultVariables,
ConversionPatternRewriter &rewriter, scf::YieldOp yield,
const TypeConverter *typeConverter) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);
SmallVector<Value> yieldOperands;
for (auto originalOperand : yield.getOperands()) {
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
Value operand = originalOperand;

if (typeConverter && !typeConverter->isLegal(operand.getType())) {
Type resultType = typeConverter->convertType(operand.getType());
auto castToTarget =
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
rewriter.create<UnrealizedConversionCastOp>(loc, resultType, operand);
operand = castToTarget.getResult(0);
}

yieldOperands.push_back(operand);
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);
Expand All @@ -118,22 +133,29 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());

// Propagate any attributes from the ODS forOp to the lowered emitc::for op.
loweredFor->setAttrs(forOp->getAttrs());

Block *loweredBody = loweredFor.getBody();

// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

// Convert the original region types into the new types by adding unrealized
// casts in the beginning of the loop. This performs the conversion in place.
josel-amd marked this conversation as resolved.
Show resolved Hide resolved
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
}

// Register the replacements for the block arguments and inline the body of
// the scf.for loop into the body of the emitc::for loop.
Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(resultVariables.begin(), resultVariables.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

Block *adaptorBody = &(adaptor.getRegion().front());
rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
cast<scf::YieldOp>(loweredBody->getTerminator()),
getTypeConverter());

rewriter.replaceOp(forOp, resultVariables);
return success();
Expand Down Expand Up @@ -169,11 +191,12 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
auto lowerRegion = [&resultVariables, &rewriter,
this](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator),
getTypeConverter());
};

Region &thenRegion = adaptor.getThenRegion();
Expand Down
60 changes: 53 additions & 7 deletions mlir/test/Conversion/SCFToEmitC/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,57 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-NEXT: return %[[VAL_4]] : f32
// CHECK-NEXT: }

func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) {
scf.for %i0 = %arg0 to %arg1 step %arg2 {
%c1 = arith.constant 1 : index
} {test.value = 5 : index}
return
func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
scf.yield %acc : index
}
return %r : index
}
// CHECK-LABEL: func.func @loop_with_attr
// CHECK: {test.value = 5 : index}

// CHECK-LABEL: func.func @for_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_7]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_8]] : index
// CHECK: }


func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
%zero = arith.constant 0 : index
%r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index {
%sn = arith.addi %acc, %acc : index
scf.yield %sn: index
}
return %r : index
}

// CHECK-LABEL: func.func @for_yield_update_loop_carried_var(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index {
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_9]] : index
// CHECK: }
27 changes: 27 additions & 0 deletions mlir/test/Conversion/SCFToEmitC/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }


func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
%x = scf.if %arg0 -> (index) {
scf.yield %0 : index
} else {
scf.yield %1 : index
}
return
}

// CHECK: func.func @test_if_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.if %[[ARG_0]] {
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_0]] : !emitc.size_t
// CHECK: } else {
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_2]] : !emitc.size_t to %[[VAL_0]] : !emitc.size_t
// CHECK: }
// CHECK: return
// CHECK: }