Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanups to handling ml predicates and substitutions #4625

Merged
merged 15 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions pyk/src/pyk/cterm/cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from typing import TYPE_CHECKING

from ..kast import KInner
from ..kast.inner import KApply, KRewrite, KToken, Subst, bottom_up
from ..kast.inner import KApply, KRewrite, KToken, KVariable, Subst, bottom_up
from ..kast.manip import (
abstract_term_safely,
build_claim,
build_rule,
extract_subst,
flatten_label,
free_vars,
ml_pred_to_bool,
Expand All @@ -20,10 +21,10 @@
split_config_and_constraints,
split_config_from,
)
from ..prelude.k import GENERATED_TOP_CELL
from ..prelude.k import GENERATED_TOP_CELL, K
from ..prelude.kbool import andBool, orBool
from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEqualsTrue, mlImplies, mlTop
from ..utils import unique
from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEquals, mlEqualsTrue, mlImplies, mlTop
from ..utils import not_none, unique

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -217,17 +218,7 @@ def anti_unify(
if KToken('true', 'Bool') not in [disjunct_lhs, disjunct_rhs]:
new_cterm = new_cterm.add_constraint(mlEqualsTrue(orBool([disjunct_lhs, disjunct_rhs])))

new_constraints = []
fvs = new_cterm.free_vars
len_fvs = 0
while len_fvs < len(fvs):
len_fvs = len(fvs)
for constraint in common_constraints:
if constraint not in new_constraints:
constraint_fvs = free_vars(constraint)
if any(fv in fvs for fv in constraint_fvs):
new_constraints.append(constraint)
fvs = fvs | constraint_fvs
new_constraints = remove_useless_constraints(common_constraints, new_cterm.free_vars)

for constraint in new_constraints:
new_cterm = new_cterm.add_constraint(constraint)
Expand Down Expand Up @@ -322,6 +313,23 @@ def from_dict(dct: dict[str, Any]) -> CSubst:
constraints = (KInner.from_dict(c) for c in dct['constraints'])
return CSubst(subst=subst, constraints=constraints)

@staticmethod
def from_pred(pred: KInner) -> CSubst:
"""Extract from a boolean predicate a CSubst."""
subst, pred = extract_subst(pred)
return CSubst(subst=subst, constraints=flatten_label('#And', pred))

def pred(self, sort_with: KDefinition | None = None, subst: bool = True, constraints: bool = True) -> KInner:
"""Return an ML predicate representing this substitution."""
_preds: list[KInner] = []
if subst:
for k, v in self.subst.minimize().items():
sort = K if not (sort_with and sort_with.sort(v)) else not_none(sort_with.sort(v))
ehildenb marked this conversation as resolved.
Show resolved Hide resolved
_preds.append(mlEquals(KVariable(k, sort=sort), v, arg_sort=sort))
if constraints:
_preds.extend(self.constraints)
return mlAnd(_preds)

@property
def constraint(self) -> KInner:
"""Return the set of constraints as a single flattened constraint using `mlAnd`."""
Expand Down
17 changes: 3 additions & 14 deletions pyk/src/pyk/cterm/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
kore_server,
)
from ..prelude.k import GENERATED_TOP_CELL, K_ITEM
from ..prelude.ml import is_top, mlEquals
from ..prelude.ml import mlAnd

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -267,19 +267,8 @@ def implies(
raise ValueError('Received empty predicate for valid implication.')
ml_subst = self.kore_to_kast(result.substitution)
ml_pred = self.kore_to_kast(result.predicate)
ml_preds = flatten_label('#And', ml_pred)
if is_top(ml_subst):
csubst = CSubst(subst=Subst({}), constraints=ml_preds)
return CTermImplies(csubst, (), None, result.logs)
subst_pattern = mlEquals(KVariable('###VAR'), KVariable('###TERM'))
_subst: dict[str, KInner] = {}
for subst_pred in flatten_label('#And', ml_subst):
m = subst_pattern.match(subst_pred)
if m is not None and type(m['###VAR']) is KVariable:
_subst[m['###VAR'].name] = m['###TERM']
else:
raise AssertionError(f'Received a non-substitution from implies endpoint: {subst_pred}')
csubst = CSubst(subst=Subst(_subst), constraints=ml_preds)
ml_subst_pred = mlAnd(flatten_label('#And', ml_subst) + flatten_label('#And', ml_pred))
csubst = CSubst.from_pred(ml_subst_pred)
return CTermImplies(csubst, (), None, result.logs)

def assume_defined(self, cterm: CTerm, module_name: str | None = None) -> CTerm:
Expand Down
14 changes: 0 additions & 14 deletions pyk/src/pyk/kast/inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,20 +749,6 @@ def from_pred(pred: KInner) -> Subst:
raise ValueError(f'Invalid substitution predicate: {conjunct}')
return Subst(subst)

@property
def ml_pred(self) -> KInner:
"""Turn this `Subst` into a matching logic predicate using `{_#Equals_}` operator."""
items = []
for k in self:
if KVariable(k) != self[k]:
items.append(KApply('#Equals', [KVariable(k), self[k]]))
if len(items) == 0:
return KApply('#Top')
ml_term = items[0]
for _i in items[1:]:
ml_term = KApply('#And', [ml_term, _i])
return ml_term

@property
def pred(self) -> KInner:
"""Turn this `Subst` into a boolean predicate using `_==K_` operator."""
Expand Down
4 changes: 3 additions & 1 deletion pyk/src/pyk/kcfg/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def dump(self, cfgid: str, cfg: KCFG, dump_dir: Path, dot: bool = False) -> None
cover_file = covers_dir / f'config_{cover.source.id}_{cover.target.id}.txt'
cover_constraint_file = covers_dir / f'constraint_{cover.source.id}_{cover.target.id}.txt'

subst_equalities = flatten_label('#And', cover.csubst.subst.ml_pred)
subst_equalities = flatten_label(
'#And', cover.csubst.pred(sort_with=self.kprint.definition, constraints=False)
)

if not cover_file.exists():
cover_file.write_text('\n'.join(self.kprint.pretty_print(se) for se in subst_equalities))
Expand Down
12 changes: 10 additions & 2 deletions pyk/src/pyk/kcfg/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,12 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]:
term_str, constraint_str = _cterm_text(crewrite)

elif type(self._element) is KCFG.Cover:
subst_equalities = map(_boolify, flatten_label('#And', self._element.csubst.subst.ml_pred))
subst_equalities = map(
_boolify,
flatten_label(
'#And', self._element.csubst.pred(sort_with=self._kprint.definition, constraints=False)
),
)
constraints = map(_boolify, flatten_label('#And', self._element.csubst.constraint))
term_str = '\n'.join(self._kprint.pretty_print(se) for se in subst_equalities)
constraint_str = '\n'.join(self._kprint.pretty_print(c) for c in constraints)
Expand All @@ -320,7 +325,10 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]:
term_strs.append('')
term_strs.append(f' - {shorten_hashes(target_id)}')
if len(csubst.subst) > 0:
subst_equalities = map(_boolify, flatten_label('#And', csubst.subst.ml_pred))
subst_equalities = map(
_boolify,
flatten_label('#And', csubst.pred(sort_with=self._kprint.definition, constraints=False)),
)
term_strs.extend(f' {self._kprint.pretty_print(cline)}' for cline in subst_equalities)
if len(csubst.constraints) > 0:
constraints = map(_boolify, flatten_label('#And', csubst.constraint))
Expand Down
6 changes: 4 additions & 2 deletions pyk/src/pyk/proof/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,16 @@ def from_spec_modules(

return res

def path_constraints(self, final_node_id: NodeIdLike) -> KInner:
def path_constraints(self, final_node_id: NodeIdLike, sort_with: KDefinition | None = None) -> KInner:
path = self.shortest_path_to(final_node_id)
curr_constraint: KInner = mlTop()
for edge in reversed(path):
if type(edge) is KCFG.Split:
assert len(edge.targets) == 1
csubst = edge.splits[edge.targets[0].id]
curr_constraint = mlAnd([csubst.subst.minimize().ml_pred, csubst.constraint, curr_constraint])
curr_constraint = mlAnd(
[csubst.pred(sort_with=sort_with, constraints=False), csubst.constraint, curr_constraint]
)
if type(edge) is KCFG.Cover:
curr_constraint = mlAnd([edge.csubst.constraint, edge.csubst.subst.apply(curr_constraint)])
return mlAnd(flatten_label('#And', curr_constraint))
Expand Down
20 changes: 0 additions & 20 deletions pyk/src/tests/unit/kast/test_subst.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from pyk.kast.inner import KApply, KLabel, KVariable, Subst
from pyk.kast.manip import extract_subst
from pyk.prelude.kbool import TRUE
from pyk.prelude.kint import INT, intToken
from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlOr, mlTop

Expand Down Expand Up @@ -108,25 +107,6 @@ def test_unapply(term: KInner, subst: dict[str, KInner], expected: KInner) -> No
assert actual == expected


ML_PRED_TEST_DATA: Final = (
('empty', Subst({}), KApply('#Top')),
('singleton', Subst({'X': TRUE}), KApply('#Equals', [KVariable('X'), TRUE])),
(
'double',
Subst({'X': TRUE, 'Y': intToken(4)}),
KApply(
'#And',
[KApply('#Equals', [KVariable('X'), TRUE]), KApply('#Equals', [KVariable('Y'), intToken(4)])],
),
),
)


@pytest.mark.parametrize('test_id,subst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA])
def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None:
assert subst.ml_pred == pred


_0 = intToken(0)
_EQ = KLabel('_==Int_')
EXTRACT_SUBST_TEST_DATA: Final[tuple[tuple[KInner, dict[str, KInner], KInner], ...]] = (
Expand Down
31 changes: 27 additions & 4 deletions pyk/src/tests/unit/test_cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

import pytest

from pyk.cterm import CTerm, cterm_build_claim, cterm_build_rule
from pyk.cterm import CSubst, CTerm, cterm_build_claim, cterm_build_rule
from pyk.kast import Atts, KAtt
from pyk.kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KVariable
from pyk.kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KVariable, Subst
from pyk.kast.outer import KClaim
from pyk.prelude.k import GENERATED_TOP_CELL
from pyk.prelude.k import GENERATED_TOP_CELL, K
from pyk.prelude.kbool import TRUE
from pyk.prelude.kint import INT, intToken
from pyk.prelude.ml import mlAnd, mlEqualsTrue
from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlTop

from .utils import a, b, c, f, g, h, k, x, y, z

Expand Down Expand Up @@ -186,3 +187,25 @@ def test_from_kast(test_id: str, kast: KInner, expected: CTerm) -> None:

# Then
assert cterm == expected


ML_PRED_TEST_DATA: Final = (
('empty', CSubst(Subst({})), mlTop()),
('singleton', CSubst(Subst({'X': TRUE})), mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K)),
('identity', CSubst(Subst({'X': KVariable('X')})), mlTop()),
(
'double',
CSubst(Subst({'X': TRUE, 'Y': intToken(4)})),
mlAnd(
[
mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K),
mlEquals(KVariable('Y', sort=K), intToken(4), arg_sort=K),
]
),
),
)


@pytest.mark.parametrize('test_id,csubst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA])
def test_ml_pred(test_id: str, csubst: CSubst, pred: KInner) -> None:
assert csubst.pred() == pred
Loading