Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic #98805

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 54 additions & 84 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}

/// Helper function that computes an insertion point where the given value is
/// defined and can be used without a dominance violation.
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
Block *insertBlock = value.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(value))
insertPt = ++inputRes.getOwner()->getIterator();
return OpBuilder::InsertPoint(insertBlock, insertPt);
}

//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -444,11 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::BlockTypeConversion;
}

/// Materialize any necessary conversions for converted arguments that have
/// live users, using the provided `findLiveUser` to search for a user that
/// survives the conversion process.
LogicalResult
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
Block *getOrigBlock() const { return origBlock; }

const TypeConverter *getConverter() const { return converter; }

void commit(RewriterBase &rewriter) override;

Expand Down Expand Up @@ -829,15 +837,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
Block *insertBlock,
Block::iterator insertPt, Location loc,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
const TypeConverter *converter);

Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);

//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -969,49 +972,6 @@ void BlockTypeConversionRewrite::rollback() {
block->replaceAllUsesWith(origBlock);
}

LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
function_ref<Operation *(Value)> findLiveUser) {
// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
BlockArgument origArg = it.value();
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
builder.setInsertionPointToStart(block);

// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
}
if (!newArg) {
InFlightDiagnostic diag =
emitError(origArg.getLoc())
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
}
rewriterImpl.mapping.map(origArg, newArg);
}
return success();
}

void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
Expand Down Expand Up @@ -1184,8 +1144,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type newOperandType = newOperand.getType();
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
operandLoc, newOperand, desiredType, currentTypeConverter);
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
Expand Down Expand Up @@ -1298,8 +1260,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Expand All @@ -1323,8 +1286,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/replArgs, origArgType, converter);
mapping.map(origArg, argMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);

Expand All @@ -1342,7 +1306,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
legalOutputType = replArgs[0].getType();
}
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedTargetMaterialization(
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat),
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, targetMat);
}
Expand All @@ -1365,34 +1330,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(outputType.getContext());
builder.setInsertionPoint(insertBlock, insertPt);
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Block *insertBlock = input.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();

return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
outputType, converter);
}

//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
Expand Down Expand Up @@ -2504,9 +2456,9 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)) ||
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();

// Process requested operation replacements.
Expand Down Expand Up @@ -2562,10 +2514,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
findLiveUser)))
return failure();
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
// Process the remapping for each of the original arguments.
for (Value origArg :
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
// If the type of this argument changed and the argument is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value repl = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(replacementValue),
origArg.getLoc(), /*inputs=*/replacementValue,
/*outputType=*/origArg.getType(),
blockTypeConversionRewrite->getConverter());
rewriterImpl.mapping.map(origArg, repl);
}
}
}
return success();
}
Expand Down
Loading