
[mlir][tensor] Implement folding logic for size 0 tensor and memref ops #90814

Merged: 1 commit into llvm:main, May 20, 2024

Conversation

sabauma (Contributor) commented May 2, 2024

Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases (both sketched below):

  1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
  2. memref.copy operations where either the source or target memref is known to be empty.
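
For illustration, a minimal sketch of both folds (function names, shapes, and SSA value names are invented; the tests in the diff below are the authoritative examples):

// Case 1: inserting a slice with a static size-0 dimension is a no-op,
// so the fold returns the destination unchanged.
func.func @fold_insert(%src: tensor<0x2xf32>, %dst: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %0 = tensor.insert_slice %src into %dst[0, 0] [0, 2] [1, 1]
      : tensor<0x2xf32> into tensor<4x4xf32>
  // Canonicalization rewrites this to: return %dst : tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
}

// Case 2: a copy to or from a memref with a 0-sized dimension moves no
// data, so the rewrite erases the copy entirely.
func.func @fold_copy(%src: memref<0x16xf32>, %dst: memref<?x16xf32>) {
  memref.copy %src, %dst : memref<0x16xf32> to memref<?x16xf32>
  return
}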

llvmbot (Member) commented May 2, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Spenser Bauman (sabauma)

Changes

Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases:

  1. tensor.insert_slice operations where the size of the inserted slice is known to be 0.
  2. memref.copy operations where either the source or target memref is known to be empty.

Full diff: https://github.com/llvm/llvm-project/pull/90814.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+21-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b969d41d934d41..675aeacd8f0e23 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
     return success();
   }
 };
+
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  static bool isEmptyMemRef(BaseMemRefType type) {
+    return type.hasRank() &&
+      llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+  }
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter& rewriter) const override {
+    if (isEmptyMemRef(copyOp.getSource().getType()) ||
+        isEmptyMemRef(copyOp.getTarget().getType())) {
+      rewriter.eraseOp(copyOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5f..ef8a078078c864 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2606,6 +2606,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
+  if (llvm::any_of(getMixedSizes(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+    return getDest();
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed1..c4ff6480a4ce5e 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @empty_copy
+//  CHECK-NEXT:   return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+  return
+}
+
+// -----
+
 func.func @scopeMerge() {
   memref.alloca_scope {
     %cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c93..e8adb7653c3e23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
 
+// CHECK-LABEL: func @empty_insert_slice
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+//  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   return %[[ARG1]] :  tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+  return %0 : tensor<3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
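
Both patterns run as part of the standard canonicalization pass, so a quick way to exercise them locally is (input file name hypothetical):

mlir-opt --canonicalize repro.mlir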

github-actions bot commented May 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

sabauma force-pushed the insert-slice-folder branch from 5ed96f6 to 85243eb on May 2, 2024
sabauma (Contributor, Author) commented May 17, 2024

@matthias-springer @nicolasvasilache Any chance you could take a look when you have some time?

MaheshRavishankar (Contributor) left a comment

Thanks! In general, tensors with 0 dims indicate a problem somewhere to me, but this folding itself makes sense.

sabauma (Contributor, Author) commented May 20, 2024

In our case, this was occurring after some lowering patterns that generate a lot of extract_slice and insert_slice ops. We do some folding during the lowering, but the fact that the slices are size 0 is not always obvious until after canonicalization.
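
A hypothetical sketch of that situation (all names and shapes invented): the slice size is an SSA value that only becomes the constant 0 once earlier arithmetic folds, at which point the new insert_slice fold can fire:

func.func @lowered(%n: index, %t: tensor<16x8xf32>, %dst: tensor<16x8xf32>) -> tensor<16x8xf32> {
  %c0 = arith.constant 0 : index
  // Produced by lowering; not obviously 0 until arith.muli folds.
  %sz = arith.muli %c0, %n : index
  %slice = tensor.extract_slice %t[0, 0] [%sz, 8] [1, 1]
      : tensor<16x8xf32> to tensor<?x8xf32>
  %res = tensor.insert_slice %slice into %dst[0, 0] [%sz, 8] [1, 1]
      : tensor<?x8xf32> into tensor<16x8xf32>
  // Once %sz folds to a constant 0, isConstantIntValue matches it in
  // getMixedSizes() and the fold replaces %res with %dst.
  return %res : tensor<16x8xf32>
}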

sabauma merged commit 1f07bfb into llvm:main on May 20, 2024
4 checks passed