[mlir][Bufferization] castOrReallocMemRefValue: Use BufferizationOptions #89175

Merged: 2 commits into llvm:main from mgehre-amd:mgehre.bufferize_realloc on Apr 18, 2024

Conversation

mgehre-amd
Contributor

This makes it possible to configure both the op used to allocate memrefs and the op used to copy them.
It also changes the default behavior: the default allocation callback in BufferizationOptions creates memref.alloc with alignment = 64, whereas this code previously created memref.alloc without any alignment.
Fixes this TODO in castOrReallocMemRefValue:

// TODO: Use alloc/memcpy callback from BufferizationOptions if called via
// BufferizableOpInterface impl of ToMemrefOp.
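
Concretely, the realloc path now emits an aligned allocation by default. The IR difference looks roughly like this (an illustrative fragment based on the updated FileCheck lines below; the attribute value comes from the default bufferAlignment = 64 in BufferizationOptions):

// Before this patch: hard-coded ops, no alignment attribute.
%alloc = memref.alloc(%dim) : memref<?xf32>
memref.copy %arg, %alloc : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32>

// After this patch: default BufferizationOptions callbacks.
%alloc = memref.alloc(%dim) {alignment = 64 : i64} : memref<?xf32>
memref.copy %arg, %alloc : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32>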

@llvmbot added the mlir and mlir:bufferization (Bufferization infrastructure) labels on Apr 18, 2024
@llvmbot
Member

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Gehre (mgehre-amd)

Changes

This makes it possible to configure both the op used to allocate memrefs and the op used to copy them.
It also changes the default behavior: the default allocation callback in BufferizationOptions creates memref.alloc with alignment = 64, whereas this code previously created memref.alloc without any alignment.
Fixes this TODO in castOrReallocMemRefValue:

// TODO: Use alloc/memcpy callback from BufferizationOptions if called via
// BufferizableOpInterface impl of ToMemrefOp.

Full diff: https://github.com/llvm/llvm-project/pull/89175.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h (+4-2)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+18-13)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+3-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index e98b5728b38ef8..6f19dca2e82224 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue,
 /// This function returns `failure()` in case of unsupported casts. E.g., casts
 /// with differing element types or memory spaces.
 FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
-                                          MemRefType type);
+                                          MemRefType type,
+                                          const BufferizationOptions &options);
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref);
+                                       ToMemrefOp toMemref,
+                                       const BufferizationOptions &options);
 
 /// Add the canonicalization patterns for bufferization.dealloc to the given
 /// pattern set to make them available to other passes (such as
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 2b226c7a1207cf..9f1295222c3525 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
-FailureOr<Value>
-mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
-                                              MemRefType destType) {
+FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
+    OpBuilder &b, Value value, MemRefType destType,
+    const BufferizationOptions &options) {
   auto srcType = llvm::cast<MemRefType>(value.getType());
 
   // Element type, rank and memory space must match.
@@ -73,18 +73,23 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
     Value size = b.create<memref::DimOp>(loc, value, i);
     dynamicOperands.push_back(size);
   }
-  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
-  // BufferizableOpInterface impl of ToMemrefOp.
-  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
-  b.create<memref::CopyOp>(loc, value, copy);
+
+  FailureOr<Value> copy =
+      options.createAlloc(b, loc, destType, dynamicOperands);
+  if (failed(copy)) {
+    return failure();
+  }
+  if (failed(options.createMemCpy(b, loc, value, *copy))) {
+    return failure();
+  }
   return copy;
 }
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult
-mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                              ToMemrefOp toMemref) {
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+    RewriterBase &rewriter, ToMemrefOp toMemref,
+    const BufferizationOptions &options) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -105,7 +110,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
   // Ranked memref -> Ranked memref cast.
   if (rankedSrcType && rankedDestType) {
     FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, memrefToTensor.getMemref(), rankedDestType);
+        rewriter, memrefToTensor.getMemref(), rankedDestType, options);
     if (failed(replacement))
       return failure();
 
@@ -792,7 +797,7 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    return foldToMemrefToTensorPair(rewriter, toMemref);
+    return foldToMemrefToTensorPair(rewriter, toMemref, {});
   }
 };
 
@@ -840,7 +845,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                     const BufferizationOptions &options) {
   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
-  (void)foldToMemrefToTensorPair(rewriter, *this);
+  (void)foldToMemrefToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
   // or not. (And not whether the pattern matched or not.)
   return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 32f4e6a0fe8901..786a071dccfe67 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -75,7 +75,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
       if (!rankedDestType)
         return nullptr;
       FailureOr<Value> replacement =
-          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+          castOrReallocMemRefValue(builder, inputs[0], rankedDestType, {});
       if (failed(replacement))
         return nullptr;
       return *replacement;
@@ -512,8 +512,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   // Fold all to_memref(to_tensor(x)) pairs.
   for (Operation *op : toMemrefOps) {
     rewriter.setInsertionPoint(op);
-    (void)bufferization::foldToMemrefToTensorPair(rewriter,
-                                                  cast<ToMemrefOp>(op));
+    (void)bufferization::foldToMemrefToTensorPair(
+        rewriter, cast<ToMemrefOp>(op), options);
   }
 
   // Remove all dead to_tensor ops.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d928..7f1e009c303a68 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -33,7 +33,7 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[1], offset: ?>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
@@ -48,7 +48,7 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[100], offset: ?>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
@@ -63,7 +63,7 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //  CHECK-SAME:     %[[arg:.*]]: memref<?xf32, strided<[1], offset: 25>>)
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
index de75b288855f94..9cf44c335d551e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -84,7 +84,7 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
 // Note: This alloc is not needed, but it is inserted before the returned buffer
 // is promoted to an out param to reconcile mismatching layout maps on return
 // value and function signature.
-//       CHECK-NO-LAYOUT:   %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32>
+//       CHECK-NO-LAYOUT:   %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<2x5xf32>
 //       CHECK-NO-LAYOUT:   memref.copy %[[subview]], %[[alloc2]]
 //       CHECK-NO-LAYOUT:   memref.copy %[[alloc2]], %[[r]]
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 429c9e4dea9e93..0248afb11f1672 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -52,7 +52,7 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
 // CHECK-NO-LAYOUT-MAP-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32>
 //       CHECK-NO-LAYOUT-MAP:   %[[alloc:.*]] = memref.alloc() {{.*}} : memref<20x10xf32>
 //       CHECK-NO-LAYOUT-MAP:   %[[subview:.*]] = memref.subview {{.*}} : memref<20x10xf32> to memref<2x?xf32, strided<[10, 1], offset: ?>>
-//       CHECK-NO-LAYOUT-MAP:   %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) : memref<2x?xf32>
+//       CHECK-NO-LAYOUT-MAP:   %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<2x?xf32>
 //       CHECK-NO-LAYOUT-MAP:   memref.copy %[[subview]], %[[alloc_no_layout]]
 // TODO: %alloc should be deallocated here, but we currently do not dealloc
 // buffers that are inserted due to to_tensor/to_memref canonicalization (when
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0e..113aad67985d70 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -84,7 +84,7 @@ func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
 //  CHECK-NOT: bufferization.to_memref
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
 //      CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
-//      CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
+//      CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref<?xf32, strided<[1], offset: 3>>
 //      CHECK: memref.copy %[[M]], %[[ALLOC]]
 // CHECK-SAME:   memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
 //      CHECK: return %[[ALLOC]]
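
With the new options parameter, downstream users that relied on the previous unaligned memref.alloc can restore it by overriding the callbacks. Below is a minimal sketch, assuming the AllocationFn and MemCpyFn signatures declared in BufferizableOpInterface.h around this change; makeUnalignedAllocOptions is a hypothetical helper, not part of this PR:

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

using namespace mlir;

// Hypothetical helper: builds BufferizationOptions that reproduce the
// pre-patch behavior of castOrReallocMemRefValue.
static bufferization::BufferizationOptions makeUnalignedAllocOptions() {
  bufferization::BufferizationOptions options;
  // Override the allocation callback: ignore the alignment hint and emit a
  // plain memref.alloc, as the old hard-coded path did.
  options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
                            ValueRange dynShape,
                            unsigned alignment) -> FailureOr<Value> {
    (void)alignment; // Intentionally unused.
    return b.create<memref::AllocOp>(loc, type, dynShape).getResult();
  };
  // memref.copy is also the default copy op; spelled out for completeness.
  options.memCpyFn = [](OpBuilder &b, Location loc, Value from,
                        Value to) -> LogicalResult {
    b.create<memref::CopyOp>(loc, from, to);
    return success();
  };
  return options;
}

Passing such options to bufferizeOp (or to castOrReallocMemRefValue directly) then yields the unaligned allocation again. Note that the ToMemrefToTensorFolding canonicalization pattern still calls foldToMemrefToTensorPair with default-constructed options, as the diff above shows.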

@mgehre-amd force-pushed the mgehre.bufferize_realloc branch from 5158b17 to 90a464a on April 18, 2024 at 07:39
@mgehre-amd force-pushed the mgehre.bufferize_realloc branch from 90a464a to f9969fa on April 18, 2024 at 10:05
@mgehre-amd merged commit c515c78 into llvm:main on April 18, 2024 (3 of 4 checks passed)
@mgehre-amd deleted the mgehre.bufferize_realloc branch on April 18, 2024 at 13:47