From 563277791ff6b9267a25e9dd968e0530ec1bf2ec Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 15 Jul 2024 17:55:08 +0200
Subject: [PATCH] Simplify logic with `variadic_add` and `variadic_mul` helpers

---
 pytensor/tensor/blas.py                |  8 ++----
 pytensor/tensor/math.py                | 36 +++++++++++++++++---------
 pytensor/tensor/rewriting/basic.py     | 13 +++-------
 pytensor/tensor/rewriting/blas.py      | 15 +++++++----
 pytensor/tensor/rewriting/math.py      | 25 +++++-------------
 pytensor/tensor/rewriting/subtensor.py | 15 +++++------
 6 files changed, 53 insertions(+), 59 deletions(-)

diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 22a08718ae..b3cf96cbd4 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -102,7 +102,7 @@
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub
+from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
@@ -1399,11 +1399,7 @@ def item_to_var(t):
             item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
         ]
         add_inputs.extend(gemm_of_sM_list)
-        if len(add_inputs) > 1:
-            rval = [add(*add_inputs)]
-        else:
-            rval = add_inputs
-        # print "RETURNING GEMM THING", rval
+        rval = [variadic_add(*add_inputs)]
         return rval, old_dot22
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 1b5b94aa7f..d1aa438216 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -1429,18 +1429,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     else:
         shp = cast(shp, "float64")
 
-    if axis is None:
-        axis = list(range(input.ndim))
-    elif isinstance(axis, int | np.integer):
-        axis = [axis]
-    elif isinstance(axis, np.ndarray) and axis.ndim == 0:
-        axis = [int(axis)]
-    else:
-        axis = [int(a) for a in axis]
-
-    # This sequential division will possibly be optimized by PyTensor:
-    for i in axis:
-        s = true_div(s, shp[i])
+    reduced_dims = (
+        shp
+        if axis is None
+        else [shp[i] for i in normalize_axis_tuple(axis, input.type.ndim)]
+    )
+    s /= variadic_mul(*reduced_dims).astype(shp.dtype)
 
     # This can happen when axis is an empty list/tuple
     if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
@@ -1596,6 +1590,15 @@ def add(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_add(*args):
+    """Add that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(0)
+    if len(args) == 1:
+        return args[0]
+    return add(*args)
+
+
 @scalar_elemwise
 def sub(a, b):
     """elementwise subtraction"""
@@ -1608,6 +1611,15 @@ def mul(a, *other_terms):
     # see decorator for function body
 
 
+def variadic_mul(*args):
+    """Mul that accepts arbitrary number of inputs, including zero or one."""
+    if not args:
+        return constant(1)
+    if len(args) == 1:
+        return args[0]
+    return mul(*args)
+
+
 @scalar_elemwise
 def true_div(a, b):
     """elementwise [true] division (inverse of multiplication)"""
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 3cdd5b7ad6..78d00790ac 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -68,7 +68,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
-from pytensor.tensor.math import Sum, add, eq
+from pytensor.tensor.math import Sum, add, eq, variadic_add
 from pytensor.tensor.shape import Shape_i, shape_padleft
 from pytensor.tensor.type import DenseTensorType, TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -939,14 +939,9 @@ def local_sum_make_vector(fgraph, node):
     if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
         return
 
-    if len(elements) == 0:
-        element_sum = zeros(dtype=out_dtype, shape=())
-    elif len(elements) == 1:
-        element_sum = cast(elements[0], out_dtype)
-    else:
-        element_sum = cast(
-            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
-        )
+    element_sum = cast(
+        variadic_add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+    )
 
     return [element_sum]
diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py
index cc8dd472e6..d52ee70e17 100644
--- a/pytensor/tensor/rewriting/blas.py
+++ b/pytensor/tensor/rewriting/blas.py
@@ -96,7 +96,15 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
+from pytensor.tensor.math import (
+    Dot,
+    _matrix_matrix_matmul,
+    add,
+    mul,
+    neg,
+    sub,
+    variadic_add,
+)
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
     DenseTensorType,
@@ -386,10 +394,7 @@ def item_to_var(t):
             item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
         ]
         add_inputs.extend(gemm_of_sM_list)
-        if len(add_inputs) > 1:
-            rval = [add(*add_inputs)]
-        else:
-            rval = add_inputs
+        rval = [variadic_add(*add_inputs)]
         # print "RETURNING GEMM THING", rval
         return rval, old_dot22
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 41d1783644..6568bcdf3e 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -76,6 +76,8 @@
     sub,
     tri_gamma,
     true_div,
+    variadic_add,
+    variadic_mul,
 )
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import max as pt_max
@@ -1270,17 +1272,13 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
 
             if not outer_terms:
                 return None
-            elif len(outer_terms) == 1:
-                [outer_term] = outer_terms
             else:
-                outer_term = mul(*outer_terms)
+                outer_term = variadic_mul(*outer_terms)
 
             if not inner_terms:
                 inner_term = None
-            elif len(inner_terms) == 1:
-                [inner_term] = inner_terms
             else:
-                inner_term = mul(*inner_terms)
+                inner_term = variadic_mul(*inner_terms)
 
         else:  # true_div
             # We only care about removing the denominator out of the reduction
@@ -2143,10 +2141,7 @@ def local_add_remove_zeros(fgraph, node):
             assert cst.type.broadcastable == (True,) * ndim
             return [alloc_like(cst, node_output, fgraph)]
 
-    if len(new_inputs) == 1:
-        ret = [alloc_like(new_inputs[0], node_output, fgraph)]
-    else:
-        ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
+    ret = [alloc_like(variadic_add(*new_inputs), node_output, fgraph)]
 
     # The dtype should not be changed. It can happen if the input
     # that was forcing upcasting was equal to 0.
@@ -2257,10 +2252,7 @@ def local_log1p(fgraph, node):
     # scalar_inputs are potentially dimshuffled and fill'd scalars
     if scalars and np.allclose(np.sum(scalars), 1):
         if nonconsts:
-            if len(nonconsts) > 1:
-                ninp = add(*nonconsts)
-            else:
-                ninp = nonconsts[0]
+            ninp = variadic_add(*nonconsts)
             if ninp.dtype != log_arg.type.dtype:
                 ninp = ninp.astype(node.outputs[0].dtype)
             return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
@@ -3084,10 +3076,7 @@ def local_exp_over_1_plus_exp(fgraph, node):
         return
     # put the new numerator together
     new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
-    if len(new_num) == 1:
-        new_num = new_num[0]
-    else:
-        new_num = mul(*new_num)
+    new_num = variadic_mul(*new_num)
 
     if num_neg ^ denom_neg:
         new_num = -new_num
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index f234b46804..0e7f9cc3f1 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -48,6 +48,7 @@
     maximum,
     minimum,
     or_,
+    variadic_add,
 )
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.rewriting.basic import (
@@ -1241,15 +1242,11 @@ def movable(i):
         new_inputs = [i for i in node.inputs if not movable(i)] + [
             mi.owner.inputs[0] for mi in movable_inputs
         ]
-        if len(new_inputs) == 0:
-            new_add = new_inputs[0]
-        else:
-            new_add = add(*new_inputs)
-
-        # Copy over stacktrace from original output, as an error
-        # (e.g. an index error) in this add operation should
-        # correspond to an error in the original add operation.
-        copy_stack_trace(node.outputs[0], new_add)
+        new_add = variadic_add(*new_inputs)
+        # Copy over stacktrace from original output, as an error
+        # (e.g. an index error) in this add operation should
+        # correspond to an error in the original add operation.
+        copy_stack_trace(node.outputs[0], new_add)
 
         # stack up the new incsubtensors
         tip = new_add
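
Usage note: a minimal sketch of how the new helpers behave, assuming the patch above
is applied; the variables below are illustrative and not part of the commit.

    from pytensor.tensor import scalar
    from pytensor.tensor.math import variadic_add, variadic_mul

    x = scalar("x")
    y = scalar("y")

    variadic_add()      # no terms      -> constant(0) instead of an error
    variadic_add(x)     # a single term -> x itself, no Add node is built
    variadic_add(x, y)  # two or more   -> same graph as add(x, y)

    variadic_mul()      # no factors    -> constant(1)
    variadic_mul(x, y)  # two or more   -> same graph as mul(x, y)

This is what lets the rewritten call sites drop their `len(...) == 0 / == 1 / > 1`
branches: the helpers fold the empty and single-input cases internally.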