Skip to content

Commit

Permalink
Term substitution with Z3, needs more testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Dec 4, 2024
1 parent 1acea48 commit a4b8d7c
Showing 1 changed file with 70 additions and 21 deletions.
91 changes: 70 additions & 21 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,21 @@ def get_z3_select(line, domain, step):
else:
return line.get_z3_lambda()

def get_z3_substitute(line, domain, step):
assert step >= 0
line.z3 = line.get_z3()
if domain:
if step == 0:
current_states = [state.get_z3() for state in domain]
else:
# assuming that line.z3 is a term over states of step - 1
current_states = [state.get_z3_step(step - 1) for state in domain]
next_states = [state.get_z3_step(step) for state in domain]
renaming = list(zip(current_states, next_states))

line.z3 = z3.substitute(line.z3, renaming)
return line.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
self.bitwuzla = tm.mk_var(self.sid_line.get_bitwuzla(tm), self.name)
Expand Down Expand Up @@ -620,7 +635,7 @@ def get_bitwuzla_substitute(line, domain, step, tm):
# assuming that line.bitwuzla is a term over states of step - 1
current_states = [state.get_bitwuzla_step(step - 1, tm) for state in domain]
next_states = [state.get_bitwuzla_step(step, tm) for state in domain]
renaming = dict(current_next for current_next in zip(current_states, next_states))
renaming = dict(zip(current_states, next_states))

line.bitwuzla = tm.substitute_term(line.bitwuzla, renaming)
return line.bitwuzla
Expand Down Expand Up @@ -1165,11 +1180,13 @@ def __init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line
if comment == "; branch true condition":
Ite.branching_conditions = self
self.z3_lambda_line = None
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.cache_bitwuzla_instance = {}
elif comment == "; branch false condition":
Ite.non_branching_conditions = self
self.z3_lambda_line = None
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.cache_bitwuzla_instance = {}

Expand Down Expand Up @@ -1201,11 +1218,22 @@ def get_z3_lambda(self):

def get_z3_select(self, domain, step):
# only needed for branching
return State.get_z3_select(self, domain, step)
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_select(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
# only needed for branching
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_substitute(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_step(self, step):
# only needed for branching
return self.get_z3_select(self.domain, step)
if Line.LAMBDAS:
return self.get_z3_select(self.domain, step)
else:
return self.get_z3_substitute(self.domain, step)

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
Expand Down Expand Up @@ -1310,12 +1338,35 @@ def __init__(self, nid, comment, line_no):
Line.__init__(self, nid, comment, line_no)
Cache.__init__(self)
self.z3_lambda_line = None
self.cache_z3_select = {}
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.cache_bitwuzla_instance = {}

def get_z3(self, line):
if self.z3 is None:
self.z3 = line.get_z3()
return self.z3

def get_z3_lambda(self, line):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(line)
return self.z3_lambda_line

def get_z3_select(self, domain, step):
return State.get_z3_select(self, domain, step)
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_select(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_substitute(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_step(self, domain, step):
if Line.LAMBDAS:
return self.get_z3_select(domain, step)
else:
return self.get_z3_substitute(domain, step)

def get_bitwuzla(self, line, tm):
if self.bitwuzla is None:
Expand Down Expand Up @@ -1391,15 +1442,14 @@ def new_transition(self, transitions, index):
assert self.nid not in transitions, f"transition nid {self.nid} already defined @ {self.line_no}"
transitions[self.nid] = self

def get_z3(self):
return super().get_z3(self.exp_line)

def get_z3_lambda(self):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(self.exp_line)
return self.z3_lambda_line
return super().get_z3_lambda(self.exp_line)

def get_z3_select(self, step):
if step not in self.cache_z3_select:
self.cache_z3_select[step] = super().get_z3_select(self.exp_line.domain, step)
return self.cache_z3_select[step]
def get_z3_step(self, step):
return super().get_z3_step(self.exp_line.domain, step)

def get_bitwuzla(self, tm):
return super().get_bitwuzla(self.exp_line, tm)
Expand Down Expand Up @@ -1443,7 +1493,7 @@ def get_z3_step(self, step):
else:
if isinstance(self.exp_line, Constant):
self.set_value()
return self.state_line.get_z3_step(0) == self.get_z3_select(0)
return self.state_line.get_z3_step(0) == super().get_z3_step(0)

def get_bitwuzla_step(self, step, tm):
assert step == 0, f"bitwuzla init with {step} != 0"
Expand Down Expand Up @@ -1482,12 +1532,12 @@ def __str__(self):

def get_z3_step(self, step):
if step not in self.cache_z3:
self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == self.get_z3_select(step)
self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == super().get_z3_step(step)
return self.cache_z3[step]

def get_z3_change(self, step):
if step not in self.cache_z3_change:
self.cache_z3_change[step] = self.state_line.get_z3_step(step) != self.get_z3_select(step)
self.cache_z3_change[step] = self.state_line.get_z3_step(step) != super().get_z3_step(step)
return self.cache_z3_change[step]

def get_z3_no_change(self, step):
Expand Down Expand Up @@ -1531,15 +1581,14 @@ def __init__(self, nid, property_line, symbol, comment, line_no):
def set_mapped_array_expression(self):
self.property_line = self.property_line.get_mapped_array_expression_for(None)

def get_z3(self):
return super().get_z3(self.property_line)

def get_z3_lambda(self):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(self.property_line)
return self.z3_lambda_line
return super().get_z3_lambda(self.property_line)

def get_z3_step(self, step):
if step not in self.cache_z3:
self.cache_z3[step] = super().get_z3_select(self.property_line.domain, step)
return self.cache_z3[step]
return super().get_z3_step(self.property_line.domain, step)

def get_bitwuzla(self, tm):
return super().get_bitwuzla(self.property_line, tm)
Expand Down

0 comments on commit a4b8d7c

Please sign in to comment.