Skip to content

Commit

Permalink
Add missing check in inlining.
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal committed May 27, 2024
1 parent 96dc769 commit c98e78c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
44 changes: 21 additions & 23 deletions tests/filecheck/transforms/stencil-inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,12 @@ func.func @multiple_edges(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %a
// CHECK-NEXT: %3 = stencil.access %arg3[0, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %4 = stencil.access %arg3[-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %5 = stencil.access %arg3[1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %6 = stencil.access %arg3[-1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %7 = stencil.access %arg3[1, 0, 0] : !stencil.temp<[-1,65]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %8 = stencil.access %arg6[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %9 = arith.addf %3, %4 : f64
// CHECK-NEXT: %10 = arith.addf %7, %8 : f64
// CHECK-NEXT: %11 = arith.addf %9, %10 : f64
// CHECK-NEXT: %12 = stencil.store_result %11 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %12 : !stencil.result<f64>
// CHECK-NEXT: %6 = stencil.access %arg6[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %7 = arith.addf %3, %4 : f64
// CHECK-NEXT: %8 = arith.addf %5, %6 : f64
// CHECK-NEXT: %9 = arith.addf %7, %8 : f64
// CHECK-NEXT: %10 = stencil.store_result %9 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %10 : !stencil.result<f64>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %2 to %arg2 ([0, 0, 0] : [64, 64, 60]) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>
// CHECK-NEXT: func.return
Expand Down Expand Up @@ -305,16 +303,16 @@ func.func @dyn_access(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %arg1:
// CHECK-NEXT: %3 = stencil.store_result %2 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %3 : !stencil.result<f64>
// CHECK-NEXT: }
// CHECK-NEXT: %4 = stencil.apply(%5 = %1 : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) {
// CHECK-NEXT: %6 = stencil.index 0 [-1, 0, 0]
// CHECK-NEXT: %7 = stencil.dyn_access %5[%6, %6, %6] in [-2, -1, -1] : [0, 1, 1] : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>
// CHECK-NEXT: %8 = stencil.index 0 [1, 0, 0]
// CHECK-NEXT: %9 = stencil.dyn_access %5[%8, %8, %8] in [0, -1, -1] : [2, 1, 1] : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>
// CHECK-NEXT: %10 = arith.addf %9, %7 : f64
// CHECK-NEXT: %11 = stencil.store_result %10 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %11 : !stencil.result<f64>
// CHECK-NEXT: %2 = stencil.apply(%3 = %1 : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) {
// CHECK-NEXT: %4 = stencil.index 0 [-1, 0, 0]
// CHECK-NEXT: %5 = stencil.dyn_access %3[%4, %4, %4] in [-2, -1, -1] : [0, 1, 1] : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>
// CHECK-NEXT: %6 = stencil.index 0 [1, 0, 0]
// CHECK-NEXT: %7 = stencil.dyn_access %3[%6, %6, %6] in [0, -1, -1] : [2, 1, 1] : !stencil.temp<[-2,66]x[-1,65]x[-1,61]xf64>
// CHECK-NEXT: %8 = arith.addf %7, %5 : f64
// CHECK-NEXT: %9 = stencil.store_result %8 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %9 : !stencil.result<f64>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %4 to %arg1 ([0, 0, 0] : [64, 64, 60]) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>
// CHECK-NEXT: stencil.store %2 to %arg1 ([0, 0, 0] : [64, 64, 60]) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

Expand Down Expand Up @@ -344,12 +342,12 @@ func.func @simple_buffer(%arg0: !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>, %ar
// CHECK-NEXT: %3 = stencil.store_result %2 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %3 : !stencil.result<f64>
// CHECK-NEXT: }
// CHECK-NEXT: %4 = stencil.buffer %1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %5 = stencil.apply(%arg2_1 = %4 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) {
// CHECK-NEXT: %6 = stencil.access %arg2_1[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %7 = stencil.store_result %6 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %7 : !stencil.result<f64>
// CHECK-NEXT: %2 = stencil.buffer %1 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %3 = stencil.apply(%arg2 = %2 : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,60]xf64>) {
// CHECK-NEXT: %4 = stencil.access %arg2[0, 0, 0] : !stencil.temp<[0,64]x[0,64]x[0,60]xf64>
// CHECK-NEXT: %5 = stencil.store_result %4 : !stencil.result<f64>
// CHECK-NEXT: stencil.return %5 : !stencil.result<f64>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %5 to %arg1 ([0, 0, 0] : [64, 64, 60]) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>
// CHECK-NEXT: stencil.store %3 to %arg1 ([0, 0, 0] : [64, 64, 60]) : !stencil.temp<[0,64]x[0,64]x[0,60]xf64> to !stencil.field<[-3,67]x[-3,67]x[-3,67]xf64>
// CHECK-NEXT: func.return
// CHECK-NEXT: }
18 changes: 16 additions & 2 deletions xdsl/transforms/stencil_inlining.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from xdsl.transforms.canonicalization_patterns.stencil import (
RedundantOperands,
UnusedOperands,
UnusedResults,
)
from xdsl.transforms.experimental.stencil_shape_inference import update_result_size
Expand Down Expand Up @@ -98,15 +99,23 @@ def is_rerouting_possible(producer: ApplyOp, consumer: ApplyOp):
Check if rerouting is possible.
"""
# Perform producer consumer inlining instead
return not has_single_consumer(producer, consumer)
if has_single_consumer(producer, consumer):
return False
for operand in consumer.operands:
if isinstance(operand.owner, Operation):
if (operand.owner is not producer) and is_before_in_block(
producer, operand.owner
):
return False
return True


def is_inlining_possible(producer: ApplyOp, consumer: ApplyOp):
"""
Check if inlining is possible.
"""
# Don't inline any producer with conditional writes.
return not any(
r = not any(
store_result.arg is None
for store_result in producer.walk()
if isinstance(store_result, StoreResultOp)
Expand All @@ -120,6 +129,8 @@ def is_inlining_possible(producer: ApplyOp, consumer: ApplyOp):
].uses
)

return r


class StencilReroutingPattern(RewritePattern):
"""
Expand Down Expand Up @@ -190,6 +201,7 @@ def redirect_store(
if use.operation is new_consumer:
continue
use.operation.operands[use.index] = rres

rewriter.replace_op(
consumer, new_consumer, new_consumer.res[: len(consumer.res)]
)
Expand Down Expand Up @@ -225,6 +237,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter):
if is_inlining_possible(producer, consumer) and is_rerouting_possible(
producer, consumer
):
print("Rerouting!")
return self.redirect_store(producer, consumer, rewriter)


Expand Down Expand Up @@ -347,6 +360,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
StencilReroutingPattern(),
StencilInliningPattern(),
UnusedResults(),
UnusedOperands(),
RedundantOperands(),
]
),
Expand Down

0 comments on commit c98e78c

Please sign in to comment.