[mlir][SCF] Retire SCF-specific to_memref/to_tensor canonicalization patterns #74551

Merged
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -11,12 +11,13 @@ add_mlir_dialect_library(MLIRSCFDialect

  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRBufferizationDialect
  MLIRControlFlowDialect
  MLIRDialectUtils
  MLIRFunctionInterfaces
  MLIRIR
  MLIRLoopLikeInterface
  MLIRSideEffectInterfaces
  MLIRTensorDialect
  MLIRValueBoundsOpInterface
)

132 changes: 2 additions & 130 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
  }
};

/// Canonicalize the iter_args of an scf::ForOp that involve a
/// `bufferization.to_tensor` and for which only the last loop iteration is
/// actually visible outside of the loop. The canonicalization looks for a
/// pattern such as:
/// ```
/// %t0 = ... : tensor_type
/// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ...
/// // %m is either buffer_cast(%bb0) or defined above the loop
/// %m... : memref_type
/// ... // uses of %m with potential inplace updates
/// %new_tensor = bufferization.to_tensor %m : memref_type
/// ...
/// scf.yield %new_tensor : tensor_type
/// }
/// ```
///
/// `%bb0` may have either 0 or 1 use. If it has 1 use, it must be exactly the
/// `%m = buffer_cast %bb0` op that feeds into the yielded
/// `bufferization.to_tensor` op.
///
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
/// loaded, occurs between the `bufferization.to_tensor` and the yield, then
/// the value `%0` visible outside of the loop is the last
/// `bufferization.to_tensor` produced in the loop.
///
/// For now, we approximate the absence of aliasing by only supporting the case
/// when the bufferization.to_tensor is the operation immediately preceding
/// the yield.
///
/// The canonicalization rewrites the pattern as:
/// ```
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
/// ... // uses of %m with potential inplace updates
/// scf.yield %bb0: tensor_type
/// }
/// %0 = bufferization.to_tensor %m : memref_type
/// ```
///
/// A later bbArg canonicalization will further rewrite as:
/// ```
/// // %m is either a buffer_cast or defined above
/// %m... : memref_type
/// scf.for ... { // no iter_args
/// ... // uses of %m with potential inplace updates
/// }
/// %0 = bufferization.to_tensor %m : memref_type
/// ```
struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
           "unexpected multiple blocks");

    Location loc = forOp.getLoc();
    DenseMap<Value, Value> replacements;
    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
      auto yieldOp =
          cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
      Value yieldVal = yieldOp->getOperand(idx);
      auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
      bool isTensor = llvm::isa<TensorType>(bbArg.getType());

      bufferization::ToMemrefOp tensorToMemref;
      // Either bbArg has no use or it has a single buffer_cast use.
      if (bbArg.hasOneUse())
        tensorToMemref =
            dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
        continue;
      // If tensorToMemref is present, it must feed into the `ToTensorOp`.
      if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
        continue;
      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
      // must be before `ToTensorOp` in the block so that the lastWrite
      // property is not subject to additional side effects.
      // For now, we only support the case when ToTensorOp appears
      // immediately before the terminator.
      if (tensorLoadOp->getNextNode() != yieldOp)
        continue;

      // Clone the optional tensorToMemref before forOp.
      if (tensorToMemref) {
        rewriter.setInsertionPoint(forOp);
        rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
            tensorToMemref, tensorToMemref.getMemref().getType(),
            tensorToMemref.getTensor());
      }

      // Clone the tensorLoad after forOp.
      rewriter.setInsertionPointAfter(forOp);
      Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
          loc, tensorLoadOp.getMemref());
      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
      replacements.insert(std::make_pair(forOpResult, newTensorLoad));

      // Make the terminator just yield the bbArg; the old tensorLoadOp and the
      // old bbArg (which is now directly yielded) will canonicalize away.
      rewriter.startRootUpdate(yieldOp);
      yieldOp.setOperand(idx, bbArg);
      rewriter.finalizeRootUpdate(yieldOp);
    }
    if (replacements.empty())
      return failure();

    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
    // replaces the whole op and erases it unconditionally. This is wrong for
    // `forOp`, as it generally contains ops with side effects.
    // Instead, use `rewriter.replaceOpWithIf`.
    SmallVector<Value> newResults;
    newResults.reserve(forOp.getNumResults());
    for (Value v : forOp.getResults()) {
      auto it = replacements.find(v);
      newResults.push_back((it != replacements.end()) ? it->second : v);
    }
    unsigned idx = 0;
    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
      return op.get() != newResults[idx++];
    });
    return success();
  }
};
} // namespace

void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+      context);
}

std::optional<APInt> ForOp::getConstantStep() {
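Aside (not part of this PR): canonicalization patterns such as the one deleted above are applied through MLIR's greedy pattern rewriter, so a downstream project that still relied on the retired rewrite could keep a private copy of the struct and apply it in its own pass. A minimal sketch, where `MyLastTensorLoadPattern` is a hypothetical name standing in for such a copy:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Applies a privately maintained copy of the retired pattern to `root`.
// `MyLastTensorLoadPattern` is hypothetical; it stands in for a verbatim
// copy of the LastTensorLoadCanonicalization struct deleted in this PR.
static LogicalResult applyRetiredPattern(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  patterns.add<MyLastTensorLoadPattern>(root->getContext());
  // Rewrite greedily to a fixed point, as the canonicalizer driver does.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}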
50 changes: 0 additions & 50 deletions mlir/test/Dialect/SCF/canonicalize.mlir
@@ -773,56 +773,6 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {

// -----

func.func private @process(%0 : memref<128x128xf32>)
func.func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>

// CHECK-LABEL: last_value
// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
// CHECK-SAME: %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32>
func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
                      %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
                      %lb : index, %ub : index, %step : index)
  -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
{
  // CHECK-NEXT: %[[M1:.*]] = bufferization.to_memref %[[T1]] : memref<128x128xf32>
  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
  %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
      -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
  {
    %m1 = bufferization.to_memref %arg2 : memref<128x128xf32>

    // CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
    func.call @process(%m0) : (memref<128x128xf32>) -> ()

    // CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
    func.call @process(%m1) : (memref<128x128xf32>) -> ()

    // This does not hoist (fails the "bbArg has at most one use" check).
    // CHECK-NEXT: %[[T:.*]] = func.call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
    // CHECK-NEXT: %[[YIELD_T:.*]] = bufferization.to_tensor %[[T:.*]]
    %m2 = func.call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
    %3 = bufferization.to_tensor %m2 : memref<128x128xf32>

    // All of this goes away, incrementally.
    %1 = bufferization.to_tensor %m0 : memref<128x128xf32>
    %2 = bufferization.to_tensor %m1 : memref<128x128xf32>

    // CHECK-NEXT: scf.yield %[[YIELD_T]] : tensor<128x128xf32>
    scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>

    // CHECK-NEXT: }
  }

  // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
  return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}

// -----

// CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32
func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
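For illustration (hypothetical IR, not taken from the deleted test): with the pattern retired, -canonicalize no longer hoists a trailing bufferization.to_tensor out of an scf.for, presumably leaving such IR to the bufferization infrastructure. A minimal sketch of IR that is now left unchanged, assuming no other pattern applies:

// With LastTensorLoadCanonicalization retired, the to_tensor below is no
// longer moved after the loop by -canonicalize; the loop is left intact.
func.func @kept_iter_arg(%t: tensor<8xf32>, %m: memref<8xf32>,
                         %lb: index, %ub: index, %step: index) -> tensor<8xf32> {
  %0 = scf.for %i = %lb to %ub step %step iter_args(%arg = %t) -> (tensor<8xf32>) {
    %1 = bufferization.to_tensor %m : memref<8xf32>
    scf.yield %1 : tensor<8xf32>
  }
  return %0 : tensor<8xf32>
}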
1 change: 0 additions & 1 deletion utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3994,7 +3994,6 @@ cc_library(
    deps = [
        ":ArithDialect",
        ":ArithUtils",
        ":BufferizationDialect",
        ":ControlFlowDialect",
        ":ControlFlowInterfaces",
        ":DestinationStyleOpInterface",