-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Bufferization] castOrReallocMemRefValue: Use BufferizationOptions #89175
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization Author: Matthias Gehre (mgehre-amd) ChangesThis allows to configure both the op used for allocation and copy of memrefs.
Full diff: https://github.com/llvm/llvm-project/pull/89175.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index e98b5728b38ef8..6f19dca2e82224 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue,
/// This function returns `failure()` in case of unsupported casts. E.g., casts
/// with differing element types or memory spaces.
FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
- MemRefType type);
+ MemRefType type,
+ const BufferizationOptions &options);
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
- ToMemrefOp toMemref);
+ ToMemrefOp toMemref,
+ const BufferizationOptions &options);
/// Add the canonicalization patterns for bufferization.dealloc to the given
/// pattern set to make them available to other passes (such as
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 2b226c7a1207cf..9f1295222c3525 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
// Helper functions
//===----------------------------------------------------------------------===//
-FailureOr<Value>
-mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
- MemRefType destType) {
+FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
+ OpBuilder &b, Value value, MemRefType destType,
+ const BufferizationOptions &options) {
auto srcType = llvm::cast<MemRefType>(value.getType());
// Element type, rank and memory space must match.
@@ -73,18 +73,23 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
Value size = b.create<memref::DimOp>(loc, value, i);
dynamicOperands.push_back(size);
}
- // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
- // BufferizableOpInterface impl of ToMemrefOp.
- Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
- b.create<memref::CopyOp>(loc, value, copy);
+
+ FailureOr<Value> copy =
+ options.createAlloc(b, loc, destType, dynamicOperands);
+ if (failed(copy)) {
+ return failure();
+ }
+ if (failed(options.createMemCpy(b, loc, value, *copy))) {
+ return failure();
+ }
return copy;
}
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
-LogicalResult
-mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
- ToMemrefOp toMemref) {
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+ RewriterBase &rewriter, ToMemrefOp toMemref,
+ const BufferizationOptions &options) {
auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
if (!memrefToTensor)
return failure();
@@ -105,7 +110,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
// Ranked memref -> Ranked memref cast.
if (rankedSrcType && rankedDestType) {
FailureOr<Value> replacement = castOrReallocMemRefValue(
- rewriter, memrefToTensor.getMemref(), rankedDestType);
+ rewriter, memrefToTensor.getMemref(), rankedDestType, options);
if (failed(replacement))
return failure();
@@ -792,7 +797,7 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
LogicalResult matchAndRewrite(ToMemrefOp toMemref,
PatternRewriter &rewriter) const final {
- return foldToMemrefToTensorPair(rewriter, toMemref);
+ return foldToMemrefToTensorPair(rewriter, toMemref, {});
}
};
@@ -840,7 +845,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
- (void)foldToMemrefToTensorPair(rewriter, *this);
+ (void)foldToMemrefToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
// or not. (And not whether the pattern matched or not.)
return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 32f4e6a0fe8901..786a071dccfe67 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -75,7 +75,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
if (!rankedDestType)
return nullptr;
FailureOr<Value> replacement =
- castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+ castOrReallocMemRefValue(builder, inputs[0], rankedDestType, {});
if (failed(replacement))
return nullptr;
return *replacement;
@@ -512,8 +512,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Fold all to_memref(to_tensor(x)) pairs.
for (Operation *op : toMemrefOps) {
rewriter.setInsertionPoint(op);
- (void)bufferization::foldToMemrefToTensorPair(rewriter,
- cast<ToMemrefOp>(op));
+ (void)bufferization::foldToMemrefToTensorPair(
+ rewriter, cast<ToMemrefOp>(op), options);
}
// Remove all dead to_tensor ops.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d928..7f1e009c303a68 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -33,7 +33,7 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[1], offset: ?>>)
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
@@ -48,7 +48,7 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[100], offset: ?>>)
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
@@ -63,7 +63,7 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
// CHECK-SAME: %[[arg:.*]]: memref<?xf32, strided<[1], offset: 25>>)
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
-// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) {{.*}} : memref<?xf32>
// CHECK: memref.copy %[[arg]], %[[alloc]]
// CHECK: return %[[alloc]]
func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
index de75b288855f94..9cf44c335d551e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -84,7 +84,7 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
// Note: This alloc is not needed, but it is inserted before the returned buffer
// is promoted to an out param to reconcile mismatching layout maps on return
// value and function signature.
-// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32>
+// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<2x5xf32>
// CHECK-NO-LAYOUT: memref.copy %[[subview]], %[[alloc2]]
// CHECK-NO-LAYOUT: memref.copy %[[alloc2]], %[[r]]
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 429c9e4dea9e93..0248afb11f1672 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -52,7 +52,7 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
// CHECK-NO-LAYOUT-MAP-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32>
// CHECK-NO-LAYOUT-MAP: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<20x10xf32>
// CHECK-NO-LAYOUT-MAP: %[[subview:.*]] = memref.subview {{.*}} : memref<20x10xf32> to memref<2x?xf32, strided<[10, 1], offset: ?>>
-// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) : memref<2x?xf32>
+// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<2x?xf32>
// CHECK-NO-LAYOUT-MAP: memref.copy %[[subview]], %[[alloc_no_layout]]
// TODO: %alloc should be deallocated here, but we currently do not dealloc
// buffers that are inserted due to to_tensor/to_memref canonicalization (when
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0e..113aad67985d70 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -84,7 +84,7 @@ func.func @canonicalize_buffer_cast_of_tensor_load_to_copy(
// CHECK-NOT: bufferization.to_memref
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, strided<[1], offset: ?>>
-// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, strided<[1], offset: 3>>
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref<?xf32, strided<[1], offset: 3>>
// CHECK: memref.copy %[[M]], %[[ALLOC]]
// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: 3>>
// CHECK: return %[[ALLOC]]
|
5158b17
to
90a464a
Compare
90a464a
to
f9969fa
Compare
This allows to configure both the op used for allocation and copy of memrefs.
It also changes the default behavior because the default allocation in
BufferizationOptions
createsmemref.alloc
withalignment = 64
where we used to creatememref.alloca
without any alignment before.Fixes