Skip to content

Commit

Permalink
[MLIR] one-shot-bufferize: Add bufferize-bodiless-function-results op…
Browse files Browse the repository at this point in the history
…tion

When bufferizing a bodiless function
```
func.func private @foo() -> tensor<?xf32>
```
we currently fail with `cannot bufferize bodiless function that returns a tensor`.

This PR adds the option `bufferize-bodiless-function-results` (off by default),
to allow bufferizing this into
```
func.func private @foo() -> memref<?xf32, strided<[?], offset: ?>>
```
  • Loading branch information
mgehre-amd committed Jun 21, 2024
1 parent f951d24 commit 151f763
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@ struct BufferizationOptions {
/// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
bool inferFunctionResultLayout = true;

/// If true, bufferize results of bodiless functions using the
/// `functionArgTypeConverterFn`.
/// Otherwise, bufferization fails when encountering bodiless functions that
/// have tensor results.
///
/// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
bool bufferizeBodilessFunctionResults = false;

/// Type converter from tensors to memrefs. This type converter is used if no
/// memref type could be inferred during bufferization. By default, a type
/// converter that returns a memref type with a fully dynamic layout map is
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
Option<"bufferizeFunctionBoundaries", "bufferize-function-boundaries",
"bool", /*default=*/"0",
"Bufferize function boundaries (experimental).">,
Option<"bufferizeBodilessFunctionResults", "bufferize-bodiless-function-results",
"bool", /*default=*/"0",
"Bufferize results of bodiless functions.">,
Option<"copyBeforeWrite", "copy-before-write", "bool", /*default=*/"false",
"Skip the analysis. Make a buffer copy on every write.">,
ListOption<"dialectFilter", "dialect-filter", "std::string",
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ struct OneShotBufferizePass
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
opt.bufferizeBodilessFunctionResults = bufferizeBodilessFunctionResults;
opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

// Configure type converter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,17 @@ struct FuncOpInterface
if (funcOp.isExternal()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (isa<TensorType>(resultType))
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
retTypes.push_back(resultType);
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
if (!options.bufferizeBodilessFunctionResults) {
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
}
retTypes.push_back(options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options));
} else {
retTypes.push_back(resultType);
}
}
funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
return success();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 bufferize-bodiless-function-results=1" -split-input-file | FileCheck %s

func.func private @foo() -> tensor<?xf32>
// CHECK: func.func private @foo() -> memref<?xf32, strided<[?], offset: ?>>

// -----

func.func private @foo(tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)

func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
// CHECK: call @foo(%{{.*}}) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
return
}

0 comments on commit 151f763

Please sign in to comment.