diff --git a/pyomo/core/base/constraint.py b/pyomo/core/base/constraint.py index c67236656be..8cf3c48ad0a 100644 --- a/pyomo/core/base/constraint.py +++ b/pyomo/core/base/constraint.py @@ -27,6 +27,7 @@ as_numeric, is_fixed, native_numeric_types, + native_logical_types, native_types, ) from pyomo.core.expr import ( @@ -84,14 +85,15 @@ def C_rule(model, i, j): model.c = Constraint(rule=simple_constraint_rule(...)) """ - return rule_wrapper( - rule, - { - None: Constraint.Skip, - True: Constraint.Feasible, - False: Constraint.Infeasible, - }, - ) + map_types = set([type(None)]) | native_logical_types + result_map = {None: Constraint.Skip} + for l_type in native_logical_types: + result_map[l_type(True)] = Constraint.Feasible + result_map[l_type(False)] = Constraint.Infeasible + # Note: some logical types hash the same as bool (e.g., np.bool_), so + # we will pass the set of all logical types in addition to the + # result_map + return rule_wrapper(rule, result_map, map_types=map_types) def simple_constraintlist_rule(rule): @@ -109,14 +111,15 @@ def C_rule(model, i, j): model.c = ConstraintList(expr=simple_constraintlist_rule(...)) """ - return rule_wrapper( - rule, - { - None: ConstraintList.End, - True: Constraint.Feasible, - False: Constraint.Infeasible, - }, - ) + map_types = set([type(None)]) | native_logical_types + result_map = {None: ConstraintList.End} + for l_type in native_logical_types: + result_map[l_type(True)] = Constraint.Feasible + result_map[l_type(False)] = Constraint.Infeasible + # Note: some logical types hash the same as bool (e.g., np.bool_), so + # we will pass the set of all logical types in addition to the + # result_map + return rule_wrapper(rule, result_map, map_types=map_types) # diff --git a/pyomo/core/base/indexed_component.py b/pyomo/core/base/indexed_component.py index abb29580960..0d498da091d 100644 --- a/pyomo/core/base/indexed_component.py +++ b/pyomo/core/base/indexed_component.py @@ -160,9 +160,12 @@ def _get_indexed_component_data_name(component, index): """ -def rule_result_substituter(result_map): +def rule_result_substituter(result_map, map_types): _map = result_map - _map_types = set(type(key) for key in result_map) + if map_types is None: + _map_types = set(type(key) for key in result_map) + else: + _map_types = map_types def rule_result_substituter_impl(rule, *args, **kwargs): if rule.__class__ in _map_types: @@ -203,7 +206,7 @@ def rule_result_substituter_impl(rule, *args, **kwargs): """ -def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None): +def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None, map_types=None): """Wrap a rule with another function This utility method provides a way to wrap a function (rule) with @@ -230,7 +233,7 @@ def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None): """ if isinstance(wrapping_fcn, dict): - wrapping_fcn = rule_result_substituter(wrapping_fcn) + wrapping_fcn = rule_result_substituter(wrapping_fcn, map_types) if not inspect.isfunction(rule): return wrapping_fcn(rule) # Because some of our processing of initializer functions relies on