Skip to content

Commit

Permalink
Add method to_axiom to class Rule
Browse files Browse the repository at this point in the history
  • Loading branch information
tothtamas28 committed Dec 4, 2024
1 parent b419a9c commit b36da46
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 47 deletions.
202 changes: 155 additions & 47 deletions pyk/src/pyk/kore/rule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import logging
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar, final
from functools import reduce
from typing import TYPE_CHECKING, Generic, TypeVar, cast, final

from .prelude import inj
from .prelude import BOOL, SORT_GENERATED_TOP_CELL, TRUE, inj
from .syntax import (
DV,
And,
Expand All @@ -28,7 +29,7 @@
if TYPE_CHECKING:
from typing import Final

from .syntax import Definition
from .syntax import Definition, Sort

Attrs = dict[str, tuple[Pattern, ...]]

Expand Down Expand Up @@ -68,8 +69,12 @@ class Rule(ABC):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@abstractmethod
def to_axiom(self) -> Axiom: ...

@staticmethod
def from_axiom(axiom: Axiom) -> Rule:
if isinstance(axiom.pattern, Rewrites):
Expand All @@ -89,22 +94,25 @@ def from_axiom(axiom: Axiom) -> Rule:
raise ValueError(f'Cannot parse simplification rule: {axiom.text}')

@staticmethod
def extract_all(defn: Definition) -> list[Rule]:
def is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False
def is_rule(axiom: Axiom) -> bool:
if axiom == _INJ_AXIOM:
return False

if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False
if any(attr in axiom.attrs_by_key for attr in _SKIPPED_ATTRS):
return False

return True
return True

return [Rule.from_axiom(axiom) for axiom in defn.axioms if is_rule(axiom)]
@staticmethod
def extract_all(defn: Definition) -> list[Rule]:
return [Rule.from_axiom(axiom) for axiom in defn.axioms if Rule.is_rule(axiom)]


@final
@dataclass(frozen=True)
class RewriteRule(Rule):
sort = SORT_GENERATED_TOP_CELL

lhs: App
rhs: App
req: Pattern | None
Expand All @@ -114,6 +122,19 @@ class RewriteRule(Rule):
uid: str
label: str | None

def to_axiom(self) -> Axiom:
lhs = self.lhs if self.ctx is None else And(self.sort, (self.lhs, self.ctx))
req = _to_ml_pred(self.req, self.sort)
ens = _to_ml_pred(self.ens, self.sort)
return Axiom(
(),
Rewrites(
self.sort,
And(self.sort, (lhs, req)),
And(self.sort, (self.rhs, ens)),
),
)

@staticmethod
def from_axiom(axiom: Axiom) -> RewriteRule:
lhs, rhs, req, ens, ctx = RewriteRule._extract(axiom)
Expand Down Expand Up @@ -166,60 +187,125 @@ class FunctionRule(Rule):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
arg_sorts: tuple[Sort, ...]
anti_left: Pattern | None
priority: int

def to_axiom(self) -> Axiom:
R = SortVar('R') # noqa N806

def arg_list(rest: Pattern, arg_pair: tuple[EVar, Pattern]) -> Pattern:
var, arg = arg_pair
return And(R, (In(var.sort, R, var, arg), rest))

vars = tuple(EVar(f'X{i}', sort) for i, sort in enumerate(self.arg_sorts))

# \and{R}(\in{S1, R}(X1 : S1, Arg1), \and{R}(\in{S2, R}(X2 : S2, Arg2), \top{R}())) etc.
_args = reduce(
arg_list,
reversed(tuple(zip(vars, self.lhs.args, strict=True))),
cast('Pattern', Top(R)),
)

_req = _to_ml_pred(self.req, R)
req = And(R, (_req, _args))
if self.anti_left:
req = And(R, (Not(R, self.anti_left), req))

app = self.lhs.let(args=vars)
ens = _to_ml_pred(self.ens, self.sort)

return Axiom(
(R,),
Implies(
R,
req,
Equals(self.sort, R, app, And(self.sort, (self.rhs, ens))),
),
)

@staticmethod
def from_axiom(axiom: Axiom) -> FunctionRule:
lhs, rhs, req, ens = FunctionRule._extract(axiom)
anti_left: Pattern | None = None
match axiom.pattern:
case Implies(
left=And(ops=(Not(pattern=anti_left), And(ops=(_req, _args)))),
right=Equals(op_sort=sort, left=App() as app, right=_rhs),
):
pass
case Implies(
left=And(ops=(_req, _args)),
right=Equals(op_sort=sort, left=App() as app, right=_rhs),
):
pass
case _:
raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}')

arg_sorts, args = FunctionRule._extract_args(_args)
lhs = app.let(args=args)
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)

priority = _extract_priority(axiom)
return FunctionRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
arg_sorts=arg_sorts,
anti_left=anti_left,
priority=priority,
)

@staticmethod
def _extract(axiom: Axiom) -> tuple[App, Pattern, Pattern | None, Pattern | None]:
match axiom.pattern:
case Implies(
left=And(
ops=(Not(), And(ops=(_req, _args))) | (_req, _args),
),
right=Equals(left=App() as app, right=_rhs),
):
args = FunctionRule._extract_args(_args)
lhs = app.let(args=args)
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
return lhs, rhs, req, ens
case _:
raise ValueError(f'Cannot extract function rule from axiom: {axiom.text}')

@staticmethod
def _extract_args(pattern: Pattern) -> tuple[Pattern, ...]:
def _extract_args(pattern: Pattern) -> tuple[tuple[Sort, ...], tuple[Pattern, ...]]:
match pattern:
case Top():
return ()
case And(ops=(In(left=EVar(), right=arg), rest)):
return (arg,) + FunctionRule._extract_args(rest)
return (), ()
case And(ops=(In(left=EVar(sort=sort), right=arg), rest)):
sorts, args = FunctionRule._extract_args(rest)
return (sort,) + sorts, (arg,) + args
case _:
raise ValueError(f'Cannot extract argument list from pattern: {pattern.text}')


class SimpliRule(Rule, Generic[P], ABC):
lhs: P
sort: Sort

def to_axiom(self) -> Axiom:
R = SortVar('R') # noqa N806

vars = (R, self.sort) if isinstance(self.sort, SortVar) else (R,)
req = _to_ml_pred(self.req, R)
ens = _to_ml_pred(self.ens, self.sort)

return Axiom(
vars,
Implies(
R,
req,
Equals(self.sort, R, self.lhs, And(self.sort, (self.rhs, ens))),
),
attrs=(
App(
'simplification',
args=() if self.priority == 50 else (String(str(self.priority)),),
),
),
)

@staticmethod
def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None]:
def _extract(axiom: Axiom, lhs_type: type[P]) -> tuple[P, Pattern, Pattern | None, Pattern | None, Sort]:
match axiom.pattern:
case Implies(left=_req, right=Equals(left=lhs, right=_rhs)):
case Implies(left=_req, right=Equals(op_sort=sort, left=lhs, right=_rhs)):
req = _extract_condition(_req)
rhs, ens = _extract_rhs(_rhs)
if not isinstance(lhs, lhs_type):
raise ValueError(f'Invalid LHS type from simplification axiom: {axiom.text}')
return lhs, rhs, req, ens
return lhs, rhs, req, ens, sort
case _:
raise ValueError(f'Cannot extract simplification rule from axiom: {axiom.text}')

Expand All @@ -231,63 +317,67 @@ class AppRule(SimpliRule[App]):
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> AppRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, App)
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, App)
priority = _extract_simpl_priority(axiom)
return AppRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)


@final
@dataclass(frozen=True)
class CeilRule(SimpliRule):
class CeilRule(SimpliRule[Ceil]):
lhs: Ceil
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> CeilRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Ceil)
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Ceil)
priority = _extract_simpl_priority(axiom)
return CeilRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)


@final
@dataclass(frozen=True)
class EqualsRule(SimpliRule):
class EqualsRule(SimpliRule[Equals]):
lhs: Equals
rhs: Pattern
req: Pattern | None
ens: Pattern | None
sort: Sort
priority: int

@staticmethod
def from_axiom(axiom: Axiom) -> EqualsRule:
lhs, rhs, req, ens = SimpliRule._extract(axiom, Equals)
if not isinstance(lhs, Equals):
raise ValueError(f'Cannot extract LHS as Equals from axiom: {axiom.text}')
priority = _extract_priority(axiom)
lhs, rhs, req, ens, sort = SimpliRule._extract(axiom, Equals)
priority = _extract_simpl_priority(axiom)
return EqualsRule(
lhs=lhs,
rhs=rhs,
req=req,
ens=ens,
sort=sort,
priority=priority,
)

Expand Down Expand Up @@ -340,3 +430,21 @@ def _extract_priority(axiom: Axiom) -> int:
return 200 if 'owise' in attrs else 50
case _:
raise ValueError(f'Cannot extract priority from axiom: {axiom.text}')


def _extract_simpl_priority(axiom: Axiom) -> int:
attrs = axiom.attrs_by_key
match attrs['simplification']:
case App(args=() | (String(''),)):
return 50
case App(args=(String(p),)):
return int(p)
case _:
raise ValueError(f'Cannot extract simplification priority from axiom: {axiom.text}')


def _to_ml_pred(pattern: Pattern | None, sort: Sort) -> Pattern:
if pattern is None:
return Top(sort)

return Equals(BOOL, sort, pattern, TRUE)
16 changes: 16 additions & 0 deletions pyk/src/tests/integration/kore/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,19 @@ def test_extract_all(definition: Definition) -> None:
assert cnt['AppRule']
assert cnt['CeilRule']
assert cnt['EqualsRule']


def test_to_axiom(definition: Definition) -> None:
for axiom in definition.axioms:
if not Rule.is_rule(axiom):
continue

# Given
expected = axiom.let(attrs=tuple(attr for attr in axiom.attrs if attr.symbol == 'simplification'))

# When
rule = Rule.from_axiom(axiom)
actual = rule.to_axiom()

# Then
assert expected == actual

0 comments on commit b36da46

Please sign in to comment.