[mlir][tensor] Implement folding logic for size 0 tensor and memref ops #90814
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir-tensor

Author: Spenser Bauman (sabauma)

Changes

Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases:

1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
2. memref.copy operations where either the source or target memrefs are known to be empty.
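For illustration, here is a minimal before/after sketch of the two cases in MLIR assembly (function and value names are hypothetical; the PR's own tests below exercise essentially the same IR):

```mlir
// RUN: mlir-opt --canonicalize %s

// Case 1: the inserted slice has a static size of 0 along dimension 0, so the
// insert_slice writes nothing and folds to its destination, %dst.
func.func @zero_size_insert(%slice: tensor<0x2xi8>, %dst: tensor<3x3xi8>) -> tensor<3x3xi8> {
  %0 = tensor.insert_slice %slice into %dst[0, 0] [0, 2] [1, 1]
      : tensor<0x2xi8> into tensor<3x3xi8>
  return %0 : tensor<3x3xi8>  // canonicalizes to: return %dst
}

// Case 2: the source memref has a zero-sized dimension, so the copy moves no
// data and the op is simply erased.
func.func @zero_size_copy(%src: memref<0x10xf32>, %dst: memref<?x10xf32>) {
  memref.copy %src, %dst : memref<0x10xf32> to memref<?x10xf32>
  return
}
```

Note the asymmetry in how the diff implements the two cases: memref.copy produces no results, so the empty-copy case has to be a RewritePattern (FoldEmptyCopy) that erases the op, whereas tensor.insert_slice produces a value, so its zero-size case can live in InsertSliceOp::fold and simply return the destination.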
Full diff: https://github.com/llvm/llvm-project/pull/90814.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..675aeacd8f0e23 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
return success();
}
};
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+ static bool isEmptyMemRef(BaseMemRefType type) {
+ return type.hasRank() &&
+ llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+ }
+
+ LogicalResult matchAndRewrite(CopyOp copyOp,
+ PatternRewriter& rewriter) const override {
+ if (isEmptyMemRef(copyOp.getSource().getType()) ||
+ isEmptyMemRef(copyOp.getTarget().getType())) {
+ rewriter.eraseOp(copyOp);
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+ results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..ef8a078078c864 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2606,6 +2606,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
+ if (llvm::any_of(getMixedSizes(),
+ [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ return getDest();
return OpFoldResult();
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed1..c4ff6480a4ce5e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
// -----
+// CHECK-LABEL: func @empty_copy
+// CHECK-NEXT: return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+ memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+ memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+ return
+}
+
+// -----
+
func.func @scopeMerge() {
memref.alloca_scope {
%cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..e8adb7653c3e23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
// -----
+// CHECK-LABEL: func @empty_insert_slice
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+// CHECK-NOT: tensor.extract_slice
+// CHECK: return %[[ARG1]] : tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+ %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+ return %0 : tensor<3x3xi8>
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
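One nuance of the InsertSliceOp::fold change above: getMixedSizes() returns each slice size as an OpFoldResult, which may be either a static attribute or an SSA value. Since isConstantIntValue also matches values produced by a constant op (to my understanding), the fold should fire even when the zero size is dynamic, as in this hypothetical example:

```mlir
func.func @dynamic_zero_size(%slice: tensor<?x2xi8>, %dst: tensor<3x3xi8>) -> tensor<3x3xi8> {
  %c0 = arith.constant 0 : index
  // The size along dim 0 is the SSA constant %c0 rather than a static 0.
  // isConstantIntValue(ofr, 0) should still recognize it, folding the op
  // away and replacing %0 with %dst.
  %0 = tensor.insert_slice %slice into %dst[0, 0] [%c0, 2] [1, 1]
      : tensor<?x2xi8> into tensor<3x3xi8>
  return %0 : tensor<3x3xi8>
}
```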
✅ With the latest revision this PR passed the C/C++ code formatter.
Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases:

1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
2. memref.copy operations where either the source or target memrefs are known to be empty.
Force-pushed from 5ed96f6 to 85243eb
@matthias-springer @nicolasvasilache Any chance you could take a look when you have some time?
Thanks! In general, tensors with size-0 dims indicate to me a problem somewhere, but this folding itself makes sense.
In our case this was occurring after some lowering patterns which generate a lot of these zero-sized ops.