Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional fixes for LinearRepn #2865

Merged
merged 2 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyomo/repn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
RangedExpression,
)
from pyomo.core.expr.visitor import StreamBasedExpressionVisitor, _EvaluationVisitor
from pyomo.core.expr import is_fixed
from pyomo.core.expr import is_fixed, value
from pyomo.core.base.expression import ScalarExpression, _GeneralExpressionData
from pyomo.core.base.objective import ScalarObjective, _GeneralObjectiveData
import pyomo.core.kernel as kernel
Expand Down Expand Up @@ -756,7 +756,7 @@ def _register_new_before_child_dispatcher(visitor, child):
dispatcher = _before_child_dispatcher
child_type = child.__class__
if child_type in native_numeric_types:
if isinstance(child_type, complex):
if issubclass(child_type, complex):
_complex_types.add(child_type)
dispatcher[child_type] = _before_complex
else:
Expand All @@ -775,7 +775,7 @@ def _register_new_before_child_dispatcher(visitor, child):
if pv_base_type not in dispatcher:
try:
child.__class__ = pv_base_type
_register_new_before_child_dispatcher(self, child)
_register_new_before_child_dispatcher(visitor, child)
finally:
child.__class__ = child_type
if pv_base_type in visitor.exit_node_handlers:
Expand Down
103 changes: 102 additions & 1 deletion pyomo/repn/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

from pyomo.core.expr.compare import assertExpressionsEqual
from pyomo.core.expr.numeric_expr import LinearExpression, MonomialTermExpression
from pyomo.core.expr.current import Expr_if, inequality
from pyomo.core.expr.current import (
Expr_if,
inequality,
LinearExpression,
NPV_SumExpression,
)
from pyomo.repn.linear import LinearRepn, LinearRepnVisitor
from pyomo.repn.util import InvalidNumber

Expand Down Expand Up @@ -1311,3 +1316,99 @@ def test_external(self):
self.assertEqual(repn.constant, 0)
self.assertEqual(repn.linear, {})
self.assertIs(repn.nonlinear, e)

def test_type_registrations(self):
m = ConcreteModel()

cfg = VisitorConfig()
visitor = LinearRepnVisitor(*cfg)

import pyomo.repn.linear as linear

_orig_dispatcher = linear._before_child_dispatcher
linear._before_child_dispatcher = bcd = {}
try:
# native type
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, 5),
(False, (linear._CONSTANT, 5)),
)
self.assertEqual(len(bcd), 1)
self.assertIs(bcd[int], linear._before_native)
# complex type
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, 5j),
(False, (linear._CONSTANT, 5j)),
)
self.assertEqual(len(bcd), 2)
self.assertIs(bcd[complex], linear._before_complex)
# ScalarParam
m.p = Param(initialize=5)
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, m.p),
(False, (linear._CONSTANT, 5)),
)
self.assertEqual(len(bcd), 3)
self.assertIs(bcd[m.p.__class__], linear._before_param)
# ParamData
m.q = Param([0], initialize=6, mutable=True)
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, m.q[0]),
(False, (linear._CONSTANT, 6)),
)
self.assertEqual(len(bcd), 4)
self.assertIs(bcd[m.q[0].__class__], linear._before_param)
# NPV_SumExpression
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, m.p + m.q[0]),
(False, (linear._CONSTANT, 11)),
)
self.assertEqual(len(bcd), 6)
self.assertIs(bcd[NPV_SumExpression], linear._before_npv)
self.assertIs(bcd[LinearExpression], linear._before_general_expression)
# Named expression
m.e = Expression(expr=m.p + m.q[0])
self.assertEqual(
linear._register_new_before_child_dispatcher(visitor, m.e), (True, None)
)
self.assertEqual(len(bcd), 7)
self.assertIs(bcd[m.e.__class__], linear._before_named_expression)

finally:
linear._before_child_dispatcher = _orig_dispatcher

def test_to_expression(self):
m = ConcreteModel()
m.x = Var()
m.y = Var()

cfg = VisitorConfig()
visitor = LinearRepnVisitor(*cfg)
# prepopulate the visitor's var_map
visitor.walk_expression(m.x + m.y)

expr = LinearRepn()
self.assertEqual(expr.to_expression(visitor), 0)

expr.linear[id(m.x)] = 0
self.assertEqual(expr.to_expression(visitor), 0)

expr.linear[id(m.x)] = 1
assertExpressionsEqual(self, expr.to_expression(visitor), m.x)

expr.linear[id(m.x)] = 2
assertExpressionsEqual(self, expr.to_expression(visitor), 2 * m.x)

expr.linear[id(m.y)] = 3
assertExpressionsEqual(self, expr.to_expression(visitor), 2 * m.x + 3 * m.y)

expr.multiplier = 10
assertExpressionsEqual(
self, expr.to_expression(visitor), (2 * m.x + 3 * m.y) * 10
)
expr.multiplier = 1

expr.constant = 0
expr.linear[id(m.x)] = 0
expr.linear[id(m.y)] = 0
assertExpressionsEqual(self, expr.to_expression(visitor), LinearExpression())