Skip to content

Commit

Permalink
Use delayed conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Nov 21, 2024
1 parent 5ed59a8 commit b3f9b8e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
53 changes: 31 additions & 22 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,38 +81,34 @@ createVariablesForResults(T op, const TypeConverter *typeConverter,
// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc,
const TypeConverter *typeConverter = nullptr) {
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(ValueRange resultVariables,
ConversionPatternRewriter &rewriter, scf::YieldOp yield,
const TypeConverter *typeConverter) {
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
Location loc = yield.getLoc();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

SmallVector<Value> yieldOperands;
for (auto originalOperand : yield.getOperands()) {
Value operand = originalOperand;

if (typeConverter && !typeConverter->isLegal(operand.getType())) {
Type resultType = typeConverter->convertType(operand.getType());
auto castToTarget =
rewriter.create<UnrealizedConversionCastOp>(loc, resultType, operand);
operand = castToTarget.getResult(0);
Value remappedValue = rewriter.getRemappedValue(originalOperand);
if (!remappedValue) {
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}

yieldOperands.push_back(operand);
yieldOperands.push_back(remappedValue);
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);

return success();
}

LogicalResult
Expand Down Expand Up @@ -153,9 +149,12 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
replacingValues.append(resultVariables.begin(), resultVariables.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()),
getTypeConverter());
auto result = lowerYield(forOp, resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

if (failed(result)) {
return result;
}

rewriter.replaceOp(forOp, resultVariables);
return success();
Expand Down Expand Up @@ -192,11 +191,15 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter,
this](Region &region, Region &loweredRegion) {
&ifOp](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator),
getTypeConverter());
auto result = lowerYield(ifOp, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
if (failed(result)) {
return result;
}
return success();
};

Region &thenRegion = adaptor.getThenRegion();
Expand All @@ -208,11 +211,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
auto result = lowerRegion(thenRegion, loweredThenRegion);
if (failed(result)) {
return result;
}

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
auto result = lowerRegion(elseRegion, loweredElseRegion);
if (failed(result)) {
return result;
}
}

rewriter.replaceOp(ifOp, resultVariables);
Expand Down
4 changes: 1 addition & 3 deletions mlir/test/Conversion/SCFToEmitC/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index
// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] {
// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_7]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t
// CHECK: }
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index
// CHECK: return %[[VAL_8]] : index
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/SCFToEmitC/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ func.func @test_if_yield_index(%arg0: i1, %arg1: f32) {
// CHECK: func.func @test_if_yield_index(
// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t
// CHECK: emitc.if %[[ARG_0]] {
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_0]] : !emitc.size_t
// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: } else {
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t
// CHECK: emitc.assign %[[VAL_2]] : !emitc.size_t to %[[VAL_0]] : !emitc.size_t
// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t
// CHECK: }
// CHECK: return
// CHECK: }

0 comments on commit b3f9b8e

Please sign in to comment.