Generalize and simplify local_reduce_join
ricardoV94 committed Oct 8, 2024
1 parent b2c6258 commit 2086aeb
Showing 2 changed files with 91 additions and 71 deletions.
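What this commit buys in practice: reducing over the axis created by a Join (for example via `pt.stack`) collapses into a single Elemwise application of the underlying scalar op. Below is a minimal sketch, assuming a standard PyTensor install; the variable names are illustrative and not taken from the commit.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# pt.stack expand_dims each input and joins them along a new axis 0;
# summing over that same axis is just a + b, so the rewrite can replace
# CAReduce(Join(...)) with a single Elemwise add.
out = pt.sum(pt.stack([a, b]), axis=0)

f = pytensor.function([a, b], out)
av = np.asarray([1.0, 2.0], dtype=pytensor.config.floatX)
bv = np.asarray([3.0, 4.0], dtype=pytensor.config.floatX)
print(f(av, bv))  # [4. 6.]
```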
90 changes: 39 additions & 51 deletions pytensor/tensor/rewriting/math.py
```diff
@@ -91,6 +91,7 @@
     register_uncanonicalize,
     register_useless,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
@@ -1628,68 +1629,55 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
 
     Notes
     -----
     Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
     all cases.
 
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    When a, b have a dim length of 1 along the join axis
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
-
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
-
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
-
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
-
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
-
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
-
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
-
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
-
-        return [ret]
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
+
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
+
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
+    if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
+
+    if not isinstance(join_axis_tensor, Constant):
+        return None
+    join_axis = join_axis_tensor.data
+
+    # Check whether reduction happens on joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
+
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inputs = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Most times inputs to join have an expand_dims, we eagerly clean up those here
+        new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
+        new_inputs.append(new_input)
+
+    ret = Elemwise(node.op.scalar_op)(*new_inputs)
+
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction do something about the dtype.
+        return None
+
+    return [ret]
 
 
 @register_infer_shape
```
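The condition the new code checks, namely that every joined input is broadcastable (length 1) along the join axis, is what makes the rewrite valid for any constant join axis rather than only axis 0. A NumPy sketch of the underlying identity, with made-up shapes for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 1, 4))  # each input has length 1 along the join axis
b = rng.random((3, 1, 4))

joined = np.concatenate([a, b], axis=1)  # join axis = 1, shape (3, 2, 4)

# Reducing over the join axis gives the same result as applying the
# scalar op elementwise to the squeezed inputs, which is exactly what
# Elemwise(scalar_op)(*squeezed_inputs) computes.
np.testing.assert_allclose(joined.sum(axis=1), a.squeeze(1) + b.squeeze(1))
np.testing.assert_allclose(
    joined.max(axis=1), np.maximum(a.squeeze(1), b.squeeze(1))
)
```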
72 changes: 52 additions & 20 deletions tests/tensor/rewriting/test_math.py
```diff
@@ -3231,7 +3231,7 @@ def test_local_prod_of_div(self):
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )
 
     def test_local_reduce_broadcast_all_0(self):
@@ -3304,62 +3304,94 @@ def test_local_reduce_broadcast_some_1(self):
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )
 
-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including(
+            "canonicalize", "specialize", "uncanonicalize"
+        )
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out
 
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)
+
+    def test_type(self):
         # Test different axis for the join and the reduction
         # We must force the dtype, of otherwise, this tests will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
 
         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)
 
         # Test a case that was bugged in a old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
-
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)
 
+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )
 
 
 def test_local_useless_adds():
```
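To see the positive and negative cases from the updated tests end to end, one can compile a small graph and check whether the Join survives. A sketch that mirrors the tests' mode setup; the assertions follow the same graph-inspection pattern as the test file:

```python
import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.mode import get_default_mode
from pytensor.tensor.elemwise import Elemwise

mode = get_default_mode().including("canonicalize", "specialize", "uncanonicalize")

a = pt.vector("a")
b = pt.vector("b")

# Reduction along the join axis: the Join is fused into one Elemwise.
fused = function([a, b], pt.sum(pt.stack([a, b]), axis=0), mode=mode)
assert isinstance(fused.maker.fgraph.toposort()[-1].op, Elemwise)

# Reduction along a different axis: the rewrite must not apply,
# so a CAReduce node remains at the end of the graph.
kept = function([a, b], pt.sum(pt.stack([a, b]), axis=1), mode=mode)
assert not isinstance(kept.maker.fgraph.toposort()[-1].op, Elemwise)
```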
