Skip to content

Commit

Permalink
Adhoc handling of custom entries for notations and improvements on pl…
Browse files Browse the repository at this point in the history
…ugins (VernacExtend) (#51)

* handle custom entries and (partially) gen args for plugins

* hotfix listarg

* Add test for the nth locate result in a step context

* Remove handling of ListArg for identref, no longer needed

* Fix equations and add test for it

* Add coq-equations to workflow

* bump ocaml version

* Add coq-released to workflow
  • Loading branch information
pcarrott authored Oct 16, 2024
1 parent 226c0ff commit bfa476a
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 44 deletions.
18 changes: 14 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ jobs:
strategy:
matrix:
ocaml-compiler:
- "4.11"
- "5.2.0"
coq-version:
- "8.17.1"
- "8.18.0"
- "8.19.1"
- "8.19.2"

steps:
- name: Checkout
Expand All @@ -40,8 +40,8 @@ jobs:
/home/runner/work/coqpyt/coqpyt/_opam/
key: ${{ matrix.ocaml-compiler }}-${{ matrix.coq-version }}-opam

- name: Set-up OCaml ${{ matrix.ocaml-compiler }}
uses: ocaml/setup-ocaml@v2
- name: Set-up OCaml
uses: ocaml/setup-ocaml@v3
with:
ocaml-compiler: ${{ matrix.ocaml-compiler }}

Expand All @@ -51,6 +51,16 @@ jobs:
opam pin add coq ${{ matrix.coq-version }}
opam install coq-lsp
- name: Add coq-released
if: steps.cache-opam-restore.outputs.cache-hit != 'true'
run: |
opam repo add coq-released https://coq.inria.fr/opam/released
- name: Install coq-equations
if: steps.cache-opam-restore.outputs.cache-hit != 'true'
run: |
opam install coq-equations
- name: Install coqpyt
run: |
pip install -e .
Expand Down
63 changes: 45 additions & 18 deletions coqpyt/coq/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ def __add_terms(self, step: Step, expr: List):
self.__anonymous_id += 1
self.__add_term(name, step, term_type)
elif term_type == TermType.DERIVE:
name = FileContext.get_ident(expr[2][0])
self.__add_term(name, step, term_type)
if self.__ext_entry(expr[1]) == "Derive":
prop = FileContext.get_ident(expr[2][2])
self.__add_term(prop, step, term_type)
elif term_type == TermType.OBLIGATION:
for arg in expr[2]:
name = FileContext.get_ident(arg)
if name is not None:
self.__add_term(name, step, term_type)
elif term_type in [TermType.OBLIGATION, TermType.EQUATION]:
# FIXME: For Equations, we are unable of getting terms from the AST
# but these commands do generate named terms
self.__last_terms[-1].append(
("", Term(step, term_type, self.__path, self.__segments.modules[:]))
)
Expand Down Expand Up @@ -241,12 +242,15 @@ def __term_type(self, expr: List) -> TermType:
return TermType.FIXPOINT
if expr[0] == "VernacScheme":
return TermType.SCHEME
# FIXME: These are plugins and should probably be handled differently
if self.__is_extend(expr, "Obligations"):
return TermType.OBLIGATION
if self.__is_extend(expr, "VernacDeclareTacticDefinition"):
return TermType.TACTIC
if self.__is_extend(expr, "Function"):
return TermType.FUNCTION
if self.__is_extend(expr, "Define_equations", exact=False):
return TermType.EQUATION
if self.__is_extend(expr, "Derive", exact=False):
return TermType.DERIVE
if self.__is_extend(expr, "AddSetoid", exact=False):
Expand Down Expand Up @@ -293,6 +297,16 @@ def __get_names(expr: List) -> List[str]:
stack.append(v)
return res

@staticmethod
def is_id(el) -> bool:
return isinstance(el, list) and (len(el) == 3 and el[0] == "Ser_Qualid")

@staticmethod
def is_notation(el) -> bool:
return isinstance(el, list) and (
len(el) == 4 and el[0] == "CNotation" and el[2][1] != ""
)

@staticmethod
def get_id(id: List) -> Optional[str]:
# FIXME: This should be made private once [__step_context] is extracted
Expand All @@ -305,18 +319,18 @@ def get_id(id: List) -> Optional[str]:

@staticmethod
def get_ident(el: List) -> Optional[str]:
# FIXME: This should be made private once [__get_program_context] is extracted
# from ProofFile to here.
if (
len(el) == 3
and el[0] == "GenArg"
and el[1][0] == "Rawwit"
and el[1][1][0] == "ExtraArg"
):
if el[1][1][1] == "identref":
return el[2][0][1][1]
elif el[1][1][1] == "ident":
return el[2][1]
# FIXME: This method should be made private once [__get_program_context]
# is extracted from ProofFile to here.
def handle_arg_type(args, ids):
if args[0] == "ExtraArg":
if args[1] == "identref":
return ids[0][1][1]
elif args[1] == "ident":
return ids[1]
return None

if len(el) == 3 and el[0] == "GenArg" and el[1][0] == "Rawwit":
return handle_arg_type(el[1][1], el[2])
return None

@staticmethod
Expand Down Expand Up @@ -546,6 +560,19 @@ def get_term(self, name: str) -> Optional[Term]:
return self.__terms[curr_name][-1]
return None

@staticmethod
def get_notation_scope(notation: str) -> str:
"""Get the scope of a notation.
Args:
notation (str): Possibly scoped notation pattern. E.g. "_ + _ : nat_scope".
Returns:
str: The scope of the notation. E.g. "nat_scope".
"""
if notation.split(":")[-1].endswith("_scope"):
return notation.split(":")[-1].strip()
return ""

def get_notation(self, notation: str, scope: str) -> Term:
"""Get a notation from the context.
Expand Down
46 changes: 24 additions & 22 deletions coqpyt/coq/proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def __init__(
Args:
file_path (str): Path of the Coq file.
library (Optional[str], optional): The library of the file. Defaults to None.
timeout (int, optional): Timeout used in coq-lsp. Defaults to 2.
timeout (int, optional): Timeout used in coq-lsp. Defaults to 30.
workspace (Optional[str], optional): Absolute path for the workspace.
If the workspace is not defined, the workspace is equal to the
path of the file.
Expand Down Expand Up @@ -372,42 +372,44 @@ def _handle_exception(self, e):
raise e

def __locate(self, search, line):
nots = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line).split(
"\n"
)
fun = lambda x: x.endswith("(default interpretation)")
return nots[0][:-25] if fun(nots[0]) else nots[0]
located = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line)
trim = lambda x: x[:-25] if x.endswith("(default interpretation)") else x
return list(map(trim, located.split("\n")))

def __step_context(self, step: Step) -> List[Term]:
stack, res = self.context.expr(step)[:0:-1], []
while len(stack) > 0:
el = stack.pop()
if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid":
if FileContext.is_id(el):
term = self.context.get_term(FileContext.get_id(el))
if term is not None and term not in res:
res.append(term)
elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation":
elif FileContext.is_notation(el):
stack.append(el[1:])

notation_name = el[2][1]
line = len(self.__aux_file.read().split("\n"))
self.__aux_file.append(f'\nLocate "{el[2][1]}".')
self.__aux_file.append(f'\nLocate "{notation_name}".')
self.__aux_file.didChange()
notations = self.__locate(notation_name, line)
if len(notations) == 1 and notations[0] == "Unknown notation":
continue

notation_name, scope = el[2][1], ""
notation = self.__locate(notation_name, line)
if notation.split(":")[-1].endswith("_scope"):
scope = notation.split(":")[-1].strip()

if notation != "Unknown notation":
for notation in notations:
scope = FileContext.get_notation_scope(notation)
try:
term = self.context.get_notation(notation_name, scope)
if term not in res:
res.append(term)
except NotationNotFoundException as e:
if self.__error_mode == "strict":
raise e
else:
logging.warning(str(e))

stack.append(el[1:])
break
except NotationNotFoundException:
continue
else:
e = NotationNotFoundException(notation_name)
if self.__error_mode == "strict":
raise e
else:
logging.warning(str(e))
elif isinstance(el, list):
for v in reversed(el):
if isinstance(v, (dict, list)):
Expand Down
1 change: 1 addition & 0 deletions coqpyt/coq/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TermType(Enum):
SETOID = 22
FUNCTION = 23
DERIVE = 24
EQUATION = 25
OTHER = 100


Expand Down
22 changes: 22 additions & 0 deletions coqpyt/tests/proof_file/test_proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,28 @@ def test_unknown_notation(self):
assert self.proof_file.context.get_notation("{ _ }", "")


class TestProofNthLocate(SetupProofFile):
def setup_method(self, method):
self.setup("test_nth_locate.v")

def test_nth_locate(self):
"""Checks if it is able to handle notations that are not the first result
returned by the Locate command.
"""
proof_file = self.proof_file
assert len(proof_file.proofs) == 1
proof = proof_file.proofs[0]

theorem = "Lemma test : <> = <>."
assert proof.text == theorem

statement_context = [
('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []),
('Notation "<>" := BAnon : binder_scope.', TermType.NOTATION, []),
]
compare_context(statement_context, proof.context)


class TestProofNestedProofs(SetupProofFile):
def setup_method(self, method):
self.setup("test_nested_proofs.v")
Expand Down
6 changes: 6 additions & 0 deletions coqpyt/tests/resources/test_equations.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
From Equations Require Import Equations.

Equations? f (n : nat) : nat :=
f 0 := 42 ;
f (S m) with f m := { f (S m) IH := _ }.
Proof. intros. exact IH. Defined.
7 changes: 7 additions & 0 deletions coqpyt/tests/resources/test_nth_locate.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Inductive binder := BAnon | BNum :> nat -> binder.
Declare Scope binder_scope.
Notation "<>" := BAnon : binder_scope.

Open Scope binder_scope.
Lemma test : <> = <>.
Proof. reflexivity. Qed.
11 changes: 11 additions & 0 deletions coqpyt/tests/test_coq_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,17 @@ def test_derive(setup, teardown):
)


@pytest.mark.parametrize("setup", ["test_equations.v"], indirect=True)
def test_derive(setup, teardown):
coq_file.run()
assert len(coq_file.context.terms) == 0
assert coq_file.context.last_term is not None
assert (
coq_file.context.last_term.text
== "Equations? f (n : nat) : nat := f 0 := 42 ; f (S m) with f m := { f (S m) IH := _ }."
)


def test_space_in_path():
# This test exists because coq-lsp encodes spaces in paths as %20
# This causes the diagnostics to be saved in a different path than the one
Expand Down

0 comments on commit bfa476a

Please sign in to comment.