Skip to content

Commit

Permalink
Forbid runtime broadcasting in Elemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 10, 2023
1 parent 5c87d74 commit a26e46b
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 218 deletions.
12 changes: 10 additions & 2 deletions pytensor/link/jax/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@


@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs):
def jax_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs)
base_fn = jax_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
# ScalarVariables in JAX are passed as int/float.
# We wrap them in arrays just for the broadcast check
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
return base_fn(*inputs)

return elemwise_fn


@jax_funcify.register(CAReduce)
Expand Down
13 changes: 9 additions & 4 deletions pytensor/link/numba/dispatch/elemwise_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,20 @@ def compute_itershape(
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
with builder.if_else(
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
then,
otherwise,
):
with then:
msg = (
f"Incompatible shapes for input {j} and axis {i} of "
f"elemwise. Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise:
Expand Down
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
shape_padaxis,
shape_padleft,
shape_padright,
specify_broadcastable,
specify_shape,
)

Expand Down
52 changes: 36 additions & 16 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytensor.tensor.basic
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.basic import failure_code
Expand All @@ -19,9 +19,9 @@
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
Expand Down Expand Up @@ -740,9 +740,7 @@ def perform(self, node, inputs, output_storage):
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)

for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
self._check_runtime_broadcast(node, inputs)

ufunc_args = inputs
ufunc_kwargs = {}
Expand Down Expand Up @@ -818,18 +816,40 @@ def perform(self, node, inputs, output_storage):
else:
storage[0] = variable

def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1:
from pytensor.tensor.exceptions import ShapeError

raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
)
@staticmethod
def _check_runtime_broadcast(node, inputs):
for dims_and_bcast in zip(
*[
zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)
]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
one = pytensor.tensor.basic.constant(1, dtype="int64")
output_shape = []
for dim, broadcastable in enumerate(node.outputs[0].type.broadcastable):
out_dim_length = one
if not broadcastable:
# There must be some input that is not broadcastable in this dim
for inp_shape, inp_var in zip(i_shapes, node.inputs):
if not inp_var.type.broadcastable[dim]:
# Give preference to constant dims
if isinstance(inp_shape[dim], Constant):
out_dim_length = inp_shape[dim]
break
# If we haven't yet seen a non-broadcastable dim, use this one
if out_dim_length is one:
out_dim_length = inp_shape[dim]
output_shape.append(as_tensor_variable(out_dim_length, dtype="int64"))

# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]
return [tuple(output_shape)] * len(node.outputs)

def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
Expand Down Expand Up @@ -1193,7 +1213,7 @@ def c_support_code_apply(self, node, nodename):
return support_code

def c_code_cache_version_apply(self, node):
version = [14] # the version corresponding to the c code in this Op
version = [15] # the version corresponding to the c code in this Op

# now we insert versions for the ops on which we depend...
scalar_node = Apply(
Expand Down
101 changes: 42 additions & 59 deletions pytensor/tensor/elemwise_cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
if index != "x":
# Initialize the variables associated to the jth loop
# jump = stride - adjust
# If the variable has size 1 in that dim, we set the stride to zero to
# emulate broadcasting
jump = f"({var}_stride{index}) - ({adjust})"
init += f"""
{var}_n{index} = PyArray_DIMS({var})[{index}];
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_jump{index}_{j} = {jump};
"""
adjust = f"{var}_n{index}*{var}_stride{index}"
Expand All @@ -86,88 +84,73 @@ def make_checks(loop_orders, dtypes, sub):
# This loop builds multiple if conditions to verify that the
# dimensions of the inputs match, and the first one that is true
# raises an informative error message

runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]

# elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
if len(to_compare) < 2:
continue

# Find first dimension size that is != 1
jl, xl = to_compare[-1]
non1size_dim_check = f"""
npy_intp non1size_dim{xl};
non1size_dim{xl} = """
for j, x in to_compare[:-1]:
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
non1size_dim_check += f"%(lv{jl})s_n{xl};"
check += non1size_dim_check

# Check the nonsize1 dims match
# TODO: This is a bit inefficient because we are comparing one dimension against itself
check += f"""
if (non1size_dim{xl} != 1)
{{
"""
for j, x in to_compare:
j0, x0 = to_compare[0]
for j, x in to_compare[1:]:
check += f"""
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
{{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
{x},
(long long int) non1size_dim{x},
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
}} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
%(fail)s
}}
"""
check += """
}
%(fail)s
}}
"""

return init % sub + check % sub


def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from
Elemwise operations.
def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute the output dimensions of an Elemwise operation.
The code returned by this function populates the array `array_name`, but does not
initialize it.
TODO: We can decide to either specialize C code even further given the input types
or make it general, regardless of whether static broadcastable information is given
Note: We could specialize C code even further with the known static output shapes
"""
dims_c_code = ""
for i, candidates in enumerate(zip(*loop_orders)):
# TODO: Are candidates always either "x" or "i"? If that's the case we can
# simplify some logic here (e.g., we don't need to track the `idx`).
nonx_candidates = tuple(
(idx, c) for idx, c in enumerate(candidates) if c != "x"
)

# All inputs are known to be broadcastable
if not nonx_candidates:
# Borrow the length of the first non-broadcastable input dimension
for j, candidate in enumerate(candidates):
if candidate != "x":
var = sub[f"lv{int(j)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
break
# If none is non-broadcastable, the output dimension has a length of 1
else: # no-break
dims_c_code += f"{array_name}[{i}] = 1;\n"
continue

# There is only one informative source of size
if len(nonx_candidates) == 1:
idx, candidate = nonx_candidates[0]
var = sub[f"lv{int(idx)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
continue

# In this case any non-size 1 variable will define the right size
dims_c_code += f"{array_name}[{i}] = "
for idx, candidate in nonx_candidates[:-1]:
var = sub[f"lv{int(idx)}"]
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
idx, candidate = nonx_candidates[-1]
var = sub[f"lv{idx}"]
dims_c_code += f"{var}_n{candidate};\n"
return dims_c_code


Expand All @@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
if type.startswith("PYTENSOR_COMPLEX"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0])
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub)
init_dims = compute_output_dims_lengths("dims", loop_orders, sub)

# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
Expand Down Expand Up @@ -359,7 +342,7 @@ def make_reordered_loop(

# Get the (sorted) total number of iterations of each loop
declare_totals = f"int init_totals[{nnested}];\n"
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub)
declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)

# Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared.
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):

_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)

Expand Down
6 changes: 6 additions & 0 deletions tests/link/jax/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytensor
import pytensor.tensor as at
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
Expand All @@ -14,6 +15,11 @@
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise


def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))


def test_jax_Dimshuffle():
Expand Down
6 changes: 6 additions & 0 deletions tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import numpy as np
import pytest
import scipy.special
from tensor.test_elemwise import TestElemwise

import pytensor
import pytensor.tensor as at
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
Expand Down Expand Up @@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)


def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))


def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")
Expand Down
7 changes: 1 addition & 6 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,12 +1671,7 @@ def verify_op_count(f, count, cls):
(),
(),
),
pytest.param(
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(
Expand Down
7 changes: 3 additions & 4 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,10 @@ def test_mul_div_cases(self):
((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
]
):
Expand Down
Loading

0 comments on commit a26e46b

Please sign in to comment.