Commit
Generalize and rename local_reduce_chain
ricardoV94 committed Oct 8, 2024
1 parent 5b9c07e commit b2c6258
Showing 2 changed files with 205 additions and 182 deletions.
74 changes: 42 additions & 32 deletions pytensor/tensor/rewriting/math.py
@@ -100,7 +100,11 @@
     values_eq_approx_remove_inf_nan,
     values_eq_approx_remove_nan,
 )
-from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
+from pytensor.tensor.variable import (
+    TensorConstant,
+    TensorVariable,
+    get_unique_constant_value,
+)
 
 
 def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node):


 @register_canonicalize
-@node_rewriter([Sum, Prod])
-def local_op_of_op(fgraph, node):
+@node_rewriter([CAReduce])
+def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
     """
-    Prod(Prod()) -> single Prod()
-    or
     Sum(Sum()) -> single Sum()
+    or any CAReduce(CAReduce(x)) of the same type
     """
-    op_type = Sum if isinstance(node.op, Sum) else Prod
-    (node_inps,) = node.inputs
-    out_dtype = node.op.dtype
-    # This is done to make sure the rewrite doesn't affect other
-    # computations.
-    if len(fgraph.clients[node_inps]) == 1:
-        if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
-            # check to see either the inner or outer prod is doing a
-            # product over all axis, in which case we can remove it
-            if node_inps.owner.op.axis is None or node.op.axis is None:
-                return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
-
-            # figure out which axes were in the original sum
-            newaxis = list(node_inps.owner.op.axis)
-            for i in node.op.axis:
-                new_i = i
-                for ii in node_inps.owner.op.axis:
-                    if new_i >= ii:
-                        new_i += 1
-                assert new_i not in newaxis
-                newaxis.append(new_i)
-
-            assert len(newaxis) == len(
-                list(node_inps.owner.op.axis) + list(node.op.axis)
-            )
-
-            combined = op_type(newaxis, dtype=out_dtype)
-            return [combined(node_inps.owner.inputs[0])]
+    [inner_reduce] = node.inputs
+    if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
+        return None
+
+    # Don't apply rewrite if inner_reduce is used elsewhere
+    if len(fgraph.clients[inner_reduce]) > 1:
+        return None
+
+    # Check if the CAReduces have the same scalar op
+    outer_op: CAReduce = node.op
+    inner_op = inner_reduce.owner.op
+
+    if outer_op.scalar_op != inner_op.scalar_op:
+        return None
+
+    outer_axis = outer_op.axis
+    inner_axis = inner_op.axis
+    [x] = inner_reduce.owner.inputs
+    # Check whether the inner or outer reduction is over all axes,
+    # in which case the chain collapses to a single full reduction
+    if outer_axis is None or inner_axis is None:
+        return [outer_op.clone(axis=None)(x)]
+
+    # Merge axes, mapping the outer axes back to the original input's axes
+    newaxis = list(inner_axis)
+    for i in outer_axis:
+        new_i = i
+        for ii in inner_axis:
+            if new_i >= ii:
+                new_i += 1
+        assert new_i not in newaxis
+        newaxis.append(new_i)
+
+    assert len(newaxis) == len(inner_axis) + len(outer_axis)
+    return [outer_op.clone(axis=sorted(newaxis))(x)]
 
 
 @register_canonicalize
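
For context beyond the diff: the old local_op_of_op only fused chains of Sum or chains of Prod, while matching on CAReduce and comparing scalar_op generalizes the fusion to any nested pair of reductions of the same kind (e.g. Max(Max(x))). A minimal sketch of the intended effect, assuming a default-compiled PyTensor function (the exact printed graph can vary by version):

import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")

# Nested reductions of the same kind should collapse into one CAReduce node.
nested_sum = x.sum(axis=1).sum(axis=1)  # expected: single Sum over axes (1, 2)
nested_max = x.max(axis=0).max(axis=0)  # expected: single Max over axes (0, 1)

f = pytensor.function([x], [nested_sum, nested_max])
pytensor.dprint(f)  # inspect the rewritten graphs

Because the rewrite requires outer_op.scalar_op == inner_op.scalar_op, a mixed chain such as x.sum(axis=0).prod(axis=0) is deliberately left alone.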
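The axis-merging loop compensates for the dimensions the inner reduction removes: each outer axis must be shifted past every inner axis at or below it to recover its position in the original input. A standalone sketch of that bookkeeping (merge_reduce_axes is a hypothetical helper mirroring the loop in the diff):

def merge_reduce_axes(inner_axis: tuple[int, ...], outer_axis: tuple[int, ...]) -> list[int]:
    # Map each outer axis (expressed relative to the already-reduced tensor)
    # back to the original tensor's axes, then combine it with inner_axis.
    newaxis = list(inner_axis)
    for i in outer_axis:
        new_i = i
        for ii in inner_axis:
            if new_i >= ii:
                new_i += 1  # this original dim was consumed by the inner reduce
        assert new_i not in newaxis
        newaxis.append(new_i)
    return sorted(newaxis)

# x is 4-D: reduce over axes (0, 2), then reduce the result over axis 1.
# Axis 1 of the once-reduced tensor is axis 3 of the original input.
assert merge_reduce_axes((0, 2), (1,)) == [0, 2, 3]

The merged list is exactly what outer_op.clone(axis=sorted(newaxis))(x) receives, so the chain becomes a single reduction over the original input's axes.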