Skip to content

Commit

Permalink
[mlir][Transforms][NFC] Dialect Conversion: Move argument materializa…
Browse files Browse the repository at this point in the history
…tion logic (#98805)

This commit moves the argument materialization logic from
`legalizeConvertedArgumentTypes` to
`legalizeUnresolvedMaterializations`.

Before this change:
- Argument materializations were created in
`legalizeConvertedArgumentTypes` (which used to call
`materializeLiveConversions`).

After this change:
- `legalizeConvertedArgumentTypes` creates a "placeholder"
`unrealized_conversion_cast`.
- The placeholder `unrealized_conversion_cast` is replaced with an
argument materialization (using the type converter) in
`legalizeUnresolvedMaterializations`.
- All argument and target materializations now take place in the same
location (`legalizeUnresolvedMaterializations`).

This commit brings us closer towards creating all source/target/argument
materializations in one central step, which can then be made optional
(and delegated to the user) in the future. (There is one more source
materialization step that has not been moved yet.)

This commit also consolidates all `build*UnresolvedMaterialization`
functions into a single `buildUnresolvedMaterialization` function.

This is a re-upload of #96329.
  • Loading branch information
matthias-springer authored Aug 3, 2024
1 parent 3eaca31 commit 2fc71e4
Showing 1 changed file with 54 additions and 84 deletions.
138 changes: 54 additions & 84 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}

/// Helper function that computes an insertion point where the given value is
/// defined and can be used without a dominance violation.
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
Block *insertBlock = value.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(value))
insertPt = ++inputRes.getOwner()->getIterator();
return OpBuilder::InsertPoint(insertBlock, insertPt);
}

//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -444,11 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::BlockTypeConversion;
}

/// Materialize any necessary conversions for converted arguments that have
/// live users, using the provided `findLiveUser` to search for a user that
/// survives the conversion process.
LogicalResult
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
Block *getOrigBlock() const { return origBlock; }

const TypeConverter *getConverter() const { return converter; }

void commit(RewriterBase &rewriter) override;

Expand Down Expand Up @@ -829,15 +837,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
Block *insertBlock,
Block::iterator insertPt, Location loc,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
const TypeConverter *converter);

Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);

//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -969,49 +972,6 @@ void BlockTypeConversionRewrite::rollback() {
block->replaceAllUsesWith(origBlock);
}

LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
function_ref<Operation *(Value)> findLiveUser) {
// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
BlockArgument origArg = it.value();
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
builder.setInsertionPointToStart(block);

// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
}
if (!newArg) {
InFlightDiagnostic diag =
emitError(origArg.getLoc())
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
}
rewriterImpl.mapping.map(origArg, newArg);
}
return success();
}

void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
Expand Down Expand Up @@ -1184,8 +1144,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type newOperandType = newOperand.getType();
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
operandLoc, newOperand, desiredType, currentTypeConverter);
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
Expand Down Expand Up @@ -1298,8 +1260,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Expand All @@ -1323,8 +1286,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/replArgs, origArgType, converter);
mapping.map(origArg, argMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);

Expand All @@ -1342,7 +1306,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
legalOutputType = replArgs[0].getType();
}
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedTargetMaterialization(
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat),
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, targetMat);
}
Expand All @@ -1365,34 +1330,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(outputType.getContext());
builder.setInsertionPoint(insertBlock, insertPt);
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Block *insertBlock = input.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();

return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
outputType, converter);
}

//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
Expand Down Expand Up @@ -2504,9 +2456,9 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)) ||
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();

// Process requested operation replacements.
Expand Down Expand Up @@ -2562,10 +2514,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
findLiveUser)))
return failure();
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
// Process the remapping for each of the original arguments.
for (Value origArg :
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
// If the type of this argument changed and the argument is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value repl = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(replacementValue),
origArg.getLoc(), /*inputs=*/replacementValue,
/*outputType=*/origArg.getType(),
blockTypeConversionRewrite->getConverter());
rewriterImpl.mapping.map(origArg, repl);
}
}
}
return success();
}
Expand Down

0 comments on commit 2fc71e4

Please sign in to comment.