[mlir][bufferize] Improve resolveConflicts for ExtractSliceOp
It is sometimes better to make a copy of the OpResult instead of the OpOperand, e.g., when bufferizing tensor.extract_slice.

This change will eventually make parts of extract_slice's `bufferize` implementation obsolete (and simplify it): that implementation will then only need to handle in-place OpOperands.

Differential Revision: https://reviews.llvm.org/D126819
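
As a rough sketch of the effect (mirroring the test added in this commit), an out-of-place tensor.extract_slice now gets a copy of its small result rather than of the whole source tensor. The function name below is illustrative only:

func.func @example(%t: tensor<?xf32>, %idx: index, %f: f32)
    -> (tensor<5xf32>, tensor<?xf32>) {
  // The slice is written to later, so it must be bufferized out of place.
  %0 = tensor.extract_slice %t[10] [5] [1] : tensor<?xf32> to tensor<5xf32>
  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
  return %1, %t : tensor<5xf32>, tensor<?xf32>
}

// After -tensor-copy-insertion, only the 5-element result is copied, not the
// ?-sized source tensor:
//   %0 = tensor.extract_slice %t[10] [5] [1] : tensor<?xf32> to tensor<5xf32>
//   %copy = bufferization.alloc_tensor() copy(%0) {escape = false} : tensor<5xf32>
//   %1 = tensor.insert %f into %copy[%idx] : tensor<5xf32>
//   return %1, %t : tensor<5xf32>, tensor<?xf32>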
matthias-springer committed Jun 9, 2022
1 parent 72a049d commit 87b4677
Showing 2 changed files with 63 additions and 5 deletions.
50 changes: 45 additions & 5 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -44,7 +44,12 @@ constexpr const ::llvm::StringLiteral

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
RewriterBase &rewriter, const AnalysisState &state) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
SmallVector<OpResult> outOfPlaceOpResults;

// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
if (!operandType.isa<TensorType>())
@@ -53,17 +58,52 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
continue;
if (operandType.isa<UnrankedTensorType>())
return op->emitError("copies of unranked tensors are not supported");
auto tensorType = operandType.dyn_cast<RankedTensorType>();
if (!tensorType)
continue;

SmallVector<OpResult> aliasingOpResults =
state.getAliasingOpResult(opOperand);
if (aliasingOpResults.size() == 1 &&
!state.bufferizesToMemoryWrite(opOperand) &&
state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
// The op itself does not write but may create exactly one alias. Instead
// of copying the OpOperand, copy the OpResult. The OpResult can sometimes
// be smaller than the OpOperand (e.g., in the case of an extract_slice,
// where the result is usually a smaller part of the source).
outOfPlaceOpResults.push_back(aliasingOpResults.front());
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
}
}

// Insert copies of OpOperands.
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
auto tensorType = opOperand->get().getType().cast<RankedTensorType>();
SmallVector<OpResult> aliasingOpResults =
state.getAliasingOpResult(*opOperand);
bool escape = llvm::any_of(
aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
Value copy = rewriter.create<AllocTensorOp>(
op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape);
rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
}

// Insert copies of OpResults.
rewriter.setInsertionPointAfter(op);
for (OpResult opResult : outOfPlaceOpResults) {
auto tensorType = opResult.getType().cast<RankedTensorType>();
bool escape = state.isTensorYielded(opResult);
Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
ValueRange(), opResult, escape);
SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
opResult.getUses(), [](OpOperand &use) { return &use; }));
for (OpOperand *use : uses) {
// Do not update the alloc_tensor op that we just created.
if (use->getOwner() != copy.getDefiningOp())
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
}
}

return success();
}
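
A hedged aside on the last condition in the new check (the aliasing OpResult must itself have exactly one aliasing OpOperand): it rules out ops whose result may alias several operands. Assuming arith.select's tensor bufferization reports both value operands as aliasing its result, such an op still takes the OpOperand-copy path even though it does not write:

// %r may alias either %a or %b, so getAliasingOpOperand(%r) has two entries
// and an out-of-place operand (not %r) would be copied.
%r = arith.select %cond, %a, %b : tensor<5xf32>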

18 changes: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC

// CHECK-LABEL: func @extract_slice(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
// CHECK-FUNC-LABEL: func @extract_slice(
func.func @extract_slice(%t: tensor<?xf32>, %idx: index, %f: f32)
-> (tensor<5xf32>, tensor<?xf32>)
{
// CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1]
%0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[extract_slice]]) {escape = false} : tensor<5xf32>
// CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<5xf32>
// CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[alloc]]
%1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
// CHECK: return %[[insert]], %[[t]]
return %1, %t : tensor<5xf32>, tensor<?xf32>
}
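
For contrast, a hedged sketch of the fallback path, not covered by the test above: tensor.insert bufferizes to a memory write on its destination, so when the destination must also stay live, the copy remains on the OpOperand. The function name and the expected output are assumptions, not FileCheck'd here:

func.func @operand_copy(%t: tensor<5xf32>, %idx: index, %f: f32)
    -> (tensor<5xf32>, tensor<5xf32>) {
  // %t is also returned, so the write must not happen in place.
  %0 = tensor.insert %f into %t[%idx] : tensor<5xf32>
  return %0, %t : tensor<5xf32>, tensor<5xf32>
}

// Expected result of -tensor-copy-insertion (assumption): the destination
// operand is copied before the write.
//   %copy = bufferization.alloc_tensor() copy(%t) {escape = false} : tensor<5xf32>
//   %0 = tensor.insert %f into %copy[%idx] : tensor<5xf32>
//   return %0, %t : tensor<5xf32>, tensor<5xf32>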
