From 8ae14c20a3f1e57c3e7ba6b0f8bcb7ce4227e55f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 13 Dec 2023 12:43:35 +0100
Subject: [PATCH] Implement vectorize_node for CheckAndRaise Op

---
 pytensor/raise_op.py   | 19 ++++++++++++
 tests/test_raise_op.py | 68 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py
index 25e1aebf52..52d674b801 100644
--- a/pytensor/raise_op.py
+++ b/pytensor/raise_op.py
@@ -6,6 +6,7 @@
 
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import Generic
@@ -198,3 +199,21 @@ def __str__(self):
 
 
 assert_op = Assert()
+
+
+@_vectorize_node.register(CheckAndRaise)
+def vectorize_check_and_raise(op, node, batch_x, batch_cond):
+    from pytensor.tensor.extra_ops import broadcast_arrays
+    from pytensor.tensor.shape import shape_padright
+
+    batch_cond_dims = batch_cond.type.ndim
+
+    if batch_cond_dims:
+        out = op(batch_x, batch_cond.all())
+        # Condition may broadcast batch dims of x
+        # We broadcast after the Check Op, so it can be removed more easily if not needed
+        x_core_ndim = node.inputs[0].type.ndim
+        batch_out, _ = broadcast_arrays(out, shape_padright(batch_cond, x_core_ndim))
+        return batch_out.owner
+    else:
+        return op.make_node(batch_x, batch_cond)
diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py
index 2cd1cc830f..e41fec28aa 100644
--- a/tests/test_raise_op.py
+++ b/tests/test_raise_op.py
@@ -5,10 +5,13 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor.compile.mode import OPT_FAST_RUN, Mode
+from pytensor.graph import vectorize_graph
 from pytensor.graph.basic import Constant, equal_computations
 from pytensor.raise_op import Assert, CheckAndRaise, assert_op
 from pytensor.scalar.basic import ScalarType, float64
 from pytensor.sparse import as_sparse_variable
+from pytensor.tensor.basic import second
+from pytensor.tensor.elemwise import DimShuffle
 from tests import unittest_tools as utt
 
 
@@ -184,3 +187,68 @@ def test_CheckAndRaise_sparse_variable():
     a2 = check_and_raise(aspe1, aspe2.sum() > 2)
     with pytest.raises(ValueError, match="sparse_check"):
         a2.sum().eval()
+
+
+@pytensor.config.change_flags(cxx="")  # For speed-up
+def test_vectorize():
+    floatX = pytensor.config.floatX
+    x = pt.vector("x")
+    y = pt.vector("y")
+    cond = pt.all(y >= 0)
+    out = assert_op(x, cond)
+
+    batch_x = pt.matrix("batch_x", shape=(2, None))
+    batch_y = pt.matrix("batch_y", shape=(2, None))
+
+    test_x = np.arange(3).astype(floatX)
+    test_y = np.arange(4).astype(floatX)
+    test_batch_x = np.arange(6).reshape(2, 3).astype(floatX)
+    test_batch_y = np.arange(8).reshape(2, 4).astype(floatX)
+
+    # Only x is batched
+    vect_out = vectorize_graph(out, {x: batch_x, y: y})
+    assert vect_out.type.shape == (2, None)
+    assert isinstance(vect_out.owner.op, CheckAndRaise)
+    np.testing.assert_array_equal(
+        vect_out.eval({batch_x: test_batch_x, y: test_y}),
+        test_batch_x,
+    )
+    with pytest.raises(AssertionError):
+        vect_out.eval({batch_x: test_batch_x, y: -test_y})
+
+    # Only y is batched
+    vect_out = vectorize_graph(out, {x: x, y: batch_y})
+    assert vect_out.type.shape == (2, None)
+    assert vect_out.owner.op == second  # broadcast
+    assert isinstance(vect_out.owner.inputs[1].owner.op, DimShuffle)
+    assert isinstance(vect_out.owner.inputs[1].owner.inputs[0].owner.op, CheckAndRaise)
+    np.testing.assert_array_equal(
+        vect_out.eval({x: test_x, batch_y: test_batch_y}),
+        np.broadcast_to(test_x, (2, *test_x.shape)),
+    )
+    with pytest.raises(AssertionError):
+        vect_out.eval({x: test_x, batch_y: -test_batch_y})
+
+    # Both x, and y are batched
+    vect_out = vectorize_graph(out, {x: batch_x, y: batch_y})
+    assert vect_out.type.shape == (2, None)
+    assert vect_out.owner.op == second
+    assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
+    np.testing.assert_array_equal(
+        vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
+        test_batch_x,
+    )
+    with pytest.raises(AssertionError):
+        vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})
+
+    # Both x, and y are batched and broadcast each other
+    vect_out = vectorize_graph(out, {x: batch_x[:, None, :], y: batch_y[None, :, :]})
+    assert vect_out.type.shape == (2, 2, None)
+    assert vect_out.owner.op == second
+    assert isinstance(vect_out.owner.inputs[1].owner.op, CheckAndRaise)
+    np.testing.assert_array_equal(
+        vect_out.eval({batch_x: test_batch_x, batch_y: test_batch_y}),
+        np.broadcast_to(test_batch_x[:, None, :], (2, *test_batch_x.shape)),
+    )
+    with pytest.raises(AssertionError):
+        vect_out.eval({batch_x: test_batch_x, batch_y: -test_batch_y})