From bf3af90bf453d15da6946a48100571c4e6dbb48e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 25 Oct 2024 13:59:49 +0200 Subject: [PATCH] Simplify Elemwise perform method and issue informative warning when number of operands is too large. This also clears a hard to debug error when perform method attempted to falback to the C-implementation. --- pytensor/tensor/elemwise.py | 153 ++++++++++++++++-------------------- 1 file changed, 69 insertions(+), 84 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a51c2034af..a970ea31be 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Sequence from copy import copy from textwrap import dedent @@ -19,9 +20,9 @@ from pytensor.misc.frozendict import frozendict from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type +from pytensor.scalar.basic import Composite, transfer_type, upcast from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import identity as scalar_identity -from pytensor.scalar.basic import transfer_type, upcast from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length from pytensor.tensor.basic import _get_vector_length, as_tensor_variable @@ -364,6 +365,7 @@ def __init__( self.name = name self.scalar_op = scalar_op self.inplace_pattern = inplace_pattern + self.ufunc = None self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()} if nfunc_spec is None: @@ -375,14 +377,12 @@ def __init__( def __getstate__(self): d = copy(self.__dict__) d.pop("ufunc") - d.pop("nfunc") - d.pop("__epydoc_asRoutine", None) return d def __setstate__(self, d): + d.pop("nfunc", None) # This used to be stored in the Op, not anymore super().__setstate__(d) self.ufunc = None - self.nfunc = None self.inplace_pattern = frozendict(self.inplace_pattern) def get_output_info(self, *inputs): @@ -623,31 +623,47 @@ def transform(r): return ret - def prepare_node(self, node, storage_map, compute_map, impl): - # Postpone the ufunc building to the last minutes due to: - # - NumPy ufunc support only up to 32 operands (inputs and outputs) - # But our c code support more. - # - nfunc is reused for scipy and scipy is optional - if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py": - impl = "c" - - if getattr(self, "nfunc_spec", None) and impl != "c": - self.nfunc = import_func_from_string(self.nfunc_spec[0]) - + def _create_node_ufunc(self, node) -> None: if ( - (len(node.inputs) + len(node.outputs)) <= 32 - and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) - and self.ufunc is None - and impl == "py" + self.nfunc_spec is not None + # Some scalar Ops like `Add` allow for a variable number of inputs, + # whereas the numpy counterpart does not. + and len(node.inputs) == self.nfunc_spec[1] ): + ufunc = import_func_from_string(self.nfunc_spec[0]) + if ufunc is None: + raise ValueError( + f"Could not import ufunc {self.nfunc_spec[0]} for {self}" + ) + + elif self.ufunc is not None: + # Cached before + ufunc = self.ufunc + + else: + if (len(node.inputs) + len(node.outputs)) > 32: + if isinstance(self.scalar_op, Composite): + warnings.warn( + "Trying to create a Python Composite Elemwise function with more than 32 operands.\n" + "This operation should not have been introduced if the C-backend is not properly setup. " + 'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n' + "Alternatively, consider using an optional backend like NUMBA or JAX, by setting " + '`pytensor.config.mode = "NUMBA" (or "JAX").' + ) + else: + warnings.warn( + f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} " + f"with more than 32 operands. This will likely fail." + ) + ufunc = np.frompyfunc( self.scalar_op.impl, len(node.inputs), self.scalar_op.nout ) - if self.scalar_op.nin > 0: - # We can reuse it for many nodes + if self.scalar_op.nin > 0: # Default in base class is -1 + # Op has constant signature, so we can reuse ufunc for many nodes. Cache it. self.ufunc = ufunc - else: - node.tag.ufunc = ufunc + + node.tag.ufunc = ufunc # Numpy ufuncs will sometimes perform operations in # float16, in particular when the input is int8. @@ -660,15 +676,23 @@ def prepare_node(self, node, storage_map, compute_map, impl): # NumPy 1.10.1 raise an error when giving the signature # when the input is complex. So add it only when inputs is int. - out_dtype = node.outputs[0].dtype + ufunc_kwargs = {} if ( - out_dtype in float_dtypes - and isinstance(self.nfunc, np.ufunc) + isinstance(ufunc, np.ufunc) + # TODO: Why check for the dtype of the first input only? and node.inputs[0].dtype in discrete_dtypes + and len(node.outputs) == 1 + and node.outputs[0].dtype in float_dtypes ): - char = np.sctype2char(out_dtype) - sig = char * node.nin + "->" + char * node.nout - node.tag.sig = sig + char = np.sctype2char(node.outputs[0].dtype) + ufunc_kwargs["sig"] = char * node.nin + "->" + char * node.nout + + node.tag.ufunc_kwargs = ufunc_kwargs + + def prepare_node(self, node, storage_map, compute_map, impl): + if impl == "py": + self._create_node_ufunc(node) + node.tag.fake_node = Apply( self.scalar_op, [ @@ -684,71 +708,32 @@ def prepare_node(self, node, storage_map, compute_map, impl): self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) def perform(self, node, inputs, output_storage): - if (len(node.inputs) + len(node.outputs)) > 32: - # Some versions of NumPy will segfault, other will raise a - # ValueError, if the number of operands in an ufunc is more than 32. - # In that case, the C version should be used, or Elemwise fusion - # should be disabled. - # FIXME: This no longer calls the C implementation! - super().perform(node, inputs, output_storage) + ufunc = getattr(node.tag, "ufunc", None) + if ufunc is None: + self._create_node_ufunc(node) + ufunc = node.tag.ufunc self._check_runtime_broadcast(node, inputs) - ufunc_args = inputs - ufunc_kwargs = {} - # We supported in the past calling manually op.perform. - # To keep that support we need to sometimes call self.prepare_node - if self.nfunc is None and self.ufunc is None: - self.prepare_node(node, None, None, "py") - if self.nfunc and len(inputs) == self.nfunc_spec[1]: - ufunc = self.nfunc - nout = self.nfunc_spec[2] - if hasattr(node.tag, "sig"): - ufunc_kwargs["sig"] = node.tag.sig - # Unfortunately, the else case does not allow us to - # directly feed the destination arguments to the nfunc - # since it sometimes requires resizing. Doing this - # optimization is probably not worth the effort, since we - # should normally run the C version of the Op. - else: - # the second calling form is used because in certain versions of - # numpy the first (faster) version leads to segfaults - if self.ufunc: - ufunc = self.ufunc - elif not hasattr(node.tag, "ufunc"): - # It happen that make_thunk isn't called, like in - # get_underlying_scalar_constant_value - self.prepare_node(node, None, None, "py") - # prepare_node will add ufunc to self or the tag - # depending if we can reuse it or not. So we need to - # test both again. - if self.ufunc: - ufunc = self.ufunc - else: - ufunc = node.tag.ufunc - else: - ufunc = node.tag.ufunc - - nout = ufunc.nout - - variables = ufunc(*ufunc_args, **ufunc_kwargs) + outputs = ufunc(*inputs, **node.tag.get("ufunc_kwargs", {})) - if nout == 1: - variables = [variables] + if not isinstance(outputs, tuple): + outputs = (outputs,) - for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs) + for i, (out, out_storage, node_out) in enumerate( + zip(outputs, output_storage, node.outputs) ): - storage[0] = variable = np.asarray(variable, dtype=nout.dtype) + # Numpy frompyfunc always returns object arrays + out_storage[0] = out = np.asarray(out, dtype=node_out.dtype) if i in self.inplace_pattern: - odat = inputs[self.inplace_pattern[i]] - odat[...] = variable - storage[0] = odat + inp = inputs[self.inplace_pattern[i]] + inp[...] = out + out_storage[0] = inp # numpy.real return a view! - if not variable.flags.owndata: - storage[0] = variable.copy() + if not out.flags.owndata: + out_storage[0] = out.copy() @staticmethod def _check_runtime_broadcast(node, inputs):