Skip to content

Commit

Permalink
[mlir][tensor] Implement folding logic for size 0 tensor and memref o…
Browse files Browse the repository at this point in the history
…ps (#90814)

Implement folding and rewrite logic to eliminate no-op tensor and memref
operations. This handles two specific cases:

1. tensor.insert_slice operations where the size of the inserted slice
is known to be 0.
2. memref.copy operations where either the source or target memrefs are
known to be emtpy.

Co-authored-by: Spenser Bauman <sabauma@fastmail>
  • Loading branch information
sabauma and Spenser Bauman authored May 20, 2024
1 parent 250c39c commit 1f07bfb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
22 changes: 21 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
return success();
}
};

struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;

static bool isEmptyMemRef(BaseMemRefType type) {
return type.hasRank() &&
llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
}

LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
if (isEmptyMemRef(copyOp.getSource().getType()) ||
isEmptyMemRef(copyOp.getTarget().getType())) {
rewriter.eraseOp(copyOp);
return success();
}

return failure();
}
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldCopyOfCast, FoldSelfCopy>(context);
results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
}

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2609,6 +2609,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
if (llvm::any_of(getMixedSizes(),
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
return getDest();
return OpFoldResult();
}

Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {

// -----

// CHECK-LABEL: func @empty_copy
// CHECK-NEXT: return
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
return
}

// -----

func.func @scopeMerge() {
memref.alloca_scope {
%cnt = "test.count"() : () -> index
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6

// -----

// CHECK-LABEL: func @empty_insert_slice
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
// CHECK-NOT: tensor.extract_slice
// CHECK: return %[[ARG1]] : tensor<3x3xi8>
func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
%0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
return %0 : tensor<3x3xi8>
}

// -----

// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
Expand Down

0 comments on commit 1f07bfb

Please sign in to comment.