From 10154e269879f5ca28e3b44bae6f3c206410ab4b Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Wed, 7 Aug 2024 14:17:03 +0100 Subject: [PATCH 01/12] Finer analysis in stencil bufferization and related side-effect updates on stencil ops. --- .../transforms/stencil-bufferize.mlir | 4 +- xdsl/dialects/stencil.py | 26 +++++++--- xdsl/transforms/stencil_bufferize.py | 47 ++++++++++++------- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/tests/filecheck/transforms/stencil-bufferize.mlir b/tests/filecheck/transforms/stencil-bufferize.mlir index 9c193299af..4ab32f9967 100644 --- a/tests/filecheck/transforms/stencil-bufferize.mlir +++ b/tests/filecheck/transforms/stencil-bufferize.mlir @@ -366,7 +366,7 @@ func.func @store_result_lowering(%arg0 : f64) { func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { - %true = "test.op"() : () -> i1 + %true = "test.pureop"() : () -> i1 %2, %3 = "scf.if"(%true) ({ %4 = stencil.store_result %arg1 : !stencil.result scf.yield %4, %arg1 : !stencil.result, f64 @@ -384,7 +384,7 @@ func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, // CHECK: func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ // CHECK-NEXT: stencil.apply(%arg1 = %arg0 : f64) outs (%b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) { -// CHECK-NEXT: %true = "test.op"() : () -> i1 +// CHECK-NEXT: %true = "test.pureop"() : () -> i1 // CHECK-NEXT: %0, %1 = "scf.if"(%true) ({ // CHECK-NEXT: %2 = stencil.store_result %arg1 : !stencil.result // CHECK-NEXT: scf.yield %2, %arg1 : !stencil.result, f64 diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 9b30d46813..ddaa172be5 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -66,7 +66,6 @@ IsTerminator, MemoryEffect, MemoryEffectKind, - MemoryReadEffect, NoMemoryEffect, Pure, RecursiveMemoryEffect, @@ -427,10 +426,11 @@ class ApplyMemoryEffect(RecursiveMemoryEffect): def get_effects(cls, op: Operation): effects = super().get_effects(op) if effects is not None: - if len(cast(ApplyOp, op).dest) > 0: - effects.add(EffectInstance(MemoryEffectKind.WRITE)) - if any(isinstance(o.type, FieldType) for o in op.operands): - effects.add(EffectInstance(MemoryEffectKind.READ)) + for d in cast(ApplyOp, op).dest: + effects.add(EffectInstance(MemoryEffectKind.WRITE, d)) + for o in cast(ApplyOp, op).args: + if isinstance(o.type, FieldType): + effects.add(EffectInstance(MemoryEffectKind.READ, o)) return effects @@ -1133,6 +1133,12 @@ def get_apply(self): return cast(ApplyOp, ancestor) +class LoadOpMemoryEffect(MemoryEffect): + @classmethod + def get_effects(cls, op: Operation): + return {EffectInstance(MemoryEffectKind.READ, cast(LoadOp, op).field)} + + @irdl_op_definition class LoadOp(IRDLOperation): """ @@ -1171,7 +1177,7 @@ class LoadOp(IRDLOperation): assembly_format = "$field attr-dict-with-keyword `:` type($field) `->` type($res)" - traits = frozenset([MemoryReadEffect()]) + traits = frozenset([LoadOpMemoryEffect()]) @staticmethod def get( @@ -1296,6 +1302,12 @@ def verify( super().verify(attr, constraint_context) +class StoreOpMemoryEffect(MemoryEffect): + @classmethod + def get_effects(cls, op: Operation): + return {EffectInstance(MemoryEffectKind.WRITE, cast(StoreOp, op).field)} + + @irdl_op_definition class StoreOp(IRDLOperation): """ @@ -1348,6 +1360,8 @@ class StoreOp(IRDLOperation): assembly_format = "$temp `to` $field `` `(` $bounds `)` attr-dict-with-keyword `:` type($temp) `to` type($field)" + traits = frozenset([StoreOpMemoryEffect()]) + @staticmethod def get( temp: SSAValue | Operation, diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 34f85dfc77..acdded1b8a 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -33,7 +33,7 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.traits import is_side_effect_free +from xdsl.traits import MemoryEffectKind, get_effects from xdsl.transforms.dead_code_elimination import RemoveUnusedOperations from xdsl.utils.hints import isa @@ -44,6 +44,23 @@ def field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]: return FieldType[_TypeElement].new(temp.parameters) +def only_has_effect(op: Operation, effect: MemoryEffectKind) -> bool: + """ + Returns if the operation has the given side effects and no others. + """ + effects = get_effects(op) + return effects is not None and all(e.kind == effect for e in effects) + + +def might_effect( + operation: Operation, effects: set[MemoryEffectKind], value: SSAValue +) -> bool: + op_effects = get_effects(operation) + return op_effects is None or any( + e.kind in effects and e.value in (None, value) for e in op_effects + ) + + class ApplyBufferizePattern(RewritePattern): """ Naive partial `stencil.apply` bufferization. @@ -184,9 +201,7 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter): effecting = [ o for o in walk_from_to(load, user) - if underlying in o.operands - and (not is_side_effect_free(o)) - and (o not in (load, op, user)) + if might_effect(o, {MemoryEffectKind.WRITE}, underlying) ] if effecting: return @@ -205,9 +220,9 @@ class ApplyLoadStoreFoldPattern(RewritePattern): stencil.apply() outs (%temp : !stencil.field<[0,32]>) { // [...] } - // [... %temp, %dest not affected] + // [... %dest not read] %loaded = stencil.load %temp : !stencil.field<[0,32]> -> !stencil.temp<[0,32]> - // [... %dest not affected] + // [... %dest not read] stencil.store %loaded to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]> ``` yields: @@ -238,8 +253,9 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): if len(other_uses) != 1: return + # we restrict to the case where the apply and load are the only users of %temp + # for now other_use = other_uses.pop() - if not isinstance( apply := other_use.operation, ApplyOp ) or other_use.index < len(apply.args): @@ -247,26 +263,21 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): print() return - # Get first occurence of the field, to walk from it - start = op.field.owner + # Get first occurence of the destination field, to walk from it + dest = op.field + start = dest.owner if isinstance(start, Block): - if start is not op.parent: - return start = cast(Operation, start.first_op) effecting = [ o for o in walk_from_to(start, op) - if infield in o.operands - and (not is_side_effect_free(o)) - and (o not in (load, apply)) + if might_effect(o, {MemoryEffectKind.READ}, dest) ] if effecting: - print("effecting: ", effecting) - print(load) return new_operands = list(apply.operands) - new_operands[other_use.index] = op.field + new_operands[other_use.index] = dest new_apply = ApplyOp.create( operands=new_operands, @@ -285,7 +296,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): ) new_load = LoadOp.create( - operands=[op.field], + operands=[dest], result_types=[r.type for r in load.results], attributes=load.attributes.copy(), properties=load.properties.copy(), From f0406e87a79e511e34cf2454b93704bda679217b Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Wed, 7 Aug 2024 16:33:26 +0100 Subject: [PATCH 02/12] Try out more progressive bufferization on apply. This significantly lift analysis difficulty on the whole apply/load/store; might be slightly less generic? For now, it's a step forward. --- .../stencil/oec-kernels/fvtp2d_qi.mlir | 115 +++++++++++++++++ .../transforms/stencil-bufferize.mlir | 71 ++++------- xdsl/dialects/stencil.py | 9 +- .../canonicalization_patterns/stencil.py | 10 +- xdsl/transforms/stencil_bufferize.py | 116 ++++++++++-------- 5 files changed, 218 insertions(+), 103 deletions(-) diff --git a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir index 9abed5e24e..383ce68368 100644 --- a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir +++ b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir @@ -1,6 +1,7 @@ // RUN: XDSL_ROUNDTRIP // RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference | filecheck %s --check-prefix SHAPE // RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,convert-stencil-to-ll-mlir | filecheck %s --check-prefix MLIR +// RUN: xdsl-opt %s -p stencil-storage-materialization,stencil-shape-inference,stencil-bufferize | filecheck %s --check-prefix BUFF func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field, %arg2: !stencil.field, %arg3: !stencil.field, %arg4: !stencil.field, %arg5: !stencil.field, %arg6: !stencil.field) attributes {stencil.program} { %0 = stencil.cast %arg0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> @@ -547,3 +548,117 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field> // MLIR-NEXT: func.return // MLIR-NEXT: } + +// BUFF: func.func @fvtp2d_qi(%arg0 : !stencil.field, %arg1 : !stencil.field, %arg2 : !stencil.field, %arg3 : !stencil.field, %arg4 : !stencil.field, %arg5 : !stencil.field, %arg6 : !stencil.field) attributes {"stencil.program"}{ +// BUFF-NEXT: %0 = stencil.alloc : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %1 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %3 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %4 = stencil.alloc : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %5 = stencil.alloc : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %6 = stencil.cast %arg0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %7 = stencil.cast %arg1 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %8 = stencil.cast %arg2 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %9 = stencil.cast %arg3 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %10 = stencil.cast %arg4 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %11 = stencil.cast %arg5 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %12 = stencil.cast %arg6 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%5 : !stencil.field<[0,64]x[-1,66]x[0,64]xf64>) { +// BUFF-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 7.000000e+00 : f64 +// BUFF-NEXT: %cst_2 = arith.constant 1.200000e+01 : f64 +// BUFF-NEXT: %13 = arith.divf %cst_1, %cst_2 : f64 +// BUFF-NEXT: %14 = arith.divf %cst, %cst_2 : f64 +// BUFF-NEXT: %15 = stencil.access %arg7[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %16 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %17 = arith.addf %15, %16 : f64 +// BUFF-NEXT: %18 = stencil.access %arg7[0, -2, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %19 = stencil.access %arg7[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %20 = arith.addf %18, %19 : f64 +// BUFF-NEXT: %21 = arith.mulf %13, %17 : f64 +// BUFF-NEXT: %22 = arith.mulf %14, %20 : f64 +// BUFF-NEXT: %23 = arith.addf %21, %22 : f64 +// BUFF-NEXT: %24 = stencil.store_result %23 : !stencil.result +// BUFF-NEXT: stencil.return %24 : !stencil.result +// BUFF-NEXT: } to <[0, -1, 0], [64, 66, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %5 : !stencil.field<[0,64]x[-1,66]x[0,64]xf64>) outs (%4 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %3 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %2 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %1 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>) { +// BUFF-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %13 = stencil.access %arg8[0, 0, 0] : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %14 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.subf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.access %arg8[0, 1, 0] : !stencil.field<[0,64]x[-1,66]x[0,64]xf64> +// BUFF-NEXT: %17 = arith.subf %16, %14 : f64 +// BUFF-NEXT: %18 = arith.addf %15, %17 : f64 +// BUFF-NEXT: %19 = arith.mulf %15, %17 : f64 +// BUFF-NEXT: %20 = arith.cmpf olt, %19, %cst : f64 +// BUFF-NEXT: %21 = arith.select %20, %cst_1, %cst : f64 +// BUFF-NEXT: %22 = stencil.store_result %15 : !stencil.result +// BUFF-NEXT: %23 = stencil.store_result %17 : !stencil.result +// BUFF-NEXT: %24 = stencil.store_result %18 : !stencil.result +// BUFF-NEXT: %25 = stencil.store_result %21 : !stencil.result +// BUFF-NEXT: stencil.return %22, %23, %24, %25 : !stencil.result, !stencil.result, !stencil.result, !stencil.result +// BUFF-NEXT: } to <[0, -1, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %7 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg9 = %4 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg10 = %3 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg11 = %2 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>, %arg12 = %1 : !stencil.field<[0,64]x[-1,65]x[0,64]xf64>) outs (%12 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// BUFF-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// BUFF-NEXT: %cst_1 = arith.constant 1.000000e+00 : f64 +// BUFF-NEXT: %13 = stencil.access %arg12[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %14 = arith.cmpf oeq, %13, %cst : f64 +// BUFF-NEXT: %15 = arith.select %14, %cst_1, %cst : f64 +// BUFF-NEXT: %16 = stencil.access %arg12[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %17 = arith.mulf %16, %15 : f64 +// BUFF-NEXT: %18 = arith.addf %13, %17 : f64 +// BUFF-NEXT: %19 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %20 = arith.cmpf ogt, %19, %cst : f64 +// BUFF-NEXT: %21 = "scf.if"(%20) ({ +// BUFF-NEXT: %22 = stencil.access %arg10[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %23 = stencil.access %arg11[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %24 = arith.mulf %19, %23 : f64 +// BUFF-NEXT: %25 = arith.subf %22, %24 : f64 +// BUFF-NEXT: %26 = arith.subf %cst_1, %19 : f64 +// BUFF-NEXT: %27 = arith.mulf %26, %25 : f64 +// BUFF-NEXT: scf.yield %27 : f64 +// BUFF-NEXT: }, { +// BUFF-NEXT: %28 = stencil.access %arg9[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %29 = stencil.access %arg11[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> +// BUFF-NEXT: %30 = arith.mulf %19, %29 : f64 +// BUFF-NEXT: %31 = arith.addf %28, %30 : f64 +// BUFF-NEXT: %32 = arith.addf %cst_1, %19 : f64 +// BUFF-NEXT: %33 = arith.mulf %32, %31 : f64 +// BUFF-NEXT: scf.yield %33 : f64 +// BUFF-NEXT: }) : (i1) -> f64 +// BUFF-NEXT: %34 = arith.mulf %21, %18 : f64 +// BUFF-NEXT: %35 = "scf.if"(%20) ({ +// BUFF-NEXT: %36 = stencil.access %arg7[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %37 = arith.addf %36, %34 : f64 +// BUFF-NEXT: scf.yield %37 : f64 +// BUFF-NEXT: }, { +// BUFF-NEXT: %38 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %39 = arith.addf %38, %34 : f64 +// BUFF-NEXT: scf.yield %39 : f64 +// BUFF-NEXT: }) : (i1) -> f64 +// BUFF-NEXT: %40 = stencil.store_result %35 : !stencil.result +// BUFF-NEXT: stencil.return %40 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %9 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %12 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%0 : !stencil.field<[0,64]x[0,65]x[0,64]xf64>) { +// BUFF-NEXT: %13 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %14 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.mulf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.store_result %15 : !stencil.result +// BUFF-NEXT: stencil.return %16 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 65, 64]> +// BUFF-NEXT: stencil.apply(%arg7 = %6 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg8 = %10 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %arg9 = %0 : !stencil.field<[0,64]x[0,65]x[0,64]xf64>, %arg10 = %8 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%11 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// BUFF-NEXT: %13 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %14 = stencil.access %arg8[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %15 = arith.mulf %13, %14 : f64 +// BUFF-NEXT: %16 = stencil.access %arg9[0, 0, 0] : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %17 = stencil.access %arg9[0, 1, 0] : !stencil.field<[0,64]x[0,65]x[0,64]xf64> +// BUFF-NEXT: %18 = arith.subf %16, %17 : f64 +// BUFF-NEXT: %19 = arith.addf %15, %18 : f64 +// BUFF-NEXT: %20 = stencil.access %arg10[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// BUFF-NEXT: %21 = arith.divf %19, %20 : f64 +// BUFF-NEXT: %22 = stencil.store_result %21 : !stencil.result +// BUFF-NEXT: stencil.return %22 : !stencil.result +// BUFF-NEXT: } to <[0, 0, 0], [64, 64, 64]> +// BUFF-NEXT: func.return +// BUFF-NEXT: } diff --git a/tests/filecheck/transforms/stencil-bufferize.mlir b/tests/filecheck/transforms/stencil-bufferize.mlir index 4ab32f9967..4a7287ba03 100644 --- a/tests/filecheck/transforms/stencil-bufferize.mlir +++ b/tests/filecheck/transforms/stencil-bufferize.mlir @@ -73,43 +73,43 @@ func.func @copy_1d(%0 : !stencil.field, %out : !stencil.field) { // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @copy_2d(%0 : !stencil.field) { +func.func @copy_2d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]xf64>) { %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]xf64> %2 = stencil.load %1 : !stencil.field<[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,64]x[0,68]xf64> %3 = stencil.apply(%4 = %2 : !stencil.temp<[-1,64]x[0,68]xf64>) -> (!stencil.temp<[0,64]x[0,68]xf64>) { %5 = stencil.access %4[-1, 0] : !stencil.temp<[-1,64]x[0,68]xf64> stencil.return %5 : f64 } + stencil.store %3 to %out (<[0, 0], [64, 68]>) : !stencil.temp<[0,64]x[0,68]xf64> to !stencil.field<[-4,68]x[-4,68]xf64> func.return } -// CHECK: func.func @copy_2d(%0 : !stencil.field) { +// CHECK: func.func @copy_2d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[0,68]xf64> -// CHECK-NEXT: stencil.apply(%3 = %1 : !stencil.field<[-4,68]x[-4,68]xf64>) outs (%2 : !stencil.field<[0,64]x[0,68]xf64>) { -// CHECK-NEXT: %4 = stencil.access %3[-1, 0] : !stencil.field<[-4,68]x[-4,68]xf64> -// CHECK-NEXT: stencil.return %4 : f64 +// CHECK-NEXT: stencil.apply(%2 = %1 : !stencil.field<[-4,68]x[-4,68]xf64>) outs (%out : !stencil.field<[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %3 = stencil.access %2[-1, 0] : !stencil.field<[-4,68]x[-4,68]xf64> +// CHECK-NEXT: stencil.return %3 : f64 // CHECK-NEXT: } to <[0, 0], [64, 68]> // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @copy_3d(%0 : !stencil.field) { +func.func @copy_3d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> %2 = stencil.load %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -> !stencil.temp<[-1,64]x[0,64]x[0,69]xf64> %3 = stencil.apply(%4 = %2 : !stencil.temp<[-1,64]x[0,64]x[0,69]xf64>) -> (!stencil.temp<[0,64]x[0,64]x[0,68]xf64>) { %5 = stencil.access %4[-1, 0, 1] : !stencil.temp<[-1,64]x[0,64]x[0,69]xf64> stencil.return %5 : f64 } + stencil.store %3 to %out (<[0, 0, 0], [64, 64, 68]>) : !stencil.temp<[0,64]x[0,64]x[0,68]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> func.return } -// CHECK: func.func @copy_3d(%0 : !stencil.field) { +// CHECK: func.func @copy_3d(%0 : !stencil.field, %out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %1 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -// CHECK-NEXT: %2 = stencil.alloc : !stencil.field<[0,64]x[0,64]x[0,68]xf64> -// CHECK-NEXT: stencil.apply(%3 = %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64>) outs (%2 : !stencil.field<[0,64]x[0,64]x[0,68]xf64>) { -// CHECK-NEXT: %4 = stencil.access %3[-1, 0, 1] : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> -// CHECK-NEXT: stencil.return %4 : f64 +// CHECK-NEXT: stencil.apply(%2 = %1 : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64>) outs (%out : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %3 = stencil.access %2[-1, 0, 1] : !stencil.field<[-4,68]x[-4,70]x[-4,72]xf64> +// CHECK-NEXT: stencil.return %3 : f64 // CHECK-NEXT: } to <[0, 0, 0], [64, 64, 68]> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -140,20 +140,19 @@ func.func @offsets(%0 : !stencil.field, %1 : !stencil.field, %1 : !stencil.field, %2 : !stencil.field) { // CHECK-NEXT: %3 = stencil.cast %0 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // CHECK-NEXT: %4 = stencil.cast %1 : !stencil.field -> !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %5 = stencil.alloc : !stencil.field<[0,64]x[0,64]x[0,64]xf64> -// CHECK-NEXT: stencil.apply(%6 = %3 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%4 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %5 : !stencil.field<[0,64]x[0,64]x[0,64]xf64>) { -// CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -// CHECK-NEXT: %12 = arith.addf %7, %8 : f64 -// CHECK-NEXT: %13 = arith.addf %9, %10 : f64 -// CHECK-NEXT: %14 = arith.addf %12, %13 : f64 +// CHECK-NEXT: stencil.apply(%5 = %3 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%4 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { +// CHECK-NEXT: %6 = stencil.access %5[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %7 = stencil.access %5[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %8 = stencil.access %5[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %9 = stencil.access %5[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %10 = stencil.access %5[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> +// CHECK-NEXT: %11 = arith.addf %6, %7 : f64 +// CHECK-NEXT: %12 = arith.addf %8, %9 : f64 +// CHECK-NEXT: %13 = arith.addf %11, %12 : f64 // CHECK-NEXT: %cst = arith.constant -4.000000e+00 : f64 -// CHECK-NEXT: %15 = arith.mulf %11, %cst : f64 -// CHECK-NEXT: %16 = arith.addf %15, %14 : f64 -// CHECK-NEXT: stencil.return %16, %15 : f64, f64 +// CHECK-NEXT: %14 = arith.mulf %10, %cst : f64 +// CHECK-NEXT: %15 = arith.addf %14, %13 : f64 +// CHECK-NEXT: stencil.return %15 : f64 // CHECK-NEXT: } to <[0, 0, 0], [64, 64, 64]> // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -342,28 +341,6 @@ func.func @stencil_init_index_offset(%0 : !stencil.field<[0,64]x[0,64]x[0,64]xin // CHECK-NEXT: func.return // CHECK-NEXT: } -func.func @store_result_lowering(%arg0 : f64) { - %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { - %2 = stencil.store_result %arg1 : !stencil.result - %3 = stencil.store_result %arg1 : !stencil.result - stencil.return %2, %3 : !stencil.result, !stencil.result - } - %2 = stencil.buffer %1 : !stencil.temp<[0,7]x[0,7]x[0,7]xf64> -> !stencil.temp<[0,7]x[0,7]x[0,7]xf64> - %3 = stencil.buffer %0 : !stencil.temp<[0,7]x[0,7]x[0,7]xf64> -> !stencil.temp<[0,7]x[0,7]x[0,7]xf64> - func.return -} - -// CHECK: func.func @store_result_lowering(%arg0 : f64) { -// CHECK-NEXT: %0 = stencil.alloc : !stencil.field<[0,7]x[0,7]x[0,7]xf64> -// CHECK-NEXT: %1 = stencil.alloc : !stencil.field<[0,7]x[0,7]x[0,7]xf64> -// CHECK-NEXT: stencil.apply(%arg1 = %arg0 : f64) outs (%0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) { -// CHECK-NEXT: %2 = stencil.store_result %arg1 : !stencil.result -// CHECK-NEXT: %3 = stencil.store_result %arg1 : !stencil.result -// CHECK-NEXT: stencil.return %2, %3 : !stencil.result, !stencil.result -// CHECK-NEXT: } to <[0, 0, 0], [7, 7, 7]> -// CHECK-NEXT: func.return -// CHECK-NEXT: } - func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { %true = "test.pureop"() : () -> i1 diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index ddaa172be5..a1b873ad9b 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -564,7 +564,6 @@ def get( body: Block | Region, result_types: Sequence[TempType[Attribute]], ): - assert len(result_types) > 0 if isinstance(body, Block): body = Region(body) @@ -647,6 +646,14 @@ def get_accesses(self) -> Iterable[AccessPattern]: accesses.append(offsets) yield AccessPattern(tuple(accesses)) + def get_bounds(self): + if self.bounds is not None: + return self.bounds + else: + assert len(self.res) > 0 + res_type = cast(TempType[Attribute], self.res[0].type) + return res_type.bounds + class AllocOpEffect(MemoryEffect): @classmethod diff --git a/xdsl/transforms/canonicalization_patterns/stencil.py b/xdsl/transforms/canonicalization_patterns/stencil.py index b3bed37913..8cee81a291 100644 --- a/xdsl/transforms/canonicalization_patterns/stencil.py +++ b/xdsl/transforms/canonicalization_patterns/stencil.py @@ -1,7 +1,7 @@ from typing import cast from xdsl.dialects import stencil -from xdsl.ir import Attribute, Block, SSAValue +from xdsl.ir import Attribute, Block, Region, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, @@ -95,8 +95,12 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> N results.pop(i) return_args.pop(i) - new = stencil.ApplyOp.get( - op.args, block, [cast(stencil.TempType[Attribute], r.type) for r in results] + new = stencil.ApplyOp.build( + operands=[op.args, op.dest], + regions=[Region(block)], + result_types=[[cast(stencil.TempType[Attribute], r.type) for r in results]], + properties=op.properties.copy(), + attributes=op.attributes.copy(), ) replace_results: list[SSAValue | None] = list(new.res) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index acdded1b8a..090d5f735a 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -13,6 +13,7 @@ FieldType, IndexAttr, LoadOp, + ReturnOp, StencilBoundsAttr, StoreOp, TempType, @@ -34,6 +35,7 @@ ) from xdsl.rewriter import InsertPoint from xdsl.traits import MemoryEffectKind, get_effects +from xdsl.transforms.canonicalization_patterns.stencil import ApplyUnusedResults from xdsl.transforms.dead_code_elimination import RemoveUnusedOperations from xdsl.utils.hints import isa @@ -90,16 +92,16 @@ class ApplyBufferizePattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - if not op.res: + if all(not isinstance(o.type, TempType) for o in op.args): return - bounds = cast(TempType[Attribute], op.res[0].type).bounds + bounds = op.get_bounds() - dests = [ - AllocOp(result_types=[field_from_temp(cast(TempType[Attribute], r.type))]) - for r in op.res - ] - operands = [ + # dests = [ + # AllocOp(result_types=[field_from_temp(cast(TempType[Attribute], r.type))]) + # for r in op.res + # ] + args = [ ( BufferOp.create( operands=[o], @@ -108,17 +110,17 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): if isa(o.type, TempType[Attribute]) else o ) - for o in op.operands + for o in op.args ] - loads = [ - LoadOp(operands=[d], result_types=[r.type]) for d, r in zip(dests, op.res) - ] + # loads = [ + # LoadOp(operands=[d], result_types=[r.type]) for d, r in zip(dests, op.res) + # ] new = ApplyOp( - operands=[operands, dests], - regions=[Region(Block(arg_types=[SSAValue.get(a).type for a in operands]))], - result_types=[[]], + operands=[args, op.dest], + regions=[Region(Block(arg_types=[SSAValue.get(a).type for a in op.args]))], + result_types=[[r.type for r in op.res]], properties={"bounds": bounds}, ) rewriter.inline_block( @@ -128,8 +130,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): ) rewriter.replace_matched_op( - [*(o for o in operands if isinstance(o, Operation)), *dests, new, *loads], - [SSAValue.get(l) for l in loads], + [*(o for o in args if isinstance(o, Operation)), new] ) @@ -240,49 +241,43 @@ class ApplyLoadStoreFoldPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): - temp = op.temp + stored = op.temp # We are looking for a loaded destination of an apply - if not isinstance(load := temp.owner, LoadOp): + if not isinstance(apply := stored.owner, ApplyOp): return - infield = load.field - - other_uses = [u for u in infield.uses if u.operation is not load] - - if len(other_uses) != 1: - return - - # we restrict to the case where the apply and load are the only users of %temp - # for now - other_use = other_uses.pop() - if not isinstance( - apply := other_use.operation, ApplyOp - ) or other_use.index < len(apply.args): - print(other_use) - print() - return - - # Get first occurence of the destination field, to walk from it + # Check that the destination is not used between the apply and store. dest = op.field start = dest.owner if isinstance(start, Block): start = cast(Operation, start.first_op) effecting = [ o - for o in walk_from_to(start, op) - if might_effect(o, {MemoryEffectKind.READ}, dest) + for o in walk_from_to(apply, op) + if might_effect(o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest) ] if effecting: return - new_operands = list(apply.operands) - new_operands[other_use.index] = dest + temp_index = apply.results.index(stored) - new_apply = ApplyOp.create( - operands=new_operands, - result_types=[], - properties=apply.properties.copy(), + bounds = apply.get_bounds() + if not isinstance(bounds, StencilBoundsAttr): + raise ValueError( + "Stencil shape inference must be ran before bufferization." + ) + + new_apply = ApplyOp.build( + operands=[apply.args, [*apply.dest, dest]], + result_types=[ + [ + r.type + for r in apply.results[:temp_index] + + apply.results[temp_index + 1 :] + ] + ], + properties=apply.properties.copy() | {"bounds": bounds}, attributes=apply.attributes.copy(), regions=[ Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), @@ -295,16 +290,32 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): new_apply.region.block.args, ) - new_load = LoadOp.create( - operands=[dest], - result_types=[r.type for r in load.results], - attributes=load.attributes.copy(), - properties=load.properties.copy(), + old_return = new_apply.region.block.last_op + assert isinstance(old_return, ReturnOp) + uf = old_return.unroll_factor + new_return_args = list( + old_return.arg[: uf * temp_index] + + old_return.arg[uf * (temp_index + 1) :] + + old_return.arg[uf * temp_index : uf * (temp_index + 1)] + ) + new_return = ReturnOp.create( + operands=new_return_args, + properties=old_return.properties.copy(), + attributes=old_return.attributes.copy(), ) - rewriter.replace_op(apply, new_apply) - rewriter.replace_op(load, new_load) - rewriter.erase_op(op) + rewriter.replace_op(old_return, new_return) + + load = LoadOp.get(dest, bounds.lb, bounds.ub) + + rewriter.replace_op( + apply, + [new_apply, load], + new_apply.results[:temp_index] + + (load.res,) + + new_apply.results[temp_index:], + ) + rewriter.erase_matched_op() @dataclass(frozen=True) @@ -528,6 +539,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: LoadBufferFoldPattern(), ApplyLoadStoreFoldPattern(), RemoveUnusedOperations(), + ApplyUnusedResults(), ] ) ) From 65ae4b6cbde42bef6c6c2d00075c6b522b9725b2 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Thu, 8 Aug 2024 16:48:00 +0100 Subject: [PATCH 03/12] Fix up constructor. --- xdsl/dialects/stencil.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index a1b873ad9b..8c29b83f9a 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -562,16 +562,20 @@ def parse_operand(): def get( args: Sequence[SSAValue] | Sequence[Operation], body: Block | Region, - result_types: Sequence[TempType[Attribute]], + result_types: Sequence[TempType[Attribute]] | None = None, + bounds: StencilBoundsAttr | None = None, ): - + assert result_types or bounds if isinstance(body, Block): body = Region(body) + properties = {"bounds": bounds} if bounds else {} + return ApplyOp.build( operands=[list(args), []], regions=[body], result_types=[result_types], + properties=properties, ) def verify_(self) -> None: From dbf0fa11d744b84a86b1ec04deb8319072a9af19 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 13:41:12 +0100 Subject: [PATCH 04/12] Unused copypasta cleanup. --- xdsl/transforms/stencil_bufferize.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 090d5f735a..2f421d586b 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -46,14 +46,6 @@ def field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]: return FieldType[_TypeElement].new(temp.parameters) -def only_has_effect(op: Operation, effect: MemoryEffectKind) -> bool: - """ - Returns if the operation has the given side effects and no others. - """ - effects = get_effects(op) - return effects is not None and all(e.kind == effect for e in effects) - - def might_effect( operation: Operation, effects: set[MemoryEffectKind], value: SSAValue ) -> bool: From b07b52b6e693838cedc6e3e0dc505f6299af0903 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 13:41:54 +0100 Subject: [PATCH 05/12] Comments cleanup. --- xdsl/transforms/stencil_bufferize.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 2f421d586b..58e58814fe 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -89,10 +89,6 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): bounds = op.get_bounds() - # dests = [ - # AllocOp(result_types=[field_from_temp(cast(TempType[Attribute], r.type))]) - # for r in op.res - # ] args = [ ( BufferOp.create( @@ -105,10 +101,6 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): for o in op.args ] - # loads = [ - # LoadOp(operands=[d], result_types=[r.type]) for d, r in zip(dests, op.res) - # ] - new = ApplyOp( operands=[args, op.dest], regions=[Region(Block(arg_types=[SSAValue.get(a).type for a in op.args]))], From 14ab642d196b1cc6af353fb303bd29ac46fbe8a5 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 13:48:33 +0100 Subject: [PATCH 06/12] Update doc. --- xdsl/transforms/stencil_bufferize.py | 33 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 58e58814fe..22939339b0 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -194,32 +194,24 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter): rewriter.replace_matched_op(new_ops=[], new_results=[underlying]) -class ApplyLoadStoreFoldPattern(RewritePattern): +class ApplyStoreFoldPattern(RewritePattern): """ - If an allocated field is only used by an apply to write its output and loaded - to be stored in a destination field, make the apply work on the destination directly. + Fold stores of applys result Example: ```mlir - %temp = stencil.alloc : !stencil.field<[0,32]> - stencil.apply() outs (%temp : !stencil.field<[0,32]>) { + %temp = stencil.apply() -> (!stencil.temp<[0,32]>) { // [...] } - // [... %dest not read] - %loaded = stencil.load %temp : !stencil.field<[0,32]> -> !stencil.temp<[0,32]> // [... %dest not read] - stencil.store %loaded to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]> + stencil.store %temp to %dest (<[0], [32]>) : !stencil.temp<[0,32]> to !stencil.field<[-2,34]> ``` yields: ```mlir - // Will be simplified away by the canonicalizer - %temp = stencil.alloc : !stencil.field<[0,32]> - // Outputs on dest - stencil.apply() outs (%dest : !stencil.field<[0,32]>) { + // Outputs on dest directly + stencil.apply() outs (%dest : !stencil.field<[-2,34]>) { // [...] } - // Load same values from %dest instead for next operations - %loaded = stencil.load %dest : !stencil.field<[0,32]> -> !stencil.temp<[0,32]> ``` """ @@ -227,7 +219,7 @@ class ApplyLoadStoreFoldPattern(RewritePattern): def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): stored = op.temp - # We are looking for a loaded destination of an apply + # We are looking for a stored result of an apply if not isinstance(apply := stored.owner, ApplyOp): return @@ -244,6 +236,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): if effecting: return + # Get the result index to help build the new apply temp_index = apply.results.index(stored) bounds = apply.get_bounds() @@ -253,7 +246,9 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): ) new_apply = ApplyOp.build( + # We add a destination, corresponding to the removed result operands=[apply.args, [*apply.dest, dest]], + # We only remove the considered result result_types=[ [ r.type @@ -263,17 +258,21 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): ], properties=apply.properties.copy() | {"bounds": bounds}, attributes=apply.attributes.copy(), + # The block signature is the same regions=[ Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), ], ) + # The body is the same rewriter.inline_block( apply.region.block, InsertPoint.at_start(new_apply.region.block), new_apply.region.block.args, ) + # We swap the return's operand order, to make sure the order still matches destinations + # after bufferization old_return = new_apply.region.block.last_op assert isinstance(old_return, ReturnOp) uf = old_return.unroll_factor @@ -287,9 +286,9 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): properties=old_return.properties.copy(), attributes=old_return.attributes.copy(), ) - rewriter.replace_op(old_return, new_return) + # Create a load of the destination, for any other user of the result load = LoadOp.get(dest, bounds.lb, bounds.ub) rewriter.replace_op( @@ -521,7 +520,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: BufferAlloc(), CombineStoreFold(), LoadBufferFoldPattern(), - ApplyLoadStoreFoldPattern(), + ApplyStoreFoldPattern(), RemoveUnusedOperations(), ApplyUnusedResults(), ] From aeee21abd2b8b5b08103a9feaeeadd12ede6fe35 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 16:01:40 +0100 Subject: [PATCH 07/12] Rewrite AppltStoreFoldPattern in terms of apply, not store. --- xdsl/transforms/stencil_bufferize.py | 155 ++++++++++++++------------- 1 file changed, 80 insertions(+), 75 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 22939339b0..a81d098a17 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -216,89 +216,94 @@ class ApplyStoreFoldPattern(RewritePattern): """ @op_type_rewrite_pattern - def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): - stored = op.temp - - # We are looking for a stored result of an apply - if not isinstance(apply := stored.owner, ApplyOp): - return + def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): + apply = op + for temp_index, stored in enumerate(op.res): + stores = [ + use.operation + for use in stored.uses + if isinstance(use.operation, StoreOp) + ] + if not stores: + continue - # Check that the destination is not used between the apply and store. - dest = op.field - start = dest.owner - if isinstance(start, Block): - start = cast(Operation, start.first_op) - effecting = [ - o - for o in walk_from_to(apply, op) - if might_effect(o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest) - ] - if effecting: - return + store = stores[0] + + # Check that the destination is not used between the apply and store. + dest = store.field + start = dest.owner + if isinstance(start, Block): + start = cast(Operation, start.first_op) + effecting = [ + o + for o in walk_from_to(apply, op) + if might_effect( + o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest + ) + ] + if effecting: + return - # Get the result index to help build the new apply - temp_index = apply.results.index(stored) + bounds = apply.get_bounds() + if not isinstance(bounds, StencilBoundsAttr): + raise ValueError( + "Stencil shape inference must be ran before bufferization." + ) - bounds = apply.get_bounds() - if not isinstance(bounds, StencilBoundsAttr): - raise ValueError( - "Stencil shape inference must be ran before bufferization." + new_apply = ApplyOp.build( + # We add a destination, corresponding to the removed result + operands=[apply.args, [*apply.dest, dest]], + # We only remove the considered result + result_types=[ + [ + r.type + for r in apply.results[:temp_index] + + apply.results[temp_index + 1 :] + ] + ], + properties=apply.properties.copy() | {"bounds": bounds}, + attributes=apply.attributes.copy(), + # The block signature is the same + regions=[ + Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), + ], ) - new_apply = ApplyOp.build( - # We add a destination, corresponding to the removed result - operands=[apply.args, [*apply.dest, dest]], - # We only remove the considered result - result_types=[ - [ - r.type - for r in apply.results[:temp_index] - + apply.results[temp_index + 1 :] - ] - ], - properties=apply.properties.copy() | {"bounds": bounds}, - attributes=apply.attributes.copy(), - # The block signature is the same - regions=[ - Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), - ], - ) - - # The body is the same - rewriter.inline_block( - apply.region.block, - InsertPoint.at_start(new_apply.region.block), - new_apply.region.block.args, - ) + # The body is the same + rewriter.inline_block( + apply.region.block, + InsertPoint.at_start(new_apply.region.block), + new_apply.region.block.args, + ) - # We swap the return's operand order, to make sure the order still matches destinations - # after bufferization - old_return = new_apply.region.block.last_op - assert isinstance(old_return, ReturnOp) - uf = old_return.unroll_factor - new_return_args = list( - old_return.arg[: uf * temp_index] - + old_return.arg[uf * (temp_index + 1) :] - + old_return.arg[uf * temp_index : uf * (temp_index + 1)] - ) - new_return = ReturnOp.create( - operands=new_return_args, - properties=old_return.properties.copy(), - attributes=old_return.attributes.copy(), - ) - rewriter.replace_op(old_return, new_return) + # We swap the return's operand order, to make sure the order still matches destinations + # after bufferization + old_return = new_apply.region.block.last_op + assert isinstance(old_return, ReturnOp) + uf = old_return.unroll_factor + new_return_args = list( + old_return.arg[: uf * temp_index] + + old_return.arg[uf * (temp_index + 1) :] + + old_return.arg[uf * temp_index : uf * (temp_index + 1)] + ) + new_return = ReturnOp.create( + operands=new_return_args, + properties=old_return.properties.copy(), + attributes=old_return.attributes.copy(), + ) + rewriter.replace_op(old_return, new_return) - # Create a load of the destination, for any other user of the result - load = LoadOp.get(dest, bounds.lb, bounds.ub) + # Create a load of the destination, for any other user of the result + load = LoadOp.get(dest, bounds.lb, bounds.ub) - rewriter.replace_op( - apply, - [new_apply, load], - new_apply.results[:temp_index] - + (load.res,) - + new_apply.results[temp_index:], - ) - rewriter.erase_matched_op() + rewriter.replace_matched_op( + [new_apply, load], + new_apply.results[:temp_index] + + (load.res,) + + new_apply.results[temp_index:], + ) + rewriter.erase_op(store) + return @dataclass(frozen=True) From 2a3df6fcdf92a5f676e4d21907420cc0f408bed1 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 16:08:51 +0100 Subject: [PATCH 08/12] Fold all stores if multiple. --- xdsl/transforms/stencil_bufferize.py | 47 +++++++++++++++------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index a81d098a17..37f97a2872 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -215,35 +215,34 @@ class ApplyStoreFoldPattern(RewritePattern): ``` """ + @staticmethod + def is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool: + # Check that the destination is not used between the apply and store. + dest = store.field + start = dest.owner + if isinstance(start, Block): + start = cast(Operation, start.first_op) + effecting = [ + o + for o in walk_from_to(apply, store) + if might_effect(o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest) + ] + return not effecting + @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): apply = op + for temp_index, stored in enumerate(op.res): stores = [ use.operation for use in stored.uses if isinstance(use.operation, StoreOp) + and self.is_dest_safe(apply, use.operation) ] if not stores: continue - store = stores[0] - - # Check that the destination is not used between the apply and store. - dest = store.field - start = dest.owner - if isinstance(start, Block): - start = cast(Operation, start.first_op) - effecting = [ - o - for o in walk_from_to(apply, op) - if might_effect( - o, {MemoryEffectKind.READ, MemoryEffectKind.WRITE}, dest - ) - ] - if effecting: - return - bounds = apply.get_bounds() if not isinstance(bounds, StencilBoundsAttr): raise ValueError( @@ -252,7 +251,10 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): new_apply = ApplyOp.build( # We add a destination, corresponding to the removed result - operands=[apply.args, [*apply.dest, dest]], + operands=[ + apply.args, + (*apply.dest, *(store.field for store in stores)), + ], # We only remove the considered result result_types=[ [ @@ -284,7 +286,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): new_return_args = list( old_return.arg[: uf * temp_index] + old_return.arg[uf * (temp_index + 1) :] - + old_return.arg[uf * temp_index : uf * (temp_index + 1)] + + old_return.arg[uf * temp_index : uf * (temp_index + 1)] * len(stores) ) new_return = ReturnOp.create( operands=new_return_args, @@ -293,8 +295,8 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): ) rewriter.replace_op(old_return, new_return) - # Create a load of the destination, for any other user of the result - load = LoadOp.get(dest, bounds.lb, bounds.ub) + # Create a load of a destination, for any other user of the result + load = LoadOp.get(stores[0].field, bounds.lb, bounds.ub) rewriter.replace_matched_op( [new_apply, load], @@ -302,7 +304,8 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): + (load.res,) + new_apply.results[temp_index:], ) - rewriter.erase_op(store) + for store in stores: + rewriter.erase_op(store) return From 31c4b4b9f2b12391d5cbf089d20d92abecc51d6f Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 9 Aug 2024 16:14:35 +0100 Subject: [PATCH 09/12] Comments. --- xdsl/transforms/stencil_bufferize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 37f97a2872..c500c101b5 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -232,8 +232,8 @@ def is_dest_safe(apply: ApplyOp, store: StoreOp) -> bool: @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): apply = op - for temp_index, stored in enumerate(op.res): + # We are looking for a result that is stored and foldable stores = [ use.operation for use in stored.uses @@ -250,7 +250,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): ) new_apply = ApplyOp.build( - # We add a destination, corresponding to the removed result + # We add new destinations for each store of the removed result operands=[ apply.args, (*apply.dest, *(store.field for store in stores)), From f7ea30b7197f76d9e55da345445e32c2002c08ee Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 12 Aug 2024 11:05:38 +0100 Subject: [PATCH 10/12] Update docstrings. --- xdsl/transforms/stencil_bufferize.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index c500c101b5..17bffe1fe9 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -49,6 +49,9 @@ def field_from_temp(temp: TempType[_TypeElement]) -> FieldType[_TypeElement]: def might_effect( operation: Operation, effects: set[MemoryEffectKind], value: SSAValue ) -> bool: + """ + Return True if the operation might have any of the given effects on the given value. + """ op_effects = get_effects(operation) return op_effects is None or any( e.kind in effects and e.value in (None, value) for e in op_effects @@ -59,11 +62,8 @@ class ApplyBufferizePattern(RewritePattern): """ Naive partial `stencil.apply` bufferization. - Just replace all operands with the field result of a stencil.buffer on them, meaning - "The buffer those value are allocated to"; and allocate buffers for every result, - loading them back after the apply, to keep types fine with users. - - Point is to fold as much as possible all the allocations and loads. + Just replace all temp arguments with the field result of a stencil.buffer on them, meaning + "The buffer those value are allocated to". Example: ```mlir @@ -74,11 +74,9 @@ class ApplyBufferizePattern(RewritePattern): yields: ```mlir %in_buf = stencil.buffer %in : !stencil.temp<[0,32]xf64> -> !stencil.field<[0,32]xf64> - %out_buf = stencil.alloc : !stencil.field<[0,32]>xf64 stencil.apply(%0 = %in_buf : !stencil.field<[0,32]>xf64) outs (%out_buf : !stencil.field<[0,32]>xf64) { // [...] } - %out = stencil.load %out_buf : !stencil.field<[0,32]>xf64 -> !stencil.temp<[0,32]>xf64 ``` """ From d27e544ceed9d04bb22a86507a950e07212a2ba8 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 16 Aug 2024 14:04:42 +0100 Subject: [PATCH 11/12] Update xdsl/dialects/stencil.py Co-authored-by: Sasha Lopoukhine --- xdsl/dialects/stencil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 7dfcfb8761..52b30003af 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -653,7 +653,7 @@ def get_bounds(self): if self.bounds is not None: return self.bounds else: - assert len(self.res) > 0 + assert self.res res_type = cast(TempType[Attribute], self.res[0].type) return res_type.bounds From 01d874b58f1239bdcd894c6ad69378f0fda675b9 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 16 Aug 2024 16:43:55 +0100 Subject: [PATCH 12/12] Empty tuple --- xdsl/dialects/stencil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index d8fd634f1a..ce57a85c77 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -561,7 +561,7 @@ def parse_operand(): def get( args: Sequence[SSAValue] | Sequence[Operation], body: Block | Region, - result_types: Sequence[TempType[Attribute]] | None = None, + result_types: Sequence[TempType[Attribute]] = (), bounds: StencilBoundsAttr | None = None, ): assert result_types or bounds