Skip to content

Commit

Permalink
Fix unroll_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Jan 27, 2025
1 parent 9134b56 commit b59ba56
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 31 deletions.
28 changes: 3 additions & 25 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,28 +85,6 @@ def _get_connectivity(
return connectivities[0]


def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall:
return im.shift(
offsets, iterator
) # TODO: location? # TODO test_unroll_reduce failing because of offsets


def _make_deref(iterator: itir.Expr) -> itir.FunCall:
return im.deref(iterator) # TODO: location


def _make_can_deref(iterator: itir.Expr) -> itir.FunCall:
return im.can_deref(iterator) # TODO: location? # TODO test_unroll_reduce failing


def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall:
return im.if_(cond, true_expr, false_expr) # TODO: location?


def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall:
return im.list_get(offset, expr) # TODO: location?


@dataclasses.dataclass(frozen=True)
class UnrollReduce(PreserveLocationVisitor, NodeTranslator):
# we use one UID generator per instance such that the generated ids are
Expand All @@ -131,13 +109,13 @@ def _visit_reduce(
assert isinstance(node.fun, itir.FunCall)
fun, init = node.fun.args

elems = [_make_list_get(offset, arg) for arg in node.args]
elems = [im.list_get(offset, arg) for arg in node.args]
step_fun: itir.Expr = itir.FunCall(fun=fun, args=[acc, *elems])
if has_skip_values:
check_arg = next(_get_neighbors_args(node.args))
offset_tag, it = check_arg.args
can_deref = _make_can_deref(_make_shift([offset_tag, offset], it))
step_fun = _make_if(can_deref, step_fun, acc)
can_deref = im.can_deref(im.shift(offset_tag, offset)(it))
step_fun = im.if_(can_deref, step_fun, acc)
step_fun = itir.Lambda(params=[itir.Sym(id=acc.id), itir.Sym(id=offset.id)], expr=step_fun)
expr = init
for i in range(max_neighbors):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,8 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0):
if has_skip_values:
neighbors_offset = red.args[shifted_arg].args[0]
neighbors_it = red.args[shifted_arg].args[1]
can_deref = im.can_deref(
ir.FunCall(
fun=im.shift(neighbors_offset, offset),
args=[neighbors_it],
)
)
can_deref = im.can_deref(im.shift(neighbors_offset, offset)(neighbors_it))

step_expr = im.if_(can_deref, step_expr, acc)
step_fun = ir.Lambda(params=[ir.Sym(id=acc.id), ir.Sym(id=offset.id)], expr=step_expr)

Expand Down

0 comments on commit b59ba56

Please sign in to comment.