diff --git a/pyomo/contrib/cp/tests/test_logical_to_disjunctive.py b/pyomo/contrib/cp/tests/test_logical_to_disjunctive.py index 9e6241942de..b9c5b20ea48 100755 --- a/pyomo/contrib/cp/tests/test_logical_to_disjunctive.py +++ b/pyomo/contrib/cp/tests/test_logical_to_disjunctive.py @@ -283,6 +283,86 @@ def test_at_least(self): self, m.cons[1].expr, m.disjuncts[0].binary_indicator_var >= 1 ) + def test_boolean_fixed_true(self): + m = self.make_model() + e = m.a.implies(m.b) + m.a.fix(True) + + visitor = LogicalToDisjunctiveVisitor() + m.cons = visitor.constraints + m.z = visitor.z_vars + m.disjuncts = visitor.disjuncts + m.disjunctions = visitor.disjunctions + + visitor.walk_expression(e) + # we'll get !a v b + self.assertEqual(len(m.z), 3) + self.assertEqual(len(m.cons), 4) + + self.assertIs(m.a.get_associated_binary(), m.z[1]) + self.assertTrue(m.z[1].fixed) + self.assertEqual(value(m.z[1]), 1) + self.assertIs(m.b.get_associated_binary(), m.z[2]) + + assertExpressionsEqual( + self, m.cons[1].expr, (1 - m.z[3]) + (1 - m.z[1]) + m.z[2] >= 1 + ) + assertExpressionsEqual(self, m.cons[2].expr, 1 - (1 - m.z[1]) + m.z[3] >= 1) + assertExpressionsEqual(self, m.cons[3].expr, m.z[3] + (1 - m.z[2]) >= 1) + assertExpressionsEqual(self, m.cons[4].expr, m.z[3] >= 1) + + def test_boolean_fixed_false(self): + m = self.make_model() + e = m.a & m.b + m.a.fix(False) + + visitor = LogicalToDisjunctiveVisitor() + m.cons = visitor.constraints + m.z = visitor.z_vars + m.disjuncts = visitor.disjuncts + m.disjunctions = visitor.disjunctions + + visitor.walk_expression(e) + # we'll get !a v b + self.assertEqual(len(m.z), 3) + self.assertEqual(len(m.cons), 3) + + self.assertIs(m.a.get_associated_binary(), m.z[1]) + self.assertTrue(m.z[1].fixed) + self.assertEqual(value(m.z[1]), 0) + self.assertIs(m.b.get_associated_binary(), m.z[2]) + + assertExpressionsEqual(self, m.cons[1].expr, m.z[1] >= m.z[3]) + assertExpressionsEqual(self, m.cons[2].expr, m.z[2] >= m.z[3]) + assertExpressionsEqual(self, m.cons[3].expr, m.z[3] >= 1) + + def test_boolean_fixed_none(self): + m = self.make_model() + e = m.a & m.b + # I don't get what this means, but you can do it, so... I guess we need + # to handle it. + m.a.fix(None) + + visitor = LogicalToDisjunctiveVisitor() + m.cons = visitor.constraints + m.z = visitor.z_vars + m.disjuncts = visitor.disjuncts + m.disjunctions = visitor.disjunctions + + visitor.walk_expression(e) + # we'll get !a v b + self.assertEqual(len(m.z), 3) + self.assertEqual(len(m.cons), 3) + + self.assertIs(m.a.get_associated_binary(), m.z[1]) + self.assertTrue(m.z[1].fixed) + self.assertIsNone(m.z[1].value) + self.assertIs(m.b.get_associated_binary(), m.z[2]) + + assertExpressionsEqual(self, m.cons[1].expr, m.z[1] >= m.z[3]) + assertExpressionsEqual(self, m.cons[2].expr, m.z[2] >= m.z[3]) + assertExpressionsEqual(self, m.cons[3].expr, m.z[3] >= 1) + def test_no_need_to_walk(self): m = self.make_model() e = m.a diff --git a/pyomo/contrib/cp/transform/logical_to_disjunctive_walker.py b/pyomo/contrib/cp/transform/logical_to_disjunctive_walker.py index 51c1f3a0e2b..00018e1f31e 100644 --- a/pyomo/contrib/cp/transform/logical_to_disjunctive_walker.py +++ b/pyomo/contrib/cp/transform/logical_to_disjunctive_walker.py @@ -43,6 +43,11 @@ def _dispatch_boolean_var(visitor, node): z = visitor.z_vars.add() visitor.boolean_to_binary_map[node] = z node.associate_binary_var(z) + if node.fixed: + visitor.boolean_to_binary_map[node].fixed = True + visitor.boolean_to_binary_map[node].set_value( + int(node.value) if node.value is not None else None, skip_validation=True + ) return False, visitor.boolean_to_binary_map[node]