[MLIR] Make OneShotModuleBufferize use OpInterface (llvm#110322)
**Description:**
This PR replaces a part of `FuncOp` and `CallOp` with
`FunctionOpInterface` and `CallOpInterface` in `OneShotModuleBufferize`.
It also fixes the integration-test error from a previous attempt at this
PR (llvm#107295).

The fixes below skip `CallOpInterface` ops whose callee cannot be resolved
to a `FunctionOpInterface`, so that the assertions are not triggered:


https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L254-L259


https://github.com/llvm/llvm-project/blob/8d780007625108a7f34e40efb8604b858e04c60c/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp#L311-L315
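For reference, a minimal sketch of the guarded lookup-and-skip pattern, condensed from the diff below (it assumes the surrounding MLIR includes and namespaces, so it is illustrative rather than drop-in):

```cpp
// Resolve the callee of a CallOpInterface. Returns null if the callee is
// not a symbol reference or does not resolve to a FunctionOpInterface.
static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FunctionOpInterface>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

// Call sites skip unresolvable callees instead of asserting:
funcOp->walk([&](CallOpInterface callOp) {
  FunctionOpInterface calledFunction = getCalledFunction(callOp);
  if (!calledFunction)
    return WalkResult::skip();
  // ... analyze calledFunction ...
  return WalkResult::advance();
});
```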

**Related Discord Discussion:**
[Link](https://discord.com/channels/636084430946959380/642426447167881246/1280556809911799900)

---------

Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
Authored by tzunghanjuang and erick-xanadu on Oct 1, 2024 · 1 parent 60b604a · commit 2026501
Showing 11 changed files with 316 additions and 281 deletions.
@@ -11,6 +11,7 @@

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfoVariant.h"
#include "llvm/ADT/SetVector.h"
@@ -260,9 +261,9 @@ struct BufferizationOptions {
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, func op, bufferization options
-using FunctionArgTypeConverterFn =
-std::function<BaseMemRefType(TensorType, Attribute memorySpace,
-func::FuncOp, const BufferizationOptions &)>;
+using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
+TensorType, Attribute memorySpace, FunctionOpInterface,
+const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {

/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
-DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
+DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;

/// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
-DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
+DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;

/// A set of all read BlockArguments of FuncOps.
-DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
+DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;

/// A set of all written-to BlockArguments of FuncOps.
-DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
+DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;

/// Keep track of which FuncOps are fully analyzed or currently being
/// analyzed.
-DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;

/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
-void startFunctionAnalysis(FuncOp funcOp);
+void startFunctionAnalysis(FunctionOpInterface funcOp);
};

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

@@ -314,7 +315,7 @@ namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-func::FuncOp funcOp,
+FunctionOpInterface funcOp,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}
@@ -361,7 +362,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-func::FuncOp funcOp,
+FunctionOpInterface funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
@@ -22,7 +22,7 @@ namespace mlir {
namespace bufferization {
namespace func_ext {

-void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
auto createdAliasingResults =
112 changes: 56 additions & 56 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
@@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-func::ReturnOp returnOp;
-for (Block &b : funcOp.getBody()) {
-if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
+static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
+Operation *returnOp = nullptr;
+for (Block &b : funcOp.getFunctionBody()) {
+auto candidateOp = b.getTerminator();
+if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
@@ -126,16 +127,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
+OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
-if (funcOp.getBody().empty()) {
+if (funcOp.getFunctionBody().empty()) {
// No function body available. Conservatively assume that every tensor
// return value may alias with any tensor bbArg.
-FunctionType type = funcOp.getFunctionType();
-for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
+for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
if (!isa<TensorType>(inputIt.value()))
continue;
-for (const auto &resultIt : llvm::enumerate(type.getResults())) {
+for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
@@ -147,7 +148,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
}

// Support only single return-terminated block in the function.
-func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");

for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -168,8 +169,8 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}

-static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
-bool isWritten) {
+static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
+bool isRead, bool isWritten) {
OpBuilder b(funcOp.getContext());
Attribute accessType;
if (isRead && isWritten) {
@@ -189,12 +190,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
-funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
+OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
-for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
-++idx) {
+for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
// Skip non-tensor arguments.
-if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
continue;
bool isRead;
bool isWritten;
@@ -204,7 +205,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
StringRef str = accessAttr.getValue();
isRead = str == "read" || str == "read-write";
isWritten = str == "write" || str == "read-write";
-} else if (funcOp.getBody().empty()) {
+} else if (funcOp.getFunctionBody().empty()) {
// If the function has no body, conservatively assume that all args are
// read + written.
isRead = true;
@@ -230,33 +231,33 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
-auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
+auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kBufferLayoutAttrName);
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
-return dyn_cast_or_null<func::FuncOp>(
+return dyn_cast_or_null<FunctionOpInterface>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(func::FuncOp funcOp,
+static void equivalenceAnalysis(FunctionOpInterface funcOp,
OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
-funcOp->walk([&](func::CallOp callOp) {
-func::FuncOp calledFunction = getCalledFunction(callOp);
-assert(calledFunction && "could not retrieved called func::FuncOp");
+funcOp->walk([&](CallOpInterface callOp) {
+FunctionOpInterface calledFunction = getCalledFunction(callOp);
+if (!calledFunction)
+return WalkResult::skip();

// No equivalence info available for the called function.
if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
int64_t bbargIdx = it.second;
if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
continue;
-Value returnVal = callOp.getResult(returnIdx);
+Value returnVal = callOp->getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
state.unionEquivalenceClasses(returnVal, argVal);
}
@@ -277,11 +278,9 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
}

/// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(func::FuncOp funcOp) {
-return llvm::any_of(funcOp.getFunctionType().getInputs(),
-llvm::IsaPred<TensorType>) ||
-llvm::any_of(funcOp.getFunctionType().getResults(),
-llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(FunctionOpInterface funcOp) {
+return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
+llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -291,16 +290,16 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
-DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
+DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
-DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-if (!funcOp.getBody().empty()) {
-func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
+WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
+if (!funcOp.getFunctionBody().empty()) {
+Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
@@ -309,9 +308,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,

// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
-return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-func::FuncOp calledFunction = getCalledFunction(callOp);
-assert(calledFunction && "could not retrieved called func::FuncOp");
+return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+FunctionOpInterface calledFunction = getCalledFunction(callOp);
+if (!calledFunction)
+return WalkResult::skip();
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
-static void foldMemRefCasts(func::FuncOp funcOp) {
-if (funcOp.getBody().empty())
+static void foldMemRefCasts(FunctionOpInterface funcOp) {
+if (funcOp.getFunctionBody().empty())
return;

-func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
SmallVector<Type> resultTypes;

for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +365,8 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
}

-auto newFuncType = FunctionType::get(
-funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+auto newFuncType = FunctionType::get(funcOp.getContext(),
+funcOp.getArgumentTypes(), resultTypes);
funcOp.setType(newFuncType);
}

@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

// A list of functions in the order in which they are analyzed + bufferized.
-SmallVector<func::FuncOp> orderedFuncOps;
+SmallVector<FunctionOpInterface> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return failure();

// Analyze ops.
-for (func::FuncOp funcOp : orderedFuncOps) {
+for (FunctionOpInterface funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;

@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,

void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
-moduleOp.walk([&](func::FuncOp op) {
+moduleOp.walk([&](FunctionOpInterface op) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
});
@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());

// A list of functions in the order in which they are analyzed + bufferized.
-SmallVector<func::FuncOp> orderedFuncOps;
+SmallVector<FunctionOpInterface> orderedFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
return failure();

// Bufferize functions.
-for (func::FuncOp funcOp : orderedFuncOps) {
+for (FunctionOpInterface funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.

-if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
+if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
// This function was not analyzed and RaW conflicts were not resolved.
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
-if (isa<func::FuncOp>(&op))
+if (isa<FunctionOpInterface>(&op))
continue;
if (failed(bufferizeOp(&op, options, statistics)))
return failure();
@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
// not be analyzed. Ops in these FuncOps will not be analyzed as well.
OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-auto func = dyn_cast<func::FuncOp>(op);
+auto func = dyn_cast<FunctionOpInterface>(op);
if (!func)
-func = op->getParentOfType<func::FuncOp>();
+func = op->getParentOfType<FunctionOpInterface>();
if (func)
return llvm::is_contained(options.noAnalysisFuncFilter,
-func.getSymName());
+func.getName());
return false;
};
OneShotBufferizationOptions updatedOptions(options);
