Skip to content

Commit

Permalink
[mlir][Transforms] Dialect conversion: Simplify handling of dropped a…
Browse files Browse the repository at this point in the history
…rguments

This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated.

When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new
BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed.

To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments.

Other minor improvements:

Improve variable name: origOutputType -> origArgType. Add an assertion to check that this field is only used for argument materializations.
Add more comments to applySignatureConversion.
Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes.

This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion.

This is a re-upload of #96207.
  • Loading branch information
matthias-springer committed Jul 13, 2024
1 parent 5773176 commit b0b7813
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 122 deletions.
173 changes: 55 additions & 118 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};

/// This structure contains the information pertaining to an argument that has
/// been converted.
struct ConvertedArgInfo {
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
Value castValue = nullptr)
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

/// The start index of in the new argument list that contains arguments that
/// replace the original.
unsigned newArgIdx;

/// The number of arguments that replaced the original argument.
unsigned newArgSize;

/// The cast value that was created to cast from the new arguments to the
/// old. This only used if 'newArgSize' > 1.
Value castValue;
};

/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
const TypeConverter *converter)
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, Block *origBlock,
const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
origBlock(origBlock), converter(converter) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
Expand All @@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;

/// The conversion information for each of the arguments. The information is
/// std::nullopt if the argument was dropped during conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;

/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
Expand Down Expand Up @@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
/// The type of materialization.
enum MaterializationKind {
/// This materialization materializes a conversion for an illegal block
/// argument type, to a legal one.
/// argument type, to the original one.
Argument,

/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target
Target,

/// This materialization materializes a conversion from a legal type back to
/// an illegal one.
Source
};

/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
Expand Down Expand Up @@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
};
} // namespace
Expand Down Expand Up @@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type outputType,
const TypeConverter *converter);

Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
Expand Down Expand Up @@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);

// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
// Handle the case of a 1->0 value mapping.
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}

// Otherwise this is a 1->1+ value mapping.
Value castValue = info->castValue;
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");

// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
castValue, origArg.getType()));
}
}
}

void BlockTypeConversionRewrite::rollback() {
Expand All @@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
Expand All @@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
if (!isDroppedArg)
diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
Expand Down Expand Up @@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);

// Remap each of the original arguments as determined by the signature
// conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);

for (unsigned i = 0; i != origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
continue;
BlockArgument origArg = block->getArgument(i);
Type origArgType = origArg.getType();

std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
signatureConversion.getInputMapping(i);
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}

// If inputMap->replacementValue is not nullptr, then the argument is
// dropped and a replacement value is provided to be the remappedValue.
if (inputMap->replacementValue) {
if (Value repl = inputMap->replacementValue) {
// This block argument was dropped and a replacement value was provided.
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}

// Otherwise, this is a 1->1+ mapping.
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
// dialect conversion. Therefore, we need an argument materialization to
// turn the replacement block arguments into a single SSA value that can be
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value newArg;
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
mapping.map(origArg, argMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);

// If this is a 1->1 mapping and the types of new and replacement arguments
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
mapping.map(origArg, newArg);
} else {
// Build argument materialization: new block arguments -> old block
// argument type.
Value argMat = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
mapping.map(origArg, argMat);

// Build target materialization: old block argument type -> legal type.
// Note: This function returns an "empty" type if no valid conversion to
// a legal type exists. In that case, we continue the conversion with the
// original block argument type.
Type legalOutputType = converter->convertType(origArg.getType());
if (legalOutputType && legalOutputType != origArg.getType()) {
newArg = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, newArg);
} else {
newArg = argMat;
}
Type legalOutputType;
if (converter)
legalOutputType = converter->convertType(origArgType);
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, targetMat);
}

appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}

appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);

// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
Expand Down Expand Up @@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Expand Down Expand Up @@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
Expand All @@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(

InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
"from "
<< inputOperands.getTypes() << " to " << outputType
"from ("
<< inputOperands.getTypes() << ") to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
Expand Down
6 changes: 2 additions & 4 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@


func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
// expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}

Expand Down Expand Up @@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
^bb0(%arg0: f32):
// expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
Expand Down

0 comments on commit b0b7813

Please sign in to comment.