Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Implement stencil inlining. #2615

Merged
merged 18 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/filecheck/dialects/stencil/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ func.func @dup_operand(%f : !stencil.field<[0,64]xf64>, %of1 : !stencil.field<[0
// CHECK-NEXT: %t = stencil.load %f : !stencil.field<[0,64]xf64> -> !stencil.temp<?xf64>
// CHECK-NEXT: %o1, %o1_1 = stencil.apply(%one = %t : !stencil.temp<?xf64>) -> (!stencil.temp<?xf64>, !stencil.temp<?xf64>) {
// CHECK-NEXT: %0 = stencil.access %one[0] : !stencil.temp<?xf64>
// CHECK-NEXT: %1 = stencil.access %one[0] : !stencil.temp<?xf64>
// CHECK-NEXT: stencil.return %0, %1 : f64, f64
// CHECK-NEXT: stencil.return %0, %0 : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %o1 to %of1 ([0] : [64]) : !stencil.temp<?xf64> to !stencil.field<[0,64]xf64>
// CHECK-NEXT: stencil.store %o1_1 to %of2 ([0] : [64]) : !stencil.temp<?xf64> to !stencil.field<[0,64]xf64>
Expand Down
353 changes: 353 additions & 0 deletions tests/filecheck/transforms/stencil-inlining.mlir

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,15 @@ class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait):
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.stencil import (
RedundantOperands,
UnusedOperands,
UnusedResults,
)

return (RedundantOperands(), UnusedResults())
return (
RedundantOperands(),
UnusedResults(),
UnusedOperands(),
)


@irdl_op_definition
Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,11 @@ def get_replace_incompatible_fpga():

return replace_incompatible_fpga.ReplaceIncompatibleFPGA

def get_stencil_inlining():
from xdsl.transforms import stencil_inlining
georgebisbas marked this conversation as resolved.
Show resolved Hide resolved

return stencil_inlining.StencilInliningPass

def get_stencil_unroll():
from xdsl.transforms import stencil_unroll

Expand Down Expand Up @@ -639,6 +644,7 @@ def get_test_lower_snitch_stream_to_asm():
"snitch-allocate-registers": get_snitch_register_allocation,
"stencil-shape-inference": get_stencil_shape_inference,
"stencil-storage-materialization": get_stencil_storage_materialization,
"stencil-inlining": get_stencil_inlining,
"stencil-unroll": get_stencil_unroll,
"test-lower-snitch-stream-to-asm": get_test_lower_snitch_stream_to_asm,
}
Expand Down
32 changes: 29 additions & 3 deletions xdsl/transforms/canonicalization_patterns/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.transforms.common_subexpression_elimination import cse


class RedundantOperands(RewritePattern):
Expand All @@ -30,13 +31,38 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N
if not found_duplicate:
return

bbargs = op.region.block.args
for i, a in enumerate(bbargs):
if rbargs[i] == i:
continue
a.replace_by(bbargs[rbargs[i]])

cse(op.region.block)


class UnusedOperands(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
op_args = op.region.block.args
unused = {a for a in op_args if not a.uses}
if not unused:
return
bbargs = [a for a in op_args if a not in unused]
bbargs_type = [a.type for a in bbargs]
operands = [a for i, a in enumerate(op.args) if op_args[i] not in unused]

for arg in unused:
op.region.block.erase_arg(arg)

new = stencil.ApplyOp.get(
unique_operands,
block := Block(arg_types=[uo.type for uo in unique_operands]),
operands,
new_block := Block(arg_types=bbargs_type),
[cast(stencil.TempType[Attribute], r.type) for r in op.res],
)

rewriter.inline_block_at_start(
op.region.block, block, [block.args[i] for i in rbargs]
op.region.detach_block(0), new.region.block, new_block.args
)
rewriter.replace_matched_op(new)

Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RemoveUnusedOperations(RewritePattern):
"""

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter):
if is_trivially_dead(op):
if is_trivially_dead(op) and op.parent is not None:
rewriter.erase_op(op)


Expand Down
Loading
Loading