From c54f424904f6abd9af90b9b22e1f922dc2f2e631 Mon Sep 17 00:00:00 2001 From: Everett Hildenbrandt Date: Fri, 6 Sep 2024 02:43:57 -0600 Subject: [PATCH] Improvements/refactors to substitution extraction routine (#4631) Pulled out of: https://github.com/runtimeverification/k/pull/4625 This PR improves the substitution extraction machinery in `kast.manip`, and adds tests. This isn't used anywhere at the moment, but #4625 will start using it heavily. - The code for `extract_substs` is simplified. - The cases of circular substitutions are handled slightly more gracefully. --------- Co-authored-by: rv-jenkins --- pyk/src/pyk/kast/manip.py | 65 ++++++++++++--------------- pyk/src/tests/unit/kast/test_subst.py | 7 ++- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/pyk/src/pyk/kast/manip.py b/pyk/src/pyk/kast/manip.py index 932397bf80c..f92c8d69798 100644 --- a/pyk/src/pyk/kast/manip.py +++ b/pyk/src/pyk/kast/manip.py @@ -215,48 +215,39 @@ def extract_rhs(term: KInner) -> KInner: def extract_subst(term: KInner) -> tuple[Subst, KInner]: - def _subst_for_terms(term1: KInner, term2: KInner) -> Subst | None: - if type(term1) is KVariable and type(term2) not in {KToken, KVariable}: - return Subst({term1.name: term2}) - if type(term2) is KVariable and type(term1) not in {KToken, KVariable}: - return Subst({term2.name: term1}) - return None - - def _extract_subst(conjunct: KInner) -> Subst | None: - if type(conjunct) is KApply: - if conjunct.label.name == '#Equals': - subst = _subst_for_terms(conjunct.args[0], conjunct.args[1]) - - if subst is not None: - return subst - - if ( - conjunct.args[0] == TRUE - and type(conjunct.args[1]) is KApply - and conjunct.args[1].label.name in {'_==K_', '_==Int_'} - ): - subst = _subst_for_terms(conjunct.args[1].args[0], conjunct.args[1].args[1]) - - if subst is not None: - return subst + _subst = {} + rem_conjuncts: list[KInner] = [] + def _extract_subst(_term1: KInner, _term2: KInner) -> tuple[str, KInner] | None: + if ( + (type(_term1) is KVariable and _term1.name not in _subst) + and not (type(_term2) is KVariable and _term2.name in _subst) + and _term1.name not in free_vars(_term2) + ): + return (_term1.name, _term2) + if ( + (type(_term2) is KVariable and _term2.name not in _subst) + and not (type(_term1) is KVariable and _term1.name in _subst) + and _term2.name not in free_vars(_term1) + ): + return (_term2.name, _term1) + if _term1 == TRUE and type(_term2) is KApply and _term2.label.name in {'_==K_', '_==Int_'}: + return _extract_subst(_term2.args[0], _term2.args[1]) + if _term2 == TRUE and type(_term1) is KApply and _term1.label.name in {'_==K_', '_==Int_'}: + return _extract_subst(_term1.args[0], _term1.args[1]) return None - conjuncts = flatten_label('#And', term) - subst = Subst() - rem_conjuncts: list[KInner] = [] - - for conjunct in conjuncts: - new_subst = _extract_subst(conjunct) - if new_subst is None: - rem_conjuncts.append(conjunct) + for conjunct in flatten_label('#And', term): + if type(conjunct) is KApply and conjunct.label.name == '#Equals': + if _conjunct_subst := _extract_subst(conjunct.args[0], conjunct.args[1]): + name, value = _conjunct_subst + _subst[name] = value + else: + rem_conjuncts.append(conjunct) else: - new_subst = subst.union(new_subst) - if new_subst is None: - raise ValueError('Conflicting substitutions') # TODO handle this case - subst = new_subst + rem_conjuncts.append(conjunct) - return subst, mlAnd(rem_conjuncts) + return Subst(_subst), mlAnd(rem_conjuncts) def count_vars(term: KInner) -> Counter[str]: diff --git a/pyk/src/tests/unit/kast/test_subst.py b/pyk/src/tests/unit/kast/test_subst.py index 3a06f698f5f..1ce2b11d2d6 100644 --- a/pyk/src/tests/unit/kast/test_subst.py +++ b/pyk/src/tests/unit/kast/test_subst.py @@ -133,8 +133,11 @@ def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None: (a, {}, a), (mlEquals(a, b), {}, mlEquals(a, b)), (mlEquals(x, a), {'x': a}, mlTop()), - (mlEquals(x, _0), {}, mlEquals(x, _0)), - (mlEquals(x, y), {}, mlEquals(x, y)), + (mlEquals(x, _0), {'x': _0}, mlTop()), + (mlEquals(x, y), {'x': y}, mlTop()), + (mlEquals(x, f(x)), {}, mlEquals(x, f(x))), + (mlAnd([mlEquals(x, y), mlEquals(x, b)]), {'x': y}, mlEquals(x, b)), + (mlAnd([mlEquals(x, b), mlEquals(x, y)]), {'x': b}, mlEquals(x, y)), (mlAnd([mlEquals(a, b), mlEquals(x, a)]), {'x': a}, mlEquals(a, b)), (mlEqualsTrue(_EQ(a, b)), {}, mlEqualsTrue(_EQ(a, b))), (mlEqualsTrue(_EQ(x, a)), {'x': a}, mlTop()),