Skip to content

Commit

Permalink
Merge pull request #2722 from jsiirola/expr-multiple-dispatch
Browse files Browse the repository at this point in the history
Rework expression generation to leverage multiple dispatch
  • Loading branch information
blnicho authored Apr 25, 2023
2 parents 9f57f07 + 78cf674 commit 5417276
Show file tree
Hide file tree
Showing 64 changed files with 20,014 additions and 2,875 deletions.
2 changes: 1 addition & 1 deletion examples/pyomobook/scripts-ch/value_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# unexpected expression instead of value
a = model.u - 1
print(a) # "u - 1"
print(type(a)) # <class 'pyomo.core.expr.numeric_expr.SumExpression'>
print(type(a)) # <class 'pyomo.core.expr.numeric_expr.LinearExpression'>

# correct way to access the value
b = pyo.value(model.u) - 1
Expand Down
2 changes: 1 addition & 1 deletion examples/pyomobook/scripts-ch/value_expression.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
u - 1
<class 'pyomo.core.expr.numeric_expr.SumExpression'>
<class 'pyomo.core.expr.numeric_expr.LinearExpression'>
1.0
<class 'float'>
6 changes: 5 additions & 1 deletion examples/pyomobook/test_book_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def gurobi_fully_licensed():
if not param_available:
raise unittest.SkipTest('Parameterized is not available.')

# Needed for testing (switches the matplotlib backend):
from pyomo.common.dependencies import matplotlib_available

bool(matplotlib_available)

# Find all *.txt files, and use them to define baseline tests
currdir = this_file_dir()
datadir = currdir
Expand Down Expand Up @@ -135,7 +140,6 @@ def gurobi_fully_licensed():
'test_performance_ch_wl': ['numpy', 'matplotlib'],
}


#
# Initialize the availability data
#
Expand Down
14 changes: 10 additions & 4 deletions pyomo/common/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def _finalize_matplotlib(module, available):
def _finalize_numpy(np, available):
if not available:
return
numeric_types.RegisterBooleanType(np.bool_)
numeric_types.RegisterLogicalType(np.bool_)
for t in (
np.int_,
np.intc,
Expand All @@ -708,10 +708,16 @@ def _finalize_numpy(np, available):
np.uint64,
):
numeric_types.RegisterIntegerType(t)
numeric_types.RegisterBooleanType(t)
for t in (np.float_, np.float16, np.float32, np.float64, np.ndarray):
# We have deprecated RegisterBooleanType, so we will mock up the
# registration here (to bypass the deprecation warning) until we
# finally remove all support for it
numeric_types._native_boolean_types.add(t)
for t in (np.float_, np.float16, np.float32, np.float64):
numeric_types.RegisterNumericType(t)
numeric_types.RegisterBooleanType(t)
# We have deprecated RegisterBooleanType, so we will mock up the
# registration here (to bypass the deprecation warning) until we
# finally remove all support for it
numeric_types._native_boolean_types.add(t)


dill, dill_available = attempt_import('dill')
Expand Down
44 changes: 38 additions & 6 deletions pyomo/common/numeric_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from pyomo.common.deprecation import deprecated, relocated_module_attribute

#: Python set used to identify numeric constants, boolean values, strings
#: and instances of
#: :class:`NonNumericValue <pyomo.core.expr.numvalue.NonNumericValue>`,
#: which is commonly used in code that walks Pyomo expression trees.
#:
#: :data:`nonpyomo_leaf_types` = :data:`native_types <pyomo.core.expr.numvalue.native_types>` + { :data:`NonNumericValue <pyomo.core.expr.numvalue.NonNumericValue>` }
nonpyomo_leaf_types = set([])
nonpyomo_leaf_types = set()

# It is *significantly* faster to build the list of types we want to
# test against as a "static" set, and not to regenerate it locally for
Expand All @@ -34,12 +35,23 @@
#: Python set used to identify numeric constants. This set includes
#: native Python types as well as numeric types from Python packages
#: like numpy, which may be registered by users.
native_numeric_types = set([int, float, bool, complex])
native_integer_types = set([int, bool])
native_boolean_types = set([int, bool, str, bytes])
native_numeric_types = {int, float, complex}
native_integer_types = {int}
native_logical_types = {bool}
pyomo_constant_types = set() # includes NumericConstant

_native_boolean_types = {int, bool, str, bytes}
relocated_module_attribute(
'native_boolean_types',
'pyomo.common.numeric_types._native_boolean_types',
version='6.5.1.dev0',
msg="The native_boolean_types set will be removed in the future: the set "
"contains types that were convertible to bool, and not types that should "
"be treated as if they were bool (as was the case for the other "
"native_*_types sets). Users likely should use native_logical_types.",
)


#: Python set used to identify numeric constants and related native
#: types. This set includes
#: native Python types as well as numeric types from Python packages
Expand All @@ -49,7 +61,8 @@
native_types = set([bool, str, type(None), slice, bytes])
native_types.update(native_numeric_types)
native_types.update(native_integer_types)
native_types.update(native_boolean_types)
native_types.update(_native_boolean_types)
native_types.update(native_logical_types)

nonpyomo_leaf_types.update(native_types)

Expand Down Expand Up @@ -80,6 +93,11 @@ def RegisterIntegerType(new_type):
nonpyomo_leaf_types.add(new_type)


@deprecated(
"The native_boolean_types set (and hence RegisterBooleanType) "
"is deprecated. Users likely should use RegisterLogicalType.",
version='6.5.1.dev0',
)
def RegisterBooleanType(new_type):
"""
A utility function for updating the set of types that are
Expand All @@ -88,6 +106,20 @@ def RegisterBooleanType(new_type):
The argument should be a class (e.g., numpy.bool_).
"""
native_boolean_types.add(new_type)
_native_boolean_types.add(new_type)
native_types.add(new_type)
nonpyomo_leaf_types.add(new_type)


def RegisterLogicalType(new_type):
"""
A utility function for updating the set of types that are
recognized as handling boolean values. This function does not
register the type of integer or numeric.
The argument should be a class (e.g., numpy.bool_).
"""
_native_boolean_types.add(new_type)
native_logical_types.add(new_type)
native_types.add(new_type)
nonpyomo_leaf_types.add(new_type)
2 changes: 1 addition & 1 deletion pyomo/contrib/cp/interval_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pyomo.core.base.global_set import UnindexedComponent_index
from pyomo.core.base.indexed_component import IndexedComponent, UnindexedComponent_set
from pyomo.core.base.initializer import BoundInitializer, Initializer
from pyomo.core.expr.current import GetAttrExpression, GetItemExpression
from pyomo.core.expr.current import GetItemExpression


class IntervalVarTimePoint(ScalarVar):
Expand Down
13 changes: 13 additions & 0 deletions pyomo/contrib/cp/repn/docplex_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,12 @@ def _get_bool_valued_expr(arg):


def _handle_monomial_expr(visitor, node, arg1, arg2):
# Monomial terms show up a lot. This handles some common
# simplifications (necessary in part for the unit tests)
if arg2[1].__class__ in EXPR.native_types:
return _GENERAL, arg1[1] * arg2[1]
elif arg1[1] == 1:
return arg2
return (_GENERAL, cp.times(_get_int_valued_expr(arg1), _get_int_valued_expr(arg2)))


Expand Down Expand Up @@ -913,7 +919,13 @@ def _handle_always_in_node(visitor, node, cumul_func, lb, ub, start, end):
class LogicalToDoCplex(StreamBasedExpressionVisitor):
_operator_handles = {
EXPR.GetItemExpression: _handle_getitem,
EXPR.Structural_GetItemExpression: _handle_getitem,
EXPR.Numeric_GetItemExpression: _handle_getitem,
EXPR.Boolean_GetItemExpression: _handle_getitem,
EXPR.GetAttrExpression: _handle_getattr,
EXPR.Structural_GetAttrExpression: _handle_getattr,
EXPR.Numeric_GetAttrExpression: _handle_getattr,
EXPR.Boolean_GetAttrExpression: _handle_getattr,
CallExpression: _handle_call,
EXPR.NegationExpression: _handle_negation_node,
EXPR.ProductExpression: _handle_product_node,
Expand All @@ -922,6 +934,7 @@ class LogicalToDoCplex(StreamBasedExpressionVisitor):
EXPR.AbsExpression: _handle_abs_node,
EXPR.MonomialTermExpression: _handle_monomial_expr,
EXPR.SumExpression: _handle_sum_node,
EXPR.LinearExpression: _handle_sum_node,
MinExpression: _handle_min_node,
MaxExpression: _handle_max_node,
NotExpression: _handle_not_node,
Expand Down
Loading

0 comments on commit 5417276

Please sign in to comment.