Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements/refactors to substitution extraction routine #4631

Merged
merged 5 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 28 additions & 37 deletions pyk/src/pyk/kast/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,48 +215,39 @@ def extract_rhs(term: KInner) -> KInner:


def extract_subst(term: KInner) -> tuple[Subst, KInner]:
def _subst_for_terms(term1: KInner, term2: KInner) -> Subst | None:
if type(term1) is KVariable and type(term2) not in {KToken, KVariable}:
return Subst({term1.name: term2})
if type(term2) is KVariable and type(term1) not in {KToken, KVariable}:
return Subst({term2.name: term1})
return None

def _extract_subst(conjunct: KInner) -> Subst | None:
if type(conjunct) is KApply:
if conjunct.label.name == '#Equals':
subst = _subst_for_terms(conjunct.args[0], conjunct.args[1])

if subst is not None:
return subst

if (
conjunct.args[0] == TRUE
and type(conjunct.args[1]) is KApply
and conjunct.args[1].label.name in {'_==K_', '_==Int_'}
):
subst = _subst_for_terms(conjunct.args[1].args[0], conjunct.args[1].args[1])

if subst is not None:
return subst
_subst = {}
rem_conjuncts: list[KInner] = []

def _extract_subst(_term1: KInner, _term2: KInner) -> tuple[str, KInner] | None:
ehildenb marked this conversation as resolved.
Show resolved Hide resolved
if (
(type(_term1) is KVariable and _term1.name not in _subst)
ehildenb marked this conversation as resolved.
Show resolved Hide resolved
and not (type(_term2) is KVariable and _term2.name in _subst)
and _term1.name not in free_vars(_term2)
):
return (_term1.name, _term2)
if (
(type(_term2) is KVariable and _term2.name not in _subst)
and not (type(_term1) is KVariable and _term1.name in _subst)
and _term2.name not in free_vars(_term1)
):
return (_term2.name, _term1)
if _term1 == TRUE and type(_term2) is KApply and _term2.label.name in {'_==K_', '_==Int_'}:
return _extract_subst(_term2.args[0], _term2.args[1])
if _term2 == TRUE and type(_term1) is KApply and _term1.label.name in {'_==K_', '_==Int_'}:
return _extract_subst(_term1.args[0], _term1.args[1])
return None

conjuncts = flatten_label('#And', term)
subst = Subst()
rem_conjuncts: list[KInner] = []

for conjunct in conjuncts:
new_subst = _extract_subst(conjunct)
if new_subst is None:
rem_conjuncts.append(conjunct)
for conjunct in flatten_label('#And', term):
if type(conjunct) is KApply and conjunct.label.name == '#Equals':
if _conjunct_subst := _extract_subst(conjunct.args[0], conjunct.args[1]):
name, value = _conjunct_subst
_subst[name] = value
else:
rem_conjuncts.append(conjunct)
else:
new_subst = subst.union(new_subst)
if new_subst is None:
raise ValueError('Conflicting substitutions') # TODO handle this case
subst = new_subst
rem_conjuncts.append(conjunct)

return subst, mlAnd(rem_conjuncts)
return Subst(_subst), mlAnd(rem_conjuncts)


def count_vars(term: KInner) -> Counter[str]:
Expand Down
7 changes: 5 additions & 2 deletions pyk/src/tests/unit/kast/test_subst.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def test_ml_pred(test_id: str, subst: Subst, pred: KInner) -> None:
(a, {}, a),
(mlEquals(a, b), {}, mlEquals(a, b)),
(mlEquals(x, a), {'x': a}, mlTop()),
(mlEquals(x, _0), {}, mlEquals(x, _0)),
(mlEquals(x, y), {}, mlEquals(x, y)),
(mlEquals(x, _0), {'x': _0}, mlTop()),
(mlEquals(x, y), {'x': y}, mlTop()),
(mlEquals(x, f(x)), {}, mlEquals(x, f(x))),
(mlAnd([mlEquals(x, y), mlEquals(x, b)]), {'x': y}, mlEquals(x, b)),
(mlAnd([mlEquals(x, b), mlEquals(x, y)]), {'x': b}, mlEquals(x, y)),
(mlAnd([mlEquals(a, b), mlEquals(x, a)]), {'x': a}, mlEquals(a, b)),
(mlEqualsTrue(_EQ(a, b)), {}, mlEqualsTrue(_EQ(a, b))),
(mlEqualsTrue(_EQ(x, a)), {'x': a}, mlTop()),
Expand Down
Loading