Skip to content

Commit

Permalink
wip: save work
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Mar 6, 2024
1 parent 662b796 commit 75a657c
Showing 1 changed file with 113 additions and 43 deletions.
156 changes: 113 additions & 43 deletions grudge/pytato_transforms/pytato_indirection_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ def _is_materialized(expr: Array) -> bool:


def _can_index_lambda_propagate_indirections_without_changing_axes(
expr: IndexLambda) -> bool:

expr: IndexLambda, iel_axis: Optional[int], idof_axis: Optional[int]
) -> bool:
"""
Returns *True* only if the axes being reindexed appear at the same
positions in the bindings' indexing locations.
"""
from pytato.utils import are_shapes_equal
from pytato.raising import (index_lambda_to_high_level_op,
BinaryOp)
Expand Down Expand Up @@ -219,8 +223,8 @@ def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]):
return result


def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], from_element_indices:
Tuple[Array, ...]):
def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...],
from_element_indices: Tuple[Array, ...]):
assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices)
assert all(dof_pick_list.ndim == 2 for dof_pick_list in dof_pick_lists)
assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices)
Expand All @@ -239,7 +243,10 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
from_element_indices: Tuple[Array, ...],
dof_pick_lists: Tuple[Array, ...]
) -> Array:

raise NotImplementedError("We still need to port this from"
" the previous version, where only"
" indirections only along the element"
" axes.")
if iel_axis is not None:
assert idof_axis is not None
assert len(from_element_indices) != 0
Expand All @@ -263,6 +270,56 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array,
return rec_expr


def _is_iel_idof_picking(expr: AdvancedIndexInContiguousAxes,
iel_axis: Optional[int],
idof_axis: Optional[int],
) -> bool:
if expr.ndim != 2:
return False

if expr.array.ndim != 2:
return False

if not ((iel_axis is None and idof_axis is None)
or (iel_axis == 0 and idof_axis == 1)):
return False

if (isinstance(expr.indices[0], Array)
and isinstance(expr.indices[1], Array)):
from pytato.utils import are_shape_components_equal
from_el_indices, dof_pick_lists = expr.indices
assert isinstance(from_el_indices, Array)
assert isinstance(dof_pick_lists, Array)

if dof_pick_lists.ndim != 1:
return False
if from_el_indices.ndim != 2:
return False
if are_shape_components_equal(from_el_indices.shape[1], 1):
return False

return True
else:
return False


def _is_iel_only_picking(expr: AdvancedIndexInContiguousAxes,
iel_axis: Optional[int]) -> bool:
if expr.ndim != 1:
return False

if expr.array.ndim != 1:
return False

if not isinstance(expr.indices[0], Array):
return False

if iel_axis not in [0, None]:
return False

return True


class PickListFusers(Mapper):
def __init__(self) -> None:
self.can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated()
Expand All @@ -283,18 +340,22 @@ def rec(self, # type: ignore[override]
" is illegal for PickListFusers. Pass arrays"
" instead.")

if iel_axis is not None:
assert idof_axis is not None
if idof_axis is not None:
assert iel_axis is not None
assert 0 <= iel_axis < expr.ndim
assert 0 <= idof_axis < expr.ndim
# the condition below ensures that we are only dealing with indirections
# appearing at contiguous locations.
assert abs(iel_axis-idof_axis) == 1
else:
assert len(dof_pick_lists) == len(from_element_indices)
elif iel_axis is not None:
assert idof_axis is None
assert len(dof_pick_lists) == 0
assert len(from_element_indices) > 0
else:
assert iel_axis is None
assert len(from_element_indices) == 0

assert len(dof_pick_lists) == len(from_element_indices)
assert len(dof_pick_lists) == 0

key = (expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)
try:
Expand All @@ -318,8 +379,8 @@ def __call__(self, # type: ignore[override]

def _map_input_base(self,
expr: InputArgumentBase,
iel_axis: int,
idof_axis: int,
iel_axis: Optional[int],
idof_axis: Optional[int],
from_element_indices: Tuple[Array, ...],
dof_pick_lists: Tuple[Array, ...]) -> Array:
return _pick_list_fusers_map_materialized_node(
Expand Down Expand Up @@ -351,30 +412,36 @@ def map_index_lambda(self,
rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)

if iel_axis is not None:
assert idof_axis is not None
assert _can_index_lambda_propagate_indirections_without_changing_axes(
expr)
from pytato.utils import are_shapes_equal
new_el_dim, new_dofs_dim = dof_pick_lists[0].shape
assert are_shapes_equal(from_element_indices[0].shape, (new_el_dim, 1))

new_shape = tuple(
new_el_dim if idim == iel_axis else (
new_dofs_dim if idim == idof_axis else dim)
for idim, dim in enumerate(expr.shape))

return IndexLambda(
expr.expr,
new_shape,
expr.dtype,
Map({name: self.rec(bnd, iel_axis, idof_axis,
from_element_indices,
dof_pick_lists)
for name, bnd in expr.bindings.items()}),
var_to_reduction_descr=expr.var_to_reduction_descr,
tags=expr.tags,
axes=expr.axes
)
expr, iel_axis, idof_axis)
if idof_axis is None:
# TODO: Not encountered any practical DAGs that take this code path.
# Implement this branch only if seen in any practical applications.
raise NotImplementedError
else:
assert idof_axis is not None
from pytato.utils import are_shapes_equal
new_el_dim, new_dofs_dim = dof_pick_lists[0].shape
assert are_shapes_equal(from_element_indices[0].shape,
(new_el_dim, 1))

new_shape = tuple(
new_el_dim if idim == iel_axis else (
new_dofs_dim if idim == idof_axis else dim)
for idim, dim in enumerate(expr.shape))

return IndexLambda(
expr.expr,
new_shape,
expr.dtype,
Map({name: self.rec(bnd, iel_axis, idof_axis,
from_element_indices,
dof_pick_lists)
for name, bnd in expr.bindings.items()}),
var_to_reduction_descr=expr.var_to_reduction_descr,
tags=expr.tags,
axes=expr.axes
)
else:
return IndexLambda(
expr.expr,
Expand Down Expand Up @@ -405,14 +472,17 @@ def map_contiguous_advanced_index(self,
return _pick_list_fusers_map_materialized_node(
rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists)

if self.can_pick_indirections_be_propagated(expr,
iel_axis or 0,
idof_axis or 1):
idx1, idx2 = expr.indices
assert isinstance(idx1, Array) and isinstance(idx2, Array)
return self.rec(expr.array, 0, 1,
from_element_indices + (idx1,),
dof_pick_lists + (idx2,))
if (_is_iel_idof_picking(expr, iel_axis, idof_axis)
and self.can_pick_indirections_be_propagated(expr,
iel_axis or 0,
idof_axis or 1)):
raise NotImplementedError
elif (_is_iel_only_picking(expr, iel_axis)
and self.can_pick_indirections_be_propagated(expr,
iel_axis or 0,
None)):
assert idof_axis is None
raise NotImplementedError
else:
assert iel_axis is None and idof_axis is None
return AdvancedIndexInContiguousAxes(
Expand Down

0 comments on commit 75a657c

Please sign in to comment.