Skip to content

Commit

Permalink
Fixing a typo in the hull transformation, and adding a test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
emma58 committed Jul 19, 2023
1 parent 4f2d5fc commit 848ece7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyomo/gdp/plugins/hull.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _transform_disjunctionData(
# mark this as local because we won't re-disaggregate if this is
# a nested disjunction
if local_var_set is not None:
local_var_set.append(disaggregatedVar)
local_var_set.append(disaggregated_var)
var_free = 1 - sum(
disj.indicator_var.get_associated_binary()
for disj in disjunctsVarAppearsIn[var]
Expand Down
77 changes: 77 additions & 0 deletions pyomo/gdp/tests/test_hull.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,6 +1750,83 @@ def test_disaggregated_vars_are_set_to_0_correctly(self):
self.assertEqual(value(hull.get_disaggregated_var(m.x, m.d3)), 1.2)
self.assertEqual(value(hull.get_disaggregated_var(m.x, m.d4)), 0)

def test_nested_with_local_vars(self):
m = ConcreteModel()

m.x = Var(bounds=(0, 10))
m.S = RangeSet(2)

@m.Disjunct()
def d_l(d):
d.lambdas = Var(m.S, bounds=(0, 1))
d.LocalVars = Suffix(direction=Suffix.LOCAL)
d.LocalVars[d] = list(d.lambdas.values())
d.c1 = Constraint(expr=d.lambdas[1] + d.lambdas[2] == 1)
d.c2 = Constraint(expr=m.x == 2 * d.lambdas[1] + 3 * d.lambdas[2])

@m.Disjunct()
def d_r(d):
@d.Disjunct()
def d_l(e):
e.lambdas = Var(m.S, bounds=(0, 1))
e.LocalVars = Suffix(direction=Suffix.LOCAL)
e.LocalVars[e] = list(e.lambdas.values())
e.c1 = Constraint(expr=e.lambdas[1] + e.lambdas[2] == 1)
e.c2 = Constraint(expr=m.x == 2 * e.lambdas[1] + 3 * e.lambdas[2])

@d.Disjunct()
def d_r(e):
e.lambdas = Var(m.S, bounds=(0, 1))
e.LocalVars = Suffix(direction=Suffix.LOCAL)
e.LocalVars[e] = list(e.lambdas.values())
e.c1 = Constraint(expr=e.lambdas[1] + e.lambdas[2] == 1)
e.c2 = Constraint(expr=m.x == 2 * e.lambdas[1] + 3 * e.lambdas[2])

d.inner_disj = Disjunction(expr=[d.d_l, d.d_r])

m.disj = Disjunction(expr=[m.d_l, m.d_r])
m.obj = Objective(expr=m.x)

hull = TransformationFactory('gdp.hull')
hull.apply_to(m)

x1 = hull.get_disaggregated_var(m.x, m.d_l)
x2 = hull.get_disaggregated_var(m.x, m.d_r)
x3 = hull.get_disaggregated_var(m.x, m.d_r.d_l)
x4 = hull.get_disaggregated_var(m.x, m.d_r.d_r)

for d, x in [(m.d_l, x1), (m.d_r.d_l, x3), (m.d_r.d_r, x4)]:
lambda1 = hull.get_disaggregated_var(d.lambdas[1], d)
self.assertIs(lambda1, d.lambdas[1])
lambda2 = hull.get_disaggregated_var(d.lambdas[2], d)
self.assertIs(lambda2, d.lambdas[2])

cons = hull.get_transformed_constraints(d.c1)
self.assertEqual(len(cons), 1)
convex_combo = cons[0]
assertExpressionsEqual(
self,
convex_combo.expr,
lambda1 + lambda2 - (1 - d.indicator_var.get_associated_binary()) * 0.0
== d.indicator_var.get_associated_binary(),
)
cons = hull.get_transformed_constraints(d.c2)
self.assertEqual(len(cons), 1)
get_x = cons[0]
assertExpressionsEqual(
self,
get_x.expr,
x
- (2 * lambda1 + 3 * lambda2)
- (1 - d.indicator_var.get_associated_binary()) * 0.0
== 0.0 * d.indicator_var.get_associated_binary(),
)

cons = hull.get_disaggregation_constraint(m.x, m.disj)
assertExpressionsEqual(self, cons.expr, m.x == x1 + x2)
cons = hull.get_disaggregation_constraint(m.x, m.d_r.inner_disj)
assertExpressionsEqual(self, cons.expr, x2 == x3 + x4)


class TestSpecialCases(unittest.TestCase):
def test_local_vars(self):
Expand Down

0 comments on commit 848ece7

Please sign in to comment.