From 151f763c5ae2cbfca47c46ad3dc45021bf2b92af Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 21 Jun 2024 10:15:22 +0200 Subject: [PATCH] [MLIR] one-shot-bufferize: Add bufferize-bodiless-function-results option When bufferizing a bodiless function ``` func.func private @foo() -> tensor ``` we currently fail with `cannot bufferize bodiless function that returns a tensor`. This PR adds the option `bufferize-bodiless-function-results` (off by default), to allow bufferizing this into ``` func.func private @foo() -> memref> ``` --- .../Bufferization/IR/BufferizableOpInterface.h | 8 ++++++++ .../Dialect/Bufferization/Transforms/Passes.td | 3 +++ .../Bufferization/Transforms/Bufferize.cpp | 1 + .../FuncBufferizableOpInterfaceImpl.cpp | 15 +++++++++++---- ...dule-bufferize-bodiless-functions-results.mlir | 15 +++++++++++++++ 5 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-bodiless-functions-results.mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index d8cfeee2466360a..36c2e691247318a 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -343,6 +343,14 @@ struct BufferizationOptions { /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. bool inferFunctionResultLayout = true; + /// If true, bufferize results of bodiless functions using the + /// `functionArgTypeConverterFn`. + /// Otherwise, bufferization fails when encountering bodiless functions that + /// have tensor results. + /// + /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. + bool bufferizeBodilessFunctionResults = false; + /// Type converter from tensors to memrefs. This type converter is used if no /// memref type could be inferred during bufferization. By default, a type /// converter that returns a memref type with a fully dynamic layout map is diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index 1c3cdec81a39e07..c3c943c3533f5b1 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -476,6 +476,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> { Option<"bufferizeFunctionBoundaries", "bufferize-function-boundaries", "bool", /*default=*/"0", "Bufferize function boundaries (experimental).">, + Option<"bufferizeBodilessFunctionResults", "bufferize-bodiless-function-results", + "bool", /*default=*/"0", + "Bufferize results of bodiless functions.">, Option<"copyBeforeWrite", "copy-before-write", "bool", /*default=*/"false", "Skip the analysis. Make a buffer copy on every write.">, ListOption<"dialectFilter", "dialect-filter", "std::string", diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index de1f0d79e12ee81..f0b1c9d0c1630a7 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -219,6 +219,7 @@ struct OneShotBufferizePass opt.printConflicts = printConflicts; opt.testAnalysisOnly = testAnalysisOnly; opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; + opt.bufferizeBodilessFunctionResults = bufferizeBodilessFunctionResults; opt.noAnalysisFuncFilter = noAnalysisFuncFilter; // Configure type converter. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 4cdbbf35dc876bc..21a657277d9a622 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -407,10 +407,17 @@ struct FuncOpInterface if (funcOp.isExternal()) { SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (isa(resultType)) - return funcOp->emitError() << "cannot bufferize bodiless function " - << "that returns a tensor"; - retTypes.push_back(resultType); + if (auto tensorType = dyn_cast(resultType)) { + if (!options.bufferizeBodilessFunctionResults) { + return funcOp->emitError() << "cannot bufferize bodiless function " + << "that returns a tensor"; + } + retTypes.push_back(options.functionArgTypeConverterFn( + tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, + options)); + } else { + retTypes.push_back(resultType); + } } funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); return success(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-bodiless-functions-results.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-bodiless-functions-results.mlir new file mode 100644 index 000000000000000..58ecd395001ad31 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-bodiless-functions-results.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 bufferize-bodiless-function-results=1" -split-input-file | FileCheck %s + +func.func private @foo() -> tensor +// CHECK: func.func private @foo() -> memref> + +// ----- + +func.func private @foo(tensor) -> (f32, tensor, f32) +// CHECK: func.func private @foo(memref>) -> (f32, memref>, f32) + +func.func @call_to_unknown_tensor_returning_func(%t : tensor) { + call @foo(%t) : (tensor) -> (f32, tensor, f32) + // CHECK: call @foo(%{{.*}}) : (memref>) -> (f32, memref>, f32) + return +}