Skip to content

Commit

Permalink
Pushing select and substitution code down in class hierarchy to expre…
Browse files Browse the repository at this point in the history
…ssions
  • Loading branch information
ckirsch committed Dec 5, 2024
1 parent 0e958de commit 7f38869
Showing 1 changed file with 101 additions and 144 deletions.
245 changes: 101 additions & 144 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,86 +238,6 @@ def get(nid):
assert Line.is_defined(nid), f"undefined nid {self.nid} @ {self.line_no}"
return Line.lines[nid]

def get_z3_lambda(self, line):
if self.z3_lambda_line is None:
if line.domain:
self.z3_lambda_line = z3.Lambda([state.get_z3() for state in line.domain], line.get_z3())
else:
self.z3_lambda_line = line.get_z3()
return self.z3_lambda_line

def get_z3_select(self, domain, step):
if step not in self.cache_z3_instance:
if domain:
self.cache_z3_instance[step] = z3.Select(self.get_z3_lambda(), *[state.get_z3_step(step) for state in domain])
else:
self.cache_z3_instance[step] = self.get_z3_lambda()
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
assert step >= 0
if step not in self.cache_z3_instance:
self.z3 = self.get_z3()
if domain:
if step == 0:
current_states = [state.get_z3() for state in domain]
else:
# assuming that self.z3 is a term over states of step - 1
current_states = [state.get_z3_step(step - 1) for state in domain]
next_states = [state.get_z3_step(step) for state in domain]
renaming = list(zip(current_states, next_states))

self.z3 = z3.substitute(self.z3, renaming)
self.cache_z3_instance[step] = self.z3
return self.cache_z3_instance[step]

def get_z3_step(self, domain, step):
if Line.LAMBDAS:
return self.get_z3_select(domain, step)
else:
return self.get_z3_substitute(domain, step)

def get_bitwuzla_lambda(self, line, tm):
if self.bitwuzla_lambda_line is None:
if line.domain:
self.bitwuzla_lambda_line = tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in line.domain], line.get_bitwuzla(tm)])
else:
self.bitwuzla_lambda_line = line.get_bitwuzla(tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, domain, step, tm):
if step not in self.cache_bitwuzla_instance:
if domain:
self.cache_bitwuzla_instance[step] = tm.mk_term(bitwuzla.Kind.APPLY,
[self.get_bitwuzla_lambda(tm), *[state.get_bitwuzla_step(step, tm) for state in domain]])
else:
self.cache_bitwuzla_instance[step] = self.get_bitwuzla_lambda(tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, domain, step, tm):
assert step >= 0
if step not in self.cache_bitwuzla_instance:
self.bitwuzla = self.get_bitwuzla(tm)
if domain:
if step == 0:
current_states = [state.get_bitwuzla(tm) for state in domain]
else:
# assuming that self.bitwuzla is a term over states of step - 1
current_states = [state.get_bitwuzla_step(step - 1, tm) for state in domain]
next_states = [state.get_bitwuzla_step(step, tm) for state in domain]
renaming = dict(zip(current_states, next_states))

self.bitwuzla = tm.substitute_term(self.bitwuzla, renaming)
self.cache_bitwuzla_instance[step] = self.bitwuzla
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_step(self, domain, step, tm):
if Line.LAMBDAS:
return self.get_bitwuzla_select(domain, step, tm)
else:
return self.get_bitwuzla_substitute(domain, step, tm)

class Sort(Line):
keyword = OP_SORT

Expand Down Expand Up @@ -440,6 +360,91 @@ def __init__(self, nid, sid_line, domain, comment, line_no):
if not isinstance(sid_line, Sort):
raise model_error("sort", line_no)

def get_z3_lambda(self, seq_line):
if seq_line.z3_lambda is None:
if self.domain:
seq_line.z3_lambda = z3.Lambda([state.get_z3() for state in self.domain], self.get_z3())
else:
seq_line.z3_lambda = self.get_z3()
return seq_line.z3_lambda

def get_z3_select(self, seq_line, step):
if step not in seq_line.cache_z3_instance:
if self.domain:
seq_line.cache_z3_instance[step] = z3.Select(self.get_z3_lambda(seq_line),
*[state.get_z3_step(step) for state in self.domain])
else:
seq_line.cache_z3_instance[step] = self.get_z3_lambda(seq_line)
return seq_line.cache_z3_instance[step]

def get_z3_substitute(self, seq_line, step):
assert step >= 0
if step not in seq_line.cache_z3_instance:
if step == 0:
seq_line.cache_z3_instance[step] = self.get_z3()
else:
seq_line.cache_z3_instance[step] = seq_line.cache_z3_instance[step - 1]
if self.domain:
if step == 0:
current_states = [state.get_z3() for state in self.domain]
else:
# assuming that self.z3 is a term over states of step - 1
current_states = [state.get_z3_step(step - 1) for state in self.domain]
next_states = [state.get_z3_step(step) for state in self.domain]
renaming = list(zip(current_states, next_states))

seq_line.cache_z3_instance[step] = z3.substitute(seq_line.cache_z3_instance[step], renaming)
return seq_line.cache_z3_instance[step]

def get_z3_instance(self, seq_line, step):
if Line.LAMBDAS:
return self.get_z3_select(seq_line, step)
else:
return self.get_z3_substitute(seq_line, step)

def get_bitwuzla_lambda(self, seq_line, tm):
if seq_line.bitwuzla_lambda is None:
if self.domain:
seq_line.bitwuzla_lambda = tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in self.domain], self.get_bitwuzla(tm)])
else:
seq_line.bitwuzla_lambda = self.get_bitwuzla(tm)
return seq_line.bitwuzla_lambda

def get_bitwuzla_select(self, seq_line, step, tm):
if step not in seq_line.cache_bitwuzla_instance:
if self.domain:
seq_line.cache_bitwuzla_instance[step] = tm.mk_term(bitwuzla.Kind.APPLY,
[self.get_bitwuzla_lambda(seq_line, tm),
*[state.get_bitwuzla_step(step, tm) for state in self.domain]])
else:
seq_line.cache_bitwuzla_instance[step] = self.get_bitwuzla_lambda(seq_line, tm)
return seq_line.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, seq_line, step, tm):
assert step >= 0
if step not in seq_line.cache_bitwuzla_instance:
if step == 0:
seq_line.cache_bitwuzla_instance[step] = self.get_bitwuzla(tm)
else:
seq_line.cache_bitwuzla_instance[step] = seq_line.cache_bitwuzla_instance[step - 1]
if self.domain:
if step == 0:
current_states = [state.get_bitwuzla(tm) for state in self.domain]
else:
current_states = [state.get_bitwuzla_step(step - 1, tm) for state in self.domain]
next_states = [state.get_bitwuzla_step(step, tm) for state in self.domain]
renaming = dict(zip(current_states, next_states))

seq_line.cache_bitwuzla_instance[step] = tm.substitute_term(seq_line.cache_bitwuzla_instance[step], renaming)
return seq_line.cache_bitwuzla_instance[step]

def get_bitwuzla_instance(self, seq_line, step, tm):
if Line.LAMBDAS:
return self.get_bitwuzla_select(seq_line, step, tm)
else:
return self.get_bitwuzla_substitute(seq_line, step, tm)

class Constant(Expression):
def __init__(self, nid, sid_line, value, comment, line_no):
super().__init__(nid, sid_line, {}, comment, line_no)
Expand Down Expand Up @@ -1203,15 +1208,15 @@ def __init__(self, nid, sid_line, arg1_line, arg2_line, arg3_line, comment, line
self.ite_cache = {}
if comment == "; branch true condition":
Ite.branching_conditions = self
self.z3_lambda_line = None
self.z3_lambda = None
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.bitwuzla_lambda = None
self.cache_bitwuzla_instance = {}
elif comment == "; branch false condition":
Ite.non_branching_conditions = self
self.z3_lambda_line = None
self.z3_lambda = None
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.bitwuzla_lambda = None
self.cache_bitwuzla_instance = {}

def copy(self, arg1_line, arg2_line, arg3_line):
Expand All @@ -1234,13 +1239,9 @@ def get_z3(self):
self.arg2_line.get_z3(), self.arg3_line.get_z3())
return self.z3

def get_z3_lambda(self):
# only needed for branching
return super().get_z3_lambda(self)

def get_z3_step(self, step):
# only needed for branching
return super().get_z3_step(self.domain, step)
return self.get_z3_instance(self, step)

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
Expand All @@ -1250,13 +1251,9 @@ def get_bitwuzla(self, tm):
self.arg3_line.get_bitwuzla(tm)])
return self.bitwuzla

def get_bitwuzla_lambda(self, tm):
# only needed for branching
return super().get_bitwuzla_lambda(self, tm)

def get_bitwuzla_step(self, step, tm):
# only needed for branching
return super().get_bitwuzla_step(self.domain, step, tm)
return self.get_bitwuzla_instance(self, step, tm)

class Write(Ternary):
keyword = OP_WRITE
Expand Down Expand Up @@ -1327,21 +1324,11 @@ class Sequential(Line, Cache):
def __init__(self, nid, comment, line_no):
Line.__init__(self, nid, comment, line_no)
Cache.__init__(self)
self.z3_lambda_line = None
self.z3_lambda = None
self.cache_z3_instance = {}
self.bitwuzla_lambda_line = None
self.bitwuzla_lambda = None
self.cache_bitwuzla_instance = {}

def get_z3(self, line):
if self.z3 is None:
self.z3 = line.get_z3()
return self.z3

def get_bitwuzla(self, line, tm):
if self.bitwuzla is None:
self.bitwuzla = line.get_bitwuzla(tm)
return self.bitwuzla

class Transitional(Sequential):
def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_line, index):
super().__init__(nid, comment, line_no)
Expand Down Expand Up @@ -1390,24 +1377,6 @@ def new_transition(self, transitions, index):
assert self.nid not in transitions, f"transition nid {self.nid} already defined @ {self.line_no}"
transitions[self.nid] = self

def get_z3(self):
return super().get_z3(self.exp_line)

def get_z3_lambda(self):
return super().get_z3_lambda(self.exp_line)

def get_z3_step(self, step):
return super().get_z3_step(self.exp_line.domain, step)

def get_bitwuzla(self, tm):
return super().get_bitwuzla(self.exp_line, tm)

def get_bitwuzla_lambda(self, tm):
return super().get_bitwuzla_lambda(self.exp_line, tm)

def get_bitwuzla_step(self, step, tm):
return super().get_bitwuzla_step(self.exp_line.domain, step, tm)

def set_value(self):
self.state_line.set_value(self.exp_line.get_value())

Expand Down Expand Up @@ -1441,7 +1410,7 @@ def get_z3_step(self, step):
else:
if isinstance(self.exp_line, Constant):
self.set_value()
return self.state_line.get_z3_step(0) == super().get_z3_step(0)
return self.state_line.get_z3_step(0) == self.exp_line.get_z3_instance(self, 0)

def get_bitwuzla_step(self, step, tm):
assert step == 0, f"bitwuzla init with {step} != 0"
Expand All @@ -1456,7 +1425,7 @@ def get_bitwuzla_step(self, step, tm):
self.set_value()
return tm.mk_term(bitwuzla.Kind.EQUAL,
[self.state_line.get_bitwuzla_step(0, tm),
super().get_bitwuzla_step(0, tm)])
self.exp_line.get_bitwuzla_instance(self, 0, tm)])

class Next(Transitional):
keyword = OP_NEXT
Expand All @@ -1480,12 +1449,12 @@ def __str__(self):

def get_z3_step(self, step):
if step not in self.cache_z3:
self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == super().get_z3_step(step)
self.cache_z3[step] = self.state_line.get_z3_step(step + 1) == self.exp_line.get_z3_instance(self, step)
return self.cache_z3[step]

def get_z3_change(self, step):
if step not in self.cache_z3_change:
self.cache_z3_change[step] = self.state_line.get_z3_step(step) != super().get_z3_step(step)
self.cache_z3_change[step] = self.state_line.get_z3_step(step) != self.exp_line.get_z3_instance(self, step)
return self.cache_z3_change[step]

def get_z3_no_change(self, step):
Expand All @@ -1497,14 +1466,14 @@ def get_bitwuzla_step(self, step, tm):
if step not in self.cache_bitwuzla:
self.cache_bitwuzla[step] = tm.mk_term(bitwuzla.Kind.EQUAL,
[self.state_line.get_bitwuzla_step(step + 1, tm),
super().get_bitwuzla_step(step, tm)])
self.exp_line.get_bitwuzla_instance(self, step, tm)])
return self.cache_bitwuzla[step]

def get_bitwuzla_change(self, step, tm):
if step not in self.cache_bitwuzla_change:
self.cache_bitwuzla_change[step] = tm.mk_term(bitwuzla.Kind.DISTINCT,
[self.state_line.get_bitwuzla_step(step, tm),
super().get_bitwuzla_step(step, tm)])
self.exp_line.get_bitwuzla_instance(self, step, tm)])
return self.cache_bitwuzla_change[step]

def get_bitwuzla_no_change(self, step, tm):
Expand All @@ -1529,23 +1498,11 @@ def __init__(self, nid, property_line, symbol, comment, line_no):
def set_mapped_array_expression(self):
self.property_line = self.property_line.get_mapped_array_expression_for(None)

def get_z3(self):
return super().get_z3(self.property_line)

def get_z3_lambda(self):
return super().get_z3_lambda(self.property_line)

def get_z3_step(self, step):
return super().get_z3_step(self.property_line.domain, step)

def get_bitwuzla(self, tm):
return super().get_bitwuzla(self.property_line, tm)

def get_bitwuzla_lambda(self, tm):
return super().get_bitwuzla_lambda(self.property_line, tm)
return self.property_line.get_z3_instance(self, step)

def get_bitwuzla_step(self, step, tm):
return super().get_bitwuzla_step(self.property_line.domain, step, tm)
return self.property_line.get_bitwuzla_instance(self, step, tm)

class Constraint(Property):
keyword = OP_CONSTRAINT
Expand Down

0 comments on commit 7f38869

Please sign in to comment.