diff --git a/pyk/src/pyk/cterm/cterm.py b/pyk/src/pyk/cterm/cterm.py index 6a8189aa75..0c04e0ff20 100644 --- a/pyk/src/pyk/cterm/cterm.py +++ b/pyk/src/pyk/cterm/cterm.py @@ -6,11 +6,12 @@ from typing import TYPE_CHECKING from ..kast import KInner -from ..kast.inner import KApply, KRewrite, KToken, Subst, bottom_up +from ..kast.inner import KApply, KRewrite, KToken, KVariable, Subst, bottom_up from ..kast.manip import ( abstract_term_safely, build_claim, build_rule, + extract_subst, flatten_label, free_vars, ml_pred_to_bool, @@ -20,9 +21,9 @@ split_config_and_constraints, split_config_from, ) -from ..prelude.k import GENERATED_TOP_CELL +from ..prelude.k import GENERATED_TOP_CELL, K from ..prelude.kbool import andBool, orBool -from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEqualsTrue, mlImplies, mlTop +from ..prelude.ml import is_bottom, is_top, mlAnd, mlBottom, mlEquals, mlEqualsTrue, mlImplies, mlTop from ..utils import unique if TYPE_CHECKING: @@ -217,17 +218,7 @@ def anti_unify( if KToken('true', 'Bool') not in [disjunct_lhs, disjunct_rhs]: new_cterm = new_cterm.add_constraint(mlEqualsTrue(orBool([disjunct_lhs, disjunct_rhs]))) - new_constraints = [] - fvs = new_cterm.free_vars - len_fvs = 0 - while len_fvs < len(fvs): - len_fvs = len(fvs) - for constraint in common_constraints: - if constraint not in new_constraints: - constraint_fvs = free_vars(constraint) - if any(fv in fvs for fv in constraint_fvs): - new_constraints.append(constraint) - fvs = fvs | constraint_fvs + new_constraints = remove_useless_constraints(common_constraints, new_cterm.free_vars) for constraint in new_constraints: new_cterm = new_cterm.add_constraint(constraint) @@ -341,6 +332,26 @@ def from_dict(dct: dict[str, Any]) -> CSubst: constraints = (KInner.from_dict(c) for c in dct['constraints']) return CSubst(subst=subst, constraints=constraints) + @staticmethod + def from_pred(pred: KInner) -> CSubst: + """Extract from a boolean predicate a CSubst.""" + subst, pred = extract_subst(pred) + return CSubst(subst=subst, constraints=flatten_label('#And', pred)) + + def pred(self, sort_with: KDefinition | None = None, subst: bool = True, constraints: bool = True) -> KInner: + """Return an ML predicate representing this substitution.""" + _preds: list[KInner] = [] + if subst: + for k, v in self.subst.minimize().items(): + sort = K + if sort_with is not None: + _sort = sort_with.sort(v) + sort = _sort if _sort is not None else sort + _preds.append(mlEquals(KVariable(k, sort=sort), v, arg_sort=sort)) + if constraints: + _preds.extend(self.constraints) + return mlAnd(_preds) + @property def constraint(self) -> KInner: """Return the set of constraints as a single flattened constraint using `mlAnd`.""" diff --git a/pyk/src/pyk/cterm/symbolic.py b/pyk/src/pyk/cterm/symbolic.py index 1ad00617d6..9cb52bbb4a 100644 --- a/pyk/src/pyk/cterm/symbolic.py +++ b/pyk/src/pyk/cterm/symbolic.py @@ -25,7 +25,7 @@ kore_server, ) from ..prelude.k import GENERATED_TOP_CELL, K_ITEM -from ..prelude.ml import is_top, mlEquals +from ..prelude.ml import mlAnd if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -267,19 +267,8 @@ def implies( raise ValueError('Received empty predicate for valid implication.') ml_subst = self.kore_to_kast(result.substitution) ml_pred = self.kore_to_kast(result.predicate) - ml_preds = flatten_label('#And', ml_pred) - if is_top(ml_subst): - csubst = CSubst(subst=Subst({}), constraints=ml_preds) - return CTermImplies(csubst, (), None, result.logs) - subst_pattern = mlEquals(KVariable('###VAR'), KVariable('###TERM')) - _subst: dict[str, KInner] = {} - for subst_pred in flatten_label('#And', ml_subst): - m = subst_pattern.match(subst_pred) - if m is not None and type(m['###VAR']) is KVariable: - _subst[m['###VAR'].name] = m['###TERM'] - else: - raise AssertionError(f'Received a non-substitution from implies endpoint: {subst_pred}') - csubst = CSubst(subst=Subst(_subst), constraints=ml_preds) + ml_subst_pred = mlAnd(flatten_label('#And', ml_subst) + flatten_label('#And', ml_pred)) + csubst = CSubst.from_pred(ml_subst_pred) return CTermImplies(csubst, (), None, result.logs) def assume_defined(self, cterm: CTerm, module_name: str | None = None) -> CTerm: diff --git a/pyk/src/pyk/kast/inner.py b/pyk/src/pyk/kast/inner.py index b27a0d3153..e034c048e7 100644 --- a/pyk/src/pyk/kast/inner.py +++ b/pyk/src/pyk/kast/inner.py @@ -749,20 +749,6 @@ def from_pred(pred: KInner) -> Subst: raise ValueError(f'Invalid substitution predicate: {conjunct}') return Subst(subst) - @property - def ml_pred(self) -> KInner: - """Turn this `Subst` into a matching logic predicate using `{_#Equals_}` operator.""" - items = [] - for k in self: - if KVariable(k) != self[k]: - items.append(KApply('#Equals', [KVariable(k), self[k]])) - if len(items) == 0: - return KApply('#Top') - ml_term = items[0] - for _i in items[1:]: - ml_term = KApply('#And', [ml_term, _i]) - return ml_term - @property def pred(self) -> KInner: """Turn this `Subst` into a boolean predicate using `_==K_` operator.""" diff --git a/pyk/src/pyk/kcfg/show.py b/pyk/src/pyk/kcfg/show.py index b4e2db7333..83eb51b4dd 100644 --- a/pyk/src/pyk/kcfg/show.py +++ b/pyk/src/pyk/kcfg/show.py @@ -469,7 +469,9 @@ def dump(self, cfgid: str, cfg: KCFG, dump_dir: Path, dot: bool = False) -> None cover_file = covers_dir / f'config_{cover.source.id}_{cover.target.id}.txt' cover_constraint_file = covers_dir / f'constraint_{cover.source.id}_{cover.target.id}.txt' - subst_equalities = flatten_label('#And', cover.csubst.subst.ml_pred) + subst_equalities = flatten_label( + '#And', cover.csubst.pred(sort_with=self.kprint.definition, constraints=False) + ) if not cover_file.exists(): cover_file.write_text('\n'.join(self.kprint.pretty_print(se) for se in subst_equalities)) diff --git a/pyk/src/pyk/kcfg/tui.py b/pyk/src/pyk/kcfg/tui.py index 128c377b53..f346d81113 100644 --- a/pyk/src/pyk/kcfg/tui.py +++ b/pyk/src/pyk/kcfg/tui.py @@ -309,7 +309,12 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]: term_str, constraint_str = _cterm_text(crewrite) elif type(self._element) is KCFG.Cover: - subst_equalities = map(_boolify, flatten_label('#And', self._element.csubst.subst.ml_pred)) + subst_equalities = map( + _boolify, + flatten_label( + '#And', self._element.csubst.pred(sort_with=self._kprint.definition, constraints=False) + ), + ) constraints = map(_boolify, flatten_label('#And', self._element.csubst.constraint)) term_str = '\n'.join(self._kprint.pretty_print(se) for se in subst_equalities) constraint_str = '\n'.join(self._kprint.pretty_print(c) for c in constraints) @@ -320,7 +325,10 @@ def _cterm_text(cterm: CTerm) -> tuple[str, str]: term_strs.append('') term_strs.append(f' - {shorten_hashes(target_id)}') if len(csubst.subst) > 0: - subst_equalities = map(_boolify, flatten_label('#And', csubst.subst.ml_pred)) + subst_equalities = map( + _boolify, + flatten_label('#And', csubst.pred(sort_with=self._kprint.definition, constraints=False)), + ) term_strs.extend(f' {self._kprint.pretty_print(cline)}' for cline in subst_equalities) if len(csubst.constraints) > 0: constraints = map(_boolify, flatten_label('#And', csubst.constraint)) diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py index 9244a0c850..62f5956f91 100644 --- a/pyk/src/pyk/proof/reachability.py +++ b/pyk/src/pyk/proof/reachability.py @@ -487,14 +487,16 @@ def from_spec_modules( return res - def path_constraints(self, final_node_id: NodeIdLike) -> KInner: + def path_constraints(self, final_node_id: NodeIdLike, sort_with: KDefinition | None = None) -> KInner: path = self.shortest_path_to(final_node_id) curr_constraint: KInner = mlTop() for edge in reversed(path): if type(edge) is KCFG.Split: assert len(edge.targets) == 1 csubst = edge.splits[edge.targets[0].id] - curr_constraint = mlAnd([csubst.subst.minimize().ml_pred, csubst.constraint, curr_constraint]) + curr_constraint = mlAnd( + [csubst.pred(sort_with=sort_with, constraints=False), csubst.constraint, curr_constraint] + ) if type(edge) is KCFG.Cover: curr_constraint = mlAnd([edge.csubst.constraint, edge.csubst.subst.apply(curr_constraint)]) return mlAnd(flatten_label('#And', curr_constraint)) diff --git a/pyk/src/tests/unit/kast/test_subst.py b/pyk/src/tests/unit/kast/test_subst.py index fa5fd1c628..d8fe816605 100644 --- a/pyk/src/tests/unit/kast/test_subst.py +++ b/pyk/src/tests/unit/kast/test_subst.py @@ -7,7 +7,6 @@ from pyk.kast.inner import KApply, KLabel, KVariable, Subst from pyk.kast.manip import extract_subst -from pyk.prelude.kbool import TRUE from pyk.prelude.kint import INT, intToken from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlOr, mlTop @@ -108,25 +107,6 @@ def test_unapply(term: KInner, subst: dict[str, KInner], expected: KInner) -> No assert actual == expected -ML_PRED_TEST_DATA: Final = ( - ('empty', Subst({}), KApply('#Top')), - ('singleton', Subst({'X': TRUE}), KApply('#Equals', [KVariable('X'), TRUE])), - ( - 'double', - Subst({'X': TRUE, 'Y': intToken(4)}), - KApply( - '#And', - [KApply('#Equals', [KVariable('X'), TRUE]), KApply('#Equals', [KVariable('Y'), intToken(4)])], - ), - ), -) - - -@pytest.mark.parametrize('test_id,subst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA]) -def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None: - assert subst.ml_pred == pred - - _0 = intToken(0) _EQ = KLabel('_==Int_') EXTRACT_SUBST_TEST_DATA: Final[tuple[tuple[KInner, dict[str, KInner], KInner], ...]] = ( diff --git a/pyk/src/tests/unit/test_cterm.py b/pyk/src/tests/unit/test_cterm.py index ba2fce44a6..955894504f 100644 --- a/pyk/src/tests/unit/test_cterm.py +++ b/pyk/src/tests/unit/test_cterm.py @@ -9,9 +9,10 @@ from pyk.kast import Atts, KAtt from pyk.kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KVariable, Subst from pyk.kast.outer import KClaim -from pyk.prelude.k import GENERATED_TOP_CELL +from pyk.prelude.k import GENERATED_TOP_CELL, K +from pyk.prelude.kbool import TRUE from pyk.prelude.kint import INT, intToken -from pyk.prelude.ml import mlAnd, mlEqualsTrue +from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlTop from .utils import a, b, c, f, g, ge_ml, h, k, lt_ml, x, y, z @@ -188,6 +189,28 @@ def test_from_kast(test_id: str, kast: KInner, expected: CTerm) -> None: assert cterm == expected +ML_PRED_TEST_DATA: Final = ( + ('empty', CSubst(Subst({})), mlTop()), + ('singleton', CSubst(Subst({'X': TRUE})), mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K)), + ('identity', CSubst(Subst({'X': KVariable('X')})), mlTop()), + ( + 'double', + CSubst(Subst({'X': TRUE, 'Y': intToken(4)})), + mlAnd( + [ + mlEquals(KVariable('X', sort=K), TRUE, arg_sort=K), + mlEquals(KVariable('Y', sort=K), intToken(4), arg_sort=K), + ] + ), + ), +) + + +@pytest.mark.parametrize('test_id,csubst,pred', ML_PRED_TEST_DATA, ids=[test_id for test_id, *_ in ML_PRED_TEST_DATA]) +def test_ml_pred(test_id: str, csubst: CSubst, pred: KInner) -> None: + assert csubst.pred() == pred + + APPLY_TEST_DATA: Final = ( (CTerm.top(), CSubst(), CTerm.top()), (CTerm.bottom(), CSubst(), CTerm.bottom()),