Skip to content

Commit

Permalink
Revert "[mlir][Transforms][NFC] Dialect Conversion: Move argument mat…
Browse files Browse the repository at this point in the history
…erialization logic (#96329)"

This reverts commit c01ce79. It depends
on f1e0657 which breaks SCF lowering.
  • Loading branch information
d0k committed Jun 27, 2024
1 parent b5cc19e commit 605098d
Showing 1 changed file with 81 additions and 52 deletions.
133 changes: 81 additions & 52 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}

/// Helper function that computes an insertion point where the given value is
/// defined and can be used without a dominance violation.
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
Block *insertBlock = value.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(value))
insertPt = ++inputRes.getOwner()->getIterator();
return OpBuilder::InsertPoint(insertBlock, insertPt);
}

//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -455,9 +445,11 @@ class BlockTypeConversionRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::BlockTypeConversion;
}

Block *getOrigBlock() const { return origBlock; }

const TypeConverter *getConverter() const { return converter; }
/// Materialize any necessary conversions for converted arguments that have
/// live users, using the provided `findLiveUser` to search for a user that
/// survives the conversion process.
LogicalResult
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);

void commit(RewriterBase &rewriter) override;

Expand Down Expand Up @@ -849,10 +841,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
Block *insertBlock,
Block::iterator insertPt, Location loc,
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);

//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
Expand Down Expand Up @@ -985,6 +981,49 @@ void BlockTypeConversionRewrite::rollback() {
block->replaceAllUsesWith(origBlock);
}

LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
function_ref<Operation *(Value)> findLiveUser) {
// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
BlockArgument origArg = it.value();
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
builder.setInsertionPointToStart(block);

// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
}
if (!newArg) {
InFlightDiagnostic diag =
emitError(origArg.getLoc())
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
}
rewriterImpl.mapping.map(origArg, newArg);
}
return success();
}

void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
Expand Down Expand Up @@ -1157,10 +1196,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type newOperandType = newOperand.getType();
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
/*origArgType=*/{}, currentTypeConverter);
Value castValue = buildUnresolvedTargetMaterialization(
operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
Expand Down Expand Up @@ -1288,9 +1325,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Expand All @@ -1315,9 +1351,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/replArgs,
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs,
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
Expand All @@ -1339,22 +1374,34 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, Type origArgType,
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType, Type origArgType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(ip.getBlock(), ip.getPoint());
OpBuilder builder(insertBlock, insertPt);
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origArgType);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Block *insertBlock = input.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();

return buildUnresolvedMaterialization(
MaterializationKind::Target, insertBlock, insertPt, loc, input,
outputType, /*origArgType=*/{}, converter);
}

//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
Expand Down Expand Up @@ -2468,9 +2515,9 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)) ||
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();

// Process requested operation replacements.
Expand Down Expand Up @@ -2526,28 +2573,10 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
// Process the remapping for each of the original arguments.
for (Value origArg :
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
// If the type of this argument changed and the argument is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value repl = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(replacementValue),
origArg.getLoc(), /*inputs=*/replacementValue,
/*outputType=*/origArg.getType(), /*origArgType=*/{},
blockTypeConversionRewrite->getConverter());
rewriterImpl.mapping.map(origArg, repl);
}
}
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
findLiveUser)))
return failure();
}
return success();
}
Expand Down

0 comments on commit 605098d

Please sign in to comment.