Skip to content

Commit

Permalink
Cleanups to handling ml predicates and substitutions (#4625)
Browse files Browse the repository at this point in the history
~Blocked on: #4631~
~Blocked on: #4630~
~Blocked on: #4633~

While reviewing and going over
#4621 with @Stevengre, it
became somewhat clear that how we handle turning substitutions into ML
predicates is a bit dirty. This attempts to clean this up a bit. Where
potentially breaking changes to the API are introduced here, I've checked
whether they affect the following repos, which is what "downstream" means below:
`evm-semantics kontrol wasm-semantics riscv-semantics mir-semantics`.

In particular:

- The function `CTerm.anti_unify` is simplified: it now reuses
`remove_useless_constraints` from `kast.manip` instead of reimplementing it.
- The functions `CSubst.from_pred` and `CSubst.pred` are added, as
replacements for `Subst.ml_pred`. This is because `Subst.ml_pred`
doesn't have a good way to produce correctly sorted predicates, because
it's in module `kast.inner`.
- `Subst.ml_pred` is removed, and tests are updated to use the new
`CSubst` variant. None of the downstream repositories use
`Subst.ml_pred` directly.
- The new `CSubst.pred` assigns the correct sort to each generated `#Equals`
clause, defaulting to sort `K` or, if a `KDefinition` is supplied, using it to
do sort inference. It also provides options for controlling whether the
substitution and/or the constraints are included in the generated predicate.
- A test is added for a `CSubst.pred` case which caused a bug in the
integration tests dealing with identity substitutions.
- The `CTermSymbolic.implies` function is updated to reuse
`CSubst.from_pred` instead of reimplementing it.
- In the case of duplicate entries, the first is kept and the later ones are
turned into predicates.
  • Loading branch information
ehildenb authored Sep 9, 2024
1 parent 15cbf88 commit 2f4d4f5
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 69 deletions.
39 changes: 25 additions & 14 deletions pyk/src/pyk/cterm/cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from typing import TYPE_CHECKING

from ..kast import KInner
from ..kast.inner import KApply, KRewrite, KToken, Subst, bottom_up
from ..kast.inner import KApply, KRewrite, KToken, KVariable, Subst, bottom_up
from ..kast.manip import (
abstract_term_safely,
build_claim,
build_rule,
extract_subst,
flatten_label,
free_vars,
ml_pred_to_bool,
Expand All @@ -20,9 +21,9 @@
split_config_and_constraints,
split_config_from,
)
from ..prelude.k import GENERATED_TOP_CELL
from ..prelude.k import GENERATED_TOP_CELL, K
from ..prelude.kbool import andBool, orBool
from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEqualsTrue, mlImplies, mlTop
from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEquals, mlEqualsTrue, mlImplies, mlTop
from ..utils import unique

if TYPE_CHECKING:
Expand Down Expand Up @@ -217,17 +218,7 @@ def anti_unify(
if KToken('true', 'Bool') not in [disjunct_lhs, disjunct_rhs]:
new_cterm = new_cterm.add_constraint(mlEqualsTrue(orBool([disjunct_lhs, disjunct_rhs])))

new_constraints = []
fvs = new_cterm.free_vars
len_fvs = 0
while len_fvs < len(fvs):
len_fvs = len(fvs)
for constraint in common_constraints:
if constraint not in new_constraints:
constraint_fvs = free_vars(constraint)
if any(fv in fvs for fv in constraint_fvs):
new_constraints.append(constraint)
fvs = fvs | constraint_fvs
new_constraints = remove_useless_constraints(common_constraints, new_cterm.free_vars)

for constraint in new_constraints:
new_cterm = new_cterm.add_constraint(constraint)
Expand Down Expand Up @@ -341,6 +332,26 @@ def from_dict(dct: dict[str, Any]) -> CSubst:
constraints = (KInner.from_dict(c) for c in dct['constraints'])
return CSubst(subst=subst, constraints=constraints)

@staticmethod
def from_pred(pred: KInner) -> CSubst:
"""Extract from a boolean predicate a CSubst."""
subst, pred = extract_subst(pred)
return CSubst(subst=subst, constraints=flatten_label('#And', pred))

def pred(self, sort_with: KDefinition | None = None, subst: bool = True, constraints: bool = True) -> KInner:
"""Return an ML predicate representing this substitution."""
_preds: list[KInner] = []
if subst:
for k, v in self.subst.minimize().items():
sort = K
if sort_with is not None:
_sort = sort_with.sort(v)
sort = _sort if _sort is not None else sort
_preds.append(mlEquals(KVariable(k, sort=sort), v, arg_sort=sort))
if constraints:
_preds.extend(self.constraints)
return mlAnd(_preds)

@property
def constraint(self) -> KInner:
"""Return the set of constraints as a single flattened constraint using `mlAnd`."""
Expand Down
17 changes: 3 additions & 14 deletions pyk/src/pyk/cterm/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
kore_server,
)
from ..prelude.k import GENERATED_TOP_CELL, K_ITEM
from ..prelude.ml import is_top, mlEquals
from ..prelude.ml import mlAnd

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -267,19 +267,8 @@ def implies(
raise ValueError('Received empty predicate for valid implication.')
ml_subst = self.kore_to_kast(result.substitution)
ml_pred = self.kore_to_kast(result.predicate)
ml_preds = flatten_label('#And', ml_pred)
if is_top(ml_subst):
csubst = CSubst(subst=Subst({}), constraints=ml_preds)
return CTermImplies(csubst, (), None, result.logs)
subst_pattern = mlEquals(KVariable('###VAR'), KVariable('###TERM'))
_subst: dict[str, KInner] = {}
for subst_pred in flatten_label('#And', ml_subst):
m = subst_pattern.match(subst_pred)
if m is not None and type(m['###VAR']) is KVariable:
_subst[m['###VAR'].name] = m['###TERM']
else:
raise AssertionError(f'Received a non-substitution from implies endpoint: {subst_pred}')
csubst = CSubst(subst=Subst(_subst), constraints=ml_preds)
ml_subst_pred = mlAnd(flatten_label('#And', ml_subst) + flatten_label('#And', ml_pred))
csubst = CSubst.from_pred(ml_subst_pred)
return CTermImplies(csubst, (), None, result.logs)

def assume_defined(self, cterm: CTerm, module_name: str | None = None) -> CTerm:
Expand Down
14 changes: 0 additions & 14 deletions pyk/src/pyk/kast/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,20 +749,6 @@ def from_pred(pred: KInner) -> Subst:
raise ValueError(f'Invalid substitution predicate: {conjunct}')
return Subst(subst)

@property
def ml_pred(self) -> KInner:
"""Turn this `Subst` into a matching logic predicate using `{_#Equals_}` operator."""
items = []
for k in self:
if KVariable(k) != self[k]:
items.append(KApply('#Equals', [KVariable(k), self[k]]))
if len(items) == 0:
return KApply('#Top')
ml_term = items[0]
for _i in items[1:]:
ml_term = KApply('#And', [ml_term, _i])
return ml_term

@property
def pred(self) -> KInner:
"""Turn this `Subst` into a boolean predicate using `_==K_` operator."""
Expand Down
4 changes: 3 additions & 1 deletion pyk/src/pyk/kcfg/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def dump(self, cfgid: str, cfg: KCFG, dump_dir: Path, dot: bool = False) -> None
cover_file = covers_dir / f'config_{cover.source.id}_{cover.target.id}.txt'
cover_constraint_file = covers_dir / f'constraint_{cover.source.id}_{cover.target.id}.txt'

subst_equalities = flatten_label('#And', cover.csubst.subst.ml_pred)
subst_equalities = flatten_label(
'#And', cover.csubst.pred(sort_with=self.kprint.definition, constraints=False)
)

if not cover_file.exists():
cover_file.write_text('\n'.join(self.kprint.pretty_print(se) for se in subst_equalities))
Expand Down
12 changes: 10 additions & 2 deletions pyk/src/pyk/kcfg/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,12 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]:
term_str, constraint_str = _cterm_text(crewrite)

elif type(self._element) is KCFG.Cover:
subst_equalities = map(_boolify, flatten_label('#And', self._element.csubst.subst.ml_pred))
subst_equalities = map(
_boolify,
flatten_label(
'#And', self._element.csubst.pred(sort_with=self._kprint.definition, constraints=False)
),
)
constraints = map(_boolify, flatten_label('#And', self._element.csubst.constraint))
term_str = '\n'.join(self._kprint.pretty_print(se) for se in subst_equalities)
constraint_str = '\n'.join(self._kprint.pretty_print(c) for c in constraints)
Expand All @@ -320,7 +325,10 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]:
term_strs.append('')
term_strs.append(f' - {shorten_hashes(target_id)}')
if len(csubst.subst) > 0:
subst_equalities = map(_boolify, flatten_label('#And', csubst.subst.ml_pred))
subst_equalities = map(
_boolify,
flatten_label('#And', csubst.pred(sort_with=self._kprint.definition, constraints=False)),
)
term_strs.extend(f' {self._kprint.pretty_print(cline)}' for cline in subst_equalities)
if len(csubst.constraints) > 0:
constraints = map(_boolify, flatten_label('#And', csubst.constraint))
Expand Down
6 changes: 4 additions & 2 deletions pyk/src/pyk/proof/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,16 @@ def from_spec_modules(

return res

def path_constraints(self, final_node_id: NodeIdLike) -> KInner:
def path_constraints(self, final_node_id: NodeIdLike, sort_with: KDefinition | None = None) -> KInner:
path = self.shortest_path_to(final_node_id)
curr_constraint: KInner = mlTop()
for edge in reversed(path):
if type(edge) is KCFG.Split:
assert len(edge.targets) == 1
csubst = edge.splits[edge.targets[0].id]
curr_constraint = mlAnd([csubst.subst.minimize().ml_pred, csubst.constraint, curr_constraint])
curr_constraint = mlAnd(
[csubst.pred(sort_with=sort_with, constraints=False), csubst.constraint, curr_constraint]
)
if type(edge) is KCFG.Cover:
curr_constraint = mlAnd([edge.csubst.constraint, edge.csubst.subst.apply(curr_constraint)])
return mlAnd(flatten_label('#And', curr_constraint))
Expand Down
20 changes: 0 additions & 20 deletions pyk/src/tests/unit/kast/test_subst.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from pyk.kast.inner import KApply, KLabel, KVariable, Subst
from pyk.kast.manip import extract_subst
from pyk.prelude.kbool import TRUE
from pyk.prelude.kint import INT, intToken
from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlOr, mlTop

Expand Down Expand Up @@ -108,25 +107,6 @@ def test_unapply(term: KInner, subst: dict[str, KInner], expected: KInner) -> No
assert actual == expected


ML_PRED_TEST_DATA: Final = (
('empty', Subst({}), KApply('#Top')),
('singleton', Subst({'X': TRUE}), KApply('#Equals', [KVariable('X'), TRUE])),
(
'double',
Subst({'X': TRUE, 'Y': intToken(4)}),
KApply(
'#And',
[KApply('#Equals', [KVariable('X'), TRUE]), KApply('#Equals', [KVariable('Y'), intToken(4)])],
),
),
)


@pytest.mark.parametrize('test_id,subst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA])
def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None:
assert subst.ml_pred == pred


_0 = intToken(0)
_EQ = KLabel('_==Int_')
EXTRACT_SUBST_TEST_DATA: Final[tuple[tuple[KInner, dict[str, KInner], KInner], ...]] = (
Expand Down
27 changes: 25 additions & 2 deletions pyk/src/tests/unit/test_cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from pyk.kast import Atts, KAtt
from pyk.kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KVariable, Subst
from pyk.kast.outer import KClaim
from pyk.prelude.k import GENERATED_TOP_CELL
from pyk.prelude.k import GENERATED_TOP_CELL, K
from pyk.prelude.kbool import TRUE
from pyk.prelude.kint import INT, intToken
from pyk.prelude.ml import mlAnd, mlEqualsTrue
from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlTop

from .utils import a, b, c, f, g, ge_ml, h, k, lt_ml, x, y, z

Expand Down Expand Up @@ -188,6 +189,28 @@ def test_from_kast(test_id: str, kast: KInner, expected: CTerm) -> None:
assert cterm == expected


ML_PRED_TEST_DATA: Final = (
('empty', CSubst(Subst({})), mlTop()),
('singleton', CSubst(Subst({'X': TRUE})), mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K)),
('identity', CSubst(Subst({'X': KVariable('X')})), mlTop()),
(
'double',
CSubst(Subst({'X': TRUE, 'Y': intToken(4)})),
mlAnd(
[
mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K),
mlEquals(KVariable('Y', sort=K), intToken(4), arg_sort=K),
]
),
),
)


@pytest.mark.parametrize('test_id,csubst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA])
def test_ml_pred(test_id: str, csubst: CSubst, pred: KInner) -> None:
assert csubst.pred() == pred


APPLY_TEST_DATA: Final = (
(CTerm.top(), CSubst(), CTerm.top()),
(CTerm.bottom(), CSubst(), CTerm.bottom()),
Expand Down

0 comments on commit 2f4d4f5

Please sign in to comment.