diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 45f39c80041c9a0..d70e6d0b79cd6fb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -833,11 +833,36 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
     return success();
   }
 };
+
+/// Erases a `memref.copy` whose source or target type has a shape entry equal
+/// to 0. Such a memref holds no elements, so the copy has nothing to transfer
+/// and can be dropped.
+struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  /// Returns true if `type` is ranked and at least one entry of its shape is
+  /// 0. Unranked types are never reported as empty.
+  static bool isEmptyMemRef(BaseMemRefType type) {
+    return type.hasRank() &&
+           llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
+  }
+
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    if (isEmptyMemRef(copyOp.getSource().getType()) ||
+        isEmptyMemRef(copyOp.getTarget().getType())) {
+      rewriter.eraseOp(copyOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldSelfCopy>(context);
+  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8a6df82abb312aa..8545c7b9af8f73d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2609,6 +2609,11 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
+  // If any of the mixed sizes is the constant 0, the inserted slice has no
+  // elements, so the op folds to its destination operand.
+  if (llvm::any_of(getMixedSizes(),
+                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+    return getDest();
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f442a61dc31ed1c..c4ff6480a4ce5e1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @empty_copy
+// CHECK-NEXT: return
+func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
+  memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
+  memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
+  return
+}
+
+// -----
+
 func.func @scopeMerge() {
   memref.alloca_scope {
     %cnt = "test.count"() : () -> index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b5a82eb3e9035de..914e5e8b8c4b861 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
 
+// CHECK-LABEL: func @empty_insert_slice
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
+// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: return %[[ARG1]] : tensor<3x3xi8>
+func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
+  return %0 : tensor<3x3xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>