Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Enable dmp.swap stencil bufferization. #3066

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/filecheck/transforms/distribute-stencil.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2}" | filecheck %s
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2},shape-inference" | filecheck %s --check-prefix SHAPE
// RUN: xdsl-opt %s -p "distribute-stencil{strategy=3d-grid slices=2,2,2},shape-inference,stencil-bufferize" | filecheck %s --check-prefix BUFF

func.func @offsets(%27 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %28 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %29 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
%33 = stencil.load %27 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<?x?x?xf64>
Expand Down Expand Up @@ -66,6 +67,25 @@
// SHAPE-NEXT: func.return
// SHAPE-NEXT: }

// BUFF: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// BUFF-NEXT: "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) -> ()
// BUFF-NEXT: stencil.apply(%3 = %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) outs (%1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// BUFF-NEXT: %4 = stencil.access %3[-1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %5 = stencil.access %3[1, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %6 = stencil.access %3[0, 1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %7 = stencil.access %3[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %8 = stencil.access %3[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// BUFF-NEXT: %9 = arith.addf %4, %5 : f64
// BUFF-NEXT: %10 = arith.addf %6, %7 : f64
// BUFF-NEXT: %11 = arith.addf %9, %10 : f64
// BUFF-NEXT: %cst = arith.constant -4.000000e+00 : f64
// BUFF-NEXT: %12 = arith.mulf %8, %cst : f64
// BUFF-NEXT: %13 = arith.addf %12, %11 : f64
// BUFF-NEXT: stencil.return %13, %12 : f64, f64
// BUFF-NEXT: } to <[0, 0, 0], [32, 32, 32]>
// BUFF-NEXT: func.return
// BUFF-NEXT: }

func.func @trivial_externals(%dyn_mem : memref<?x?x?xf64>, %sta_mem : memref<64x64x64xf64>, %dyn_field : !stencil.field<?x?x?xf64>, %sta_field : !stencil.field<[-2,62]x[0,64]x[2,66]xf64>) {
stencil.external_store %dyn_field to %dyn_mem : !stencil.field<?x?x?xf64> to memref<?x?x?xf64>
stencil.external_store %sta_field to %sta_mem : !stencil.field<[-2,62]x[0,64]x[2,66]xf64> to memref<64x64x64xf64>
Expand Down
32 changes: 29 additions & 3 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import ABC
from collections.abc import Iterable, Sequence
from math import prod
from typing import Literal
from typing import Literal, cast

from xdsl.dialects import builtin, stencil
from xdsl.ir import Attribute, Dialect, Operation, ParametrizedAttribute, SSAValue
Expand All @@ -29,7 +29,12 @@
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import HasShapeInferencePatternsTrait
from xdsl.traits import (
EffectInstance,
HasShapeInferencePatternsTrait,
MemoryEffect,
MemoryEffectKind,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -594,6 +599,27 @@ def get_shape_inference_patterns(cls):
return (DmpSwapShapeInference(), DmpSwapSwapsInference())


class SwapOpMemoryEffect(MemoryEffect):
"""
Side effect implementation of dmp.swap.
"""

@classmethod
def get_effects(cls, op: Operation) -> set[EffectInstance]:
op = cast(SwapOp, op)
# If it's operating in value-semantic mode, it has no side effects.
if op.swapped_values:
return set()
# If it's operating in reference-semantic mode, it reads and writes to its field.
# TODO: consider the empty swaps case at some point.
# Right now, it relies on it before inferring them, so not very safe.
# But it could be an elegant way to generically simplify those.
return {
EffectInstance(MemoryEffectKind.WRITE, op.input_stencil),
EffectInstance(MemoryEffectKind.READ, op.input_stencil),
}


@irdl_op_definition
class SwapOp(IRDLOperation):
"""
Expand All @@ -609,7 +635,7 @@ class SwapOp(IRDLOperation):

strategy = attr_def(DomainDecompositionStrategy)

traits = frozenset([SwapOpHasShapeInferencePatterns()])
traits = frozenset([SwapOpHasShapeInferencePatterns(), SwapOpMemoryEffect()])

def verify_(self) -> None:
if self.swapped_values:
Expand Down
50 changes: 43 additions & 7 deletions xdsl/transforms/stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.experimental.dmp import SwapOp
from xdsl.dialects.stencil import (
AllocOp,
ApplyOp,
Expand Down Expand Up @@ -166,17 +167,19 @@ def match_and_rewrite(self, op: BufferOp, rewriter: PatternRewriter):

underlying = load.field

# TODO: propery analysis of effects in between
# For illustration, only fold a single use of the handle
# (Requires more boilerplate to analyse the whole live range otherwise)
# TODO: further analysis
# For now, only handle usages in the same block
uses = op.res.uses.copy()
if len(uses) > 1:
block = op.parent
if not block or any(use.operation.parent is not block for use in uses):
return
user = uses.pop().operation
last_user = max(
uses, key=lambda u: block.get_operation_index(u.operation)
).operation

effecting = [
o
for o in walk_from_to(load, user)
for o in walk_from_to(load, last_user)
if might_effect(o, {MemoryEffectKind.WRITE}, underlying)
]
if effecting:
Expand Down Expand Up @@ -498,6 +501,37 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter):
return


class SwapBufferize(RewritePattern):
"""
Bufferize a dmp.swap operation.

NB: This should most likely consider a shared pass following canonicalize and
shape-inference.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: SwapOp, rewriter: PatternRewriter):
temp = op.input_stencil

if not isa(temp_t := temp.type, TempType[Attribute]):
return

load = temp.owner
if not isinstance(load, LoadOp):
return

buffer = BufferOp.create(
operands=[temp], result_types=[field_from_temp(temp_t)]
)
new_swap = SwapOp.get(buffer.res, op.strategy)
new_swap.swaps = op.swaps
load = LoadOp(operands=[buffer.res], result_types=[temp_t])

rewriter.replace_matched_op(
new_ops=[buffer, new_swap, load],
)


@dataclass(frozen=True)
class StencilBufferize(ModulePass):
"""
Expand All @@ -520,7 +554,9 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
ApplyStoreFoldPattern(),
RemoveUnusedOperations(),
ApplyUnusedResults(),
SwapBufferize(),
]
)
),
apply_recursively=True,
)
walker.rewrite_module(op)
Loading