Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: add originalType param to materializations #112128

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Oct 13, 2024

This commit adds an optional originalType parameter to target materialization functions. Without this parameter, target materializations are underspecified.

Note: originalType is only needed for target materializations. Source/argument materializations do not have it.

Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter is added.

This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new originalType parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an !llvm.struct that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = !llvm.struct<...>) from just the bare pointer (inputs = !llvm.ptr). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.

@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit adds an originalType parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: originalType is only needed for target materializations. For source/argument materializations, originalType always matches outputType. However, to keep the code base simple (i.e., reuse MaterializationCallbackFn for all three materializations), originalType is passed to all three materializations, even though it is only really needed for target materializations.

originalType is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.

This commit also puts the Location parameter right after the OpBuilder parameter to be consistent with MLIR conventions.

This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new originalType parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an !llvm.struct that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = !llvm.struct&lt;...&gt;) from just the bare pointer (inputs = !llvm.ptr). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.

Note for LLVM integration: For all argument/source/target materialization functions, move the Location parameter to the second position and add a Type originalType parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match inputs.front().getType(). If the original type cannot be recovered (which is unlikely), pass Type().


Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff

31 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+2-2)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+40-16)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+4-3)
  • (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-11)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+10-10)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+6-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+22-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp (+18-16)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+33-19)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+9-5)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+10-7)
  • (modified) mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (+3-2)
  • (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+11-9)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-3)
  • (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (+3-2)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the 5-th argument is the original type of
+  /// value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Matthias Springer (matthias-springer)

Changes

This commit adds an originalType parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: originalType is only needed for target materializations. For source/argument materializations, originalType always matches outputType. However, to keep the code base simple (i.e., reuse MaterializationCallbackFn for all three materializations), originalType is passed to all three materializations, even though it is only really needed for target materializations.

originalType is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.

This commit also puts the Location parameter right after the OpBuilder parameter to be consistent with MLIR conventions.

This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new originalType parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an !llvm.struct that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = !llvm.struct&lt;...&gt;) from just the bare pointer (inputs = !llvm.ptr). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.

Note for LLVM integration: For all argument/source/target materialization functions, move the Location parameter to the second position and add a Type originalType parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match inputs.front().getType(). If the original type cannot be recovered (which is unlikely), pass Type().


Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff

31 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+2-2)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+40-16)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+4-3)
  • (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-11)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+10-10)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+6-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+22-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp (+18-16)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+33-19)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+9-5)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+10-7)
  • (modified) mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (+3-2)
  • (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+11-9)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-3)
  • (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (+3-2)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the 5-th argument is the original type of
+  /// value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2024

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

This commit adds an originalType parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: originalType is only needed for target materializations. For source/argument materializations, originalType always matches outputType. However, to keep the code base simple (i.e., reuse MaterializationCallbackFn for all three materializations), originalType is passed to all three materializations, even though it is only really needed for target materializations.

originalType is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.

This commit also puts the Location parameter right after the OpBuilder parameter to be consistent with MLIR conventions.

This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new originalType parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an !llvm.struct that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = !llvm.struct&lt;...&gt;) from just the bare pointer (inputs = !llvm.ptr). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.

Note for LLVM integration: For all argument/source/target materialization functions, move the Location parameter to the second position and add a Type originalType parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match inputs.front().getType(). If the original type cannot be recovered (which is unlikely), pass Type().


Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff

31 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+2-2)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+40-16)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+4-3)
  • (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-11)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+10-10)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+6-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+22-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp (+18-16)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+33-19)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+9-5)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+10-7)
  • (modified) mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (+3-2)
  • (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+11-9)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-3)
  • (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (+3-2)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the 5-th argument is the original type of
+  /// value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2024

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

This commit adds an originalType parameter to all materialization functions. Without this parameter, target materializations are underspecified.

Note: originalType is only needed for target materializations. For source/argument materializations, originalType always matches outputType. However, to keep the code base simple (i.e., reuse MaterializationCallbackFn for all three materializations), originalType is passed to all three materializations, even though it is only really needed for target materializations.

originalType is the original type of an SSA value. For argument materializations, it matches the original argument type (which is also the output type). For source materializations, it also matches the output type.

For target materializations, consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the originalType parameter exists.

This commit also puts the Location parameter right after the OpBuilder parameter to be consistent with MLIR conventions.

This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new originalType parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an !llvm.struct that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = !llvm.struct&lt;...&gt;) from just the bare pointer (inputs = !llvm.ptr). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.

Note for LLVM integration: For all argument/source/target materialization functions, move the Location parameter to the second position and add a Type originalType parameter to the lambda. No changes are needed to the body of the lambda. When an argument/source materialization is called in your code base, pass the output type as original type. When a target materialization is called, try to pass the original type of the SSA value, which may match inputs.front().getType(). If the original type cannot be recovered (which is unlikely), pass Type().


Patch is 56.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112128.diff

31 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+2-2)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+40-16)
  • (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+4-3)
  • (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-11)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExtendToSupportedTypes.cpp (+7-6)
  • (modified) mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp (+10-10)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+3-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+6-6)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (+22-20)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp (+18-16)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+33-19)
  • (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+9-5)
  • (modified) mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp (+10-7)
  • (modified) mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp (+3-2)
  • (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+11-9)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+3-3)
  • (modified) mlir/test/lib/Transforms/TestDialectConversion.cpp (+3-2)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index c536fd19fcc69a..f1b057eedb2340 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -173,9 +173,9 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
   }
 
   static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
-                                          BoxProcType type,
+                                          mlir::Location loc, BoxProcType type,
                                           mlir::ValueRange inputs,
-                                          mlir::Location loc) {
+                                          mlir::Type originalType) {
     assert(inputs.size() == 1);
     return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
                                      inputs[0]);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..f22599c4d4aabf 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -170,7 +170,7 @@ class TypeConverter {
 
   /// All of the following materializations require function objects that are
   /// convertible to the following form:
-  ///   `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+  ///   `std::optional<Value>(OpBuilder &, Location, T, ValueRange, Type)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
@@ -178,13 +178,19 @@ class TypeConverter {
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// The type that is provided as the 5-th argument is the original type of
+  /// value. For more details, see the documentation below.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: The original type matches the result type `T` for argument
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -194,8 +200,11 @@ class TypeConverter {
   /// converting a legal replacement value back to an illegal source type.
   /// This is used when some uses of the original, illegal value must persist
   /// beyond the main conversion.
+  ///
+  /// Note: The original type matches the result type `T` for source
+  /// materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -203,8 +212,19 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, the original type can be
+  /// different from the type of the input. For example, let's assume that a
+  /// conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2"
+  /// (type "t2"). Then a different conversion pattern "P2" matches an op that
+  /// has "v1" as an operand. Let's furthermore assume that "P2" determines
+  /// that the legalized type of "t1" is "t3", which may be different from
+  /// "t2". In this example, the target materialization callback will be
+  /// invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note
+  /// that the original type "t1" cannot be recovered from just "t3" and "v2";
+  /// that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
-                              std::decay_t<FnT>>::template arg_t<1>>
+                              std::decay_t<FnT>>::template arg_t<2>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -303,20 +323,22 @@ class TypeConverter {
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
+                                      Type resultType, ValueRange inputs,
+                                      Type originalType) const {
     return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType) const {
     return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
+                                 resultType, inputs, originalType);
   }
 
   /// Convert an attribute present `attr` from within the type `type` using
@@ -334,8 +356,10 @@ class TypeConverter {
       Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
+  ///
+  /// Arguments: builder, location, result type, inputs, original type
   using MaterializationCallbackFn = std::function<std::optional<Value>(
-      OpBuilder &, Type, ValueRange, Location)>;
+      OpBuilder &, Location, Type, ValueRange, Type)>;
 
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
@@ -346,7 +370,7 @@ class TypeConverter {
   Value
   materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
                         OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
+                        ValueRange inputs, Type originalType) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -394,10 +418,10 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               OpBuilder &builder, Type resultType, ValueRange inputs,
-               Location loc) -> std::optional<Value> {
+               OpBuilder &builder, Location loc, Type resultType,
+               ValueRange inputs, Type originalType) -> std::optional<Value> {
       if (T derivedType = dyn_cast<T>(resultType))
-        return callback(builder, derivedType, inputs, loc);
+        return callback(builder, loc, derivedType, inputs, originalType);
       return std::nullopt;
     };
   }
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..90f796fce576a9 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -44,8 +44,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// materializations for 1:N type conversions, which materialize one value in
   /// a source type as N values in target types.
   using OneToNMaterializationCallbackFn =
-      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
-                                                      Value, Location)>;
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, Location,
+                                                      TypeRange, Value, Type)>;
 
   /// Creates the mapping of the given range of original types to target types
   /// of the conversion and stores that mapping in the given (signature)
@@ -63,7 +63,8 @@ class OneToNTypeConverter : public TypeConverter {
   /// returns `std::nullopt`.
   std::optional<SmallVector<Value>>
   materializeTargetConversion(OpBuilder &builder, Location loc,
-                              TypeRange resultTypes, Value input) const;
+                              TypeRange resultTypes, Value input,
+                              Type originalType) const;
 
   /// Adds a 1:N target materialization to the converter. Such materializations
   /// build IR that converts N values with target types into 1 value of the
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 77603739137614..5b067647fff726 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -281,8 +281,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
-    auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                                ValueRange inputs, Location loc) {
+    auto addUnrealizedCast = [](OpBuilder &builder, Location loc, Type type,
+                                ValueRange inputs, Type originalType) {
       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
       return std::optional<Value>(cast.getResult(0));
     };
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
index d8f3e995109538..b0fc27e59f7501 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp
@@ -46,7 +46,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
     Location loc = arg.getLoc();
     Value newArg = block.insertArgument(argNum, newTy, loc);
     Value convertedValue = converter.materializeSourceConversion(
-        builder, op->getLoc(), ty, newArg);
+        builder, op->getLoc(), ty, newArg, ty);
     if (!convertedValue) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("failed to cast new argument {0} to type {1})",
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 5a92fa839e9847..66a0ce74f5841c 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,8 +159,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // insert a target materialization from the original block argument type to
   // a legal type.
   addArgumentMaterialization(
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc) -> std::optional<Value> {
+      [&](OpBuilder &builder, Location loc, UnrankedMemRefType resultType,
+          ValueRange inputs, Type originalType) -> std::optional<Value> {
         if (inputs.size() == 1) {
           // Bare pointers are not supported for unranked memrefs because a
           // memref descriptor cannot be built just from a bare pointer.
@@ -174,9 +174,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
             .getResult(0);
       });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> std::optional<Value> {
+  addArgumentMaterialization([&](OpBuilder &builder, Location loc,
+                                 MemRefType resultType, ValueRange inputs,
+                                 Type originalType) -> std::optional<Value> {
     Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
@@ -201,18 +201,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addSourceMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc) -> std::optional<Value> {
+  addTargetMaterialization([&](OpBuilder &builder, Location loc,
+                               Type resultType, ValueRange inputs,
+                               Type originalType) -> std::optional<Value> {
     if (inputs.size() != 1)
       return std::nullopt;
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..46acfdab96e648 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1185,7 +1185,7 @@ struct MemRefReshapeOpLowering
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
             dimSize = typeConverter->materializeTargetConversion(
-                rewriter, loc, indexType, dimSize);
+                rewriter, loc, indexType, dimSize, dimSize.getType());
           assert(dimSize && "Invalid memref element type");
         }
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17b..d57960169de217 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -97,12 +97,12 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
     // All other types legal
     return type;
   });
-  converter.addTargetMaterialization(
-      [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
-      });
+  converter.addTargetMaterialization([](OpBuilder &b, Location loc, Type target,
+                                        ValueRange input, Type originalType) {
+    auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
+    extFOp.setFastmath(arith::FastMathFlags::contract);
+    return extFOp;
+  });
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsPatterns(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 875d8c40e92cc1..3378fe3ee6680d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -42,8 +42,9 @@ using namespace mlir::bufferization;
 // BufferizeTypeConverter
 //===----------------------------------------------------------------------===//
 
-static Value materializeToTensor(OpBuilder &builder, TensorType type,
-                                 ValueRange inputs, Location loc) {
+static Value materializeToTensor(OpBuilder &builder, Location loc,
+                                 TensorType type, ValueRange inputs,
+                                 Type originalType) {
   assert(inputs.size() == 1);
   assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@@ -63,8 +64,9 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
   });
   addArgumentMaterialization(materializeToTensor);
   addSourceMaterialization(materializeToTensor);
-  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
-                              ValueRange inputs, Location loc) -> Value {
+  addTargetMaterialization([](OpBuilder &builder, Location loc,
+                              BaseMemRefType type, ValueRange inputs,
+                              Type originalType) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 83de9b37974f67..1315805caa675f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -17,9 +17,9 @@ using namespace mlir;
 namespace {
 
 std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
-                                                 Type resultType,
+                                                 Location loc, Type resultType,
                                                  ValueRange inputs,
-                                                 Location loc) {
+                                                 Type originalType) {
   if (inputs.size() != 1)
     return std::nullopt;
 
diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
index 2728936bf33fd3..3b472293ef88b6 100644
--- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
+++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp
@@ -161,7 +161,7 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
     if (input.getType() != type) {
       Value newInput = converter.materializeSourceConversion(
-          rewriter, input.getLoc(), type, input);
+          rewriter, input.getLoc(), type, input, type);
       if (!newInput) {
         return emitDefiniteFailure() << "Failed to materialize conversion of "
                                      << input << " to type " << type;
@@ -180,7 +180,8 @@ transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
     Value convertedOutput = newOutput;
     if (output.getType() != newOutput.getType()) {
       convertedOutput = converter.materializeTargetConversion(
-          rewriter, output.getLoc(), output.getType(), newOutput);
+          rewriter, output.getLoc(), output.getType(), newOutput,
+          output.getType());
       if (!convertedOutput) {
         return emitDefiniteFailure()
                << "Failed to materialize conversion of " << newOutput
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..557ef265c5b30c 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeC...
[truncated]

@matthias-springer matthias-springer changed the title [mlir][Transforms] Dialect conversion: add originalType param to materialization [mlir][Transforms] Dialect conversion: add originalType param to materializations Oct 13, 2024
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see uses of materializeTargetConversion within normal conversion patterns where its non-obvious to me as to how that method should now be used.

With your example of "v2" replacing "v1", the adaptor of the conversion pattern would then receive "v2", is that right?
Isn't t1 then also inaccessible to the conversion pattern? Or can the case you describe not happen during conversion pattern application OR should materializeTargetConversion not actually ever be called by conversion patterns?

mlir/include/mlir/Transforms/DialectConversion.h Outdated Show resolved Hide resolved
mlir/include/mlir/Transforms/DialectConversion.h Outdated Show resolved Hide resolved
@matthias-springer
Copy link
Member Author

matthias-springer commented Oct 13, 2024

I see uses of materializeTargetConversion within normal conversion patterns where its non-obvious to me as to how that method should now be used.

Ideally, patterns should not call this function at all. It's not really an API violation. I'm just saying that it should not be necessary. The uses that I've seen in upstream MLIR are workarounds around 1:N limitations and they are going to disappear soon. Worst case, users can always pass originalType = Type() and ignore this new parameter. The dialect conversion itself does not require this parameter, only the target materializations may need it; but these are also written by the user. Basically, users can choose to use this additional information or ignore it.

With your example of "v2" replacing "v1", the adaptor of the conversion pattern would then receive "v2", is that right?

Almost. The adaptor receives "v2" only if the type of "v2" is valid according to the pattern's type converter. Otherwise, we will insert another target materialization from "v2" -> "t3" and that value is then passed to the adaptor. It may not be possible to compute this target materialization without having access to the original type. That's what's happening during the MemRef->LLVM lowering that I mentioned in the commit message.

The logic for this is in ConversionPatternRewriterImpl::remapValues. The implementation has access to the original types; the same is the case for the patterns themselves: they can always get the original types from op instead of adaptor. However, they may not be able to get original types for ops other than the matched op. (Because they may have to traverse "old" IR which is problematic in a dialect conversion.)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/type_conv_original_type branch from 2c438b0 to d245912 Compare October 13, 2024 12:04
@matthias-springer matthias-springer force-pushed the users/matthias-springer/type_conv_original_type branch from d245912 to 80bc49f Compare October 13, 2024 12:07
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now, thank you!

Ideally, patterns should not call this function at all. It's not really an API violation. I'm just saying that it should not be necessary. The uses that I've seen in upstream MLIR are workarounds around 1:N limitations and they are going to disappear soon.

Once your 1:N changes are merged, I think I'd be absolutely amazing to get rid of the API entirely then. I know at least from my own code and other downstream code that it is used rather defensively. But if it is really not needed then it'd help to conceptually simplify dialect conversion by not making these functions public.

Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Interesting corner case -- glad you caught it. This makes sense to fix even without the merging. But I am looking forward to the merging anyways :)

@matthias-springer matthias-springer merged commit 0d906a4 into main Oct 15, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/type_conv_original_type branch October 15, 2024 06:52
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
…terializations (llvm#112128)

This commit adds an optional `originalType` parameter to target
materialization functions. Without this parameter, target
materializations are underspecified.

Note: `originalType` is only needed for target materializations.
Source/argument materializations do not have it.

Consider the following example: Let's assume that a conversion pattern
"P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then
a different conversion pattern "P2" matches an op that has "v1" as an
operand. Let's furthermore assume that "P2" determines that the
legalized type of "t1" is "t3", which may be different from "t2". In
this example, the target materialization callback will be invoked with:
outputType = "t3", inputs = "v2", originalType = "t1". Note that the
original type "t1" cannot be recovered from just "t3" and "v2"; that's
why the `originalType` parameter is added.

This change is in preparation of merging the 1:1 and 1:N dialect
conversion drivers. As part of that change, argument materializations
will be removed (as they are no longer needed; they were just a
workaround because of missing 1:N support in the dialect conversion).
The new `originalType` parameter is needed when lowering MemRef to LLVM.
During that lowering, MemRef function block arguments are replaced with
the elements that make up a MemRef descriptor. The type converter is set
up in such a way that the legalized type of a MemRef type is an
`!llvm.struct` that represents the MemRef descriptor. When the bare
pointer calling convention is enabled, the function block arguments
consist of just an LLVM pointer. In such a case, a target
materialization will be invoked to construct a MemRef descriptor (output
type = `!llvm.struct<...>`) from just the bare pointer (inputs =
`!llvm.ptr`). The original MemRef type is required to construct the
MemRef descriptor, as static sizes/strides/offset cannot be inferred
from just the bare pointer.
bricknerb pushed a commit to bricknerb/llvm-project that referenced this pull request Oct 17, 2024
…terializations (llvm#112128)

This commit adds an optional `originalType` parameter to target
materialization functions. Without this parameter, target
materializations are underspecified.

Note: `originalType` is only needed for target materializations.
Source/argument materializations do not have it.

Consider the following example: Let's assume that a conversion pattern
"P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then
a different conversion pattern "P2" matches an op that has "v1" as an
operand. Let's furthermore assume that "P2" determines that the
legalized type of "t1" is "t3", which may be different from "t2". In
this example, the target materialization callback will be invoked with:
outputType = "t3", inputs = "v2", originalType = "t1". Note that the
original type "t1" cannot be recovered from just "t3" and "v2"; that's
why the `originalType` parameter is added.

This change is in preparation of merging the 1:1 and 1:N dialect
conversion drivers. As part of that change, argument materializations
will be removed (as they are no longer needed; they were just a
workaround because of missing 1:N support in the dialect conversion).
The new `originalType` parameter is needed when lowering MemRef to LLVM.
During that lowering, MemRef function block arguments are replaced with
the elements that make up a MemRef descriptor. The type converter is set
up in such a way that the legalized type of a MemRef type is an
`!llvm.struct` that represents the MemRef descriptor. When the bare
pointer calling convention is enabled, the function block arguments
consist of just an LLVM pointer. In such a case, a target
materialization will be invoked to construct a MemRef descriptor (output
type = `!llvm.struct<...>`) from just the bare pointer (inputs =
`!llvm.ptr`). The original MemRef type is required to construct the
MemRef descriptor, as static sizes/strides/offset cannot be inferred
from just the bare pointer.
EricWF pushed a commit to efcs/llvm-project that referenced this pull request Oct 22, 2024
…terializations (llvm#112128)

This commit adds an optional `originalType` parameter to target
materialization functions. Without this parameter, target
materializations are underspecified.

Note: `originalType` is only needed for target materializations.
Source/argument materializations do not have it.

Consider the following example: Let's assume that a conversion pattern
"P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then
a different conversion pattern "P2" matches an op that has "v1" as an
operand. Let's furthermore assume that "P2" determines that the
legalized type of "t1" is "t3", which may be different from "t2". In
this example, the target materialization callback will be invoked with:
outputType = "t3", inputs = "v2", originalType = "t1". Note that the
original type "t1" cannot be recovered from just "t3" and "v2"; that's
why the `originalType` parameter is added.

This change is in preparation of merging the 1:1 and 1:N dialect
conversion drivers. As part of that change, argument materializations
will be removed (as they are no longer needed; they were just a
workaround because of missing 1:N support in the dialect conversion).
The new `originalType` parameter is needed when lowering MemRef to LLVM.
During that lowering, MemRef function block arguments are replaced with
the elements that make up a MemRef descriptor. The type converter is set
up in such a way that the legalized type of a MemRef type is an
`!llvm.struct` that represents the MemRef descriptor. When the bare
pointer calling convention is enabled, the function block arguments
consist of just an LLVM pointer. In such a case, a target
materialization will be invoked to construct a MemRef descriptor (output
type = `!llvm.struct<...>`) from just the bare pointer (inputs =
`!llvm.ptr`). The original MemRef type is required to construct the
MemRef descriptor, as static sizes/strides/offset cannot be inferred
from just the bare pointer.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants