Skip to content

Commit

Permalink
Merge pull request #2850 from emma58/logical-to-disjunctive-bugfix
Browse files Browse the repository at this point in the history
contrib.cp: Fixing a bug where the logical-to-disjunctive walker ignored fixedness of BooleanVars
  • Loading branch information
jsiirola authored May 29, 2023
2 parents 2cb3210 + 8adbb64 commit 4743218
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
80 changes: 80 additions & 0 deletions pyomo/contrib/cp/tests/test_logical_to_disjunctive.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,86 @@ def test_at_least(self):
self, m.cons[1].expr, m.disjuncts[0].binary_indicator_var >= 1
)

def test_boolean_fixed_true(self):
m = self.make_model()
e = m.a.implies(m.b)
m.a.fix(True)

visitor = LogicalToDisjunctiveVisitor()
m.cons = visitor.constraints
m.z = visitor.z_vars
m.disjuncts = visitor.disjuncts
m.disjunctions = visitor.disjunctions

visitor.walk_expression(e)
# we'll get !a v b
self.assertEqual(len(m.z), 3)
self.assertEqual(len(m.cons), 4)

self.assertIs(m.a.get_associated_binary(), m.z[1])
self.assertTrue(m.z[1].fixed)
self.assertEqual(value(m.z[1]), 1)
self.assertIs(m.b.get_associated_binary(), m.z[2])

assertExpressionsEqual(
self, m.cons[1].expr, (1 - m.z[3]) + (1 - m.z[1]) + m.z[2] >= 1
)
assertExpressionsEqual(self, m.cons[2].expr, 1 - (1 - m.z[1]) + m.z[3] >= 1)
assertExpressionsEqual(self, m.cons[3].expr, m.z[3] + (1 - m.z[2]) >= 1)
assertExpressionsEqual(self, m.cons[4].expr, m.z[3] >= 1)

def test_boolean_fixed_false(self):
m = self.make_model()
e = m.a & m.b
m.a.fix(False)

visitor = LogicalToDisjunctiveVisitor()
m.cons = visitor.constraints
m.z = visitor.z_vars
m.disjuncts = visitor.disjuncts
m.disjunctions = visitor.disjunctions

visitor.walk_expression(e)
# we'll get !a v b
self.assertEqual(len(m.z), 3)
self.assertEqual(len(m.cons), 3)

self.assertIs(m.a.get_associated_binary(), m.z[1])
self.assertTrue(m.z[1].fixed)
self.assertEqual(value(m.z[1]), 0)
self.assertIs(m.b.get_associated_binary(), m.z[2])

assertExpressionsEqual(self, m.cons[1].expr, m.z[1] >= m.z[3])
assertExpressionsEqual(self, m.cons[2].expr, m.z[2] >= m.z[3])
assertExpressionsEqual(self, m.cons[3].expr, m.z[3] >= 1)

def test_boolean_fixed_none(self):
m = self.make_model()
e = m.a & m.b
# I don't get what this means, but you can do it, so... I guess we need
# to handle it.
m.a.fix(None)

visitor = LogicalToDisjunctiveVisitor()
m.cons = visitor.constraints
m.z = visitor.z_vars
m.disjuncts = visitor.disjuncts
m.disjunctions = visitor.disjunctions

visitor.walk_expression(e)
# we'll get !a v b
self.assertEqual(len(m.z), 3)
self.assertEqual(len(m.cons), 3)

self.assertIs(m.a.get_associated_binary(), m.z[1])
self.assertTrue(m.z[1].fixed)
self.assertIsNone(m.z[1].value)
self.assertIs(m.b.get_associated_binary(), m.z[2])

assertExpressionsEqual(self, m.cons[1].expr, m.z[1] >= m.z[3])
assertExpressionsEqual(self, m.cons[2].expr, m.z[2] >= m.z[3])
assertExpressionsEqual(self, m.cons[3].expr, m.z[3] >= 1)

def test_no_need_to_walk(self):
m = self.make_model()
e = m.a
Expand Down
5 changes: 5 additions & 0 deletions pyomo/contrib/cp/transform/logical_to_disjunctive_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def _dispatch_boolean_var(visitor, node):
z = visitor.z_vars.add()
visitor.boolean_to_binary_map[node] = z
node.associate_binary_var(z)
if node.fixed:
visitor.boolean_to_binary_map[node].fixed = True
visitor.boolean_to_binary_map[node].set_value(
int(node.value) if node.value is not None else None, skip_validation=True
)
return False, visitor.boolean_to_binary_map[node]


Expand Down

0 comments on commit 4743218

Please sign in to comment.