Skip to content

Commit

Permalink
Implement vectorize_node for CheckAndRaise Op
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 15, 2023
1 parent 31a4df6 commit 8ae14c2
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
19 changes: 19 additions & 0 deletions pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic
Expand Down Expand Up @@ -198,3 +199,21 @@ def __str__(self):


assert_op = Assert()


@_vectorize_node.register(CheckAndRaise)
def vectorize_check_and_raise(op, node, batch_x, batch_cond):
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.shape import shape_padright

batch_cond_dims = batch_cond.type.ndim

if batch_cond_dims:
out = op(batch_x, batch_cond.all())
# Condition may broadcast batch dims of x
# We broadcast after the Check Op, so it can be removed more easily if not needed
x_core_ndim = node.inputs[0].type.ndim
batch_out, _ = broadcast_arrays(out, shape_padright(batch_cond, x_core_ndim))
return batch_out.owner
else:
return op.make_node(batch_x, batch_cond)
68 changes: 68 additions & 0 deletions tests/test_raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import OPT_FAST_RUN, Mode
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Constant, equal_computations
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import ScalarType, float64
from pytensor.sparse import as_sparse_variable
from pytensor.tensor.basic import second
from pytensor.tensor.elemwise import DimShuffle
from tests import unittest_tools as utt


Expand Down Expand Up @@ -184,3 +187,68 @@ def test_CheckAndRaise_sparse_variable():
a2 = check_and_raise(aspe1, aspe2.sum() > 2)
with pytest.raises(ValueError, match="sparse_check"):
a2.sum().eval()


@pytensor.config.change_flags(cxx="") # For speed-up
def test_vectorize():
floatX = pytensor.config.floatX
x = pt.vector("x")
y = pt.vector("y")
cond = pt.all(y >= 0)
out = assert_op(x, cond)

batch_x = pt.matrix("batch_x", shape=(2, None))
batch_y = pt.matrix("batch_y", shape=(2, None))

test_x = np.arange(3).astype(floatX)
test_y = np.arange(4).astype(floatX)
test_batch_x = np.arange(6).reshape(2, 3).astype(floatX)
test_batch_y = np.arange(8).reshape(2, 4).astype(floatX)

# Only x is batched
vect_out = vectorize_graph(out, {x: batch_x, y: y})
assert vect_out.type.shape == (2, None)
assert isinstance(vect_out.owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, y: test_y}),
test_batch_x,
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, y: -test_y})

# Only y is batched
vect_out = vectorize_graph(out, {x: x, y: batch_y})
assert vect_out.type.shape == (2, None)
assert vect_out.owner.op == second # broadcast
assert isinstance(vect_out.owner.inputs[1].owner.op, DimShuffle)
assert isinstance(vect_out.owner.inputs[1].owner.inputs[0].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({x: test_x, batch_y: test_batch_y}),
np.broadcast_to(test_x, (2, *test_x.shape)),
)
with pytest.raises(AssertionError):
vect_out.eval({x: test_x, batch_y: -test_batch_y})

# Both x, and y are batched
vect_out = vectorize_graph(out, {x: batch_x, y: batch_y})
assert vect_out.type.shape == (2, None)
assert vect_out.owner.op == second
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
test_batch_x,
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})

# Both x, and y are batched and broadcast each other
vect_out = vectorize_graph(out, {x: batch_x[:, None, :], y: batch_y[None, :, :]})
assert vect_out.type.shape == (2, 2, None)
assert vect_out.owner.op == second
assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
np.testing.assert_array_equal(
vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
np.broadcast_to(test_batch_x[:, None, :], (2, *test_batch_x.shape)),
)
with pytest.raises(AssertionError):
vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})

0 comments on commit 8ae14c2

Please sign in to comment.