diff --git a/pyk/src/pyk/cterm/cterm.py b/pyk/src/pyk/cterm/cterm.py index 10f9f1cb52..1b0cbe3f58 100644 --- a/pyk/src/pyk/cterm/cterm.py +++ b/pyk/src/pyk/cterm/cterm.py @@ -102,6 +102,11 @@ def is_bottom(self) -> bool: """Check if a given `CTerm` is trivially empty.""" return is_bottom(self.config, weak=True) or any(is_bottom(cterm, weak=True) for cterm in self.constraints) + @property + def constraint(self) -> KInner: + """Return the set of constraints as a single flattened constraint using `mlAnd`.""" + return mlAnd(self.constraints) + @staticmethod def _constraint_sort_key(term: KInner) -> tuple[int, str]: term_str = str(term) @@ -253,6 +258,46 @@ def remove_useless_constraints(self, keep_vars: Iterable[str] = ()) -> CTerm: return CTerm(self.config, new_constraints) +def merge_cterms(t1: CTerm, t2: CTerm) -> CTerm | None: + """Return a `CTerm` which is the merge of the two input `CTerm` instances. + + Args: + t1: First `CTerm` to merge. + t2: Second `CTerm` to merge. + + Returns: + A `CTerm` which is the merge of the two input `CTerm` instances. + """ + # check all cells in t1 and t1, if they are the same, keep them, otherwise, create a new free variable for them + t1_config, t1_subst = split_config_from(t1.config) + t2_config, t2_subst = split_config_from(t2.config) + + if t1_config != t2_config: + # cannot merge two configurations with different structure + return None + + new_subst = Subst({}) + new_t1_subst = Subst({}) + new_t2_subst = Subst({}) + + for cell in t1_subst: + if t1_subst[cell] == t2_subst[cell]: + # keep the cell if it is the same + new_subst = new_subst * Subst({cell: t1_subst[cell]}) + else: + # create a new free variable for the cell + new_t1_subst = new_t1_subst * Subst({cell: t1_subst[cell]}) + new_t2_subst = new_t2_subst * Subst({cell: t2_subst[cell]}) + new_config = new_subst(t1_config) + + new_constraints: list[KInner] = [] + for new_subst, t in [(new_t1_subst, t1), (new_t2_subst, t2)]: + if new_subst: + new_constraints.append(mlImplies(new_subst.ml_pred, t.constraint)) + + return CTerm(new_config, new_constraints) + + def anti_unify(state1: KInner, state2: KInner, kdef: KDefinition | None = None) -> tuple[KInner, Subst, Subst]: """Return a generalized state over the two input states. diff --git a/pyk/src/tests/unit/test_cterm.py b/pyk/src/tests/unit/test_cterm.py index c9991f9fcd..8268dc93b8 100644 --- a/pyk/src/tests/unit/test_cterm.py +++ b/pyk/src/tests/unit/test_cterm.py @@ -6,12 +6,13 @@ import pytest from pyk.cterm import CTerm, cterm_build_claim, cterm_build_rule +from pyk.cterm.cterm import merge_cterms from pyk.kast import Atts, KAtt from pyk.kast.inner import KApply, KLabel, KRewrite, KSequence, KSort, KVariable from pyk.kast.outer import KClaim from pyk.prelude.k import GENERATED_TOP_CELL from pyk.prelude.kint import INT, intToken -from pyk.prelude.ml import mlAnd, mlEqualsTrue +from pyk.prelude.ml import mlAnd, mlEqualsTrue, mlImplies, mlEquals, mlTop from .utils import a, b, c, f, g, h, k, x, y, z @@ -186,3 +187,104 @@ def test_from_kast(test_id: str, kast: KInner, expected: CTerm) -> None: # Then assert cterm == expected + + +MERGE_TEST_DATA: Final = ( + (CTerm.top(), CTerm.top(), CTerm.top()), + (CTerm.bottom(), CTerm.top(), None), + (CTerm(k(intToken(1))), CTerm(k(intToken(1))), CTerm(k(intToken(1)))), + ( + CTerm(k(KVariable('X'))), + CTerm(k(KVariable('X'))), + CTerm(k(KVariable('X'))), + ), + ( + CTerm(k(KVariable('X'))), + CTerm(k(KVariable('Y'))), + CTerm( + k(KVariable('K_CELL')), + [ + mlImplies(mlEquals(KVariable('K_CELL'), KVariable('X')), mlTop()), + mlImplies(mlEquals(KVariable('K_CELL'), KVariable('Y')), mlTop()), + ], + ), + ), + # ( + # CTerm(k(intToken(1))), + # CTerm(k(intToken(2))), + # CTerm( + # k(KVariable('K_CELL')), + # [ + # mlImplies(mlEquals(KVariable('TOP_CELL'), intToken(1)), mlTop()), + # mlImplies(mlEquals(KVariable('TOP_CELL'), intToken(2)), mlTop()), + # ], + # ), + # ), + # ( + # CTerm( + # k(KVariable('X')), + # [ge_ml('X', 0)], + # ), + # CTerm( + # k(KVariable('X')), + # [ge_ml('X', 0)], + # ), + # CTerm( + # k(KVariable('K_CELL')), + # [ + # mlImplies(mlEquals(KVariable('TOP_CELL'), KVariable('X')), ge_ml('X', 0)), + # ], + # ), + # ), + # ( + # CTerm( + # k(KVariable('X')), + # [ge_ml('X', 0), lt_ml('X', 3)], + # ), + # CTerm( + # k(KVariable('X')), + # [ge_ml('X', 0), lt_ml('X', 5)], + # ), + # CTerm( + # k(KVariable('K_CELL')), + # [ + # mlImplies(mlEquals(KVariable('TOP_CELL'), KVariable('X')), mlAnd([lt_ml('X', 3), ge_ml('X', 0)])), + # mlImplies( + # mlEquals(KVariable('TOP_CELL'), KVariable('X')), + # mlAnd( + # [ + # lt_ml('X', 5), + # ge_ml('X', 0), + # ] + # ), + # ), + # ], + # ), + # ), + # ( + # CTerm( + # k(KVariable('X')), + # [ge_ml('X', 0), lt_ml('X', 3)], + # ), + # CTerm( + # k(KVariable('Y')), + # [ge_ml('Y', 0), lt_ml('Y', 5)], + # ), + # CTerm( + # k(KVariable('K_CELL')), + # [ + # mlImplies(mlEquals(KVariable('TOP_CELL'), KVariable('X')), mlAnd([lt_ml('X', 3), ge_ml('X', 0)])), + # mlImplies(mlEquals(KVariable('TOP_CELL'), KVariable('Y')), mlAnd([lt_ml('Y', 5), ge_ml('Y', 0)])), + # ], + # ), + # ), +) + + +@pytest.mark.parametrize('t1,t2,expected', MERGE_TEST_DATA, ids=count()) +def test_cterm_merge(t1: CTerm, t2: CTerm, expected: CTerm) -> None: + # When + merged = merge_cterms(t1, t2) + + # Then + assert merged == expected diff --git a/pyk/src/tests/unit/utils.py b/pyk/src/tests/unit/utils.py index 73ed9ff877..d31d2e0e9d 100644 --- a/pyk/src/tests/unit/utils.py +++ b/pyk/src/tests/unit/utils.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING from pyk.kast.inner import KApply, KLabel, KVariable +from pyk.prelude.kint import geInt, intToken, ltInt +from pyk.prelude.ml import mlEqualsTrue if TYPE_CHECKING: from typing import Final @@ -16,3 +18,11 @@ f, g, h = map(KLabel, ('f', 'g', 'h')) k = KLabel('') + + +def lt_ml(var: str, n: int) -> KApply: + return mlEqualsTrue(ltInt(KVariable(var), intToken(n))) + + +def ge_ml(var: str, n: int) -> KApply: + return mlEqualsTrue(geInt(KVariable(var), intToken(n)))