Skip to content

Commit

Permalink
Add Boolean overloads: &, |, ^, ~ (#103)
Browse files Browse the repository at this point in the history
Also: add a dummy `Model` pseudo-object and fix how `evaluate` handles contexts.
  • Loading branch information
alex-ozdemir authored Jan 6, 2025
1 parent be54c23 commit a2a3746
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
with:
options: "--check --verbose"
src: "cvc5_pythonic_api"
version: "23.7.0"
version: "24.10.0"

- uses: actions/checkout@v2
with:
Expand Down Expand Up @@ -56,7 +56,7 @@ jobs:
- name: Build cvc5
run: |
cd cvc5/
./configure.sh production --auto-download --python-bindings --cocoa
./configure.sh production --auto-download --python-bindings --cocoa --gpl
cd build/
make -j${{ env.num_proc }}
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ check:
pyright ./cvc5_pythonic_api

fmt:
black --required-version 23.7.0 ./cvc5_pythonic_api
black --required-version 24 ./cvc5_pythonic_api

check-fmt:
black --check --verbose --required-version 23.7.0 ./cvc5_pythonic_api
black --check --verbose --required-version 24 ./cvc5_pythonic_api

coverage:
coverage run test_unit.py && coverage report && coverage html
121 changes: 90 additions & 31 deletions cvc5_pythonic_api/cvc5_pythonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
* Missing features:
* Patterns
* Models for uninterpreted sorts
* The `Model` function
* In our API, this function returns an object whose only method is `evaluate`.
* Pseudo-boolean counting constraints
* AtMost, AtLeast, PbLe, PbGe, PbEq
* HTML integration
Expand Down Expand Up @@ -558,9 +560,6 @@ def _ctx_from_ast_arg_list(args, default_ctx=None):
if is_ast(a):
if ctx is None:
ctx = a.ctx
else:
if debugging():
_assert(ctx == a.ctx, "Context mismatch")
if ctx is None:
ctx = default_ctx
return ctx
Expand Down Expand Up @@ -1245,8 +1244,6 @@ def If(a, b, c, ctx=None):
s = BoolSort(ctx)
a = s.cast(a)
b, c = _coerce_exprs(b, c, ctx)
if debugging():
_assert(a.ctx == b.ctx, "Context mismatch")
return _to_expr_ref(ctx.solver.mkTerm(Kind.ITE, a.ast, b.ast, c.ast), ctx)


Expand Down Expand Up @@ -1429,6 +1426,38 @@ def __mul__(self, other):
return 0
return If(self, other, 0)

def __and__(self, other):
"""Create the SMT and expression `self & other`.
>>> solve(Bool("x") & Bool("y"))
[x = True, y = True]
"""
return And(self, other)

def __or__(self, other):
"""Create the SMT or expression `self | other`.
>>> solve(Bool("x") | Bool("y"), Not(Bool("x")))
[x = False, y = True]
"""
return Or(self, other)

def __xor__(self, other):
"""Create the SMT xor expression `self ^ other`.
>>> solve(Bool("x") ^ Bool("y"), Not(Bool("x")))
[x = False, y = True]
"""
return Xor(self, other)

def __invert__(self):
"""Create the SMT not expression `~self`.
>>> solve(~Bool("x"))
[x = False]
"""
return Not(self)


def is_bool(a):
"""Return `True` if `a` is an SMT Boolean expression.
Expand Down Expand Up @@ -1875,8 +1904,6 @@ def cast(self, val):
String
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
val_s = val.sort()
if self.eq(val_s):
return val
Expand Down Expand Up @@ -2617,8 +2644,6 @@ def cast(self, val):
failed
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
val_s = val.sort()
if self.eq(val_s):
return val
Expand Down Expand Up @@ -4067,8 +4092,6 @@ def cast(self, val):
'#b00000000000000000000000000001010'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
# Idea: use sign_extend if sort of val is a bitvector of smaller size
return val
else:
Expand Down Expand Up @@ -5494,7 +5517,6 @@ def ArraySort(*sig):
if debugging():
for s in sig:
_assert(is_sort(s), "SMT sort expected")
_assert(s.ctx == r.ctx, "Context mismatch")
ctx = d.ctx
if len(sig) == 2:
return ArraySortRef(ctx.solver.mkArraySort(d.ast, r.ast), ctx)
Expand Down Expand Up @@ -6238,12 +6260,22 @@ def proof(self):
[a + 2 == 0, a == 0],
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))))
(TRANS: (a == 0) == False,
(CONG: (a == 0) == (-2 == 0),
[5],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(TRANS: (a + 2 == 0) == (a == -2),
(CONG: (a + 2 == 0) == (2 + a == 0),
[5],
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
[a + 2 == 2 + a, 3, 7]),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
[(2 + a == 0) == (a == -2), 3, 7]))),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
[(-2 == 0) == False, 3, 7])))))
"""
p = self.solver.getProof()[0]
return ProofRef(self, p)
Expand Down Expand Up @@ -6789,13 +6821,36 @@ def decls(self):


def evaluate(t):
"""Evaluates the given term (assuming it is constant!)"""
"""Evaluates the given term (assuming it is constant!)
>>> evaluate(evaluate(BitVecVal(1, 8) + BitVecVal(2, 8)) + BitVecVal(3, 8))
6
"""
if not isinstance(t, ExprRef):
raise TypeError("Can only evaluate `ExprRef`s")
s = Solver()
s.check()
m = s.model()
return m[t]


class EmptyModel:
def evaluate(self, t):
return evaluate(t)


def Model(ctx=None):
"""Return an object for evaluating terms.
We recommend using the standalone `evaluate` function for this instead,
but we also provide this function and its return object for z3 compatibility.
>>> Model().evaluate(BitVecVal(1, 8) + BitVecVal(2, 8))
3
"""
return EmptyModel()


class ProofRef:
"""A proof tree where every proof reference corresponds to the
root step of a proof. The branches of the root step are the
Expand Down Expand Up @@ -6857,12 +6912,22 @@ def getChildren(self):
>>> p
(EQ_RESOLVE: False,
(ASSUME: a == 0, [a == 0]),
(MACRO_SR_EQ_INTRO: (a == 0) == False,
[a == 0, 7, 12],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(MACRO_SR_EQ_INTRO: (a + 2 == 0) == (a == -2),
[a + 2 == 0, 7, 12]))))
(TRANS: (a == 0) == False,
(CONG: (a == 0) == (-2 == 0),
[5],
(EQ_RESOLVE: a == -2,
(ASSUME: a + 2 == 0, [a + 2 == 0]),
(TRANS: (a + 2 == 0) == (a == -2),
(CONG: (a + 2 == 0) == (2 + a == 0),
[5],
(TRUST_THEORY_REWRITE: a + 2 == 2 + a,
[a + 2 == 2 + a, 3, 7]),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (2 + a == 0) == (a == -2),
[(2 + a == 0) == (a == -2), 3, 7]))),
(REFL: 0 == 0, [0])),
(TRUST_THEORY_REWRITE: (-2 == 0) == False,
[(-2 == 0) == False, 3, 7])))
"""
children = self.proof.getChildren()
return [ProofRef(self.solver, cp) for cp in children]
Expand Down Expand Up @@ -6965,8 +7030,6 @@ def cast(self, val):
'(fp #b0 #b01111111 #b00000000000000000000000)'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
return val
else:
return FPVal(val, None, self, self.ctx)
Expand Down Expand Up @@ -8633,7 +8696,6 @@ def CreateDatatypes(*ds):
_assert(
all([isinstance(d, Datatype) for d in ds]), "Arguments must be Datatypes"
)
_assert(all([d.ctx == ds[0].ctx for d in ds]), "Context mismatch")
_assert(all([d.constructors != [] for d in ds]), "Non-empty Datatypes expected")
ctx = ds[0].ctx
s = ctx.solver
Expand Down Expand Up @@ -9240,9 +9302,6 @@ def cast(self, val):
'#f10m31'
"""
if is_expr(val):
if debugging():
_assert(self.ctx == val.ctx, "Context mismatch")
# Idea: use sign_extend if sort of val is a bitvector of smaller size
return val
else:
return FiniteFieldVal(val, self)
Expand Down

0 comments on commit a2a3746

Please sign in to comment.