-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: add originalType
param to materializations
#112128
[mlir][Transforms] Dialect conversion: add originalType
param to materializations
#112128
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesThis commit adds an Note:
For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists. This commit also puts the This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new Note for LLVM integration: For all argument/source/target materialization functions, move the Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff 31 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
}
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
- BoxProcType type,
+ mlir::Location loc, BoxProcType type,
mlir::ValueRange inputs,
- mlir::Location loc) {
+ mlir::Type originalType) {
assert(inputs.size() == 1);
return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
/// All of the following materializations require function objects that are
/// convertible to the following form:
- /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+ /// `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// The type that is provided as the 5-th argument is the original type of
+ /// value. For more details, see the documentation below.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: The original type matches the result type `T` for argument
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
+ ///
+ /// Note: The original type matches the result type `T` for source
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, the original type can be
+ /// different from the type of the input. For example, let's assume that a
+ /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+ /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+ /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+ /// that the legalized type of "t1" is "t3", which may be different from
+ /// "t2". In this example, the target materialization callback will be
+ /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+ /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+ /// that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
/// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a conversion.
+ ///
+ /// Arguments: builder, location, result type, inputs, original type
using MaterializationCallbackFn = std::function<std::optional<Value>(
- OpBuilder &, Type, ValueRange, Location)>;
+ OpBuilder &, Location, Type, ValueRange, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
Value
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
+ ValueRange inputs, Type originalType) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ OpBuilder &builder, Location loc, Type resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc);
+ return callback(builder, loc, derivedType, inputs, originalType);
return std::nullopt;
};
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
+ std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+ TypeRange, Value, Type)>;
/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
+ TypeRange resultTypes, Value input,
+ Type originalType) const;
/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
- auto addUnrealizedCast = [](OpBuilder &builder, Type type,
- ValueRange inputs, Location loc) {
+ auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+ ValueRange inputs, Type originalType) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
Location loc = arg.getLoc();
Value newArg = block.insertArgument(argNum, newTy, loc);
Value convertedValue = converter.materializeSourceConversion(
- builder, op->getLoc(), ty, newArg);
+ builder, op->getLoc(), ty, newArg, ty);
if (!convertedValue) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+ MemRefType resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addSourceMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addTargetMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
- rewriter, loc, indexType, dimSize);
+ rewriter, loc, indexType, dimSize, dimSize.getType());
assert(dimSize && "Invalid memref element type");
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
// All other types legal
return type;
});
- converter.addTargetMaterialization(
- [](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
- });
+ converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+ ValueRange input, Type originalType) {
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
+ });
}
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
- ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+ TensorType type, ValueRange inputs,
+ Type originalType) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
});
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
- addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
- ValueRange inputs, Location loc) -> Value {
+ addTargetMaterialization([](OpBuilder &builder, Location loc,
+ BaseMemRefType type, ValueRange inputs,
+ Type originalType) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
namespace {
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
- Type resultType,
+ Location loc, Type resultType,
ValueRange inputs,
- Location loc) {
+ Type originalType) {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
if (input.getType() != type) {
Value newInput = converter.materializeSourceConversion(
- rewriter, input.getLoc(), type, input);
+ rewriter, input.getLoc(), type, input, type);
if (!newInput) {
return emitDefiniteFailure() << "Failed to materialize conversion of "
<< input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
Value convertedOutput = newOutput;
if (output.getType() != newOutput.getType()) {
convertedOutput = converter.materializeTargetConversion(
- rewriter, output.getLoc(), output.getType(), newOutput);
+ rewriter, output.getLoc(), output.getType(), newOutput,
+ output.getType());
if (!convertedOutput) {
return emitDefiniteFailure()
<< "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]
|
@llvm/pr-subscribers-mlir-tensor Author: Matthias Springer (matthias-springer) ChangesThis commit adds an Note:
For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists. This commit also puts the This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new Note for LLVM integration: For all argument/source/target materialization functions, move the Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff 31 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
}
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
- BoxProcType type,
+ mlir::Location loc, BoxProcType type,
mlir::ValueRange inputs,
- mlir::Location loc) {
+ mlir::Type originalType) {
assert(inputs.size() == 1);
return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
/// All of the following materializations require function objects that are
/// convertible to the following form:
- /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+ /// `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// The type that is provided as the 5-th argument is the original type of
+ /// value. For more details, see the documentation below.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: The original type matches the result type `T` for argument
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
+ ///
+ /// Note: The original type matches the result type `T` for source
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, the original type can be
+ /// different from the type of the input. For example, let's assume that a
+ /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+ /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+ /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+ /// that the legalized type of "t1" is "t3", which may be different from
+ /// "t2". In this example, the target materialization callback will be
+ /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+ /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+ /// that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
/// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a conversion.
+ ///
+ /// Arguments: builder, location, result type, inputs, original type
using MaterializationCallbackFn = std::function<std::optional<Value>(
- OpBuilder &, Type, ValueRange, Location)>;
+ OpBuilder &, Location, Type, ValueRange, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
Value
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
+ ValueRange inputs, Type originalType) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ OpBuilder &builder, Location loc, Type resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc);
+ return callback(builder, loc, derivedType, inputs, originalType);
return std::nullopt;
};
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
+ std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+ TypeRange, Value, Type)>;
/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
+ TypeRange resultTypes, Value input,
+ Type originalType) const;
/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
- auto addUnrealizedCast = [](OpBuilder &builder, Type type,
- ValueRange inputs, Location loc) {
+ auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+ ValueRange inputs, Type originalType) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
Location loc = arg.getLoc();
Value newArg = block.insertArgument(argNum, newTy, loc);
Value convertedValue = converter.materializeSourceConversion(
- builder, op->getLoc(), ty, newArg);
+ builder, op->getLoc(), ty, newArg, ty);
if (!convertedValue) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+ MemRefType resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addSourceMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addTargetMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
- rewriter, loc, indexType, dimSize);
+ rewriter, loc, indexType, dimSize, dimSize.getType());
assert(dimSize && "Invalid memref element type");
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
// All other types legal
return type;
});
- converter.addTargetMaterialization(
- [](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
- });
+ converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+ ValueRange input, Type originalType) {
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
+ });
}
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
- ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+ TensorType type, ValueRange inputs,
+ Type originalType) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
});
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
- addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
- ValueRange inputs, Location loc) -> Value {
+ addTargetMaterialization([](OpBuilder &builder, Location loc,
+ BaseMemRefType type, ValueRange inputs,
+ Type originalType) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
namespace {
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
- Type resultType,
+ Location loc, Type resultType,
ValueRange inputs,
- Location loc) {
+ Type originalType) {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
if (input.getType() != type) {
Value newInput = converter.materializeSourceConversion(
- rewriter, input.getLoc(), type, input);
+ rewriter, input.getLoc(), type, input, type);
if (!newInput) {
return emitDefiniteFailure() << "Failed to materialize conversion of "
<< input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
Value convertedOutput = newOutput;
if (output.getType() != newOutput.getType()) {
convertedOutput = converter.materializeTargetConversion(
- rewriter, output.getLoc(), output.getType(), newOutput);
+ rewriter, output.getLoc(), output.getType(), newOutput,
+ output.getType());
if (!convertedOutput) {
return emitDefiniteFailure()
<< "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]
|
@llvm/pr-subscribers-mlir-sparse Author: Matthias Springer (matthias-springer) ChangesThis commit adds an Note:
For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists. This commit also puts the This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new Note for LLVM integration: For all argument/source/target materialization functions, move the Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff 31 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
}
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
- BoxProcType type,
+ mlir::Location loc, BoxProcType type,
mlir::ValueRange inputs,
- mlir::Location loc) {
+ mlir::Type originalType) {
assert(inputs.size() == 1);
return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
/// All of the following materializations require function objects that are
/// convertible to the following form:
- /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+ /// `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// The type that is provided as the 5-th argument is the original type of
+ /// value. For more details, see the documentation below.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: The original type matches the result type `T` for argument
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
+ ///
+ /// Note: The original type matches the result type `T` for source
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, the original type can be
+ /// different from the type of the input. For example, let's assume that a
+ /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+ /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+ /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+ /// that the legalized type of "t1" is "t3", which may be different from
+ /// "t2". In this example, the target materialization callback will be
+ /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+ /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+ /// that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
/// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a conversion.
+ ///
+ /// Arguments: builder, location, result type, inputs, original type
using MaterializationCallbackFn = std::function<std::optional<Value>(
- OpBuilder &, Type, ValueRange, Location)>;
+ OpBuilder &, Location, Type, ValueRange, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
Value
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
+ ValueRange inputs, Type originalType) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ OpBuilder &builder, Location loc, Type resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc);
+ return callback(builder, loc, derivedType, inputs, originalType);
return std::nullopt;
};
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
+ std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+ TypeRange, Value, Type)>;
/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
+ TypeRange resultTypes, Value input,
+ Type originalType) const;
/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
- auto addUnrealizedCast = [](OpBuilder &builder, Type type,
- ValueRange inputs, Location loc) {
+ auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+ ValueRange inputs, Type originalType) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
Location loc = arg.getLoc();
Value newArg = block.insertArgument(argNum, newTy, loc);
Value convertedValue = converter.materializeSourceConversion(
- builder, op->getLoc(), ty, newArg);
+ builder, op->getLoc(), ty, newArg, ty);
if (!convertedValue) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+ MemRefType resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addSourceMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addTargetMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
- rewriter, loc, indexType, dimSize);
+ rewriter, loc, indexType, dimSize, dimSize.getType());
assert(dimSize && "Invalid memref element type");
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
// All other types legal
return type;
});
- converter.addTargetMaterialization(
- [](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
- });
+ converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+ ValueRange input, Type originalType) {
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
+ });
}
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
- ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+ TensorType type, ValueRange inputs,
+ Type originalType) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
});
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
- addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
- ValueRange inputs, Location loc) -> Value {
+ addTargetMaterialization([](OpBuilder &builder, Location loc,
+ BaseMemRefType type, ValueRange inputs,
+ Type originalType) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
namespace {
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
- Type resultType,
+ Location loc, Type resultType,
ValueRange inputs,
- Location loc) {
+ Type originalType) {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
if (input.getType() != type) {
Value newInput = converter.materializeSourceConversion(
- rewriter, input.getLoc(), type, input);
+ rewriter, input.getLoc(), type, input, type);
if (!newInput) {
return emitDefiniteFailure() << "Failed to materialize conversion of "
<< input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
Value convertedOutput = newOutput;
if (output.getType() != newOutput.getType()) {
convertedOutput = converter.materializeTargetConversion(
- rewriter, output.getLoc(), output.getType(), newOutput);
+ rewriter, output.getLoc(), output.getType(), newOutput,
+ output.getType());
if (!convertedOutput) {
return emitDefiniteFailure()
<< "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]
|
@llvm/pr-subscribers-mlir-scf Author: Matthias Springer (matthias-springer) ChangesThis commit adds an Note:
For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists. This commit also puts the This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new Note for LLVM integration: For all argument/source/target materialization functions, move the Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff 31 Files Affected:
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
}
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
- BoxProcType type,
+ mlir::Location loc, BoxProcType type,
mlir::ValueRange inputs,
- mlir::Location loc) {
+ mlir::Type originalType) {
assert(inputs.size() == 1);
return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
/// All of the following materializations require function objects that are
/// convertible to the following form:
- /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+ /// `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// The type that is provided as the 5-th argument is the original type of
+ /// value. For more details, see the documentation below.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: The original type matches the result type `T` for argument
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
+ ///
+ /// Note: The original type matches the result type `T` for source
+ /// materializations.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, the original type can be
+ /// different from the type of the input. For example, let's assume that a
+ /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+ /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+ /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+ /// that the legalized type of "t1" is "t3", which may be different from
+ /// "t2". In this example, the target materialization callback will be
+ /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+ /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+ /// that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
+ std::decay_t<FnT>>::template arg_t<2>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
+ Type resultType, ValueRange inputs,
+ Type originalType) const {
return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
+ resultType, inputs, originalType);
}
/// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a conversion.
+ ///
+ /// Arguments: builder, location, result type, inputs, original type
using MaterializationCallbackFn = std::function<std::optional<Value>(
- OpBuilder &, Type, ValueRange, Location)>;
+ OpBuilder &, Location, Type, ValueRange, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
Value
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
+ ValueRange inputs, Type originalType) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ OpBuilder &builder, Location loc, Type resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc);
+ return callback(builder, loc, derivedType, inputs, originalType);
return std::nullopt;
};
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
/// materializations for 1:N type conversions, which materialize one value in
/// a source type as N values in target types.
using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
+ std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+ TypeRange, Value, Type)>;
/// Creates the mapping of the given range of original types to target types
/// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
/// returns `std::nullopt`.
std::optional<SmallVector<Value>>
materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
+ TypeRange resultTypes, Value input,
+ Type originalType) const;
/// Adds a 1:N target materialization to the converter. Such materializations
/// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
- auto addUnrealizedCast = [](OpBuilder &builder, Type type,
- ValueRange inputs, Location loc) {
+ auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+ ValueRange inputs, Type originalType) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
};
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
Location loc = arg.getLoc();
Value newArg = block.insertArgument(argNum, newTy, loc);
Value convertedValue = converter.materializeSourceConversion(
- builder, op->getLoc(), ty, newArg);
+ builder, op->getLoc(), ty, newArg, ty);
if (!convertedValue) {
return rewriter.notifyMatchFailure(
op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+ ValueRange inputs, Type originalType) -> std::optional<Value> {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+ MemRefType resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addSourceMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
- addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs,
- Location loc) -> std::optional<Value> {
+ addTargetMaterialization([&](OpBuilder &builder, Location loc,
+ Type resultType, ValueRange inputs,
+ Type originalType) -> std::optional<Value> {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
Type indexType = getIndexType();
if (dimSize.getType() != indexType)
dimSize = typeConverter->materializeTargetConversion(
- rewriter, loc, indexType, dimSize);
+ rewriter, loc, indexType, dimSize, dimSize.getType());
assert(dimSize && "Invalid memref element type");
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
// All other types legal
return type;
});
- converter.addTargetMaterialization(
- [](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
- });
+ converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+ ValueRange input, Type originalType) {
+ auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+ extFOp.setFastmath(arith::FastMathFlags::contract);
+ return extFOp;
+ });
}
void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
- ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+ TensorType type, ValueRange inputs,
+ Type originalType) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
});
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
- addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
- ValueRange inputs, Location loc) -> Value {
+ addTargetMaterialization([](OpBuilder &builder, Location loc,
+ BaseMemRefType type, ValueRange inputs,
+ Type originalType) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
namespace {
std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
- Type resultType,
+ Location loc, Type resultType,
ValueRange inputs,
- Location loc) {
+ Type originalType) {
if (inputs.size() != 1)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
if (input.getType() != type) {
Value newInput = converter.materializeSourceConversion(
- rewriter, input.getLoc(), type, input);
+ rewriter, input.getLoc(), type, input, type);
if (!newInput) {
return emitDefiniteFailure() << "Failed to materialize conversion of "
<< input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
Value convertedOutput = newOutput;
if (output.getType() != newOutput.getType()) {
convertedOutput = converter.materializeTargetConversion(
- rewriter, output.getLoc(), output.getType(), newOutput);
+ rewriter, output.getLoc(), output.getType(), newOutput,
+ output.getType());
if (!convertedOutput) {
return emitDefiniteFailure()
<< "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]
|
originalType
param to materializationoriginalType
param to materializations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see uses of materializeTargetConversion
within normal conversion patterns where its non-obvious to me as to how that method should now be used.
With your example of "v2" replacing "v1", the adaptor of the conversion pattern would then receive "v2", is that right?
Isn't t1 then also inaccessible to the conversion pattern? Or can the case you describe not happen during conversion pattern application OR should materializeTargetConversion
not actually ever be called by conversion patterns?
Ideally, patterns should not call this function at all. It's not really an API violation. I'm just saying that it should not be necessary. The uses that I've seen in upstream MLIR are workarounds around 1:N limitations and they are going to disappear soon. Worst case, users can always pass
Almost. The adaptor receives "v2" only if the type of "v2" is valid according to the pattern's type converter. Otherwise, we will insert another target materialization from "v2" -> "t3" and that value is then passed to the adaptor. It may not be possible to compute this target materialization without having access to the original type. That's what's happening during the MemRef->LLVM lowering that I mentioned in the commit message. The logic for this is in |
2c438b0
to
d245912
Compare
d245912
to
80bc49f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM now, thank you!
Ideally, patterns should not call this function at all. It's not really an API violation. I'm just saying that it should not be necessary. The uses that I've seen in upstream MLIR are workarounds around 1:N limitations and they are going to disappear soon.
Once your 1:N changes are merged, I think I'd be absolutely amazing to get rid of the API entirely then. I know at least from my own code and other downstream code that it is used rather defensively. But if it is really not needed then it'd help to conceptually simplify dialect conversion by not making these functions public.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Interesting corner case -- glad you caught it. This makes sense to fix even without the merging. But I am looking forward to the merging anyways :)
…terializations (llvm#112128) This commit adds an optional `originalType` parameter to target materialization functions. Without this parameter, target materializations are underspecified. Note: `originalType` is only needed for target materializations. Source/argument materializations do not have it. Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the `originalType` parameter is added. This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.
…terializations (llvm#112128) This commit adds an optional `originalType` parameter to target materialization functions. Without this parameter, target materializations are underspecified. Note: `originalType` is only needed for target materializations. Source/argument materializations do not have it. Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the `originalType` parameter is added. This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.
…terializations (llvm#112128) This commit adds an optional `originalType` parameter to target materialization functions. Without this parameter, target materializations are underspecified. Note: `originalType` is only needed for target materializations. Source/argument materializations do not have it. Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the `originalType` parameter is added. This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.
This commit adds an optional
originalType
parameter to target materialization functions. Without this parameter, target materializations are underspecified.Note:
originalType
is only needed for target materializations. Source/argument materializations do not have it.Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the
originalType
parameter is added.This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new
originalType
parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an!llvm.struct
that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type =!llvm.struct<...>
) from just the bare pointer (inputs =!llvm.ptr
). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.