Skip to content

Commit

Permalink
Merge pull request #3152 from jsiirola/simple-rule-numpy-bool
Browse files Browse the repository at this point in the history
Generalize the simple_constraint_rule decorator
  • Loading branch information
mrmundt authored Feb 20, 2024
2 parents 306bdc8 + 7ecdcc9 commit 783872b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
35 changes: 19 additions & 16 deletions pyomo/core/base/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
as_numeric,
is_fixed,
native_numeric_types,
native_logical_types,
native_types,
)
from pyomo.core.expr import (
Expand Down Expand Up @@ -84,14 +85,15 @@ def C_rule(model, i, j):
model.c = Constraint(rule=simple_constraint_rule(...))
"""
return rule_wrapper(
rule,
{
None: Constraint.Skip,
True: Constraint.Feasible,
False: Constraint.Infeasible,
},
)
map_types = set([type(None)]) | native_logical_types
result_map = {None: Constraint.Skip}
for l_type in native_logical_types:
result_map[l_type(True)] = Constraint.Feasible
result_map[l_type(False)] = Constraint.Infeasible
# Note: some logical types hash the same as bool (e.g., np.bool_), so
# we will pass the set of all logical types in addition to the
# result_map
return rule_wrapper(rule, result_map, map_types=map_types)


def simple_constraintlist_rule(rule):
Expand All @@ -109,14 +111,15 @@ def C_rule(model, i, j):
model.c = ConstraintList(expr=simple_constraintlist_rule(...))
"""
return rule_wrapper(
rule,
{
None: ConstraintList.End,
True: Constraint.Feasible,
False: Constraint.Infeasible,
},
)
map_types = set([type(None)]) | native_logical_types
result_map = {None: ConstraintList.End}
for l_type in native_logical_types:
result_map[l_type(True)] = Constraint.Feasible
result_map[l_type(False)] = Constraint.Infeasible
# Note: some logical types hash the same as bool (e.g., np.bool_), so
# we will pass the set of all logical types in addition to the
# result_map
return rule_wrapper(rule, result_map, map_types=map_types)


#
Expand Down
11 changes: 7 additions & 4 deletions pyomo/core/base/indexed_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,12 @@ def _get_indexed_component_data_name(component, index):
"""


def rule_result_substituter(result_map):
def rule_result_substituter(result_map, map_types):
_map = result_map
_map_types = set(type(key) for key in result_map)
if map_types is None:
_map_types = set(type(key) for key in result_map)
else:
_map_types = map_types

def rule_result_substituter_impl(rule, *args, **kwargs):
if rule.__class__ in _map_types:
Expand Down Expand Up @@ -203,7 +206,7 @@ def rule_result_substituter_impl(rule, *args, **kwargs):
"""


def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None):
def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None, map_types=None):
"""Wrap a rule with another function
This utility method provides a way to wrap a function (rule) with
Expand All @@ -230,7 +233,7 @@ def rule_wrapper(rule, wrapping_fcn, positional_arg_map=None):
"""
if isinstance(wrapping_fcn, dict):
wrapping_fcn = rule_result_substituter(wrapping_fcn)
wrapping_fcn = rule_result_substituter(wrapping_fcn, map_types)
if not inspect.isfunction(rule):
return wrapping_fcn(rule)
# Because some of our processing of initializer functions relies on
Expand Down

0 comments on commit 783872b

Please sign in to comment.