Skip to content

Commit

Permalink
compiler: fix cse with different conditionals
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 15, 2024
1 parent 3024584 commit 99fb87e
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 17 deletions.
2 changes: 1 addition & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def detect_accesses(exprs):
other_dims = set()
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims)
other_dims.update(e.implicit_dims or {})
other_dims = filter_sorted(other_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

Expand Down
40 changes: 25 additions & 15 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sympy.core.basic import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.ir import Cluster, Scope, cluster_pass, ClusterizedEq
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
Expand Down Expand Up @@ -90,12 +90,13 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):

while True:
# Detect redundancies
counted = count(processed).items()
targets = OrderedDict([(k, estimate_cost(k, True)) for k, v in counted if v > 1])
counted = count(processed, None).items()
targets = OrderedDict([(k, estimate_cost(k[0], True))
for k, v in counted if v > 1])

# Rule out Dimension-independent data dependencies
targets = OrderedDict([(k, v) for k, v in targets.items()
if not k.free_symbols & exclude])
if not k[0].free_symbols & exclude])

if not targets or max(targets.values()) < min_cost:
break
Expand All @@ -111,7 +112,10 @@ def _cse(maybe_exprs, make, min_cost=1, mode='default'):
updated = []
for e in processed:
pe = e
for k, v in chosen:
pe_c = e.conditionals
for (k, c), v in chosen:
if not c == pe_c:
continue
pe, changed = _uxreplace(pe, {k: v})
if changed and v not in scheduled:
updated.append(pe.func(v, k, operation=None))
Expand Down Expand Up @@ -156,53 +160,59 @@ def _compact_temporaries(exprs, exclude):


@singledispatch
def count(expr):
def count(expr, conds):
"""
Construct a mapper `expr -> #occurrences` for each sub-expression in `expr`.
"""
mapper = Counter()
for a in expr.args:
mapper.update(count(a))
mapper.update(count(a, None))
return mapper


@count.register(list)
@count.register(tuple)
def _(exprs):
def _(exprs, conds):
mapper = Counter()
for e in exprs:
mapper.update(count(e))
mapper.update(count(e, None))
return mapper


@count.register(ClusterizedEq)
def _(exprs, conds):
conditionals = exprs.conditionals
return count(exprs.rhs, conditionals)


@count.register(Indexed)
@count.register(Symbol)
def _(expr):
def _(expr, conds):
"""
Handler for objects preventing CSE to propagate through their arguments.
"""
return Counter()


@count.register(IndexDerivative)
def _(expr):
def _(expr, conds):
"""
Handler for symbol-binding objects. There can be many of them and therefore
they should be detected as common subexpressions, but it's either pointless
or forbidden to look inside them.
"""
return Counter([expr])
return Counter([(expr, conds)])


@count.register(Add)
@count.register(Mul)
@count.register(Pow)
@count.register(Function)
def _(expr):
def _(expr, conds):
mapper = Counter()
for a in expr.args:
mapper.update(count(a))
mapper.update(count(a, conds))

mapper[expr] += 1
mapper[(expr, conds)] += 1

return mapper
4 changes: 4 additions & 0 deletions devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def substitutions(self):
def implicit_dims(self):
return self._implicit_dims

@property
def conditionals(self):
return None

@cached_property
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, Ge)
sin, sqrt, Ge, Lt)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.finite_differences.differentiable import diffify
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -191,6 +191,55 @@ def test_cse_w_conditionals():
assert len(FindNodes(Conditional).visit(op)) == 1


def test_cse_w_multi_conditionals():
grid = Grid(shape=(10, 10, 10))
x, _, _ = grid.dimensions

cd = ConditionalDimension(name='cd', parent=x, condition=Ge(x, 4),
indirect=True)

cd2 = ConditionalDimension(name='cd2', parent=x, condition=Lt(x, 4),
indirect=True)

f = Function(name='f', grid=grid)
g = Function(name='g', grid=grid)
h = Function(name='h', grid=grid)
a0 = Function(name='a0', grid=grid)
a1 = Function(name='a1', grid=grid)
a2 = Function(name='a2', grid=grid)
a3 = Function(name='a3', grid=grid)

eq0 = Eq(h, a0, implicit_dims=cd)
eq1 = Eq(a0, a0 + f*g, implicit_dims=cd)
eq2 = Eq(a1, a1 + f*g, implicit_dims=cd)
eq3 = Eq(a2, a2 + f*g, implicit_dims=cd2)
eq4 = Eq(a3, a3 + f*g, implicit_dims=cd2)

op = Operator([eq0, eq1, eq3])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 0

op = Operator([eq0, eq1, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 1

op = Operator([eq0, eq1, eq2, eq3, eq4])

assert_structure(op, ['x,y,z'], 'xyz')
assert len(FindNodes(Conditional).visit(op)) == 2

tmps = [s for s in FindSymbols().visit(op) if s.name.startswith('r')]
assert len(tmps) == 2


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down

0 comments on commit 99fb87e

Please sign in to comment.