From a26e46b32b621e8cda1d4b1fb61ba14065aa3a76 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 4 Jul 2023 17:02:40 +0200 Subject: [PATCH] Forbid runtime broadcasting in Elemwise --- pytensor/link/jax/dispatch/elemwise.py | 12 +- .../link/numba/dispatch/elemwise_codegen.py | 13 +- pytensor/tensor/__init__.py | 1 + pytensor/tensor/elemwise.py | 52 ++-- pytensor/tensor/elemwise_cgen.py | 101 ++++---- pytensor/tensor/extra_ops.py | 2 +- tests/link/jax/test_elemwise.py | 6 + tests/link/numba/test_elemwise.py | 6 + tests/tensor/rewriting/test_basic.py | 7 +- tests/tensor/rewriting/test_math.py | 7 +- tests/tensor/rewriting/test_shape.py | 13 +- tests/tensor/test_elemwise.py | 231 +++++++++--------- 12 files changed, 233 insertions(+), 218 deletions(-) diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index 39ef836b6a..7750607dcb 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -7,9 +7,17 @@ @jax_funcify.register(Elemwise) -def jax_funcify_Elemwise(op, **kwargs): +def jax_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op - return jax_funcify(scalar_op, **kwargs) + base_fn = jax_funcify(scalar_op, node=node, **kwargs) + + def elemwise_fn(*inputs): + # ScalarVariables in JAX are passed as int/float. + # We wrap them in arrays just for the broadcast check + Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs))) + return base_fn(*inputs) + + return elemwise_fn @jax_funcify.register(CAReduce) diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py index 3138110046..a57292aff2 100644 --- a/pytensor/link/numba/dispatch/elemwise_codegen.py +++ b/pytensor/link/numba/dispatch/elemwise_codegen.py @@ -35,15 +35,20 @@ def compute_itershape( with builder.if_then( builder.icmp_unsigned("!=", length, shape[i]), likely=False ): - with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( + with builder.if_else( + builder.or_( + builder.icmp_unsigned("==", length, one), + builder.icmp_unsigned("==", shape[i], one), + ) + ) as ( then, otherwise, ): with then: msg = ( - f"Incompatible shapes for input {j} and axis {i} of " - f"elemwise. Input {j} has shape 1, but is not statically " - "known to have shape 1, and thus not broadcastable." + "Runtime broadcasting not allowed. " + "One input had a distinct dimension length of 1, but was not marked as broadcastable.\n" + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." 
) ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) with otherwise: diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 2e84db13e9..dfe74b5b2f 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: shape_padaxis, shape_padleft, shape_padright, + specify_broadcastable, specify_shape, ) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 0687d30f10..6ec6952e14 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -6,7 +6,7 @@ import pytensor.tensor.basic from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code @@ -19,9 +19,9 @@ from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import transfer_type, upcast -from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.type import ( TensorType, continuous_dtypes, @@ -740,9 +740,7 @@ def perform(self, node, inputs, output_storage): # FIXME: This no longer calls the C implementation! super().perform(node, inputs, output_storage) - for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): - if len(set(dim_shapes) - {1}) > 1: - raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}") + self._check_runtime_broadcast(node, inputs) ufunc_args = inputs ufunc_kwargs = {} @@ -818,18 +816,40 @@ def perform(self, node, inputs, output_storage): else: storage[0] = variable - def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: - if len(node.outputs) > 1: - from pytensor.tensor.exceptions import ShapeError - - raise ShapeError( - "Multiple outputs are not supported by the default `Elemwise.infer_shape`" - ) + @staticmethod + def _check_runtime_broadcast(node, inputs): + for dims_and_bcast in zip( + *[ + zip(input.shape, sinput.type.broadcastable) + for input, sinput in zip(inputs, node.inputs) + ] + ): + if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: + raise ValueError( + "Runtime broadcasting not allowed. " + "At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n" + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." 
+ ) - out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True) + def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: + one = pytensor.tensor.basic.constant(1, dtype="int64") + output_shape = [] + for dim, broadcastable in enumerate(node.outputs[0].type.broadcastable): + out_dim_length = one + if not broadcastable: + # There must be some input that is not broadcastable in this dim + for inp_shape, inp_var in zip(i_shapes, node.inputs): + if not inp_var.type.broadcastable[dim]: + # Give preference to constant dims + if isinstance(inp_shape[dim], Constant): + out_dim_length = inp_shape[dim] + break + # If we haven't yet seen a non-broadcastable dim, use this one + if out_dim_length is one: + out_dim_length = inp_shape[dim] + output_shape.append(as_tensor_variable(out_dim_length, dtype="int64")) - # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s - return [tuple(as_tensor_variable(s) for s in out_shape)] + return [tuple(output_shape)] * len(node.outputs) def _c_all(self, node, nodename, inames, onames, sub): # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` @@ -1193,7 +1213,7 @@ def c_support_code_apply(self, node, nodename): return support_code def c_code_cache_version_apply(self, node): - version = [14] # the version corresponding to the c code in this Op + version = [15] # the version corresponding to the c code in this Op # now we insert versions for the ops on which we depend... scalar_node = Apply( diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py index 18106f082d..85d4a93c76 100644 --- a/pytensor/tensor/elemwise_cgen.py +++ b/pytensor/tensor/elemwise_cgen.py @@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub): if index != "x": # Initialize the variables associated to the jth loop # jump = stride - adjust - # If the variable has size 1 in that dim, we set the stride to zero to - # emulate broadcasting jump = f"({var}_stride{index}) - ({adjust})" init += f""" {var}_n{index} = PyArray_DIMS({var})[{index}]; - {var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype}); + {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype}); {var}_jump{index}_{j} = {jump}; """ adjust = f"{var}_n{index}*{var}_stride{index}" @@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub): # This loop builds multiple if conditions to verify that the # dimensions of the inputs match, and the first one that is true # raises an informative error message + + runtime_broadcast_error_msg = ( + "Runtime broadcasting not allowed. " + "One input had a distinct dimension length of 1, but was not marked as broadcastable: " + "(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). " + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." + ) + for matches in zip(*loop_orders): to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] @@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub): if len(to_compare) < 2: continue - # Find first dimension size that is != 1 - jl, xl = to_compare[-1] - non1size_dim_check = f""" - npy_intp non1size_dim{xl}; - non1size_dim{xl} = """ - for j, x in to_compare[:-1]: - non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? 
%(lv{j})s_n{x} : " - non1size_dim_check += f"%(lv{jl})s_n{xl};" - check += non1size_dim_check - - # Check the nonsize1 dims match - # TODO: This is a bit inefficient because we are comparing one dimension against itself - check += f""" - if (non1size_dim{xl} != 1) - {{ - """ - for j, x in to_compare: + j0, x0 = to_compare[0] + for j, x in to_compare[1:]: check += f""" - if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1)) + if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x}) + {{ + if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1) {{ - PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.", - {x}, - (long long int) non1size_dim{x}, + PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}", + {j0}, + {x0}, + (long long int) %(lv{j0})s_n{x0}, + {j}, + {x}, + (long long int) %(lv{j})s_n{x} + ); + }} else {{ + PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)", + {j0}, + {x0}, + (long long int) %(lv{j0})s_n{x0}, {j}, {x}, (long long int) %(lv{j})s_n{x} ); - %(fail)s }} - """ - check += """ - } + %(fail)s + }} """ return init % sub + check % sub -def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str: - """Create c_code to compute broadcasted dimensions of multiple arrays, arising from - Elemwise operations. +def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str: + """Create c_code to compute the output dimensions of an Elemwise operation. The code returned by this function populates the array `array_name`, but does not initialize it. - TODO: We can decide to either specialize C code even further given the input types - or make it general, regardless of whether static broadcastable information is given + Note: We could specialize C code even further with the known static output shapes """ dims_c_code = "" for i, candidates in enumerate(zip(*loop_orders)): - # TODO: Are candidates always either "x" or "i"? If that's the case we can - # simplify some logic here (e.g., we don't need to track the `idx`). - nonx_candidates = tuple( - (idx, c) for idx, c in enumerate(candidates) if c != "x" - ) - - # All inputs are known to be broadcastable - if not nonx_candidates: + # Borrow the length of the first non-broadcastable input dimension + for j, candidate in enumerate(candidates): + if candidate != "x": + var = sub[f"lv{int(j)}"] + dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n" + break + # If none is non-broadcastable, the output dimension has a length of 1 + else: # no-break dims_c_code += f"{array_name}[{i}] = 1;\n" - continue - - # There is only one informative source of size - if len(nonx_candidates) == 1: - idx, candidate = nonx_candidates[0] - var = sub[f"lv{int(idx)}"] - dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n" - continue - # In this case any non-size 1 variable will define the right size - dims_c_code += f"{array_name}[{i}] = " - for idx, candidate in nonx_candidates[:-1]: - var = sub[f"lv{int(idx)}"] - dims_c_code += f"({var}_n{candidate} != 1)? 
{var}_n{candidate}: " - idx, candidate = nonx_candidates[-1] - var = sub[f"lv{idx}"] - dims_c_code += f"{var}_n{candidate};\n" return dims_c_code @@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): if type.startswith("PYTENSOR_COMPLEX"): type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") nd = len(loop_orders[0]) - init_dims = compute_broadcast_dimensions("dims", loop_orders, sub) + init_dims = compute_output_dims_lengths("dims", loop_orders, sub) # TODO: it would be interesting to allocate the output in such a # way that its contiguous dimensions match one of the input's @@ -359,7 +342,7 @@ def make_reordered_loop( # Get the (sorted) total number of iterations of each loop declare_totals = f"int init_totals[{nnested}];\n" - declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub) + declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub) # Sort totals to match the new order that was computed by sorting # the loop vector. One integer variable per loop is declared. diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 078a18a6a9..26f1faa3e7 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): _broadcast_assert = Assert( "Could not broadcast dimensions. Broadcasting is only allowed along " - "axes that have a statically known length 1. Use `specify_shape` to " + "axes that have a statically known length 1. Use `specify_broadcastable` to " "inform PyTensor of a known shape." ) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py index 0f903a33b2..dab0a750b9 100644 --- a/tests/link/jax/test_elemwise.py +++ b/tests/link/jax/test_elemwise.py @@ -4,6 +4,7 @@ import pytensor import pytensor.tensor as at +from pytensor.compile import get_mode from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value @@ -14,6 +15,11 @@ from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.type import matrix, tensor, vector from tests.link.jax.test_basic import compare_jax_and_py +from tests.tensor.test_elemwise import TestElemwise + + +def test_elemwise_runtime_shape_error(): + TestElemwise.check_runtime_shapes_error(get_mode("JAX")) def test_jax_Dimshuffle(): diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 2e5bb5c1e4..01e61b6df5 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -3,12 +3,14 @@ import numpy as np import pytest import scipy.special +from tensor.test_elemwise import TestElemwise import pytensor import pytensor.tensor as at import pytensor.tensor.inplace as ati import pytensor.tensor.math as aem from pytensor import config, function +from pytensor.compile import get_mode from pytensor.compile.ops import deep_copy_op from pytensor.compile.sharedvalue import SharedVariable from pytensor.gradient import grad @@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): compare_numba_and_py(out_fg, input_vals) +def test_elemwise_runtime_shape_error(): + TestElemwise.check_runtime_shapes_error(get_mode("NUMBA")) + + def test_elemwise_speed(benchmark): x = at.dmatrix("y") y = at.dvector("z") diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index fe2b795907..570d7f50dc 100644 --- a/tests/tensor/rewriting/test_basic.py +++ 
b/tests/tensor/rewriting/test_basic.py @@ -1671,12 +1671,7 @@ def verify_op_count(f, count, cls): (), (), ), - pytest.param( - lambda x, y: at.mul(y, at.alloc(1, x)), - (), - (), - marks=pytest.mark.xfail(reason="Not implemented"), - ), + (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()), (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), ( diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 185641e4fb..1ef5335ce3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -607,11 +607,10 @@ def test_mul_div_cases(self): ((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"), ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"), - # must broadcast as their is a dimshuffle in the computation - # The broadcast leads to an extra elemwise to check compatibility - ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"), + # must broadcast as there is a dimshuffle in the computation + ((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"), # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(), Alloc] - ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"), + ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"), # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(), Alloc] ] ): diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 8c8c8a7baa..75be70f130 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -415,25 +415,20 @@ def test_vector(self): fgraph.attach_feature(shape_feature) assert shape_feature.same_shape(x, o) - def test_no_static_shapes(self): + def test_vector2(self): x = vector() y = vector() o = x + y fgraph = FunctionGraph([x, y], [o], clone=False) shape_feature = ShapeFeature() fgraph.attach_feature(shape_feature) - # We no longer assume that `x` has the same shape as `y` simply because - # neither has static shape information. Instead, when there is no - # static shape information is available, we assume that `x` and/or `y` - # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any - # combination of the two. 
- assert not shape_feature.same_shape(x, o) + assert shape_feature.same_shape(x, o) # The following case isn't implemented assert not shape_feature.same_shape(y, o) @pytest.mark.parametrize( "y_dim_0", - [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))], + [2, None], ) def test_vector_dim(self, y_dim_0): x = at.tensor(dtype="floatX", shape=(2, None)) @@ -443,7 +438,7 @@ def test_vector_dim(self, y_dim_0): shape_feature = ShapeFeature() fgraph.attach_feature(shape_feature) assert shape_feature.same_shape(x, o, 0, 0) - assert not shape_feature.same_shape(x, o, 1, 1) + assert shape_feature.same_shape(x, o, 1, 1) def test_vector_dim_err(self): x = vector() diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index fa820f062a..52f34c7be0 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -18,7 +18,6 @@ from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.exceptions import ShapeError from pytensor.tensor.math import all as at_all from pytensor.tensor.math import any as at_any from pytensor.tensor.math import exp @@ -216,117 +215,93 @@ def rand_cval(self, shp): return np.asarray(np.random.random(shp), dtype=pytensor.config.floatX) def with_linker(self, linker, op, type, rand_val): - for shape_info in ("complete", "only_broadcastable", "none"): - for xsh, ysh in [ - ((3, 5), (3, 5)), - ((3, 5), (1, 5)), - ((3, 5), (3, 1)), - ((1, 5), (5, 1)), - ((1, 1), (1, 1)), - ((self.openmp_minsize,), (self.openmp_minsize,)), - ( - (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), - (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), - ), - ((2, 3, 4, 5), (2, 3, 4, 5)), - ((2, 3, 4, 5), (1, 3, 1, 5)), - ((2, 3, 4, 5), (1, 1, 1, 1)), - ((), ()), - ]: - if shape_info == "complete": - x_type = type(pytensor.config.floatX, shape=xsh) - y_type = type(pytensor.config.floatX, shape=ysh) - elif shape_info == "only_broadcastable": - # This condition is here for backwards compatibility, when the only - # type shape provided by PyTensor was broadcastable/non-broadcastable - x_type = type( - pytensor.config.floatX, - shape=tuple(s if s == 1 else None for s in xsh), - ) - y_type = type( - pytensor.config.floatX, - shape=tuple(s if s == 1 else None for s in ysh), - ) - else: - x_type = type(pytensor.config.floatX, shape=[None for _ in xsh]) - y_type = type(pytensor.config.floatX, shape=[None for _ in ysh]) + for xsh, ysh in [ + ((3, 5), (3, 5)), + ((3, 5), (1, 5)), + ((3, 5), (3, 1)), + ((1, 5), (5, 1)), + ((1, 1), (1, 1)), + ((self.openmp_minsize,), (self.openmp_minsize,)), + ( + (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), + (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), + ), + ((2, 3, 4, 5), (2, 3, 4, 5)), + ((2, 3, 4, 5), (1, 3, 1, 5)), + ((2, 3, 4, 5), (1, 1, 1, 1)), + ((), ()), + ]: + x_type = type( + pytensor.config.floatX, + shape=tuple(s if s == 1 else None for s in xsh), + ) + y_type = type( + pytensor.config.floatX, + shape=tuple(s if s == 1 else None for s in ysh), + ) + + x = x_type("x") + y = y_type("y") + e = op(aes.add)(x, y) + f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) + xv = rand_val(xsh) + yv = rand_val(ysh) + zv = xv + yv + + unittest_tools.assert_allclose(f(xv, yv), zv) + # test Elemwise.infer_shape + # the Shape op don't implement c_code! 
+ if isinstance(linker, PerformLinker): x = x_type("x") y = y_type("y") e = op(aes.add)(x, y) - f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) - xv = rand_val(xsh) - yv = rand_val(ysh) - zv = xv + yv + f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape]))) + assert tuple(f(xv, yv)) == tuple(zv.shape) - unittest_tools.assert_allclose(f(xv, yv), zv) + def with_linker_inplace(self, linker, op, type, rand_val): + for xsh, ysh in [ + ((5, 5), (5, 5)), + ((5, 5), (1, 5)), + ((5, 5), (5, 1)), + ((1, 1), (1, 1)), + ((2, 3, 4, 5), (2, 3, 4, 5)), + ((2, 3, 4, 5), (1, 3, 1, 5)), + ((2, 3, 4, 5), (1, 1, 1, 1)), + ((), ()), + ]: + x_type = type( + pytensor.config.floatX, + shape=tuple(s if s == 1 else None for s in xsh), + ) + y_type = type( + pytensor.config.floatX, + shape=tuple(s if s == 1 else None for s in ysh), + ) - # test Elemwise.infer_shape - # the Shape op don't implement c_code! - if isinstance(linker, PerformLinker): - x = x_type("x") - y = y_type("y") - e = op(aes.add)(x, y) - f = make_function( - copy(linker).accept(FunctionGraph([x, y], [e.shape])) - ) - assert tuple(f(xv, yv)) == tuple(zv.shape) + x = x_type("x") + y = y_type("y") + e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) + f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) + xv = rand_val(xsh) + yv = rand_val(ysh) + zv = xv + yv - def with_linker_inplace(self, linker, op, type, rand_val): - for shape_info in ("complete", "only_broadcastable", "none"): - for xsh, ysh in [ - ((5, 5), (5, 5)), - ((5, 5), (1, 5)), - ((5, 5), (5, 1)), - ((1, 1), (1, 1)), - ((2, 3, 4, 5), (2, 3, 4, 5)), - ((2, 3, 4, 5), (1, 3, 1, 5)), - ((2, 3, 4, 5), (1, 1, 1, 1)), - ((), ()), - ]: - if shape_info == "complete": - x_type = type(pytensor.config.floatX, shape=xsh) - y_type = type(pytensor.config.floatX, shape=ysh) - elif shape_info == "only_broadcastable": - # This condition is here for backwards compatibility, when the only - # type shape provided by PyTensor was broadcastable/non-broadcastable - x_type = type( - pytensor.config.floatX, - shape=tuple(s if s == 1 else None for s in xsh), - ) - y_type = type( - pytensor.config.floatX, - shape=tuple(s if s == 1 else None for s in ysh), - ) - else: - x_type = type(pytensor.config.floatX, shape=[None for _ in xsh]) - y_type = type(pytensor.config.floatX, shape=[None for _ in ysh]) + f(xv, yv) + assert (xv == zv).all() + # test Elemwise.infer_shape + # the Shape op don't implement c_code! + if isinstance(linker, PerformLinker): x = x_type("x") y = y_type("y") e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) - f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) + f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape]))) xv = rand_val(xsh) yv = rand_val(ysh) zv = xv + yv - - f(xv, yv) - - assert (xv == zv).all() - # test Elemwise.infer_shape - # the Shape op don't implement c_code! 
- if isinstance(linker, PerformLinker): - x = x_type("x") - y = y_type("y") - e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) - f = make_function( - copy(linker).accept(FunctionGraph([x, y], [e.shape])) - ) - xv = rand_val(xsh) - yv = rand_val(ysh) - zv = xv + yv - assert xv.shape == zv.shape - assert tuple(f(xv, yv)) == zv.shape + assert xv.shape == zv.shape + assert tuple(f(xv, yv)) == zv.shape def test_perform(self): self.with_linker(PerformLinker(), self.op, self.type, self.rand_val) @@ -775,32 +750,42 @@ def test_input_dimensions_overflow(self): g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py")) g(*[np.zeros(2**11, config.floatX) for i in range(6)]) - def check_input_dimensions_match(self, mode): - """Make sure that our input validation works correctly and doesn't - throw erroneous broadcast-based errors. - """ + @staticmethod + def check_runtime_shapes_error(mode): + """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" x_v = matrix("x") m_v = vector("m") - x = np.array([[-1.32720483], [0.23442016]]).astype(config.floatX) - m = np.array([0.0, 0.0]).astype(config.floatX) - z_v = x_v - m_v f = pytensor.function([x_v, m_v], z_v, mode=mode) - res = f(x, m) + # Test invalid broadcasting by either x or m + for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]: + x = np.ones(x_sh).astype(config.floatX) + m = np.zeros(m_sh).astype(config.floatX) + + # This error is introduced by PyTensor, so it's the same across different backends + with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): + f(x, m) + + x = np.ones((2, 3)).astype(config.floatX) + m = np.zeros((1,)).astype(config.floatX) - assert np.array_equal(res, x - m) + x = np.ones((2, 4)).astype(config.floatX) + m = np.zeros((3,)).astype(config.floatX) + # This error is backend specific, and may have different types + with pytest.raises((ValueError, TypeError)): + f(x, m) - def test_input_dimensions_match_python(self): - self.check_input_dimensions_match(Mode(linker="py")) + def test_runtime_shapes_error_python(self): + self.check_runtime_shapes_error(Mode(linker="py")) @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test.", ) - def test_input_dimensions_match_c(self): - self.check_input_dimensions_match(Mode(linker="c")) + def test_runtime_shapes_error_c(self): + self.check_runtime_shapes_error(Mode(linker="c")) def test_str(self): op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) @@ -825,7 +810,7 @@ def test_partial_static_shape_info(self): assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 - def test_multi_output(self): + def test_infer_shape_multi_output(self): class CustomElemwise(Elemwise): def make_node(self, *args): res = super().make_node(*args) @@ -839,14 +824,26 @@ def make_node(self, *args): ], ) - z_1, z_2 = CustomElemwise(aes.add)( - as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1)) - ) + custom_elemwise = CustomElemwise(aes.add) + z_1, z_2 = custom_elemwise( + as_tensor_variable(np.eye(1)), + as_tensor_variable(np.eye(1)), + ) in_1_shape = (aes.constant(1), aes.constant(1)) + outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + for out in outs: + assert out[0].eval() == 1 + assert out[1].eval() == 1 - with pytest.raises(ShapeError): - z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + z_1, z_2 = custom_elemwise( + as_tensor_variable(np.eye(1)), 
as_tensor_variable(np.eye(3)) + ) + in_2_shape = (aes.constant(3), aes.constant(3)) + outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + for out in outs: + assert out[0].eval() == 3 + assert out[1].eval() == 3 def test_shape_types(self): x = tensor(dtype=np.float64, shape=(None, 1))
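
Note (not part of the diff above): a minimal usage sketch of the behaviour this patch enforces, assuming PyTensor with this commit applied and the default C/Python backend. The variable names and shapes are illustrative; the graph mirrors the `x_v - m_v` case from `check_runtime_shapes_error`. A dimension that only turns out to have length 1 at runtime now raises, and `specify_broadcastable` (re-exported from `pytensor.tensor` by this patch) is the intended way to opt back in.

    import numpy as np

    import pytensor
    import pytensor.tensor as at

    floatX = pytensor.config.floatX

    x = at.matrix("x")  # static shape (None, None)
    m = at.vector("m")  # static shape (None,)
    f = pytensor.function([x, m], x - m)

    # m arrives with length 1, but its type was never marked as broadcastable,
    # so Elemwise now refuses to broadcast it at runtime.
    try:
        f(np.ones((2, 3), dtype=floatX), np.zeros((1,), dtype=floatX))
    except ValueError as err:
        print(err)  # Runtime broadcasting not allowed. ...

    # If broadcasting is intended, declare it statically on the input instead.
    m_b = at.specify_broadcastable(m, 0)  # static shape becomes (1,)
    g = pytensor.function([x, m], x - m_b)
    print(g(np.ones((2, 3), dtype=floatX), np.zeros((1,), dtype=floatX)).shape)  # (2, 3)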