Skip to content

Commit

Permalink
transformations: port MLIR's empty-tensor-to-alloc-tensor pass
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 14, 2024
1 parent 82e7072 commit af5dc1d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 2 deletions.
12 changes: 10 additions & 2 deletions docs/marimo/onnx/onnx_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import marimo

__generated_with = "0.8.20"
__generated_with = "0.9.17"
app = marimo.App()


Expand Down Expand Up @@ -168,6 +168,7 @@ def __(mo):
@app.cell
def __(
ConvertOnnxToLinalgPass,
EmptyTensorToAllocTensorPass,
MLIROptPass,
init_module,
mo,
Expand All @@ -194,13 +195,18 @@ def __(
arguments=["--linalg-generalize-named-ops"]
)
),
(
mo.md(
"""We prepare the result tensors for bufferization:"""
),
EmptyTensorToAllocTensorPass()
),
(
mo.md(
"""We then use MLIR to bufferize our function:"""
),
MLIROptPass(
arguments=[
"--empty-tensor-to-alloc-tensor",
"--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map",
]
)
Expand Down Expand Up @@ -232,10 +238,12 @@ def __():
from xdsl.passes import PipelinePass
from xdsl.tools.command_line_tool import get_all_dialects
from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass
from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass
from xdsl.transforms.mlir_opt import MLIROptPass
return (
Attribute,
ConvertOnnxToLinalgPass,
EmptyTensorToAllocTensorPass,
MLContext,
MLIROptPass,
PipelinePass,
Expand Down
10 changes: 10 additions & 0 deletions tests/filecheck/transforms/empty-tensor-to-alloc-tensor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: xdsl-opt -p empty-tensor-to-alloc-tensor %s | filecheck %s

// CHECK: %val = "test.op"() : () -> index
%val = "test.op"() : () -> index

// CHECK-NEXT: %static = bufferization.alloc_tensor() : tensor<1024xi32>
%static = tensor.empty() : tensor<1024xi32>

// CHECK-NEXT: %dynamic = bufferization.alloc_tensor(%val) : tensor<1024x?xi32>
%dynamic = tensor.empty(%val) : tensor<1024x?xi32>
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def get_dmp_to_mpi():

return stencil_global_to_local.DmpToMpiPass

def get_empty_tensor_to_alloc_tensor():
from xdsl.transforms import empty_tensor_to_alloc_tensor

return empty_tensor_to_alloc_tensor.EmptyTensorToAllocTensorPass

def get_eqsat_add_costs():
from xdsl.transforms import eqsat_add_costs

Expand Down Expand Up @@ -523,6 +528,7 @@ def get_varith_fuse_repeated_operands():
"dce": get_dce,
"distribute-stencil": get_distribute_stencil,
"dmp-to-mpi": get_dmp_to_mpi,
"empty-tensor-to-alloc-tensor": get_empty_tensor_to_alloc_tensor,
"eqsat-add-costs": get_eqsat_add_costs,
"eqsat-create-eclasses": get_eqsat_create_eclasses,
"eqsat-extract": get_eqsat_extract,
Expand Down
36 changes: 36 additions & 0 deletions xdsl/transforms/empty_tensor_to_alloc_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from xdsl.context import MLContext
from xdsl.dialects import bufferization, builtin, tensor
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)


class EmptyTensorLoweringPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: tensor.EmptyOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(
bufferization.AllocTensorOp(
op.tensor.type,
op.dynamic_sizes,
)
)


class EmptyTensorToAllocTensorPass(ModulePass):
"""
tensor.empty ops return a tensor of unspecified contents who's only purpose
is to carry the tensor shape. This pass converts such ops to
bufferization.alloc_tensor ops, which bufferize to buffer allocations.
"""

name = "empty-tensor-to-alloc-tensor"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
EmptyTensorLoweringPattern(),
apply_recursively=False,
).rewrite_module(op)

0 comments on commit af5dc1d

Please sign in to comment.