Skip to content

Commit

Permalink
Fixing a couple bugs in gdp.hull with nested disjunctions and variabl…
Browse files Browse the repository at this point in the history
…es not used in every Disjunct
  • Loading branch information
emma58 committed Jul 14, 2023
1 parent 1288032 commit 455a726
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
38 changes: 23 additions & 15 deletions pyomo/gdp/plugins/hull.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _apply_to_impl(self, instance, **kwds):
# Preprocess in order to find what disjunctive components need
# transformation
gdp_tree = self._get_gdp_tree_from_targets(instance, targets)
# Hull transforms from root to leaf
preprocessed_targets = gdp_tree.topological_sort()
self._targets_set = set(preprocessed_targets)

Expand Down Expand Up @@ -446,7 +447,7 @@ def _transform_disjunctionData(
# mark this as local because we won't re-disaggregate if this is
# a nested disjunction
if local_var_set is not None:
local_var_set.append(disaggregatedVar)
local_var_set.append(disaggregated_var)
var_free = 1 - sum(
disj.indicator_var.get_associated_binary()
for disj in disjunctsVarAppearsIn[var]
Expand Down Expand Up @@ -499,14 +500,19 @@ def _transform_disjunctionData(

# We equate the sum of the disaggregated vars to var (the original)
# if parent_disjunct is None, else it needs to be the disaggregated
# var corresponding to var on the parent disjunct. This is the
# reason we transform from root to leaf: This constraint is now
# correct regardless of how nested something may have been.
parent_var = (
var
if parent_disjunct is None
else self.get_disaggregated_var(var, parent_disjunct)
)
# var corresponding to var on the parent disjunct, assuming the Var
# appears on the parent disjunct. If it does not, then we can again
# use the original. This is the reason we transform from root to
# leaf: This constraint is now correct regardless of how nested
# something may have been and accounting for the Disjunctions above
# it.
parent_var = var
if parent_disjunct is not None:
parent_var = self.get_disaggregated_var(
var, parent_disjunct, raise_exception=False
)
if parent_var is None:
parent_var = var
cons_idx = len(disaggregationConstraint)
disaggregationConstraint.add(cons_idx, parent_var == disaggregatedExpr)
# and update the map so that we can find this later. We index by
Expand Down Expand Up @@ -885,7 +891,7 @@ def _add_local_var_suffix(self, disjunct):
% (disjunct.getname(fully_qualified=True), localSuffix.ctype)
)

def get_disaggregated_var(self, v, disjunct):
def get_disaggregated_var(self, v, disjunct, raise_exception=True):
"""
Returns the disaggregated variable corresponding to the Var v and the
Disjunct disjunct.
Expand All @@ -903,11 +909,13 @@ def get_disaggregated_var(self, v, disjunct):
try:
return transBlock._disaggregatedVarMap['disaggregatedVar'][disjunct][v]
except:
logger.error(
"It does not appear '%s' is a "
"variable that appears in disjunct '%s'" % (v.name, disjunct.name)
)
raise
if raise_exception:
logger.error(
"It does not appear '%s' is a "
"variable that appears in disjunct '%s'" % (v.name, disjunct.name)
)
raise
return None

def get_src_var(self, disaggregated_var):
"""
Expand Down
39 changes: 39 additions & 0 deletions pyomo/gdp/tests/test_hull.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,6 +1750,45 @@ def test_disaggregated_vars_are_set_to_0_correctly(self):
self.assertEqual(value(hull.get_disaggregated_var(m.x, m.d3)), 1.2)
self.assertEqual(value(hull.get_disaggregated_var(m.x, m.d4)), 0)

def test_nested_with_var_that_does_not_appear_in_every_disjunct(self):
m = ConcreteModel()
m.x = Var(bounds=(0, 10))
m.y = Var(bounds=(-4, 5))
m.parent1 = Disjunct()
m.parent2 = Disjunct()
m.parent2.c = Constraint(expr=m.x == 0)
m.parent_disjunction = Disjunction(expr=[m.parent1, m.parent2])
m.child1 = Disjunct()
m.child1.c = Constraint(expr=m.x <= 8)
m.child2 = Disjunct()
m.child2.c = Constraint(expr=m.x + m.y <= 3)
m.child3 = Disjunct()
m.child3.c = Constraint(expr=m.x <= 7)
m.parent1.disjunction = Disjunction(expr=[m.child1, m.child2, m.child3])

hull = TransformationFactory('gdp.hull')
hull.apply_to(m)

y_c2 = hull.get_disaggregated_var(m.y, m.child2)
self.assertEqual(y_c2.bounds, (-4, 5))
other_y = hull.get_disaggregated_var(m.y, m.child1)
self.assertEqual(other_y.bounds, (-4, 5))
other_other_y = hull.get_disaggregated_var(m.y, m.child3)
self.assertIs(other_y, other_other_y)
y_cons = hull.get_disaggregation_constraint(m.y, m.parent1.disjunction)
# check that the disaggregated ys in the nested just sum to the original
assertExpressionsEqual(self, y_cons.expr, m.y == other_y + y_c2)

x_c1 = hull.get_disaggregated_var(m.x, m.child1)
x_c2 = hull.get_disaggregated_var(m.x, m.child2)
x_c3 = hull.get_disaggregated_var(m.x, m.child3)
x_p1 = hull.get_disaggregated_var(m.x, m.parent1)
x_p2 = hull.get_disaggregated_var(m.x, m.parent2)
x_cons_parent = hull.get_disaggregation_constraint(m.x, m.parent_disjunction)
assertExpressionsEqual(self, x_cons_parent.expr, m.x == x_p1 + x_p2)
x_cons_child = hull.get_disaggregation_constraint(m.x, m.parent1.disjunction)
assertExpressionsEqual(self, x_cons_child.expr, x_p1 == x_c1 + x_c2 + x_c3)


class TestSpecialCases(unittest.TestCase):
def test_local_vars(self):
Expand Down

0 comments on commit 455a726

Please sign in to comment.