Skip to content

Commit

Permalink
Merge pull request #3125 from jsiirola/expr-dispatcher-updates
Browse files Browse the repository at this point in the history
Update `ExitNodeDispatcher` to support subclasses, improved customization
  • Loading branch information
jsiirola authored Feb 8, 2024
2 parents 1e259f6 + ce7a6b5 commit 42fa578
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
36 changes: 26 additions & 10 deletions pyomo/repn/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pyomo.common.errors import DeveloperError, InvalidValueError
from pyomo.common.log import LoggingIntercept
from pyomo.core.expr import (
NumericExpression,
ProductExpression,
NPV_ProductExpression,
SumExpression,
Expand Down Expand Up @@ -671,16 +672,6 @@ def test_ExitNodeDispatcher_registration(self):
self.assertEqual(len(end), 4)
self.assertIn(NPV_ProductExpression, end)

class NewProductExpression(ProductExpression):
pass

node = NewProductExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'NewProductExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 4)

end[SumExpression, 2] = lambda v, n, *d: 2 * sum(d)
self.assertEqual(len(end), 5)

Expand Down Expand Up @@ -710,6 +701,31 @@ class NewProductExpression(ProductExpression):
self.assertEqual(len(end), 7)
self.assertNotIn((SumExpression, 3, 4, 5, 6), end)

class NewProductExpression(ProductExpression):
pass

node = NewProductExpression((6, 7))
self.assertEqual(end[node.__class__](None, node, *node.args), 42)
self.assertEqual(len(end), 8)
self.assertIn(NewProductExpression, end)

class UnknownExpression(NumericExpression):
pass

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__](None, node, *node.args)
self.assertEqual(len(end), 8)

node = UnknownExpression((6, 7))
with self.assertRaisesRegex(
DeveloperError, r".*Unexpected expression node type 'UnknownExpression'"
):
end[node.__class__, 6, 7](None, node, *node.args)
self.assertEqual(len(end), 8)

def test_BeforeChildDispatcher_registration(self):
class BeforeChildDispatcherTester(BeforeChildDispatcher):
@staticmethod
Expand Down
73 changes: 41 additions & 32 deletions pyomo/repn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,42 +387,51 @@ def __init__(self, *args, **kwargs):
super().__init__(None, *args, **kwargs)

def __missing__(self, key):
return functools.partial(self.register_dispatcher, key=key)

def register_dispatcher(self, visitor, node, *data, key=None):
if type(key) is tuple:
node_class = key[0]
else:
node_class = key
bases = node_class.__mro__
# Note: if we add an `etype`, then this special-case can be removed
if (
isinstance(node, _named_subexpression_types)
or type(node) is kernel.expression.noclone
issubclass(node_class, _named_subexpression_types)
or node_class is kernel.expression.noclone
):
base_type = Expression
elif not node.is_potentially_variable():
base_type = node.potentially_variable_base_class()
else:
base_type = node.__class__
if isinstance(key, tuple):
base_key = (base_type,) + key[1:]
# Only cache handlers for unary, binary and ternary operators
cache = len(key) <= 4
else:
base_key = base_type
cache = True
if base_key in self:
fcn = self[base_key]
elif base_type in self:
fcn = self[base_type]
elif any((k[0] if k.__class__ is tuple else k) is base_type for k in self):
raise DeveloperError(
f"Base expression key '{base_key}' not found when inserting dispatcher"
f" for node '{type(node).__name__}' while walking expression tree."
)
else:
raise DeveloperError(
f"Unexpected expression node type '{type(node).__name__}' "
"found while walking expression tree."
)
bases = [Expression]
fcn = None
for base_type in bases:
if isinstance(key, tuple):
base_key = (base_type,) + key[1:]
# Only cache handlers for unary, binary and ternary operators
cache = len(key) <= 4
else:
base_key = base_type
cache = True
if base_key in self:
fcn = self[base_key]
elif base_type in self:
fcn = self[base_type]
elif any((k[0] if type(k) is tuple else k) is base_type for k in self):
raise DeveloperError(
f"Base expression key '{base_key}' not found when inserting "
f"dispatcher for node '{node_class.__name__}' while walking "
"expression tree."
)
if fcn is None:
return self.unexpected_expression_type(key)
if cache:
self[key] = fcn
return fcn(visitor, node, *data)
return fcn

def unexpected_expression_type(self, key):
if type(key) is tuple:
node_class = key[0]
else:
node_class = key
raise DeveloperError(
f"Unexpected expression node type '{node_class.__name__}' "
f"found while walking expression tree."
)


def apply_node_operation(node, args):
Expand Down

0 comments on commit 42fa578

Please sign in to comment.