diff --git a/pyomo/common/dependencies.py b/pyomo/common/dependencies.py index 4ddbe1c9ee8..a0717dba883 100644 --- a/pyomo/common/dependencies.py +++ b/pyomo/common/dependencies.py @@ -778,6 +778,8 @@ def _finalize_matplotlib(module, available): def _finalize_numpy(np, available): if not available: return + # Register ndarray as a native type to prevent 1-element ndarrays + # from accidentally registering ndarray as a native_numeric_type. numeric_types.native_types.add(np.ndarray) numeric_types.RegisterLogicalType(np.bool_) for t in ( @@ -798,12 +800,30 @@ def _finalize_numpy(np, available): # registration here (to bypass the deprecation warning) until we # finally remove all support for it numeric_types._native_boolean_types.add(t) - for t in (np.float_, np.float16, np.float32, np.float64): + _floats = [np.float_, np.float16, np.float32, np.float64] + # float96 and float128 may or may not be defined in this particular + # numpy build (it depends on platform and version). + # Register them only if they are present + if hasattr(np, 'float96'): + _floats.append(np.float96) + if hasattr(np, 'float128'): + _floats.append(np.float128) + for t in _floats: numeric_types.RegisterNumericType(t) # We have deprecated RegisterBooleanType, so we will mock up the # registration here (to bypass the deprecation warning) until we # finally remove all support for it numeric_types._native_boolean_types.add(t) + _complex = [np.complex_, np.complex64, np.complex128] + # complex192 and complex256 may or may not be defined in this + # particular numpy build (it depends on platform and version). + # Register them only if they are present + if hasattr(np, 'complex192'): + _complex.append(np.complex192) + if hasattr(np, 'complex256'): + _complex.append(np.complex256) + for t in _complex: + numeric_types.RegisterComplexType(t) dill, dill_available = attempt_import('dill') diff --git a/pyomo/common/numeric_types.py b/pyomo/common/numeric_types.py index dbad3ef0853..af7eeded3cf 100644 --- a/pyomo/common/numeric_types.py +++ b/pyomo/common/numeric_types.py @@ -41,9 +41,14 @@ #: Python set used to identify numeric constants. This set includes #: native Python types as well as numeric types from Python packages #: like numpy, which may be registered by users. -native_numeric_types = {int, float, complex} +#: +#: Note that :data:`native_numeric_types` does NOT include +#: :py:`complex`, as that is not a valid constant in Pyomo numeric +#: expressions. +native_numeric_types = {int, float} native_integer_types = {int} native_logical_types = {bool} +native_complex_types = {complex} pyomo_constant_types = set() # includes NumericConstant _native_boolean_types = {int, bool, str, bytes} @@ -64,34 +69,53 @@ #: like numpy. #: #: :data:`native_types` = :data:`native_numeric_types ` + { str } -native_types = set([bool, str, type(None), slice, bytes]) +native_types = {bool, str, type(None), slice, bytes} native_types.update(native_numeric_types) native_types.update(native_integer_types) -native_types.update(_native_boolean_types) +native_types.update(native_complex_types) native_types.update(native_logical_types) +native_types.update(_native_boolean_types) nonpyomo_leaf_types.update(native_types) -def RegisterNumericType(new_type): - """ - A utility function for updating the set of types that are - recognized to handle numeric values. +def RegisterNumericType(new_type: type): + """Register the specified type as a "numeric type". + + A utility function for registering new types as "native numeric + types" that can be leaf nodes in Pyomo numeric expressions. The + type should be compatible with :py:class:`float` (that is, store a + scalar and be castable to a Python float). + + Parameters + ---------- + new_type: type + The new numeric type (e.g, numpy.float64) - The argument should be a class (e.g, numpy.float64). """ native_numeric_types.add(new_type) native_types.add(new_type) nonpyomo_leaf_types.add(new_type) -def RegisterIntegerType(new_type): - """ - A utility function for updating the set of types that are - recognized to handle integer values. This also registers the type - as numeric but does not register it as boolean. +def RegisterIntegerType(new_type: type): + """Register the specified type as an "integer type". + + A utility function for registering new types as "native integer + types". Integer types can be leaf nodes in Pyomo numeric + expressions. The type should be compatible with :py:class:`float` + (that is, store a scalar and be castable to a Python float). + + Registering a type as an integer type implies + :py:func:`RegisterNumericType`. + + Note that integer types are NOT registered as logical / Boolean types. + + Parameters + ---------- + new_type: type + The new integer type (e.g, numpy.int64) - The argument should be a class (e.g., numpy.int64). """ native_numeric_types.add(new_type) native_integer_types.add(new_type) @@ -104,26 +128,64 @@ def RegisterIntegerType(new_type): "is deprecated. Users likely should use RegisterLogicalType.", version='6.6.0', ) -def RegisterBooleanType(new_type): - """ - A utility function for updating the set of types that are - recognized as handling boolean values. This function does not - register the type of integer or numeric. +def RegisterBooleanType(new_type: type): + """Register the specified type as a "logical type". + + A utility function for registering new types as "native logical + types". Logical types can be leaf nodes in Pyomo logical + expressions. The type should be compatible with :py:class:`bool` + (that is, store a scalar and be castable to a Python bool). + + Note that logical types are NOT registered as numeric types. + + Parameters + ---------- + new_type: type + The new logical type (e.g, numpy.bool_) - The argument should be a class (e.g., numpy.bool_). """ _native_boolean_types.add(new_type) native_types.add(new_type) nonpyomo_leaf_types.add(new_type) -def RegisterLogicalType(new_type): +def RegisterComplexType(new_type: type): + """Register the specified type as an "complex type". + + A utility function for registering new types as "native complex + types". Complex types can NOT be leaf nodes in Pyomo numeric + expressions. The type should be compatible with :py:class:`complex` + (that is, store a scalar complex value and be castable to a Python + complex). + + Note that complex types are NOT registered as logical or numeric types. + + Parameters + ---------- + new_type: type + The new complex type (e.g, numpy.complex128) + """ - A utility function for updating the set of types that are - recognized as handling boolean values. This function does not - register the type of integer or numeric. + native_types.add(new_type) + native_complex_types.add(new_type) + nonpyomo_leaf_types.add(new_type) + + +def RegisterLogicalType(new_type: type): + """Register the specified type as a "logical type". + + A utility function for registering new types as "native logical + types". Logical types can be leaf nodes in Pyomo logical + expressions. The type should be compatible with :py:class:`bool` + (that is, store a scalar and be castable to a Python bool). + + Note that logical types are NOT registered as numeric types. + + Parameters + ---------- + new_type: type + The new logical type (e.g, numpy.bool_) - The argument should be a class (e.g., numpy.bool_). """ _native_boolean_types.add(new_type) native_logical_types.add(new_type) @@ -135,8 +197,9 @@ def check_if_numeric_type(obj): """Test if the argument behaves like a numeric type. We check for "numeric types" by checking if we can add zero to it - without changing the object's type. If that works, then we register - the type in native_numeric_types. + without changing the object's type, and that the object compares to + 0 in a meaningful way. If that works, then we register the type in + :py:attr:`native_numeric_types`. """ obj_class = obj.__class__ @@ -181,25 +244,25 @@ def check_if_numeric_type(obj): def value(obj, exception=True): """ - A utility function that returns the value of a Pyomo object or - expression. - - Args: - obj: The argument to evaluate. If it is None, a - string, or any other primitive numeric type, - then this function simply returns the argument. - Otherwise, if the argument is a NumericValue - then the __call__ method is executed. - exception (bool): If :const:`True`, then an exception should - be raised when instances of NumericValue fail to - s evaluate due to one or more objects not being - initialized to a numeric value (e.g, one or more - variables in an algebraic expression having the - value None). If :const:`False`, then the function - returns :const:`None` when an exception occurs. - Default is True. - - Returns: A numeric value or None. + A utility function that returns the value of a Pyomo object or + expression. + + Args: + obj: The argument to evaluate. If it is None, a + string, or any other primitive numeric type, + then this function simply returns the argument. + Otherwise, if the argument is a NumericValue + then the __call__ method is executed. + exception (bool): If :const:`True`, then an exception should + be raised when instances of NumericValue fail to + evaluate due to one or more objects not being + initialized to a numeric value (e.g, one or more + variables in an algebraic expression having the + value None). If :const:`False`, then the function + returns :const:`None` when an exception occurs. + Default is True. + + Returns: A numeric value or None. """ if obj.__class__ in native_types: return obj diff --git a/pyomo/core/kernel/register_numpy_types.py b/pyomo/core/kernel/register_numpy_types.py index b9205930512..5f7812354d9 100644 --- a/pyomo/core/kernel/register_numpy_types.py +++ b/pyomo/core/kernel/register_numpy_types.py @@ -17,10 +17,11 @@ version='6.1', ) -from pyomo.core.expr.numvalue import ( +from pyomo.common.numeric_types import ( RegisterNumericType, RegisterIntegerType, RegisterBooleanType, + native_complex_types, native_numeric_types, native_integer_types, native_boolean_types, @@ -37,6 +38,8 @@ numpy_float = [] numpy_bool_names = [] numpy_bool = [] +numpy_complex_names = [] +numpy_complex = [] if _has_numpy: # Historically, the lists included several numpy aliases @@ -44,6 +47,8 @@ numpy_int.extend((numpy.int_, numpy.intc, numpy.intp)) numpy_float_names.append('float_') numpy_float.append(numpy.float_) + numpy_complex_names.append('complex_') + numpy_complex.append(numpy.complex_) # Re-build the old numpy_* lists for t in native_boolean_types: @@ -63,13 +68,8 @@ # Complex -numpy_complex_names = [] -numpy_complex = [] -if _has_numpy: - numpy_complex_names.extend(('complex_', 'complex64', 'complex128')) - for _type_name in numpy_complex_names: - try: - _type = getattr(numpy, _type_name) - numpy_complex.append(_type) - except: # pragma:nocover - pass +for t in native_complex_types: + if t.__module__ == 'numpy': + if t.__name__ not in numpy_complex_names: + numpy_complex.append(t) + numpy_complex_names.append(t.__name__) diff --git a/pyomo/core/tests/unit/test_kernel_register_numpy_types.py b/pyomo/core/tests/unit/test_kernel_register_numpy_types.py index 0a9e3ab08f9..117de5c5f4c 100644 --- a/pyomo/core/tests/unit/test_kernel_register_numpy_types.py +++ b/pyomo/core/tests/unit/test_kernel_register_numpy_types.py @@ -10,7 +10,7 @@ # ___________________________________________________________________________ import pyomo.common.unittest as unittest -from pyomo.common.dependencies import numpy_available +from pyomo.common.dependencies import numpy, numpy_available from pyomo.common.log import LoggingIntercept # Boolean @@ -38,12 +38,22 @@ numpy_float_names.append('float16') numpy_float_names.append('float32') numpy_float_names.append('float64') + if hasattr(numpy, 'float96'): + numpy_float_names.append('float96') + if hasattr(numpy, 'float128'): + # On some numpy builds, the name of float128 is longdouble + numpy_float_names.append(numpy.float128.__name__) # Complex numpy_complex_names = [] if numpy_available: numpy_complex_names.append('complex_') numpy_complex_names.append('complex64') numpy_complex_names.append('complex128') + if hasattr(numpy, 'complex192'): + numpy_complex_names.append('complex192') + if hasattr(numpy, 'complex256'): + # On some numpy builds, the name of complex256 is clongdouble + numpy_complex_names.append(numpy.complex256.__name__) class TestNumpyRegistration(unittest.TestCase): diff --git a/pyomo/repn/linear.py b/pyomo/repn/linear.py index 8bffbf1d49b..a0060da6e9f 100644 --- a/pyomo/repn/linear.py +++ b/pyomo/repn/linear.py @@ -8,14 +8,18 @@ # rights in this software. # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ -import collections + import logging import sys from operator import itemgetter from itertools import filterfalse from pyomo.common.deprecation import deprecation_warning -from pyomo.common.numeric_types import native_types, native_numeric_types +from pyomo.common.numeric_types import ( + native_types, + native_numeric_types, + native_complex_types, +) from pyomo.core.expr.numeric_expr import ( NegationExpression, ProductExpression, @@ -37,10 +41,11 @@ ) from pyomo.core.expr.visitor import StreamBasedExpressionVisitor, _EvaluationVisitor from pyomo.core.expr import is_fixed, value -from pyomo.core.base.expression import ScalarExpression, _GeneralExpressionData -from pyomo.core.base.objective import ScalarObjective, _GeneralObjectiveData +from pyomo.core.base.expression import Expression import pyomo.core.kernel as kernel from pyomo.repn.util import ( + BeforeChildDispatcher, + ExitNodeDispatcher, ExprType, InvalidNumber, apply_node_operation, @@ -331,7 +336,7 @@ def _handle_division_nonlinear(visitor, node, arg1, arg2): def _handle_pow_constant_constant(visitor, node, *args): arg1, arg2 = args ans = apply_node_operation(node, (arg1[1], arg2[1])) - if ans.__class__ in _complex_types: + if ans.__class__ in native_complex_types: ans = complex_number_error(ans, visitor, node) return _CONSTANT, ans @@ -380,7 +385,7 @@ def _handle_pow_nonlinear(visitor, node, arg1, arg2): def _handle_unary_constant(visitor, node, arg): ans = apply_node_operation(node, (arg[1],)) # Unary includes sqrt() which can return complex numbers - if ans.__class__ in _complex_types: + if ans.__class__ in native_complex_types: ans = complex_number_error(ans, visitor, node) return _CONSTANT, ans @@ -416,23 +421,12 @@ def _handle_named_ANY(visitor, node, arg1): return _type, arg1.duplicate() -_exit_node_handlers[ScalarExpression] = { +_exit_node_handlers[Expression] = { (_CONSTANT,): _handle_named_constant, (_LINEAR,): _handle_named_ANY, (_GENERAL,): _handle_named_ANY, } -_named_subexpression_types = [ - ScalarExpression, - _GeneralExpressionData, - kernel.expression.expression, - kernel.expression.noclone, - # Note: objectives are special named expressions - _GeneralObjectiveData, - ScalarObjective, - kernel.objective.objective, -] - # # EXPR_IF handlers # @@ -578,246 +572,162 @@ def _handle_ranged_general(visitor, node, arg1, arg2, arg3): ] = _handle_ranged_const -def _before_native(visitor, child): - return False, (_CONSTANT, child) - - -def _before_invalid(visitor, child): - return False, ( - _CONSTANT, - InvalidNumber(child, "'{child}' is not a valid numeric type"), - ) - - -def _before_complex(visitor, child): - return False, (_CONSTANT, complex_number_error(child, visitor, child)) - - -def _before_var(visitor, child): - _id = id(child) - if _id not in visitor.var_map: - if child.fixed: - return False, (_CONSTANT, visitor._eval_fixed(child)) - visitor.var_map[_id] = child - visitor.var_order[_id] = len(visitor.var_order) - ans = visitor.Result() - ans.linear[_id] = 1 - return False, (_LINEAR, ans) - - -def _before_param(visitor, child): - return False, (_CONSTANT, visitor._eval_fixed(child)) - +class LinearBeforeChildDispatcher(BeforeChildDispatcher): + def __init__(self): + # Special handling for external functions: will be handled + # as terminal nodes from the point of view of the visitor + self[ExternalFunctionExpression] = self._before_external + # Special linear / summation expressions + self[MonomialTermExpression] = self._before_monomial + self[LinearExpression] = self._before_linear + self[SumExpression] = self._before_general_expression + + @staticmethod + def _before_var(visitor, child): + _id = id(child) + if _id not in visitor.var_map: + if child.fixed: + return False, (_CONSTANT, visitor.handle_constant(child.value, child)) + visitor.var_map[_id] = child + visitor.var_order[_id] = len(visitor.var_order) + ans = visitor.Result() + ans.linear[_id] = 1 + return False, (_LINEAR, ans) -def _before_npv(visitor, child): - try: - return False, (_CONSTANT, visitor._eval_expr(child)) - except (ValueError, ArithmeticError): - return True, None - - -def _before_monomial(visitor, child): - # - # The following are performance optimizations for common - # situations (Monomial terms and Linear expressions) - # - arg1, arg2 = child._args_ - if arg1.__class__ not in native_types: - try: - arg1 = visitor._eval_expr(arg1) - except (ValueError, ArithmeticError): - return True, None + @staticmethod + def _before_monomial(visitor, child): + # + # The following are performance optimizations for common + # situations (Monomial terms and Linear expressions) + # + arg1, arg2 = child._args_ + if arg1.__class__ not in native_types: + try: + arg1 = visitor.handle_constant(visitor.evaluate(arg1), arg1) + except (ValueError, ArithmeticError): + return True, None - # We want to check / update the var_map before processing "0" - # coefficients so that we are consistent with what gets added to the - # var_map (e.g., 0*x*y: y is processed by _before_var and will - # always be added, but x is processed here) - _id = id(arg2) - if _id not in visitor.var_map: - if arg2.fixed: - return False, (_CONSTANT, arg1 * visitor._eval_fixed(arg2)) - visitor.var_map[_id] = arg2 - visitor.var_order[_id] = len(visitor.var_order) - - # Trap multiplication by 0 and nan. - if not arg1: - if arg2.fixed: - arg2 = visitor._eval_fixed(arg2) - if arg2 != arg2: - deprecation_warning( - f"Encountered {arg1}*{str(arg2.value)} in expression " - "tree. Mapping the NaN result to 0 for compatibility " - "with the lp_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.6.0', + # We want to check / update the var_map before processing "0" + # coefficients so that we are consistent with what gets added to the + # var_map (e.g., 0*x*y: y is processed by _before_var and will + # always be added, but x is processed here) + _id = id(arg2) + if _id not in visitor.var_map: + if arg2.fixed: + return False, ( + _CONSTANT, + arg1 * visitor.handle_constant(arg2.value, arg2), ) - return False, (_CONSTANT, arg1) + visitor.var_map[_id] = arg2 + visitor.var_order[_id] = len(visitor.var_order) + + # Trap multiplication by 0 and nan. + if not arg1: + if arg2.fixed: + arg2 = visitor.handle_constant(arg2.value, arg2) + if arg2 != arg2: + deprecation_warning( + f"Encountered {arg1}*{str(arg2.value)} in expression " + "tree. Mapping the NaN result to 0 for compatibility " + "with the lp_v1 writer. In the future, this NaN " + "will be preserved/emitted to comply with IEEE-754.", + version='6.6.0', + ) + return False, (_CONSTANT, arg1) - ans = visitor.Result() - ans.linear[_id] = arg1 - return False, (_LINEAR, ans) + ans = visitor.Result() + ans.linear[_id] = arg1 + return False, (_LINEAR, ans) + @staticmethod + def _before_linear(visitor, child): + var_map = visitor.var_map + var_order = visitor.var_order + next_i = len(var_order) + ans = visitor.Result() + const = 0 + linear = ans.linear + for arg in child.args: + if arg.__class__ is MonomialTermExpression: + arg1, arg2 = arg._args_ + if arg1.__class__ not in native_types: + try: + arg1 = visitor.handle_constant(visitor.evaluate(arg1), arg1) + except (ValueError, ArithmeticError): + return True, None + + # Trap multiplication by 0 and nan. + if not arg1: + if arg2.fixed: + arg2 = visitor.handle_constant(arg2.value, arg2) + if arg2 != arg2: + deprecation_warning( + f"Encountered {arg1}*{str(arg2.value)} in expression " + "tree. Mapping the NaN result to 0 for compatibility " + "with the lp_v1 writer. In the future, this NaN " + "will be preserved/emitted to comply with IEEE-754.", + version='6.6.0', + ) + continue -def _before_linear(visitor, child): - var_map = visitor.var_map - var_order = visitor.var_order - next_i = len(var_order) - ans = visitor.Result() - const = 0 - linear = ans.linear - for arg in child.args: - if arg.__class__ is MonomialTermExpression: - arg1, arg2 = arg._args_ - if arg1.__class__ not in native_types: + _id = id(arg2) + if _id not in var_map: + if arg2.fixed: + const += arg1 * visitor.handle_constant(arg2.value, arg2) + continue + var_map[_id] = arg2 + var_order[_id] = next_i + next_i += 1 + linear[_id] = arg1 + elif _id in linear: + linear[_id] += arg1 + else: + linear[_id] = arg1 + elif arg.__class__ in native_numeric_types: + const += arg + else: try: - arg1 = visitor._eval_expr(arg1) + const += visitor.handle_constant(visitor.evaluate(arg), arg) except (ValueError, ArithmeticError): return True, None - - # Trap multiplication by 0 and nan. - if not arg1: - if arg2.fixed: - arg2 = visitor._eval_fixed(arg2) - if arg2 != arg2: - deprecation_warning( - f"Encountered {arg1}*{str(arg2.value)} in expression " - "tree. Mapping the NaN result to 0 for compatibility " - "with the lp_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.6.0', - ) - continue - - _id = id(arg2) - if _id not in var_map: - if arg2.fixed: - const += arg1 * visitor._eval_fixed(arg2) - continue - var_map[_id] = arg2 - var_order[_id] = next_i - next_i += 1 - linear[_id] = arg1 - elif _id in linear: - linear[_id] += arg1 - else: - linear[_id] = arg1 - elif arg.__class__ in native_numeric_types: - const += arg + if linear: + ans.constant = const + return False, (_LINEAR, ans) else: - try: - const += visitor._eval_expr(arg) - except (ValueError, ArithmeticError): - return True, None - if linear: - ans.constant = const - return False, (_LINEAR, ans) - else: - return False, (_CONSTANT, const) - - -def _before_named_expression(visitor, child): - _id = id(child) - if _id in visitor.subexpression_cache: - _type, expr = visitor.subexpression_cache[_id] - if _type is _CONSTANT: - return False, (_type, expr) + return False, (_CONSTANT, const) + + @staticmethod + def _before_named_expression(visitor, child): + _id = id(child) + if _id in visitor.subexpression_cache: + _type, expr = visitor.subexpression_cache[_id] + if _type is _CONSTANT: + return False, (_type, expr) + else: + return False, (_type, expr.duplicate()) else: - return False, (_type, expr.duplicate()) - else: - return True, None - + return True, None -def _before_external(visitor, child): - ans = visitor.Result() - if all(is_fixed(arg) for arg in child.args): - try: - ans.constant = visitor._eval_expr(child) - return False, (_CONSTANT, ans) - except: - pass - ans.nonlinear = child - return False, (_GENERAL, ans) - - -def _before_general_expression(visitor, child): - return True, None - - -def _register_new_before_child_dispatcher(visitor, child): - dispatcher = _before_child_dispatcher - child_type = child.__class__ - if child_type in native_numeric_types: - if issubclass(child_type, complex): - _complex_types.add(child_type) - dispatcher[child_type] = _before_complex - else: - dispatcher[child_type] = _before_native - elif child_type in native_types: - dispatcher[child_type] = _before_invalid - elif not child.is_expression_type(): - if child.is_potentially_variable(): - dispatcher[child_type] = _before_var - else: - dispatcher[child_type] = _before_param - elif not child.is_potentially_variable(): - dispatcher[child_type] = _before_npv - # If we descend into the named expression (because of an - # evaluation error), then on the way back out, we will use - # the potentially variable handler to process the result. - pv_base_type = child.potentially_variable_base_class() - if pv_base_type not in dispatcher: + @staticmethod + def _before_external(visitor, child): + ans = visitor.Result() + if all(is_fixed(arg) for arg in child.args): try: - child.__class__ = pv_base_type - _register_new_before_child_dispatcher(visitor, child) - finally: - child.__class__ = child_type - if pv_base_type in visitor.exit_node_handlers: - visitor.exit_node_handlers[child_type] = visitor.exit_node_handlers[ - pv_base_type - ] - for args, fcn in visitor.exit_node_handlers[child_type].items(): - visitor.exit_node_dispatcher[(child_type, *args)] = fcn - elif id(child) in visitor.subexpression_cache or issubclass( - child_type, _GeneralExpressionData - ): - dispatcher[child_type] = _before_named_expression - visitor.exit_node_handlers[child_type] = visitor.exit_node_handlers[ - ScalarExpression - ] - for args, fcn in visitor.exit_node_handlers[child_type].items(): - visitor.exit_node_dispatcher[(child_type, *args)] = fcn - else: - dispatcher[child_type] = _before_general_expression - return dispatcher[child_type](visitor, child) - + ans.constant = visitor.handle_constant(visitor.evaluate(child), child) + return False, (_CONSTANT, ans) + except: + pass + ans.nonlinear = child + return False, (_GENERAL, ans) -_before_child_dispatcher = collections.defaultdict( - lambda: _register_new_before_child_dispatcher -) -# For efficiency reasons, we will maintain a separate list of all -# complex number types -_complex_types = set((complex,)) - -# We do not support writing complex numbers out -_before_child_dispatcher[complex] = _before_complex -# Special handling for external functions: will be handled -# as terminal nodes from the point of view of the visitor -_before_child_dispatcher[ExternalFunctionExpression] = _before_external -# Special linear / summation expressions -_before_child_dispatcher[MonomialTermExpression] = _before_monomial -_before_child_dispatcher[LinearExpression] = _before_linear -_before_child_dispatcher[SumExpression] = _before_general_expression +_before_child_dispatcher = LinearBeforeChildDispatcher() # # Initialize the _exit_node_dispatcher # def _initialize_exit_node_dispatcher(exit_handlers): - # expand the knowns set of named expressiosn - for expr in _named_subexpression_types: - exit_handlers[expr] = exit_handlers[ScalarExpression] - exit_dispatcher = {} for cls, handlers in exit_handlers.items(): for args, fcn in handlers.items(): @@ -828,7 +738,9 @@ def _initialize_exit_node_dispatcher(exit_handlers): class LinearRepnVisitor(StreamBasedExpressionVisitor): Result = LinearRepn exit_node_handlers = _exit_node_handlers - exit_node_dispatcher = _initialize_exit_node_dispatcher(_exit_node_handlers) + exit_node_dispatcher = ExitNodeDispatcher( + _initialize_exit_node_dispatcher(_exit_node_handlers) + ) expand_nonlinear_products = False max_exponential_expansion = 1 @@ -838,17 +750,19 @@ def __init__(self, subexpression_cache, var_map, var_order): self.var_map = var_map self.var_order = var_order self._eval_expr_visitor = _EvaluationVisitor(True) + self.evaluate = self._eval_expr_visitor.dfs_postorder_stack - def _eval_fixed(self, obj): - ans = obj.value + def handle_constant(self, ans, obj): if ans.__class__ not in native_numeric_types: # None can be returned from uninitialized Var/Param objects if ans is None: return InvalidNumber( - None, f"'{obj}' contains a nonnumeric value '{ans}'" + None, f"'{obj}' evaluated to a nonnumeric value '{ans}'" ) if ans.__class__ is InvalidNumber: return ans + elif ans.__class__ in native_complex_types: + return complex_number_error(ans, self, obj) else: # It is possible to get other non-numeric types. Most # common are bool and 1-element numpy.array(). We will @@ -862,43 +776,12 @@ def _eval_fixed(self, obj): ans = float(ans) except: return InvalidNumber( - ans, f"'{obj}' contains a nonnumeric value '{ans}'" + ans, f"'{obj}' evaluated to a nonnumeric value '{ans}'" ) if ans != ans: - return InvalidNumber(nan, f"'{obj}' contains a nonnumeric value '{ans}'") - if ans.__class__ in _complex_types: - return complex_number_error(ans, self, obj) - return ans - - def _eval_expr(self, expr): - ans = self._eval_expr_visitor.dfs_postorder_stack(expr) - if ans.__class__ not in native_numeric_types: - # None can be returned from uninitialized Expression objects - if ans is None: - return InvalidNumber( - ans, f"'{expr}' evaluated to nonnumeric value '{ans}'" - ) - if ans.__class__ is InvalidNumber: - return ans - else: - # It is possible to get other non-numeric types. Most - # common are bool and 1-element numpy.array(). We will - # attempt to convert the value to a float before - # proceeding. - # - # TODO: we should check bool and warn/error (while bool is - # convertible to float in Python, they have very - # different semantic meanings in Pyomo). - try: - ans = float(ans) - except: - return InvalidNumber( - ans, f"'{expr}' evaluated to nonnumeric value '{ans}'" - ) - if ans != ans: - return InvalidNumber(ans, f"'{expr}' evaluated to nonnumeric value '{ans}'") - if ans.__class__ in _complex_types: - return complex_number_error(ans, self, expr) + return InvalidNumber( + nan, f"'{obj}' evaluated to a nonnumeric value '{ans}'" + ) return ans def initializeWalker(self, expr): diff --git a/pyomo/repn/plugins/lp_writer.py b/pyomo/repn/plugins/lp_writer.py index fab94d313d5..23f5c82280a 100644 --- a/pyomo/repn/plugins/lp_writer.py +++ b/pyomo/repn/plugins/lp_writer.py @@ -427,8 +427,6 @@ def write(self, model): # Pull out the constant: we will move it to the bounds offset = repn.constant - if offset.__class__ not in int_float: - offset = float(offset) repn.constant = 0 if repn.linear or getattr(repn, 'quadratic', None): @@ -584,8 +582,6 @@ def write_expression(self, ostream, expr, is_objective): for vid, coef in sorted( expr.linear.items(), key=lambda x: getVarOrder(x[0]) ): - if coef.__class__ not in int_float: - coef = float(coef) if coef < 0: ostream.write(f'{coef!r} {getSymbol(getVar(vid))}\n') else: @@ -607,8 +603,6 @@ def _normalize_constraint(data): else: col = c1, c2 sym = f' {getSymbol(getVar(vid1))} * {getSymbol(getVar(vid2))}\n' - if coef.__class__ not in int_float: - coef = float(coef) if coef < 0: return col, repr(coef) + sym else: diff --git a/pyomo/repn/plugins/nl_writer.py b/pyomo/repn/plugins/nl_writer.py index 296ea350648..6a282bdeab4 100644 --- a/pyomo/repn/plugins/nl_writer.py +++ b/pyomo/repn/plugins/nl_writer.py @@ -11,7 +11,7 @@ import logging import os -from collections import deque, defaultdict +from collections import deque from operator import itemgetter, attrgetter, setitem from contextlib import nullcontext @@ -24,6 +24,12 @@ from pyomo.common.deprecation import deprecation_warning from pyomo.common.errors import DeveloperError from pyomo.common.gc_manager import PauseGC +from pyomo.common.numeric_types import ( + native_complex_types, + native_numeric_types, + native_types, + value, +) from pyomo.common.timing import TicTocTimer from pyomo.core.expr import ( @@ -41,9 +47,6 @@ RangedExpression, Expr_ifExpression, ExternalFunctionExpression, - native_types, - native_numeric_types, - value, ) from pyomo.core.expr.visitor import StreamBasedExpressionVisitor, _EvaluationVisitor from pyomo.core.base import ( @@ -69,6 +72,8 @@ from pyomo.opt import WriterFactory from pyomo.repn.util import ( + BeforeChildDispatcher, + ExitNodeDispatcher, ExprType, FileDeterminism, FileDeterminism_to_SortComponents, @@ -574,8 +579,6 @@ def write(self, model): # Note: Constraint.lb/ub guarantee a return value that is # either a (finite) native_numeric_type, or None const = expr.const - if const.__class__ not in int_float: - const = float(const) lb = con.lb if lb is not None: lb = repr(lb - const) @@ -589,7 +592,6 @@ def write(self, model): n_ranges += 1 elif _type == 3: # and self.config.skip_trivial_constraints: continue - pass # FIXME: this is a HACK to be compatible with the NLv1 # writer. In the future, this writer should be expanded to # look for and process Complementarity components (assuming @@ -611,7 +613,8 @@ def write(self, model): linear_cons.append((con, expr, _type, lb, ub)) elif not self.config.skip_trivial_constraints: linear_cons.append((con, expr, _type, lb, ub)) - else: # constant constraint and skip_trivial_constraints + else: + # constant constraint and skip_trivial_constraints # # TODO: skip_trivial_constraints should be an # enum that also accepts "Exception" so that @@ -1321,15 +1324,12 @@ def write(self, model): for row_idx, info in enumerate(constraints): linear = info[1].linear # ASL will fail on "J 0", so if there are no coefficients - # (i.e., a constant objective), then skip this entry + # (e.g., a nonlinear-only constraint), then skip this entry if not linear: continue ostream.write(f'J{row_idx} {len(linear)}{row_comments[row_idx]}\n') for _id in sorted(linear.keys(), key=column_order.__getitem__): - val = linear[_id] - if val.__class__ not in int_float: - val = float(val) - ostream.write(f'{column_order[_id]} {val!r}\n') + ostream.write(f'{column_order[_id]} {linear[_id]!r}\n') # # "G" lines (non-empty terms in the Objective) @@ -1337,15 +1337,12 @@ def write(self, model): for obj_idx, info in enumerate(objectives): linear = info[1].linear # ASL will fail on "G 0", so if there are no coefficients - # (i.e., a constant objective), then skip this entry + # (e.g., a constant objective), then skip this entry if not linear: continue ostream.write(f'G{obj_idx} {len(linear)}{row_comments[obj_idx + n_cons]}\n') for _id in sorted(linear.keys(), key=column_order.__getitem__): - val = linear[_id] - if val.__class__ not in int_float: - val = float(val) - ostream.write(f'{column_order[_id]} {val!r}\n') + ostream.write(f'{column_order[_id]} {linear[_id]!r}\n') # Generate the return information info = NLWriterInfo( @@ -1497,33 +1494,18 @@ def _write_nl_expression(self, repn, include_const): # compiled before this point). Omitting the assertion for # efficiency. # assert repn.mult == 1 + # + # Note that repn.const should always be a int/float (it has + # already been compiled) if repn.nonlinear: nl, args = repn.nonlinear if include_const and repn.const: # Add the constant to the NL expression. AMPL adds the # constant as the second argument, so we will too. - nl = ( - self.template.binary_sum - + nl - + ( - self.template.const - % ( - repn.const - if repn.const.__class__ in int_float - else float(repn.const) - ) - ) - ) + nl = self.template.binary_sum + nl + self.template.const % repn.const self.ostream.write(nl % tuple(map(self.var_id_to_nl.__getitem__, args))) elif include_const: - self.ostream.write( - self.template.const - % ( - repn.const - if repn.const.__class__ in int_float - else float(repn.const) - ) - ) + self.ostream.write(self.template.const % repn.const) else: self.ostream.write(self.template.const % 0) @@ -1543,10 +1525,7 @@ def _write_v_line(self, expr_id, k): # ostream.write(f'V{self.next_V_line_id} {len(linear)} {k}{lbl}\n') for _id in sorted(linear, key=column_order.__getitem__): - val = linear[_id] - if val.__class__ not in int_float: - val = float(val) - ostream.write(f'{column_order[_id]} {val!r}\n') + ostream.write(f'{column_order[_id]} {linear[_id]!r}\n') self._write_nl_expression(info[1], True) self.next_V_line_id += 1 @@ -1671,9 +1650,7 @@ def compile_repn(self, visitor, prefix='', args=None, named_exprs=None): args.extend(self.nonlinear[1]) if self.const: nterms += 1 - nl_sum += template.const % ( - self.const if self.const.__class__ in int_float else float(self.const) - ) + nl_sum += template.const % self.const if nterms > 2: return (prefix + (template.nary_sum % nterms) + nl_sum, args, named_exprs) @@ -1983,7 +1960,7 @@ def handle_pow_node(visitor, node, arg1, arg2): if arg2[0] is _CONSTANT: if arg1[0] is _CONSTANT: ans = apply_node_operation(node, (arg1[1], arg2[1])) - if ans.__class__ in _complex_types: + if ans.__class__ in native_complex_types: ans = complex_number_error(ans, visitor, node) return _CONSTANT, ans elif not arg2[1]: @@ -2230,232 +2207,165 @@ def handle_external_function_node(visitor, node, *args): return (_GENERAL, AMPLRepn(0, None, nonlin)) -_operator_handles = { - NegationExpression: handle_negation_node, - ProductExpression: handle_product_node, - DivisionExpression: handle_division_node, - PowExpression: handle_pow_node, - AbsExpression: handle_abs_node, - UnaryFunctionExpression: handle_unary_node, - Expr_ifExpression: handle_exprif_node, - EqualityExpression: handle_equality_node, - InequalityExpression: handle_inequality_node, - RangedExpression: handle_ranged_inequality_node, - _GeneralExpressionData: handle_named_expression_node, - ScalarExpression: handle_named_expression_node, - kernel.expression.expression: handle_named_expression_node, - kernel.expression.noclone: handle_named_expression_node, - # Note: objectives are special named expressions - _GeneralObjectiveData: handle_named_expression_node, - ScalarObjective: handle_named_expression_node, - kernel.objective.objective: handle_named_expression_node, - ExternalFunctionExpression: handle_external_function_node, - # These are handled explicitly in beforeChild(): - # LinearExpression: handle_linear_expression, - # SumExpression: handle_sum_expression, - # - # Note: MonomialTermExpression is only hit when processing NPV - # subexpressions that raise errors (e.g., log(0) * m.x), so no - # special processing is needed [it is just a product expression] - MonomialTermExpression: handle_product_node, -} - - -def _before_native(visitor, child): - return False, (_CONSTANT, child) - - -def _before_complex(visitor, child): - return False, (_CONSTANT, complex_number_error(child, visitor, child)) - - -def _before_string(visitor, child): - visitor.encountered_string_arguments = True - ans = AMPLRepn(child, None, None) - ans.nl = (visitor.template.string % (len(child), child), ()) - return False, (_GENERAL, ans) - - -def _before_var(visitor, child): - _id = id(child) - if _id not in visitor.var_map: - if child.fixed: - return False, (_CONSTANT, visitor._eval_fixed(child)) - visitor.var_map[_id] = child - return False, (_MONOMIAL, _id, 1) - - -def _before_param(visitor, child): - return False, (_CONSTANT, visitor._eval_fixed(child)) - +_operator_handles = ExitNodeDispatcher( + { + NegationExpression: handle_negation_node, + ProductExpression: handle_product_node, + DivisionExpression: handle_division_node, + PowExpression: handle_pow_node, + AbsExpression: handle_abs_node, + UnaryFunctionExpression: handle_unary_node, + Expr_ifExpression: handle_exprif_node, + EqualityExpression: handle_equality_node, + InequalityExpression: handle_inequality_node, + RangedExpression: handle_ranged_inequality_node, + Expression: handle_named_expression_node, + ExternalFunctionExpression: handle_external_function_node, + # These are handled explicitly in beforeChild(): + # LinearExpression: handle_linear_expression, + # SumExpression: handle_sum_expression, + # + # Note: MonomialTermExpression is only hit when processing NPV + # subexpressions that raise errors (e.g., log(0) * m.x), so no + # special processing is needed [it is just a product expression] + MonomialTermExpression: handle_product_node, + } +) -def _before_npv(visitor, child): - try: - return False, (_CONSTANT, visitor._eval_expr(child)) - except (ValueError, ArithmeticError): - return True, None +class AMPLBeforeChildDispatcher(BeforeChildDispatcher): + __slots__ = () -def _before_monomial(visitor, child): - # - # The following are performance optimizations for common - # situations (Monomial terms and Linear expressions) - # - arg1, arg2 = child._args_ - if arg1.__class__ not in native_types: - try: - arg1 = visitor._eval_expr(arg1) - except (ValueError, ArithmeticError): - return True, None + def __init__(self): + # Special linear / summation expressions + self[MonomialTermExpression] = self._before_monomial + self[LinearExpression] = self._before_linear + self[SumExpression] = self._before_general_expression + + @staticmethod + def _before_string(visitor, child): + visitor.encountered_string_arguments = True + ans = AMPLRepn(child, None, None) + ans.nl = (visitor.template.string % (len(child), child), ()) + return False, (_GENERAL, ans) + + @staticmethod + def _before_var(visitor, child): + _id = id(child) + if _id not in visitor.var_map: + if child.fixed: + return False, (_CONSTANT, visitor.handle_constant(child.value, child)) + visitor.var_map[_id] = child + return False, (_MONOMIAL, _id, 1) + + @staticmethod + def _before_monomial(visitor, child): + # + # The following are performance optimizations for common + # situations (Monomial terms and Linear expressions) + # + arg1, arg2 = child._args_ + if arg1.__class__ not in native_types: + try: + arg1 = visitor.handle_constant(visitor.evaluate(arg1), arg1) + except (ValueError, ArithmeticError): + return True, None - # Trap multiplication by 0 and nan. - if not arg1: - if arg2.fixed: - arg2 = visitor._eval_fixed(arg2) - if arg2 != arg2: - deprecation_warning( - f"Encountered {arg1}*{arg2} in expression tree. " - "Mapping the NaN result to 0 for compatibility " - "with the nl_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.4.3', + # Trap multiplication by 0 and nan. + if not arg1: + if arg2.fixed: + arg2 = visitor.handle_constant(arg2.value, arg2) + if arg2 != arg2: + deprecation_warning( + f"Encountered {arg1}*{arg2} in expression tree. " + "Mapping the NaN result to 0 for compatibility " + "with the nl_v1 writer. In the future, this NaN " + "will be preserved/emitted to comply with IEEE-754.", + version='6.4.3', + ) + return False, (_CONSTANT, arg1) + + _id = id(arg2) + if _id not in visitor.var_map: + if arg2.fixed: + return False, ( + _CONSTANT, + arg1 * visitor.handle_constant(arg2.value, arg2), ) - return False, (_CONSTANT, arg1) - - _id = id(arg2) - if _id not in visitor.var_map: - if arg2.fixed: - return False, (_CONSTANT, arg1 * visitor._eval_fixed(arg2)) - visitor.var_map[_id] = arg2 - return False, (_MONOMIAL, _id, arg1) - + visitor.var_map[_id] = arg2 + return False, (_MONOMIAL, _id, arg1) + + @staticmethod + def _before_linear(visitor, child): + # Because we are going to modify the LinearExpression in this + # walker, we need to make a copy of the arg list from the original + # expression tree. + var_map = visitor.var_map + const = 0 + linear = {} + for arg in child.args: + if arg.__class__ is MonomialTermExpression: + arg1, arg2 = arg._args_ + if arg1.__class__ not in native_types: + try: + arg1 = visitor.handle_constant(visitor.evaluate(arg1), arg1) + except (ValueError, ArithmeticError): + return True, None + + # Trap multiplication by 0 and nan. + if not arg1: + if arg2.fixed: + arg2 = visitor.handle_constant(arg2.value, arg2) + if arg2 != arg2: + deprecation_warning( + f"Encountered {arg1}*{str(arg2.value)} in expression " + "tree. Mapping the NaN result to 0 for compatibility " + "with the nl_v1 writer. In the future, this NaN " + "will be preserved/emitted to comply with IEEE-754.", + version='6.4.3', + ) + continue -def _before_linear(visitor, child): - # Because we are going to modify the LinearExpression in this - # walker, we need to make a copy of the arg list from the original - # expression tree. - var_map = visitor.var_map - const = 0 - linear = {} - for arg in child.args: - if arg.__class__ is MonomialTermExpression: - arg1, arg2 = arg._args_ - if arg1.__class__ not in native_types: + _id = id(arg2) + if _id not in var_map: + if arg2.fixed: + const += arg1 * visitor.handle_constant(arg2.value, arg2) + continue + var_map[_id] = arg2 + linear[_id] = arg1 + elif _id in linear: + linear[_id] += arg1 + else: + linear[_id] = arg1 + elif arg.__class__ in native_types: + const += arg + else: try: - arg1 = visitor._eval_expr(arg1) + const += visitor.handle_constant(visitor.evaluate(arg), arg) except (ValueError, ArithmeticError): return True, None - # Trap multiplication by 0 and nan. - if not arg1: - if arg2.fixed: - arg2 = visitor._eval_fixed(arg2) - if arg2 != arg2: - deprecation_warning( - f"Encountered {arg1}*{str(arg2.value)} in expression " - "tree. Mapping the NaN result to 0 for compatibility " - "with the nl_v1 writer. In the future, this NaN " - "will be preserved/emitted to comply with IEEE-754.", - version='6.4.3', - ) - continue - - _id = id(arg2) - if _id not in var_map: - if arg2.fixed: - const += arg1 * visitor._eval_fixed(arg2) - continue - var_map[_id] = arg2 - linear[_id] = arg1 - elif _id in linear: - linear[_id] += arg1 - else: - linear[_id] = arg1 - elif arg.__class__ in native_types: - const += arg + if linear: + return False, (_GENERAL, AMPLRepn(const, linear, None)) else: - try: - const += visitor._eval_expr(arg) - except (ValueError, ArithmeticError): - return True, None - - if linear: - return False, (_GENERAL, AMPLRepn(const, linear, None)) - else: - return False, (_CONSTANT, const) + return False, (_CONSTANT, const) - -def _before_named_expression(visitor, child): - _id = id(child) - if _id in visitor.subexpression_cache: - obj, repn, info = visitor.subexpression_cache[_id] - if info[2]: - if repn.linear: - return False, (_MONOMIAL, next(iter(repn.linear)), 1) - else: - return False, (_CONSTANT, repn.const) - return False, (_GENERAL, repn.duplicate()) - else: - return True, None - - -def _before_general_expression(visitor, child): - return True, None - - -def _register_new_before_child_handler(visitor, child): - handlers = _before_child_handlers - child_type = child.__class__ - if child_type in native_numeric_types: - if isinstance(child_type, complex): - _complex_types.add(child_type) - handlers[child_type] = _before_complex - else: - handlers[child_type] = _before_native - elif issubclass(child_type, str): - handlers[child_type] = _before_string - elif child_type in native_types: - handlers[child_type] = _before_native - elif not child.is_expression_type(): - if child.is_potentially_variable(): - handlers[child_type] = _before_var + @staticmethod + def _before_named_expression(visitor, child): + _id = id(child) + if _id in visitor.subexpression_cache: + obj, repn, info = visitor.subexpression_cache[_id] + if info[2]: + if repn.linear: + return False, (_MONOMIAL, next(iter(repn.linear)), 1) + else: + return False, (_CONSTANT, repn.const) + return False, (_GENERAL, repn.duplicate()) else: - handlers[child_type] = _before_param - elif not child.is_potentially_variable(): - handlers[child_type] = _before_npv - # If we descend into the named expression (because of an - # evaluation error), then on the way back out, we will use - # the potentially variable handler to process the result. - pv_base_type = child.potentially_variable_base_class() - if pv_base_type not in handlers: - try: - child.__class__ = pv_base_type - _register_new_before_child_handler(visitor, child) - finally: - child.__class__ = child_type - if pv_base_type in _operator_handles: - _operator_handles[child_type] = _operator_handles[pv_base_type] - elif id(child) in visitor.subexpression_cache or issubclass( - child_type, _GeneralExpressionData - ): - handlers[child_type] = _before_named_expression - _operator_handles[child_type] = handle_named_expression_node - else: - handlers[child_type] = _before_general_expression - return handlers[child_type](visitor, child) - + return True, None -_before_child_handlers = defaultdict(lambda: _register_new_before_child_handler) -_complex_types = set((complex,)) -_before_child_handlers[complex] = _before_complex -for _type in native_types: - if issubclass(_type, str): - _before_child_handlers[_type] = _before_string -# Special linear / summation expressions -_before_child_handlers[MonomialTermExpression] = _before_monomial -_before_child_handlers[LinearExpression] = _before_linear -_before_child_handlers[SumExpression] = _before_general_expression +_before_child_handlers = AMPLBeforeChildDispatcher() class AMPLRepnVisitor(StreamBasedExpressionVisitor): @@ -2482,48 +2392,19 @@ def __init__( self.use_named_exprs = use_named_exprs self.encountered_string_arguments = False self._eval_expr_visitor = _EvaluationVisitor(True) + self.evaluate = self._eval_expr_visitor.dfs_postorder_stack - def _eval_fixed(self, obj): - ans = obj.value + def handle_constant(self, ans, obj): if ans.__class__ not in native_numeric_types: # None can be returned from uninitialized Var/Param objects if ans is None: return InvalidNumber( - None, f"'{obj}' contains a nonnumeric value '{ans}'" - ) - if ans.__class__ is InvalidNumber: - return ans - else: - # It is possible to get other non-numeric types. Most - # common are bool and 1-element numpy.array(). We will - # attempt to convert the value to a float before - # proceeding. - # - # TODO: we should check bool and warn/error (while bool is - # convertible to float in Python, they have very - # different semantic meanings in Pyomo). - try: - ans = float(ans) - except: - return InvalidNumber( - ans, f"'{obj}' contains a nonnumeric value '{ans}'" - ) - if ans != ans: - return InvalidNumber(nan, f"'{obj}' contains a nonnumeric value '{ans}'") - if ans.__class__ in _complex_types: - return complex_number_error(ans, self, obj) - return ans - - def _eval_expr(self, expr): - ans = self._eval_expr_visitor.dfs_postorder_stack(expr) - if ans.__class__ not in native_numeric_types: - # None can be returned from uninitialized Expression objects - if ans is None: - return InvalidNumber( - ans, f"'{expr}' evaluated to nonnumeric value '{ans}'" + None, f"'{obj}' evaluated to a nonnumeric value '{ans}'" ) if ans.__class__ is InvalidNumber: return ans + elif ans.__class__ in native_complex_types: + return complex_number_error(ans, self, obj) else: # It is possible to get other non-numeric types. Most # common are bool and 1-element numpy.array(). We will @@ -2537,12 +2418,12 @@ def _eval_expr(self, expr): ans = float(ans) except: return InvalidNumber( - ans, f"'{expr}' evaluated to nonnumeric value '{ans}'" + ans, f"'{obj}' evaluated to a nonnumeric value '{ans}'" ) if ans != ans: - return InvalidNumber(ans, f"'{expr}' evaluated to nonnumeric value '{ans}'") - if ans.__class__ in _complex_types: - return complex_number_error(ans, self, expr) + return InvalidNumber( + nan, f"'{obj}' evaluated to a nonnumeric value '{ans}'" + ) return ans def initializeWalker(self, expr): @@ -2612,7 +2493,6 @@ def finalizeResult(self, result): # variables are not accidentally re-characterized as # nonlinear. pass - # ans.nonlinear = orig.nonlinear ans.nl = None if ans.nonlinear.__class__ is list: @@ -2620,8 +2500,8 @@ def finalizeResult(self, result): if not ans.linear: ans.linear = {} - linear = ans.linear if ans.mult != 1: + linear = ans.linear mult, ans.mult = ans.mult, 1 ans.const *= mult if linear: diff --git a/pyomo/repn/quadratic.py b/pyomo/repn/quadratic.py index fbe3860078a..2d11261de5d 100644 --- a/pyomo/repn/quadratic.py +++ b/pyomo/repn/quadratic.py @@ -28,7 +28,7 @@ InequalityExpression, RangedExpression, ) -from pyomo.core.base.expression import ScalarExpression +from pyomo.core.base.expression import Expression from . import linear from .linear import _merge_dict, to_expression @@ -341,7 +341,7 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2): # # NAMED EXPRESSION handlers # -_exit_node_handlers[ScalarExpression][(_QUADRATIC,)] = linear._handle_named_ANY +_exit_node_handlers[Expression][(_QUADRATIC,)] = linear._handle_named_ANY # # EXPR_IF handlers @@ -401,5 +401,7 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2): class QuadraticRepnVisitor(linear.LinearRepnVisitor): Result = QuadraticRepn exit_node_handlers = _exit_node_handlers - exit_node_dispatcher = linear._initialize_exit_node_dispatcher(_exit_node_handlers) + exit_node_dispatcher = linear.ExitNodeDispatcher( + linear._initialize_exit_node_dispatcher(_exit_node_handlers) + ) max_exponential_expansion = 2 diff --git a/pyomo/repn/tests/test_linear.py b/pyomo/repn/tests/test_linear.py index 1501ecfcc9d..faf12a7da09 100644 --- a/pyomo/repn/tests/test_linear.py +++ b/pyomo/repn/tests/test_linear.py @@ -1492,53 +1492,48 @@ def test_type_registrations(self): visitor = LinearRepnVisitor(*cfg) _orig_dispatcher = linear._before_child_dispatcher - linear._before_child_dispatcher = bcd = {} + linear._before_child_dispatcher = bcd = _orig_dispatcher.__class__() + bcd.clear() try: # native type self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, 5), - (False, (linear._CONSTANT, 5)), + bcd.register_dispatcher(visitor, 5), (False, (linear._CONSTANT, 5)) ) self.assertEqual(len(bcd), 1) - self.assertIs(bcd[int], linear._before_native) + self.assertIs(bcd[int], bcd._before_native) # complex type self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, 5j), - (False, (linear._CONSTANT, 5j)), + bcd.register_dispatcher(visitor, 5j), (False, (linear._CONSTANT, 5j)) ) self.assertEqual(len(bcd), 2) - self.assertIs(bcd[complex], linear._before_complex) + self.assertIs(bcd[complex], bcd._before_complex) # ScalarParam m.p = Param(initialize=5) self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, m.p), - (False, (linear._CONSTANT, 5)), + bcd.register_dispatcher(visitor, m.p), (False, (linear._CONSTANT, 5)) ) self.assertEqual(len(bcd), 3) - self.assertIs(bcd[m.p.__class__], linear._before_param) + self.assertIs(bcd[m.p.__class__], bcd._before_param) # ParamData m.q = Param([0], initialize=6, mutable=True) self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, m.q[0]), - (False, (linear._CONSTANT, 6)), + bcd.register_dispatcher(visitor, m.q[0]), (False, (linear._CONSTANT, 6)) ) self.assertEqual(len(bcd), 4) - self.assertIs(bcd[m.q[0].__class__], linear._before_param) + self.assertIs(bcd[m.q[0].__class__], bcd._before_param) # NPV_SumExpression self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, m.p + m.q[0]), + bcd.register_dispatcher(visitor, m.p + m.q[0]), (False, (linear._CONSTANT, 11)), ) self.assertEqual(len(bcd), 6) - self.assertIs(bcd[NPV_SumExpression], linear._before_npv) - self.assertIs(bcd[LinearExpression], linear._before_general_expression) + self.assertIs(bcd[NPV_SumExpression], bcd._before_npv) + self.assertIs(bcd[LinearExpression], bcd._before_general_expression) # Named expression m.e = Expression(expr=m.p + m.q[0]) - self.assertEqual( - linear._register_new_before_child_dispatcher(visitor, m.e), (True, None) - ) + self.assertEqual(bcd.register_dispatcher(visitor, m.e), (True, None)) self.assertEqual(len(bcd), 7) - self.assertIs(bcd[m.e.__class__], linear._before_named_expression) + self.assertIs(bcd[m.e.__class__], bcd._before_named_expression) finally: linear._before_child_dispatcher = _orig_dispatcher diff --git a/pyomo/repn/tests/test_util.py b/pyomo/repn/tests/test_util.py index 58cbbe049cf..58ee09a1006 100644 --- a/pyomo/repn/tests/test_util.py +++ b/pyomo/repn/tests/test_util.py @@ -18,6 +18,13 @@ from pyomo.common.collections import ComponentMap from pyomo.common.errors import DeveloperError, InvalidValueError from pyomo.common.log import LoggingIntercept +from pyomo.core.expr import ( + ProductExpression, + NPV_ProductExpression, + SumExpression, + DivisionExpression, + NPV_DivisionExpression, +) from pyomo.environ import ( ConcreteModel, Block, @@ -32,6 +39,9 @@ ) import pyomo.repn.util from pyomo.repn.util import ( + _CONSTANT, + BeforeChildDispatcher, + ExitNodeDispatcher, FileDeterminism, FileDeterminism_to_SortComponents, InvalidNumber, @@ -637,6 +647,179 @@ class MockConfig(object): # verify no side effects self.assertEqual(MockConfig.row_order, ref) + def test_ExitNodeDispatcher_registration(self): + end = ExitNodeDispatcher( + { + ProductExpression: lambda v, n, d1, d2: d1 * d2, + Expression: lambda v, n, d: d, + } + ) + self.assertEqual(len(end), 2) + + node = ProductExpression((3, 4)) + self.assertEqual(end[node.__class__](None, node, *node.args), 12) + self.assertEqual(len(end), 2) + + node = Expression(initialize=5) + node.construct() + self.assertEqual(end[node.__class__](None, node, *node.args), 5) + self.assertEqual(len(end), 3) + self.assertIn(node.__class__, end) + + node = NPV_ProductExpression((6, 7)) + self.assertEqual(end[node.__class__](None, node, *node.args), 42) + self.assertEqual(len(end), 4) + self.assertIn(NPV_ProductExpression, end) + + class NewProductExpression(ProductExpression): + pass + + node = NewProductExpression((6, 7)) + with self.assertRaisesRegex( + DeveloperError, r".*Unexpected expression node type 'NewProductExpression'" + ): + end[node.__class__](None, node, *node.args) + self.assertEqual(len(end), 4) + + end[SumExpression, 2] = lambda v, n, *d: 2 * sum(d) + self.assertEqual(len(end), 5) + + node = SumExpression((1, 2, 3)) + self.assertEqual(end[node.__class__, 2](None, node, *node.args), 12) + self.assertEqual(len(end), 5) + + with self.assertRaisesRegex( + DeveloperError, + r"(?s)Base expression key '\(, 3\)' not found when.*" + r"inserting dispatcher for node 'SumExpression' while walking.*" + r"expression tree.", + ): + end[node.__class__, 3](None, node, *node.args) + self.assertEqual(len(end), 5) + + end[SumExpression] = lambda v, n, *d: sum(d) + self.assertEqual(len(end), 6) + self.assertIn(SumExpression, end) + + self.assertEqual(end[node.__class__, 1](None, node, *node.args), 6) + self.assertEqual(len(end), 7) + self.assertIn((SumExpression, 1), end) + + self.assertEqual(end[node.__class__, 3, 4, 5, 6](None, node, *node.args), 6) + self.assertEqual(len(end), 7) + self.assertNotIn((SumExpression, 3, 4, 5, 6), end) + + def test_BeforeChildDispatcher_registration(self): + class BeforeChildDispatcherTester(BeforeChildDispatcher): + @staticmethod + def _before_var(visitor, child): + return child + + @staticmethod + def _before_named_expression(visitor, child): + return child + + class VisitorTester(object): + def handle_constant(self, value, node): + return value + + def evaluate(self, node): + return node() + + visitor = VisitorTester() + + bcd = BeforeChildDispatcherTester() + self.assertEqual(len(bcd), 0) + + node = 5 + self.assertEqual(bcd[node.__class__](None, node), (False, (_CONSTANT, 5))) + self.assertIs(bcd[int], bcd._before_native) + self.assertEqual(len(bcd), 1) + + node = 'string' + ans = bcd[node.__class__](None, node) + self.assertEqual(ans, (False, (_CONSTANT, InvalidNumber(node)))) + self.assertEqual( + ''.join(ans[1][1].causes), + "'string' () is not a valid numeric type", + ) + self.assertIs(bcd[str], bcd._before_string) + self.assertEqual(len(bcd), 2) + + node = True + ans = bcd[node.__class__](None, node) + self.assertEqual(ans, (False, (_CONSTANT, InvalidNumber(node)))) + self.assertEqual( + ''.join(ans[1][1].causes), + "True () is not a valid numeric type", + ) + self.assertIs(bcd[bool], bcd._before_invalid) + self.assertEqual(len(bcd), 3) + + node = 1j + ans = bcd[node.__class__](None, node) + self.assertEqual(ans, (False, (_CONSTANT, InvalidNumber(node)))) + self.assertEqual( + ''.join(ans[1][1].causes), "Complex number returned from expression" + ) + self.assertIs(bcd[complex], bcd._before_complex) + self.assertEqual(len(bcd), 4) + + class new_int(int): + pass + + node = new_int(5) + self.assertEqual(bcd[node.__class__](None, node), (False, (_CONSTANT, 5))) + self.assertIs(bcd[new_int], bcd._before_native) + self.assertEqual(len(bcd), 5) + + node = [] + ans = bcd[node.__class__](None, node) + self.assertEqual(ans, (False, (_CONSTANT, InvalidNumber([])))) + self.assertEqual( + ''.join(ans[1][1].causes), "[] () is not a valid numeric type" + ) + self.assertIs(bcd[list], bcd._before_invalid) + self.assertEqual(len(bcd), 6) + + node = Var(initialize=7) + node.construct() + self.assertIs(bcd[node.__class__](None, node), node) + self.assertIs(bcd[node.__class__], bcd._before_var) + self.assertEqual(len(bcd), 7) + + node = Param(initialize=8) + node.construct() + self.assertEqual(bcd[node.__class__](visitor, node), (False, (_CONSTANT, 8))) + self.assertIs(bcd[node.__class__], bcd._before_param) + self.assertEqual(len(bcd), 8) + + node = Expression(initialize=9) + node.construct() + self.assertIs(bcd[node.__class__](None, node), node) + self.assertIs(bcd[node.__class__], bcd._before_named_expression) + self.assertEqual(len(bcd), 9) + + node = SumExpression((3, 5)) + self.assertEqual(bcd[node.__class__](None, node), (True, None)) + self.assertIs(bcd[node.__class__], bcd._before_general_expression) + self.assertEqual(len(bcd), 10) + + node = NPV_ProductExpression((3, 5)) + self.assertEqual(bcd[node.__class__](visitor, node), (False, (_CONSTANT, 15))) + self.assertEqual(len(bcd), 12) + self.assertIs(bcd[NPV_ProductExpression], bcd._before_npv) + self.assertIs(bcd[ProductExpression], bcd._before_general_expression) + self.assertEqual(len(bcd), 12) + + node = NPV_DivisionExpression((3, 0)) + self.assertEqual(bcd[node.__class__](visitor, node), (True, None)) + self.assertEqual(len(bcd), 14) + self.assertIs(bcd[NPV_DivisionExpression], bcd._before_npv) + self.assertIs(bcd[DivisionExpression], bcd._before_general_expression) + self.assertEqual(len(bcd), 14) + if __name__ == "__main__": unittest.main() diff --git a/pyomo/repn/util.py b/pyomo/repn/util.py index e60adbc0b33..8c850987253 100644 --- a/pyomo/repn/util.py +++ b/pyomo/repn/util.py @@ -9,7 +9,9 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +import collections import enum +import functools import itertools import logging import operator @@ -18,6 +20,12 @@ from pyomo.common.collections import Sequence, ComponentMap from pyomo.common.deprecation import deprecation_warning from pyomo.common.errors import DeveloperError, InvalidValueError +from pyomo.common.numeric_types import ( + check_if_numeric_type, + native_types, + native_numeric_types, + native_complex_types, +) from pyomo.core.pyomoobject import PyomoObject from pyomo.core.base import ( Var, @@ -26,11 +34,13 @@ Objective, Block, Constraint, + Expression, Suffix, SortComponents, ) from pyomo.core.base.component import ActiveComponent -from pyomo.core.expr.numvalue import native_numeric_types, is_fixed, value +from pyomo.core.base.expression import _ExpressionData +from pyomo.core.expr.numvalue import is_fixed, value import pyomo.core.expr as EXPR import pyomo.core.kernel as kernel @@ -43,6 +53,11 @@ EXPR.LinearExpression, EXPR.NPV_SumExpression, } +_named_subexpression_types = ( + _ExpressionData, + kernel.expression.expression, + kernel.objective.objective, +) HALT_ON_EVALUATION_ERROR = False nan = float('nan') @@ -221,6 +236,195 @@ def __rpow__(self, other): return self._op(operator.pow, other, self) +_CONSTANT = ExprType.CONSTANT + + +class BeforeChildDispatcher(collections.defaultdict): + """Dispatcher for handling the :py:class:`StreamBasedExpressionVisitor` + `beforeChild` callback + + This dispatcher implements a specialization of :py:`defaultdict` + that supports automatic type registration. Any missing types will + return the :py:meth:`register_dispatcher` method, which (when called + as a callback) will interrogate the type, identify the appropriate + callback, add the callback to the dict, and return the result of + calling the callback. As the callback is added to the dict, no type + will incur the overhead of `register_dispatcher` more than once. + + Note that all dispatchers are implemented as `staticmethod` + functions to avoid the (unnecessary) overhead of binding to the + dispatcher object. + + """ + + __slots__ = () + + def __missing__(self, key): + return self.register_dispatcher + + def register_dispatcher(self, visitor, child): + child_type = type(child) + if child_type in native_numeric_types: + self[child_type] = self._before_native + elif issubclass(child_type, str): + self[child_type] = self._before_string + elif child_type in native_types: + if issubclass(child_type, tuple(native_complex_types)): + self[child_type] = self._before_complex + else: + self[child_type] = self._before_invalid + elif not hasattr(child, 'is_expression_type'): + if check_if_numeric_type(child): + self[child_type] = self._before_native + else: + self[child_type] = self._before_invalid + elif not child.is_expression_type(): + if child.is_potentially_variable(): + self[child_type] = self._before_var + else: + self[child_type] = self._before_param + elif not child.is_potentially_variable(): + self[child_type] = self._before_npv + pv_base_type = child.potentially_variable_base_class() + if pv_base_type not in self: + try: + child.__class__ = pv_base_type + self.register_dispatcher(visitor, child) + finally: + child.__class__ = child_type + elif ( + issubclass(child_type, _named_subexpression_types) + or child_type is kernel.expression.noclone + ): + self[child_type] = self._before_named_expression + else: + self[child_type] = self._before_general_expression + return self[child_type](visitor, child) + + @staticmethod + def _before_general_expression(visitor, child): + return True, None + + @staticmethod + def _before_native(visitor, child): + return False, (_CONSTANT, child) + + @staticmethod + def _before_complex(visitor, child): + return False, (_CONSTANT, complex_number_error(child, visitor, child)) + + @staticmethod + def _before_invalid(visitor, child): + return False, ( + _CONSTANT, + InvalidNumber( + child, f"{child!r} ({type(child)}) is not a valid numeric type" + ), + ) + + @staticmethod + def _before_string(visitor, child): + return False, ( + _CONSTANT, + InvalidNumber( + child, f"{child!r} ({type(child)}) is not a valid numeric type" + ), + ) + + @staticmethod + def _before_npv(visitor, child): + try: + return False, ( + _CONSTANT, + visitor.handle_constant(visitor.evaluate(child), child), + ) + except (ValueError, ArithmeticError): + return True, None + + @staticmethod + def _before_param(visitor, child): + return False, (_CONSTANT, visitor.handle_constant(child.value, child)) + + # + # The following methods must be defined by derivative classes (along + # with any other special-case handling they want to implement; + # usually including handling for Monomial, Linear, and + # ExternalFunction + # + + # @staticmethod + # def _before_var(visitor, child): + # pass + + # @staticmethod + # def _before_named_expression(visitor, child): + # pass + + +class ExitNodeDispatcher(collections.defaultdict): + """Dispatcher for handling the :py:class:`StreamBasedExpressionVisitor` + `exitNode` callback + + This dispatcher implements a specialization of :py:`defaultdict` + that supports automatic type registration. Any missing types will + return the :py:meth:`register_dispatcher` method, which (when called + as a callback) will interrogate the type, identify the appropriate + callback, add the callback to the dict, and return the result of + calling the callback. As the callback is added to the dict, no type + will incur the overhead of `register_dispatcher` more than once. + + Note that in this case, the client is expected to register all + non-NPV expression types. The auto-registration is designed to only + handle two cases: + - Auto-detection of user-defined Named Expression types + - Automatic mappimg of NPV expressions to their equivalent non-NPV handlers + + """ + + __slots__ = () + + def __init__(self, *args, **kwargs): + super().__init__(None, *args, **kwargs) + + def __missing__(self, key): + return functools.partial(self.register_dispatcher, key=key) + + def register_dispatcher(self, visitor, node, *data, key=None): + if ( + isinstance(node, _named_subexpression_types) + or type(node) is kernel.expression.noclone + ): + base_type = Expression + elif not node.is_potentially_variable(): + base_type = node.potentially_variable_base_class() + else: + base_type = node.__class__ + if isinstance(key, tuple): + base_key = (base_type,) + key[1:] + # Only cache handlers for unary, binary and ternary operators + cache = len(key) <= 4 + else: + base_key = base_type + cache = True + if base_key in self: + fcn = self[base_key] + elif base_type in self: + fcn = self[base_type] + elif any((k[0] if k.__class__ is tuple else k) is base_type for k in self): + raise DeveloperError( + f"Base expression key '{base_key}' not found when inserting dispatcher" + f" for node '{type(node).__name__}' while walking expression tree." + ) + else: + raise DeveloperError( + f"Unexpected expression node type '{type(node).__name__}' " + "found while walking expression tree." + ) + if cache: + self[key] = fcn + return fcn(visitor, node, *data) + + def apply_node_operation(node, args): try: ans = node._apply_operation(args) diff --git a/pyomo/util/calc_var_value.py b/pyomo/util/calc_var_value.py index 81bbd285dd2..42d38f2f874 100644 --- a/pyomo/util/calc_var_value.py +++ b/pyomo/util/calc_var_value.py @@ -10,7 +10,7 @@ # ___________________________________________________________________________ from pyomo.common.errors import IterationLimitError -from pyomo.core.expr.numvalue import native_numeric_types, value, is_fixed +from pyomo.common.numeric_types import native_numeric_types, native_complex_types, value from pyomo.core.expr.calculus.derivatives import differentiate from pyomo.core.base.constraint import Constraint, _ConstraintData @@ -92,6 +92,9 @@ def calculate_variable_from_constraint( if lower != upper: raise ValueError(f"Constraint '{constraint}' must be an equality constraint") + _invalid_types = set(native_complex_types) + _invalid_types.add(type(None)) + if variable.value is None: # Note that we use "skip_validation=True" here as well, as the # variable domain may not admit the calculated initial guesses, @@ -151,7 +154,7 @@ def calculate_variable_from_constraint( # to using Newton's method. residual_2 = None - if residual_2 is not None and type(residual_2) is not complex: + if residual_2.__class__ not in _invalid_types: # if the variable appears linearly with a coefficient of 1, then we # are done if abs(residual_2 - upper) < eps: @@ -167,11 +170,7 @@ def calculate_variable_from_constraint( if slope: variable.set_value(-intercept / slope, skip_validation=True) body_val = value(body, exception=False) - if ( - body_val is not None - and body_val.__class__ is not complex - and abs(body_val - upper) < eps - ): + if body_val.__class__ not in _invalid_types and abs(body_val - upper) < eps: # Re-set the variable value to trigger any warnings WRT # the final variable state variable.set_value(variable.value) @@ -234,7 +233,7 @@ def calculate_variable_from_constraint( xk = value(variable) try: fk = value(expr) - if type(fk) is complex: + if fk.__class__ in _invalid_types and fk is not None: raise ValueError("Complex numbers are not allowed in Newton's method.") except: # We hit numerical problems with the last step (possible if @@ -275,7 +274,7 @@ def calculate_variable_from_constraint( # HACK for Python3 support, pending resolution of #879 # Issue #879 also pertains to other checks for "complex" # in this method. - if type(fkp1) is complex: + if fkp1.__class__ in _invalid_types: # We cannot perform computations on complex numbers fkp1 = None if fkp1 is not None and fkp1**2 < c1 * fk**2: @@ -289,7 +288,7 @@ def calculate_variable_from_constraint( if alpha <= alpha_min: residual = value(expr, exception=False) - if residual is None or type(residual) is complex: + if residual.__class__ in _invalid_types: residual = "{function evaluation error}" raise IterationLimitError( f"Linesearch iteration limit reached solving for "