From 9463524d678d171251279400bf8ebe7c6a89353f Mon Sep 17 00:00:00 2001 From: pcarrott Date: Tue, 5 Sep 2023 17:36:07 +0100 Subject: [PATCH 01/27] Handle Goal keyword --- coqlspclient/coq_file.py | 9 +++++++++ tests/resources/test_goal.v | 21 +++++++++++++++++++++ tests/test_proof_state.py | 17 +++++++++++++++++ 3 files changed, 47 insertions(+) create mode 100644 tests/resources/test_goal.v diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 039d46c..9114ef5 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -72,6 +72,7 @@ def __init__( self.curr_section: List[str] = [] self.__segment_stack: List[SegmentType] = [] self.context = FileContext() + self.__anonymous_id: Optional[int] = None def __enter__(self): return self @@ -367,6 +368,14 @@ def traverse_ast(el, inductive=False): self.__add_term( name, self.ast[self.steps_taken], text, TermType.NOTATION ) + elif expr[0] == "VernacDefinition" and expr[2][0]["v"][0] == "Anonymous": + # We associate the anonymous term to the same name used internally by Coq + if self.__anonymous_id is None: + name, self.__anonymous_id = "Unnamed_thm", 0 + else: + name = f"Unnamed_thm{self.__anonymous_id}" + self.__anonymous_id += 1 + self.__add_term(name, self.ast[self.steps_taken], text, term_type) else: names = traverse_ast(expr) for name in names: diff --git a/tests/resources/test_goal.v b/tests/resources/test_goal.v new file mode 100644 index 0000000..9f3ade8 --- /dev/null +++ b/tests/resources/test_goal.v @@ -0,0 +1,21 @@ +(* http://d.hatena.ne.jp/hzkr/20100902 *) + +Definition ignored : forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + intros. (* 全部剥がす *) + apply H. + exact H0. (* apply H0 でも同じ *) +Save opaque. + + +Goal forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + intros p q f. (* 名前の数だけ剥がす *) + assumption. (* exact f でも同じ *) +Qed. + + +Goal forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + exact (fun p q f x => f x). +Defined transparent. \ No newline at end of file diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index a7104ca..c737fba 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -574,3 +574,20 @@ def test_type_class(setup, teardown): ("Inductive unit : Set := tt : unit.", TermType.INDUCTIVE, []), ] compare_context(context, state.proofs[1].context) + + +@pytest.mark.parametrize("setup", [("test_goal.v", None)], indirect=True) +def test_goal(setup, teardown): + assert len(state.proofs) == 3 + for proof in state.proofs: + assert proof.text == "Goal forall P Q: Prop, (P -> Q) -> P -> Q." + compare_context( + [ + ( + 'Notation "A -> B" := (forall (_ : A), B) : type_scope.', + TermType.NOTATION, + [], + ) + ], + proof.context, + ) From 40ab1af8eae2a976b343b0edf71287309e838708 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Tue, 5 Sep 2023 17:46:53 +0100 Subject: [PATCH 02/27] Minor fix: index check --- coqlspclient/coq_file.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 9114ef5..df2386c 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -206,6 +206,7 @@ def __get_term_type(expr: List) -> TermType: elif ( len(expr) > 1 and isinstance(expr[1], list) + and len(expr[1]) > 0 and expr[1][0] == "VernacDeclareTacticDefinition" ): return TermType.TACTIC @@ -347,9 +348,9 @@ def traverse_ast(el, inductive=False): elif len(self.curr_module_type) > 0: return elif ( - len(expr) >= 2 + len(expr) > 1 and isinstance(expr[1], list) - and len(expr[1]) == 2 + and len(expr[1]) > 0 and expr[1][0] == "VernacDeclareTacticDefinition" ): name = self.__get_tactic_name(expr) From 9e2620e41387730380baa5d15797b8fc83e311a5 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Tue, 5 Sep 2023 22:14:52 +0100 Subject: [PATCH 03/27] Fix scope of notations to capture list notations --- coqlspclient/coq_structs.py | 9 +++++++- tests/resources/test_list_notation.v | 5 +++++ tests/test_proof_state.py | 31 +++++++++++++++++++++------- 3 files changed, 37 insertions(+), 8 deletions(-) create mode 100644 tests/resources/test_list_notation.v diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 069c8e0..697ce6a 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -112,7 +112,9 @@ def get_notation(self, notation: str, scope: str) -> Term: regex = f"{re.escape(notation_id)}".split("\\ ") for i, sub in enumerate(regex): if sub == "_": - regex[i] = "(.+)" + # We match the wildcard with the description from here: + # https://coq.inria.fr/distrib/current/refman/language/core/basic.html#grammar-token-ident + regex[i] = "([a-zA-Z][a-zA-Z0-9_']*|_[a-zA-Z0-9_']+)" else: # Handle '_' regex[i] = f"({sub}|('{sub}'))" @@ -122,6 +124,11 @@ def get_notation(self, notation: str, scope: str) -> Term: for term in self.terms.keys(): if re.match(regex, term): return self.terms[term] + # We search again in case the stored id contains the scope but no scope is provided + for term in self.terms.keys(): + unscoped = term.split(":")[0].strip() + if re.match(regex, unscoped): + return self.terms[term] # Search Infix if re.match("^_ ([^ ]*) _$", notation): diff --git a/tests/resources/test_list_notation.v b/tests/resources/test_list_notation.v new file mode 100644 index 0000000..5730ba4 --- /dev/null +++ b/tests/resources/test_list_notation.v @@ -0,0 +1,5 @@ +Require Import List. +Import List.ListNotations. + +Goal [1] ++ [2] = [1; 2]. +Proof. reflexivity. Qed. \ No newline at end of file diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index c737fba..46721d2 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -94,7 +94,7 @@ def test_get_proofs(setup, teardown): ] statement_context = [ ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ] @@ -182,7 +182,7 @@ def test_get_proofs(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ('Notation "n * m" := (mul n m) : nat_scope', TermType.NOTATION, []), ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), @@ -254,7 +254,7 @@ def test_get_proofs(setup, teardown): ] statement_context = [ ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ] @@ -345,7 +345,7 @@ def test_get_proofs(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n * m" := (mul n m) : nat_scope', TermType.NOTATION, []), ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), @@ -404,6 +404,18 @@ def test_exists_notation(setup, teardown): ) +@pytest.mark.parametrize("setup", [("test_list_notation.v", None)], indirect=True) +def test_list_notation(setup, teardown): + assert len(state.proofs) == 1 + context = [ + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), + ('Infix "++" := app (right associativity, at level 60) : list_scope.', TermType.NOTATION, []), + ('Notation "[ x ]" := (cons x nil) : list_scope.', TermType.NOTATION, ['ListNotations']), + ('Notation "[ x ; y ; .. ; z ]" := (cons x (cons y .. (cons z nil) ..)) (format "[ \'[\' x ; \'/\' y ; \'/\' .. ; \'/\' z \']\' ]") : list_scope.', TermType.NOTATION, ['ListNotations']), + ] + compare_context(context, state.proofs[0].context) + + @pytest.mark.parametrize("setup", [("test_unknown_notation.v", None)], indirect=True) def test_unknown_notation(setup, teardown): """Checks if it is able to handle the notation { _ } that is unknown for the @@ -494,7 +506,7 @@ def test_obligation(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ] programs = [ ("id1", "S (pred n)"), @@ -579,8 +591,13 @@ def test_type_class(setup, teardown): @pytest.mark.parametrize("setup", [("test_goal.v", None)], indirect=True) def test_goal(setup, teardown): assert len(state.proofs) == 3 - for proof in state.proofs: - assert proof.text == "Goal forall P Q: Prop, (P -> Q) -> P -> Q." + goals = [ + "Definition ignored : forall P Q: Prop, (P -> Q) -> P -> Q.", + "Goal forall P Q: Prop, (P -> Q) -> P -> Q.", + "Goal forall P Q: Prop, (P -> Q) -> P -> Q.", + ] + for i, proof in enumerate(state.proofs): + assert proof.text == goals[i] compare_context( [ ( From af7db68befa4e7d4e446fca599ce1f8ba7fd2014 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Wed, 6 Sep 2023 16:18:44 +0100 Subject: [PATCH 04/27] Minor fix removing whitespace from notations --- coqlspclient/coq_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index df2386c..a36996d 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -356,7 +356,7 @@ def traverse_ast(el, inductive=False): name = self.__get_tactic_name(expr) self.__add_term(name, self.ast[self.steps_taken], text, TermType.TACTIC) elif expr[0] == "VernacNotation": - name = text.split('"')[1] + name = text.split('"')[1].strip() if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() self.__add_term( From 0457326ce3aae88a1f2a2b0f33fd99bf59c7c438 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Wed, 6 Sep 2023 19:00:12 +0100 Subject: [PATCH 05/27] Fix locate libraries --- coqlspclient/proof_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 08930ba..77d675c 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -162,7 +162,7 @@ def get_context(file_path: str, timeout: int): for i, library in enumerate(libraries): v_file = aux_file.get_diagnostics( "Locate Library", library, last_line + i + 1 - ).split("\n")[-1][:-1] + ).split()[-1][:-1] coq_file = CoqFile(v_file, library=library, timeout=timeout) coq_file.run() From 23146487c7a68db2bfa898dea39635fb7dfde3aa Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 7 Sep 2023 01:53:57 +0100 Subject: [PATCH 06/27] Avoid RecursionError by making traverse_ast iterative --- coqlspclient/coq_file.py | 52 ++++++++++--------- coqlspclient/proof_state.py | 99 ++++++++++++++++++++----------------- tests/test_proof_state.py | 18 +++++-- 3 files changed, 96 insertions(+), 73 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index a36996d..caf378b 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -8,6 +8,7 @@ from coqlspclient.coq_structs import Step, FileContext, Term, TermType, SegmentType from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Optional +from collections import deque class CoqFile(object): @@ -276,30 +277,33 @@ def __get_tactic_name(self, expr): return None def __process_step(self, sign): - def traverse_ast(el, inductive=False): - if isinstance(el, dict): - if "v" in el and isinstance(el["v"], list) and len(el["v"]) == 2: - if el["v"][0] == "Id": - return [el["v"][1]] - if el["v"][0] == "Name": - return [el["v"][1][1]] - - return [x for v in el.values() for x in traverse_ast(v, inductive)] - elif isinstance(el, list): - if len(el) > 0: - if el[0] == "CLocalAssum": - return [] - if el[0] == "VernacInductive": - inductive = True - - res = [] - for v in el: - res.extend(traverse_ast(v, inductive)) - if not inductive and len(res) > 0: - return [res[0]] - return res - - return [] + def traverse_ast(ast): + inductive = True if ast[0] == "VernacInductive" else False + stack, res = deque(ast[1:]), [] + while len(stack) > 0: + el = stack.popleft() + if isinstance(el, dict): + if "v" in el and isinstance(el["v"], list) and len(el["v"]) == 2: + if el["v"][0] == "Id": + if not inductive: + return [el["v"][1]] + res.append(el["v"][1]) + elif el["v"][0] == "Name": + if not inductive: + return [el["v"][1][1]] + res.append(el["v"][1][1]) + + for v in reversed(el.values()): + if isinstance(v, (dict, list)): + stack.appendleft(v) + elif isinstance(el, list): + if len(el) > 0 and el[0] == "CLocalAssum": + continue + + for v in reversed(el): + if isinstance(v, (dict, list)): + stack.appendleft(v) + return res try: # TODO: A negative sign should handle things differently. For example: diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 77d675c..6b56be5 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -27,6 +27,7 @@ from coqlspclient.coq_file import CoqFile from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Dict, Optional, Tuple +from collections import deque class _AuxFile(object): @@ -232,34 +233,42 @@ def __locate(self, search, line): return nots[0][:-25] if fun(nots[0]) else nots[0] def __step_context(self, step=None): - def traverse_ast(el): - if isinstance(el, dict): - return [x for v in el.values() for x in traverse_ast(v)] - elif isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": - id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]]) - term = self.__get_term(id) - return [] if term is None else [(lambda x: x, term)] - elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": - line = len(self.__aux_file.read().split("\n")) - self.__aux_file.append(f'\nLocate "{el[2][1]}".') - - def __search_notation(call): - notation_name = call[0] - scope = "" - notation = call[1](*call[2:]) - if notation == "Unknown notation": - return None - if notation.split(":")[-1].endswith("_scope"): - scope = notation.split(":")[-1].strip() - return self.context.get_notation(notation_name, scope) - - return [ - (__search_notation, (el[2][1], self.__locate, el[2][1], line)) - ] + traverse_ast(el[1:]) - elif isinstance(el, list): - return [x for v in el for x in traverse_ast(v)] - - return [] + def traverse_ast(ast): + stack, res = deque([ast]), [] + while len(stack) > 0: + el = stack.popleft() + if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": + id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]]) + term = self.__get_term(id) + if term is not None: + res.append((lambda x: x, term)) + elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": + line = len(self.__aux_file.read().split("\n")) + self.__aux_file.append(f'\nLocate "{el[2][1]}".') + + def __search_notation(call): + notation_name = call[0] + scope = "" + notation = call[1](*call[2:]) + if notation == "Unknown notation": + return None + if notation.split(":")[-1].endswith("_scope"): + scope = notation.split(":")[-1].strip() + return self.context.get_notation(notation_name, scope) + + res.append( + (__search_notation, (el[2][1], self.__locate, el[2][1], line)) + ) + stack.appendleft(el[1:]) + elif isinstance(el, list): + for v in reversed(el): + if isinstance(v, (dict, list)): + stack.appendleft(v) + elif isinstance(el, dict): + for v in reversed(el.values()): + if isinstance(v, (dict, list)): + stack.appendleft(v) + return res if step is None: step = self.__current_step.ast @@ -279,24 +288,22 @@ def __get_last_term(self): return last_term def __get_program_context(self): - def traverse_ast(el, keep_id=False): - if ( - isinstance(el, list) - and len(el) == 2 - and ( - (el[0] == "Id" and keep_id) - or (el[0] == "ExtraArg" and el[1] == "identref") - ) - ): - return el[1] - elif isinstance(el, list): - for x in el: - id = traverse_ast(x, keep_id=keep_id) - if id == "identref": - keep_id = True - elif id is not None: - return id - return "identref" if keep_id else None + def traverse_ast(ast): + stack = deque(ast) + while len(stack) > 0: + el = stack.popleft() + if ( + isinstance(el, list) + and len(el) == 3 + and el[0] == "GenArg" + and el[1][0] == "Rawwit" + and el[1][1][1] == "identref" + ): + return el[2][0][1][1] + elif isinstance(el, list): + for v in reversed(el): + if isinstance(v, list): + stack.appendleft(v) return None # Tags: diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index 46721d2..318c289 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -409,9 +409,21 @@ def test_list_notation(setup, teardown): assert len(state.proofs) == 1 context = [ ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), - ('Infix "++" := app (right associativity, at level 60) : list_scope.', TermType.NOTATION, []), - ('Notation "[ x ]" := (cons x nil) : list_scope.', TermType.NOTATION, ['ListNotations']), - ('Notation "[ x ; y ; .. ; z ]" := (cons x (cons y .. (cons z nil) ..)) (format "[ \'[\' x ; \'/\' y ; \'/\' .. ; \'/\' z \']\' ]") : list_scope.', TermType.NOTATION, ['ListNotations']), + ( + 'Infix "++" := app (right associativity, at level 60) : list_scope.', + TermType.NOTATION, + [], + ), + ( + 'Notation "[ x ]" := (cons x nil) : list_scope.', + TermType.NOTATION, + ["ListNotations"], + ), + ( + "Notation \"[ x ; y ; .. ; z ]\" := (cons x (cons y .. (cons z nil) ..)) (format \"[ '[' x ; '/' y ; '/' .. ; '/' z ']' ]\") : list_scope.", + TermType.NOTATION, + ["ListNotations"], + ), ] compare_context(context, state.proofs[0].context) From f39d7c134ec77b8d69a5e392eed4f78d8cb8f86b Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 7 Sep 2023 18:19:38 +0100 Subject: [PATCH 07/27] Rename traverse_ast as traverse_expr and replace deque with reversed list --- coqlspclient/coq_file.py | 15 +++++++-------- coqlspclient/proof_state.py | 29 +++++++++++++---------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index caf378b..94b5e52 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -8,7 +8,6 @@ from coqlspclient.coq_structs import Step, FileContext, Term, TermType, SegmentType from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Optional -from collections import deque class CoqFile(object): @@ -277,11 +276,11 @@ def __get_tactic_name(self, expr): return None def __process_step(self, sign): - def traverse_ast(ast): - inductive = True if ast[0] == "VernacInductive" else False - stack, res = deque(ast[1:]), [] + def traverse_expr(expr): + inductive = True if expr[0] == "VernacInductive" else False + stack, res = expr[:0:-1], [] while len(stack) > 0: - el = stack.popleft() + el = stack.pop() if isinstance(el, dict): if "v" in el and isinstance(el["v"], list) and len(el["v"]) == 2: if el["v"][0] == "Id": @@ -295,14 +294,14 @@ def traverse_ast(ast): for v in reversed(el.values()): if isinstance(v, (dict, list)): - stack.appendleft(v) + stack.append(v) elif isinstance(el, list): if len(el) > 0 and el[0] == "CLocalAssum": continue for v in reversed(el): if isinstance(v, (dict, list)): - stack.appendleft(v) + stack.append(v) return res try: @@ -382,7 +381,7 @@ def traverse_ast(ast): self.__anonymous_id += 1 self.__add_term(name, self.ast[self.steps_taken], text, term_type) else: - names = traverse_ast(expr) + names = traverse_expr(expr) for name in names: self.__add_term(name, self.ast[self.steps_taken], text, term_type) diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 6b56be5..a59e68b 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -27,7 +27,6 @@ from coqlspclient.coq_file import CoqFile from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Dict, Optional, Tuple -from collections import deque class _AuxFile(object): @@ -232,11 +231,11 @@ def __locate(self, search, line): fun = lambda x: x.endswith("(default interpretation)") return nots[0][:-25] if fun(nots[0]) else nots[0] - def __step_context(self, step=None): - def traverse_ast(ast): - stack, res = deque([ast]), [] + def __step_context(self): + def traverse_expr(expr): + stack, res = expr[:0:-1], [] while len(stack) > 0: - el = stack.popleft() + el = stack.pop() if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]]) term = self.__get_term(id) @@ -259,20 +258,18 @@ def __search_notation(call): res.append( (__search_notation, (el[2][1], self.__locate, el[2][1], line)) ) - stack.appendleft(el[1:]) + stack.append(el[1:]) elif isinstance(el, list): for v in reversed(el): if isinstance(v, (dict, list)): - stack.appendleft(v) + stack.append(v) elif isinstance(el, dict): for v in reversed(el.values()): if isinstance(v, (dict, list)): - stack.appendleft(v) + stack.append(v) return res - if step is None: - step = self.__current_step.ast - return traverse_ast(step.span) + return traverse_expr(CoqFile.expr(self.__current_step.ast)) def __get_last_term(self): terms = self.coq_file.terms @@ -288,10 +285,10 @@ def __get_last_term(self): return last_term def __get_program_context(self): - def traverse_ast(ast): - stack = deque(ast) + def traverse_expr(expr): + stack = expr[:0:-1] while len(stack) > 0: - el = stack.popleft() + el = stack.pop() if ( isinstance(el, list) and len(el) == 3 @@ -303,7 +300,7 @@ def traverse_ast(ast): elif isinstance(el, list): for v in reversed(el): if isinstance(v, list): - stack.appendleft(v) + stack.append(v) return None # Tags: @@ -315,7 +312,7 @@ def traverse_ast(ast): # 5 - Next Obligation tag = CoqFile.expr(self.__current_step.ast)[1][1] if tag in [0, 1, 4]: - id = traverse_ast(CoqFile.expr(self.__current_step.ast)) + id = traverse_expr(CoqFile.expr(self.__current_step.ast)) # This works because the obligation must be in the # same module as the program id = ".".join(self.coq_file.curr_module + [id]) From 1f23228e9fbc8245d94798fc75d1fb1b1885d18b Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 7 Sep 2023 22:05:12 +0100 Subject: [PATCH 08/27] Fix bullets and where notations --- coqlspclient/coq_file.py | 25 +++++++++++++------------ tests/resources/test_bullets.v | 2 +- tests/resources/test_where_notation.v | 6 +++++- tests/test_coq_file.py | 2 ++ tests/test_proof_state.py | 8 ++++---- 5 files changed, 25 insertions(+), 18 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 94b5e52..9a60a19 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -156,7 +156,7 @@ def __get_text(self, range: Range, trim: bool = False): start_character = 0 if prev_step is None else prev_step.range.end.character lines = self.__lines[start_line : end_line + 1] - lines[-1] = lines[-1][: end_character + 1] + lines[-1] = lines[-1][:end_character] lines[0] = lines[0][start_character:] text = "\n".join(lines) return " ".join(text.split()) if trim else text @@ -246,18 +246,19 @@ def __handle_where_notations(self, expr: List, term_type: TermType): # handles when multiple notations are defined for span in spans: - range = Range( - Position( - span["decl_ntn_string"]["loc"]["line_nb"] - 1, - span["decl_ntn_string"]["loc"]["bp"] - - span["decl_ntn_string"]["loc"]["bol_pos"], - ), - Position( - span["decl_ntn_interp"]["loc"]["line_nb_last"] - 1, - span["decl_ntn_interp"]["loc"]["ep"] - - span["decl_ntn_interp"]["loc"]["bol_pos"], - ), + start = Position( + span["decl_ntn_string"]["loc"]["line_nb"] - 1, + span["decl_ntn_string"]["loc"]["bp"] + - span["decl_ntn_string"]["loc"]["bol_pos"], ) + end = Position( + span["decl_ntn_interp"]["loc"]["line_nb_last"] - 1, + span["decl_ntn_interp"]["loc"]["ep"] + - span["decl_ntn_interp"]["loc"]["bol_pos"], + ) + if self.__lines[end.line][end.character] == ")": + end.character += 1 + range = Range(start, end) text = self.__get_text(range, trim=True) name = FileContext.get_notation_key( span["decl_ntn_string"]["v"], span["decl_ntn_scope"] diff --git a/tests/resources/test_bullets.v b/tests/resources/test_bullets.v index 1681031..929d850 100644 --- a/tests/resources/test_bullets.v +++ b/tests/resources/test_bullets.v @@ -1,6 +1,6 @@ Theorem bullets: forall x y: nat, x = x /\ y = y. Proof. intros x y. split. - - reflexivity. + -reflexivity. - reflexivity. Qed. diff --git a/tests/resources/test_where_notation.v b/tests/resources/test_where_notation.v index 141ea9a..0862366 100644 --- a/tests/resources/test_where_notation.v +++ b/tests/resources/test_where_notation.v @@ -1,4 +1,5 @@ Reserved Notation "A & B" (at level 80). +Reserved Notation "'ONE'" (at level 80). Fixpoint plus_test (n m : nat) {struct n} : nat := match n with @@ -8,4 +9,7 @@ end where "n + m" := (plus n m) : test_scope and "n - m" := (minus n m). Inductive and' (A B : Prop) : Prop := conj' : A -> B -> A & B -where "A & B" := (and' A B). \ No newline at end of file +where "A & B" := (and' A B). + +Fixpoint incr (n : nat) : nat := n + ONE +where "'ONE'" := 1. \ No newline at end of file diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index bd9b943..07ff2d7 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -36,6 +36,8 @@ def test_where_notation(setup, teardown): assert coq_file.context.terms["n - m"].text == 'Notation "n - m" := (minus n m)' assert "A & B" in coq_file.context.terms assert coq_file.context.terms["A & B"].text == 'Notation "A & B" := (and\' A B)' + assert "'ONE'" in coq_file.context.terms + assert coq_file.context.terms["'ONE'"].text == "Notation \"'ONE'\" := 1" @pytest.mark.parametrize("setup", ["test_get_notation.v"], indirect=True) diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index 318c289..edc8c42 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -491,11 +491,11 @@ def test_bullets(setup, teardown): proofs = state.proofs assert len(proofs) == 1 steps = [ - "\n intros x y. ", + "\n intros x y.", " split.", - "\n - ", - " reflexivity.", - "\n - ", + "\n -", + "reflexivity.", + "\n -", " reflexivity.", ] assert len(proofs[0].steps) == 6 From 0949e471ad111161bfe32588549b1bc610e607a7 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 8 Sep 2023 20:08:28 +0100 Subject: [PATCH 09/27] Handle Unicode characters in where notations --- coqlspclient/coq_file.py | 34 +++++++++++++++++---------- tests/resources/test_where_notation.v | 6 ++++- tests/test_coq_file.py | 2 ++ 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 9a60a19..24a3dd4 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -55,12 +55,9 @@ def __init__( self.__lines = f.read().split("\n") try: - self.coq_lsp_client.didOpen( - TextDocumentItem(uri, "coq", 1, "\n".join(self.__lines)) - ) - self.ast = self.coq_lsp_client.get_document( - TextDocumentIdentifier(uri) - ).spans + text, text_id = "\n".join(self.__lines), TextDocumentIdentifier(uri) + self.coq_lsp_client.didOpen(TextDocumentItem(uri, "coq", 1, text)) + self.ast = self.coq_lsp_client.get_document(text_id).spans except Exception as e: self.__handle_exception(e) raise e @@ -141,7 +138,20 @@ def __step_expr(self): curr_step = self.ast[self.steps_taken] return CoqFile.expr(curr_step) - def __get_text(self, range: Range, trim: bool = False): + def __get_text(self, range: Range, trim: bool = False, encode: bool = False): + def slice_line( + line: str, start: Optional[int] = None, stop: Optional[int] = None + ): + # The encode flag indicates if range.character is measured in bytes, + # rather than characters. If true, the string must be encoded before + # indexing. This special treatment is necessary for handling Unicode + # characters which take up more than 1 byte. + return ( + line.encode("utf-8")[start:stop].decode() + if encode + else line[start:stop] + ) + end_line = range.end.line end_character = range.end.character @@ -156,8 +166,8 @@ def __get_text(self, range: Range, trim: bool = False): start_character = 0 if prev_step is None else prev_step.range.end.character lines = self.__lines[start_line : end_line + 1] - lines[-1] = lines[-1][:end_character] - lines[0] = lines[0][start_character:] + lines[-1] = slice_line(lines[-1], stop=end_character) + lines[0] = slice_line(lines[0], start=start_character) text = "\n".join(lines) return " ".join(text.split()) if trim else text @@ -256,10 +266,10 @@ def __handle_where_notations(self, expr: List, term_type: TermType): span["decl_ntn_interp"]["loc"]["ep"] - span["decl_ntn_interp"]["loc"]["bol_pos"], ) - if self.__lines[end.line][end.character] == ")": + if chr(self.__lines[end.line].encode("utf-8")[end.character]) == ")": end.character += 1 range = Range(start, end) - text = self.__get_text(range, trim=True) + text = self.__get_text(range, trim=True, encode=True) name = FileContext.get_notation_key( span["decl_ntn_string"]["v"], span["decl_ntn_scope"] ) @@ -278,7 +288,7 @@ def __get_tactic_name(self, expr): def __process_step(self, sign): def traverse_expr(expr): - inductive = True if expr[0] == "VernacInductive" else False + inductive = expr[0] == "VernacInductive" stack, res = expr[:0:-1], [] while len(stack) > 0: el = stack.pop() diff --git a/tests/resources/test_where_notation.v b/tests/resources/test_where_notation.v index 0862366..6238dcd 100644 --- a/tests/resources/test_where_notation.v +++ b/tests/resources/test_where_notation.v @@ -1,5 +1,6 @@ Reserved Notation "A & B" (at level 80). Reserved Notation "'ONE'" (at level 80). +Reserved Notation "x 🀄 y" (at level 80). Fixpoint plus_test (n m : nat) {struct n} : nat := match n with @@ -12,4 +13,7 @@ Inductive and' (A B : Prop) : Prop := conj' : A -> B -> A & B where "A & B" := (and' A B). Fixpoint incr (n : nat) : nat := n + ONE -where "'ONE'" := 1. \ No newline at end of file +where "'ONE'" := 1. + +Fixpoint unicode x y := x 🀄 y +where "x 🀄 y" := (plus_test x y). \ No newline at end of file diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index 07ff2d7..0b3baec 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -38,6 +38,8 @@ def test_where_notation(setup, teardown): assert coq_file.context.terms["A & B"].text == 'Notation "A & B" := (and\' A B)' assert "'ONE'" in coq_file.context.terms assert coq_file.context.terms["'ONE'"].text == "Notation \"'ONE'\" := 1" + assert "x 🀄 y" in coq_file.context.terms + assert coq_file.context.terms["x 🀄 y"].text == 'Notation "x 🀄 y" := (plus_test x y)' @pytest.mark.parametrize("setup", ["test_get_notation.v"], indirect=True) From 417cc546ee28f4aef164799bb32e2ee3f9d5062f Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 8 Sep 2023 23:57:46 +0100 Subject: [PATCH 10/27] Add more characters in regex wildcard for notations --- coqlspclient/coq_structs.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 697ce6a..9e351ce 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -114,7 +114,10 @@ def get_notation(self, notation: str, scope: str) -> Term: if sub == "_": # We match the wildcard with the description from here: # https://coq.inria.fr/distrib/current/refman/language/core/basic.html#grammar-token-ident - regex[i] = "([a-zA-Z][a-zA-Z0-9_']*|_[a-zA-Z0-9_']+)" + # Coq accepts more characters, but no one should need more than these... + # chars = "A-Za-zÀ-ÖØ-öø-ˁˆ-ˑˠ-ˤˬˮͰ-ʹͶͷͺ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-Ֆՙա-և" + chars = "A-Za-z" + regex[i] = f"([{chars}][{chars}0-9_']*|_[{chars}0-9_']+)" else: # Handle '_' regex[i] = f"({sub}|('{sub}'))" From 25c02ec9d4357af5c9c1c59aa0e75b0de6facbdf Mon Sep 17 00:00:00 2001 From: pcarrott Date: Tue, 12 Sep 2023 19:30:16 +0100 Subject: [PATCH 11/27] Add keywords for Setoids and Relations --- coqlspclient/coq_file.py | 34 +++++++++++++++++++++++++++------- coqlspclient/coq_structs.py | 2 ++ coqlspclient/proof_state.py | 14 +++++--------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 24a3dd4..6aede99 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -121,6 +121,18 @@ def get_id(id: List) -> str: return id[1] return "" + @staticmethod + def get_identref(el: List) -> Optional[str]: + if ( + len(el) == 3 + and el[0] == "GenArg" + and el[1][0] == "Rawwit" + and el[1][1][0] == "ExtraArg" + and el[1][1][1] == "identref" + ): + return el[2][0][1][1] + return None + @staticmethod def expr(step: RangedSpan) -> Optional[List]: if ( @@ -213,11 +225,14 @@ def __get_term_type(expr: List) -> TermType: return TermType.FIXPOINT elif expr[0] == "VernacScheme": return TermType.SCHEME + elif expr[0] == "VernacExtend" and expr[1][0].startswith("AddSetoid"): + return TermType.SETOID + elif expr[0] == "VernacExtend" and expr[1][0].startswith( + ("AddRelation", "AddParametricRelation") + ): + return TermType.RELATION elif ( - len(expr) > 1 - and isinstance(expr[1], list) - and len(expr[1]) > 0 - and expr[1][0] == "VernacDeclareTacticDefinition" + expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition" ): return TermType.TACTIC else: @@ -289,6 +304,9 @@ def __get_tactic_name(self, expr): def __process_step(self, sign): def traverse_expr(expr): inductive = expr[0] == "VernacInductive" + add_cmd = expr[0] == "VernacExtend" and expr[1][0].startswith( + ("AddSetoid", "AddRelation", "AddParametricRelation") + ) stack, res = expr[:0:-1], [] while len(stack) > 0: el = stack.pop() @@ -310,6 +328,10 @@ def traverse_expr(expr): if len(el) > 0 and el[0] == "CLocalAssum": continue + identref = CoqFile.get_identref(el) + if identref is not None and add_cmd: + return [identref] + for v in reversed(el): if isinstance(v, (dict, list)): stack.append(v) @@ -362,9 +384,7 @@ def traverse_expr(expr): elif len(self.curr_module_type) > 0: return elif ( - len(expr) > 1 - and isinstance(expr[1], list) - and len(expr[1]) > 0 + expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition" ): name = self.__get_tactic_name(expr) diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 9e351ce..4286dac 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -35,6 +35,8 @@ class TermType(Enum): PROPOSITION = 16 PROPERTY = 17 OBLIGATION = 18 + RELATION = 19 + SETOID = 20 OTHER = 100 diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index a59e68b..a3cd67c 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -289,15 +289,11 @@ def traverse_expr(expr): stack = expr[:0:-1] while len(stack) > 0: el = stack.pop() - if ( - isinstance(el, list) - and len(el) == 3 - and el[0] == "GenArg" - and el[1][0] == "Rawwit" - and el[1][1][1] == "identref" - ): - return el[2][0][1][1] - elif isinstance(el, list): + if isinstance(el, list): + identref = CoqFile.get_identref(el) + if identref is not None: + return identref + for v in reversed(el): if isinstance(v, list): stack.append(v) From ceac087a5c491c2a0b083a270ef15e54b8755a15 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Tue, 12 Sep 2023 22:53:54 +0100 Subject: [PATCH 12/27] Minor refactor: get_id and reversed --- coqlspclient/coq_file.py | 16 ++++------------ coqlspclient/coq_structs.py | 3 +-- coqlspclient/proof_state.py | 3 +-- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 6aede99..19dc1d6 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -116,10 +116,10 @@ def __validate(self): @staticmethod def get_id(id: List) -> str: if id[0] == "Ser_Qualid": - return ".".join([l[1] for l in id[1][1][::-1]] + [id[2][1]]) + return ".".join([l[1] for l in reversed(id[1][1])] + [id[2][1]]) elif id[0] == "Id": return id[1] - return "" + return None @staticmethod def get_identref(el: List) -> Optional[str]: @@ -197,7 +197,7 @@ def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType) self.context.update(terms={full_name(name): term}) curr_file_module = "" - for module in self.file_module[::-1]: + for module in reversed(self.file_module): curr_file_module = module + "." + curr_file_module self.context.update(terms={curr_file_module + name: term}) @@ -293,14 +293,6 @@ def __handle_where_notations(self, expr: List, term_type: TermType): text = "Notation " + text self.__add_term(name, RangedSpan(range, span), text, TermType.NOTATION) - def __get_tactic_name(self, expr): - if len(expr[2][0][2][0][1][0]) == 2 and expr[2][0][2][0][1][0][0] == "v": - name = CoqFile.get_id(expr[2][0][2][0][1][0][1]) - if name != "": - return name - - return None - def __process_step(self, sign): def traverse_expr(expr): inductive = expr[0] == "VernacInductive" @@ -387,7 +379,7 @@ def traverse_expr(expr): expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition" ): - name = self.__get_tactic_name(expr) + name = CoqFile.get_id(expr[2][0][2][0][1][0][1]) self.__add_term(name, self.ast[self.steps_taken], text, TermType.TACTIC) elif expr[0] == "VernacNotation": name = text.split('"')[1].strip() diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 4286dac..b0fb706 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -117,8 +117,7 @@ def get_notation(self, notation: str, scope: str) -> Term: # We match the wildcard with the description from here: # https://coq.inria.fr/distrib/current/refman/language/core/basic.html#grammar-token-ident # Coq accepts more characters, but no one should need more than these... - # chars = "A-Za-zÀ-ÖØ-öø-ˁˆ-ˑˠ-ˤˬˮͰ-ʹͶͷͺ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-Ֆՙա-և" - chars = "A-Za-z" + chars = "A-Za-zÀ-ÖØ-öø-ˁˆ-ˑˠ-ˤˬˮͰ-ʹͶͷͺ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-Ֆՙա-և" regex[i] = f"([{chars}][{chars}0-9_']*|_[{chars}0-9_']+)" else: # Handle '_' diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index a3cd67c..85f69d8 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -237,8 +237,7 @@ def traverse_expr(expr): while len(stack) > 0: el = stack.pop() if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": - id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]]) - term = self.__get_term(id) + term = self.__get_term(CoqFile.get_id(el)) if term is not None: res.append((lambda x: x, term)) elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": From 73521f44dfcce899631ea92c6a15759c719fa029 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Wed, 13 Sep 2023 17:26:07 +0100 Subject: [PATCH 13/27] Add keywords Function, CoFixpoint and CoInductive --- coqlspclient/coq_file.py | 53 +++++++++++++++++++++++-------------- coqlspclient/coq_structs.py | 33 ++++++++++++----------- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 19dc1d6..00d95ee 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -133,6 +133,14 @@ def get_identref(el: List) -> Optional[str]: return el[2][0][1][1] return None + @staticmethod + def get_v(el): + if isinstance(el, dict) and "v" in el: + return el["v"] + elif isinstance(el, list) and len(el) == 2 and el[0] == "v": + return el[1] + return None + @staticmethod def expr(step: RangedSpan) -> Optional[List]: if ( @@ -217,10 +225,14 @@ def __get_term_type(expr: List) -> TermType: return TermType.RECORD elif expr[0] == "VernacInductive" and expr[1][0] == "Variant": return TermType.VARIANT + elif expr[0] == "VernacInductive" and expr[1][0] == "CoInductive": + return TermType.COINDUCTIVE elif expr[0] == "VernacInductive": return TermType.INDUCTIVE elif expr[0] == "VernacInstance": return TermType.INSTANCE + elif expr[0] == "VernacCoFixpoint": + return TermType.COFIXPOINT elif expr[0] == "VernacFixpoint": return TermType.FIXPOINT elif expr[0] == "VernacScheme": @@ -235,6 +247,8 @@ def __get_term_type(expr: List) -> TermType: expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition" ): return TermType.TACTIC + elif expr[0] == "VernacExtend" and expr[1][0] == "Function": + return TermType.FUNCTION else: return TermType.OTHER @@ -296,23 +310,23 @@ def __handle_where_notations(self, expr: List, term_type: TermType): def __process_step(self, sign): def traverse_expr(expr): inductive = expr[0] == "VernacInductive" - add_cmd = expr[0] == "VernacExtend" and expr[1][0].startswith( - ("AddSetoid", "AddRelation", "AddParametricRelation") - ) + extend = expr[0] == "VernacExtend" stack, res = expr[:0:-1], [] while len(stack) > 0: el = stack.pop() - if isinstance(el, dict): - if "v" in el and isinstance(el["v"], list) and len(el["v"]) == 2: - if el["v"][0] == "Id": - if not inductive: - return [el["v"][1]] - res.append(el["v"][1]) - elif el["v"][0] == "Name": - if not inductive: - return [el["v"][1][1]] - res.append(el["v"][1][1]) + v = CoqFile.get_v(el) + if v is not None and isinstance(v, list) and len(v) == 2: + id = CoqFile.get_id(v) + if id is not None: + if not inductive: + return [id] + res.append(id) + elif v[0] == "Name": + if not inductive: + return [v[1][1]] + res.append(v[1][1]) + if isinstance(el, dict): for v in reversed(el.values()): if isinstance(v, (dict, list)): stack.append(v) @@ -321,7 +335,7 @@ def traverse_expr(expr): continue identref = CoqFile.get_identref(el) - if identref is not None and add_cmd: + if identref is not None and extend: return [identref] for v in reversed(el): @@ -349,6 +363,8 @@ def traverse_expr(expr): expr = self.__step_expr() if expr == [None]: return + if expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve": + return term_type = CoqFile.__get_term_type(expr) if expr[0] == "VernacEndSegment": @@ -375,12 +391,9 @@ def traverse_expr(expr): # and should be overriden. elif len(self.curr_module_type) > 0: return - elif ( - expr[0] == "VernacExtend" - and expr[1][0] == "VernacDeclareTacticDefinition" - ): - name = CoqFile.get_id(expr[2][0][2][0][1][0][1]) - self.__add_term(name, self.ast[self.steps_taken], text, TermType.TACTIC) + elif expr[0] == "VernacExtend" and expr[1][0] == "VernacTacticNotation": + # FIXME: Handle this case + return elif expr[0] == "VernacNotation": name = text.split('"')[1].strip() if text[:-1].split(":")[-1].endswith("_scope"): diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index b0fb706..48996be 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -22,21 +22,24 @@ class TermType(Enum): DEFINITION = 3 NOTATION = 4 INDUCTIVE = 5 - RECORD = 6 - CLASS = 7 - INSTANCE = 8 - FIXPOINT = 9 - TACTIC = 10 - SCHEME = 11 - VARIANT = 12 - FACT = 13 - REMARK = 14 - COROLLARY = 15 - PROPOSITION = 16 - PROPERTY = 17 - OBLIGATION = 18 - RELATION = 19 - SETOID = 20 + COINDUCTIVE = 6 + RECORD = 7 + CLASS = 8 + INSTANCE = 9 + FIXPOINT = 10 + COFIXPOINT = 11 + SCHEME = 12 + VARIANT = 13 + FACT = 14 + REMARK = 15 + COROLLARY = 16 + PROPOSITION = 17 + PROPERTY = 18 + OBLIGATION = 19 + TACTIC = 20 + RELATION = 21 + SETOID = 22 + FUNCTION = 23 OTHER = 100 From 63bcad2524a13769e2decd1e152eebbb494b648c Mon Sep 17 00:00:00 2001 From: pcarrott Date: Wed, 13 Sep 2023 22:36:18 +0100 Subject: [PATCH 14/27] Add keyword Derive --- coqlspclient/coq_file.py | 22 ++++++++++++++++------ coqlspclient/coq_structs.py | 1 + coqlspclient/proof_state.py | 6 +++--- tests/resources/test_derive.v | 14 ++++++++++++++ tests/test_coq_file.py | 16 ++++++++++++++++ 5 files changed, 50 insertions(+), 9 deletions(-) create mode 100644 tests/resources/test_derive.v diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 00d95ee..0ff6d22 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -122,15 +122,17 @@ def get_id(id: List) -> str: return None @staticmethod - def get_identref(el: List) -> Optional[str]: + def get_ident(el: List) -> Optional[str]: if ( len(el) == 3 and el[0] == "GenArg" and el[1][0] == "Rawwit" and el[1][1][0] == "ExtraArg" - and el[1][1][1] == "identref" ): - return el[2][0][1][1] + if el[1][1][1] == "identref": + return el[2][0][1][1] + elif el[1][1][1] == "ident": + return el[2][1] return None @staticmethod @@ -237,6 +239,8 @@ def __get_term_type(expr: List) -> TermType: return TermType.FIXPOINT elif expr[0] == "VernacScheme": return TermType.SCHEME + elif expr[0] == "VernacExtend" and expr[1][0].startswith("Derive"): + return TermType.DERIVE elif expr[0] == "VernacExtend" and expr[1][0].startswith("AddSetoid"): return TermType.SETOID elif expr[0] == "VernacExtend" and expr[1][0].startswith( @@ -334,9 +338,9 @@ def traverse_expr(expr): if len(el) > 0 and el[0] == "CLocalAssum": continue - identref = CoqFile.get_identref(el) - if identref is not None and extend: - return [identref] + ident = CoqFile.get_ident(el) + if ident is not None and extend: + return [ident] for v in reversed(el): if isinstance(v, (dict, list)): @@ -416,6 +420,12 @@ def traverse_expr(expr): name = f"Unnamed_thm{self.__anonymous_id}" self.__anonymous_id += 1 self.__add_term(name, self.ast[self.steps_taken], text, term_type) + elif term_type == TermType.DERIVE: + name = CoqFile.get_ident(expr[2][0]) + self.__add_term(name, self.ast[self.steps_taken], text, term_type) + if expr[1][0] == "Derive": + prop = CoqFile.get_ident(expr[2][2]) + self.__add_term(prop, self.ast[self.steps_taken], text, term_type) else: names = traverse_expr(expr) for name in names: diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 48996be..63a60e0 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -40,6 +40,7 @@ class TermType(Enum): RELATION = 21 SETOID = 22 FUNCTION = 23 + DERIVE = 24 OTHER = 100 diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 85f69d8..b32a7b4 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -289,9 +289,9 @@ def traverse_expr(expr): while len(stack) > 0: el = stack.pop() if isinstance(el, list): - identref = CoqFile.get_identref(el) - if identref is not None: - return identref + ident = CoqFile.get_ident(el) + if ident is not None: + return ident for v in reversed(el): if isinstance(v, list): diff --git a/tests/resources/test_derive.v b/tests/resources/test_derive.v new file mode 100644 index 0000000..30e4a48 --- /dev/null +++ b/tests/resources/test_derive.v @@ -0,0 +1,14 @@ +Require Import Coq.derive.Derive. + +Derive incr +SuchThat (forall n, incr n = plus 1 n) +As incr_correct. +Proof. intros n. simpl. subst incr. reflexivity. Qed. + +Inductive Le : nat -> nat -> Set := +| LeO : forall n:nat, Le 0 n +| LeS : forall n m:nat, Le n m -> Le (S n) (S m). +Derive Inversion leminv1 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Inversion_clear leminv2 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Dependent Inversion leminv3 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Dependent Inversion_clear leminv4 with (forall n m:nat, Le (S n) m) Sort Prop. \ No newline at end of file diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index 0b3baec..9c21171 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -75,3 +75,19 @@ def test_module_type(setup, teardown): # We ignore terms inside a Module Type since they can't be used outside # and should be overriden. assert len(coq_file.context.terms) == 1 + + +@pytest.mark.parametrize("setup", ["test_derive.v"], indirect=True) +def test_derive(setup, teardown): + coq_file.run() + for key in ["incr", "incr_correct"]: + assert key in coq_file.context.terms + assert ( + coq_file.context.terms[key].text + == "Derive incr SuchThat (forall n, incr n = plus 1 n) As incr_correct." + ) + keywords = ["Inversion", "Inversion_clear", "Dependent Inversion", "Dependent Inversion_clear"] + for i in range(4): + key = f"leminv{i + 1}" + assert key in coq_file.context.terms + assert coq_file.context.terms[key].text == f"Derive {keywords[i]} {key} with (forall n m:nat, Le (S n) m) Sort Prop." From 08c8212966e61914f511dc2bee8bdf9b5f7c4c79 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 14 Sep 2023 00:29:51 +0100 Subject: [PATCH 15/27] Increase default timeout, refactor exceptions (new error codes for ServerTimeout and NotationNotFound) --- coqlspclient/coq_file.py | 7 +++++-- coqlspclient/coq_lsp_client.py | 4 ++-- coqlspclient/coq_lsp_structs.py | 10 ---------- coqlspclient/coq_structs.py | 28 +++++++++++++++++++++------- coqlspclient/proof_state.py | 23 +++++++++++------------ pylspclient/lsp_structs.py | 1 + tests/test_coq_file.py | 12 ++++++++++-- tests/test_proof_state.py | 5 +++-- 8 files changed, 53 insertions(+), 37 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 0ff6d22..7eaf930 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -31,7 +31,7 @@ def __init__( self, file_path: str, library: Optional[str] = None, - timeout: int = 2, + timeout: int = 30, workspace: Optional[str] = None, ): """Creates a CoqFile. @@ -95,7 +95,10 @@ def __init_path(self, file_path, library): self.__path = new_path def __handle_exception(self, e): - if not (isinstance(e, ResponseError) and e.code == ErrorCodes.ServerQuit.value): + if not isinstance(e, ResponseError) or e.code not in [ + ErrorCodes.ServerQuit.value, + ErrorCodes.ServerTimeout.value, + ]: self.coq_lsp_client.shutdown() self.coq_lsp_client.exit() if self.__from_lib: diff --git a/coqlspclient/coq_lsp_client.py b/coqlspclient/coq_lsp_client.py index cd8ebef..306b399 100644 --- a/coqlspclient/coq_lsp_client.py +++ b/coqlspclient/coq_lsp_client.py @@ -25,7 +25,7 @@ class CoqLspClient(LspClient): def __init__( self, root_uri: str, - timeout: int = 2, + timeout: int = 30, memory_limit: int = 2097152, init_options: Dict = __DEFAULT_INIT_OPTIONS, ): @@ -107,7 +107,7 @@ def __wait_for_operation(self): if timeout <= 0: self.shutdown() self.exit() - raise ResponseError(ErrorCodes.ServerQuit, "Server quit") + raise ResponseError(ErrorCodes.ServerTimeout, "Server timeout") def didOpen(self, textDocument: TextDocumentItem): """Open a text document in the server. diff --git a/coqlspclient/coq_lsp_structs.py b/coqlspclient/coq_lsp_structs.py index f457cd1..c81103a 100644 --- a/coqlspclient/coq_lsp_structs.py +++ b/coqlspclient/coq_lsp_structs.py @@ -214,13 +214,3 @@ def parse(coqFileProgressParams: Dict) -> Optional["CoqFileProgressParams"]: ) ) return CoqFileProgressParams(textDocument, processing) - - -class CoqErrorCodes(Enum): - InvalidFile = 0 - - -class CoqError(Exception): - def __init__(self, code: CoqErrorCodes, message: str): - self.code = code - self.message = message diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 63a60e0..eec3d71 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -4,12 +4,6 @@ from coqlspclient.coq_lsp_structs import RangedSpan, GoalAnswer -class Step(object): - def __init__(self, text: str, ast: RangedSpan): - self.text = text - self.ast = ast - - class SegmentType(Enum): MODULE = 1 MODULE_TYPE = 2 @@ -44,6 +38,23 @@ class TermType(Enum): OTHER = 100 +class CoqErrorCodes(Enum): + InvalidFile = 0 + NotationNotFound = 1 + + +class CoqError(Exception): + def __init__(self, code: CoqErrorCodes, message: str): + self.code = code + self.message = message + + +class Step(object): + def __init__(self, text: str, ast: RangedSpan): + self.text = text + self.ast = ast + + class Term: def __init__( self, @@ -145,7 +156,10 @@ def get_notation(self, notation: str, scope: str) -> Term: if key in self.terms: return self.terms[key] - raise RuntimeError(f"Notation not found in context: {notation_id}") + raise CoqError( + CoqErrorCodes.NotationNotFound, + f"Notation not found in context: {notation_id}", + ) @staticmethod def get_notation_key(notation: str, scope: str): diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index b32a7b4..1736e93 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -10,27 +10,23 @@ ErrorCodes, ) from coqlspclient.coq_structs import ( - ProofStep, - FileContext, + TermType, + CoqErrorCodes, + CoqError, Step, Term, + ProofStep, ProofTerm, - TermType, -) -from coqlspclient.coq_lsp_structs import ( - CoqError, - CoqErrorCodes, - Result, - Query, - GoalAnswer, + FileContext, ) +from coqlspclient.coq_lsp_structs import Result, Query, GoalAnswer from coqlspclient.coq_file import CoqFile from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Dict, Optional, Tuple class _AuxFile(object): - def __init__(self, file_path: Optional[str] = None, timeout: int = 2): + def __init__(self, file_path: Optional[str] = None, timeout: int = 30): self.__init_path(file_path) uri = f"file://{self.path}" self.coq_lsp_client = CoqLspClient(uri, timeout=timeout) @@ -69,7 +65,10 @@ def append(self, text): f.write(text) def __handle_exception(self, e): - if not (isinstance(e, ResponseError) and e.code == ErrorCodes.ServerQuit.value): + if not isinstance(e, ResponseError) or e.code not in [ + ErrorCodes.ServerQuit.value, + ErrorCodes.ServerTimeout.value, + ]: self.coq_lsp_client.shutdown() self.coq_lsp_client.exit() os.remove(self.path) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index d6ae4ea..52d0f86 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -609,6 +609,7 @@ class ErrorCodes(enum.Enum): InternalError = -32603 serverErrorStart = -32099 serverErrorEnd = -32000 + ServerTimeout = -32004 ServerQuit = -32003 ServerNotInitialized = -32002 UnknownErrorCode = -32001 diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index 9c21171..8ed8d69 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -86,8 +86,16 @@ def test_derive(setup, teardown): coq_file.context.terms[key].text == "Derive incr SuchThat (forall n, incr n = plus 1 n) As incr_correct." ) - keywords = ["Inversion", "Inversion_clear", "Dependent Inversion", "Dependent Inversion_clear"] + keywords = [ + "Inversion", + "Inversion_clear", + "Dependent Inversion", + "Dependent Inversion_clear", + ] for i in range(4): key = f"leminv{i + 1}" assert key in coq_file.context.terms - assert coq_file.context.terms[key].text == f"Derive {keywords[i]} {key} with (forall n m:nat, Le (S n) m) Sort Prop." + assert ( + coq_file.context.terms[key].text + == f"Derive {keywords[i]} {key} with (forall n m:nat, Le (S n) m) Sort Prop." + ) diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index edc8c42..c964a38 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -3,7 +3,7 @@ import pytest from typing import List, Tuple from coqlspclient.coq_lsp_structs import * -from coqlspclient.coq_structs import TermType, Term +from coqlspclient.coq_structs import TermType, Term, CoqError, CoqErrorCodes from coqlspclient.proof_state import ProofState, CoqFile versionId: VersionedTextDocumentIdentifier = None @@ -433,8 +433,9 @@ def test_unknown_notation(setup, teardown): """Checks if it is able to handle the notation { _ } that is unknown for the Locate command because it is a default notation. """ - with pytest.raises(RuntimeError): + with pytest.raises(CoqError) as e_info: assert state.context.get_notation("{ _ }", "") + assert e_info.value.code == CoqErrorCodes.NotationNotFound @pytest.mark.parametrize("setup", [("test_nested_proofs.v", None)], indirect=True) From 2d84bf3d267761d5dac692036a7a7d71bddc8a7d Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 14 Sep 2023 01:24:12 +0100 Subject: [PATCH 16/27] Minor fix --- coqlspclient/coq_file.py | 2 +- coqlspclient/coq_structs.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 7eaf930..7ce3b2b 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -333,7 +333,7 @@ def traverse_expr(expr): return [v[1][1]] res.append(v[1][1]) - if isinstance(el, dict): + elif isinstance(el, dict): for v in reversed(el.values()): if isinstance(v, (dict, list)): stack.append(v) diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index eec3d71..5cc8ec4 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -140,14 +140,15 @@ def get_notation(self, notation: str, scope: str) -> Term: regex = "^" + "\\ ".join(regex) + "$" # Search notations + unscoped = None for term in self.terms.keys(): if re.match(regex, term): return self.terms[term] - # We search again in case the stored id contains the scope but no scope is provided - for term in self.terms.keys(): - unscoped = term.split(":")[0].strip() - if re.match(regex, unscoped): - return self.terms[term] + if re.match(regex, term.split(":")[0].strip()): + unscoped = term + # In case the stored id contains the scope but no scope is provided + if unscoped is not None: + return self.terms[unscoped] # Search Infix if re.match("^_ ([^ ]*) _$", notation): From 0acc1042f74743c3a875dfb1b4ad94cf48c2ca47 Mon Sep 17 00:00:00 2001 From: Nfsaavedra Date: Fri, 15 Sep 2023 16:18:06 +0100 Subject: [PATCH 17/27] improve running files with errors but our program became an error --- coqlspclient/coq_file.py | 29 +++++++++++++++++++++++++-- coqlspclient/proof_state.py | 13 +++++-------- pylspclient/lsp_endpoint.py | 3 ++- pylspclient/lsp_structs.py | 39 +++++++++++++------------------------ tests/test_coq_file.py | 1 + 5 files changed, 48 insertions(+), 37 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 7ce3b2b..6d2ea05 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -3,7 +3,7 @@ import uuid import tempfile from pylspclient.lsp_structs import TextDocumentItem, TextDocumentIdentifier -from pylspclient.lsp_structs import ResponseError, ErrorCodes +from pylspclient.lsp_structs import ResponseError, ErrorCodes, Diagnostic from coqlspclient.coq_lsp_structs import Position, GoalAnswer, RangedSpan, Range from coqlspclient.coq_structs import Step, FileContext, Term, TermType, SegmentType from coqlspclient.coq_lsp_client import CoqLspClient @@ -62,6 +62,7 @@ def __init__( self.__handle_exception(e) raise e + self.__first_error: Optional[Diagnostic] = None self.__validate() self.steps_taken: int = 0 self.curr_module: List[str] = [] @@ -110,9 +111,13 @@ def __validate(self): self.is_valid = True return + self.coq_lsp_client.lsp_endpoint.diagnostics[uri].sort( + key=lambda d: d.range.start.line + ) for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: if diagnostic.severity == 1: self.is_valid = False + self.__first_error = diagnostic return self.is_valid = True @@ -489,6 +494,16 @@ def terms(self) -> List[Term]: ) ) + @property + def diagnostics(self) -> List[Diagnostic]: + """ + Returns: + List[Diagnostic]: The diagnostics of the file. + Includes all messages given by Coq. + """ + uri = f"file://{self.__path}" + return self.coq_lsp_client.lsp_endpoint.diagnostics[uri] + @staticmethod def get_term_type(ast: RangedSpan) -> TermType: expr = CoqFile.expr(ast) @@ -512,7 +527,17 @@ def exec(self, nsteps=1) -> List[Step]: len(self.ast) - self.steps_taken if sign > 0 else self.steps_taken, ) for _ in range(nsteps): - steps.append(Step(self.__step_text(), self.ast[self.steps_taken])) + ast = self.ast[self.steps_taken] + if not self.is_valid and ( + (ast.range.end.line > self.__first_error.range.start.line) + or ( + ast.range.end.line == self.__first_error.range.start.line + and ast.range.end.character + >= self.__first_error.range.start.character + ) + ): + raise RuntimeError(self.__first_error.message) + steps.append(Step(self.__step_text(), ast)) self.__process_step(sign) return steps diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 1736e93..a265463 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -102,18 +102,15 @@ def __get_queries(self, keyword): lines = self.read().split("\n") for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: command = lines[ - diagnostic.range["start"]["line"] : diagnostic.range["end"]["line"] + 1 + diagnostic.range.start.line : diagnostic.range.end.line + 1 ] if len(command) == 1: command[0] = command[0][ - diagnostic.range["start"]["character"] : diagnostic.range["end"][ - "character" - ] - + 1 + diagnostic.range.start.character : diagnostic.range.end.character + 1 ] else: - command[0] = command[0][diagnostic.range["start"]["character"] :] - command[-1] = command[-1][: diagnostic.range["end"]["character"] + 1] + command[0] = command[0][diagnostic.range.start.character :] + command[-1] = command[-1][: diagnostic.range.end.character + 1] command = "".join(command).strip() if command.startswith(keyword): @@ -132,7 +129,7 @@ def get_diagnostics(self, keyword, search, line): for query in self.__get_queries(keyword): if query.query == f"{search}": for result in query.results: - if result.range["start"]["line"] == line: + if result.range.start.line == line: return result.message break return None diff --git a/pylspclient/lsp_endpoint.py b/pylspclient/lsp_endpoint.py index 120018c..19dab99 100644 --- a/pylspclient/lsp_endpoint.py +++ b/pylspclient/lsp_endpoint.py @@ -2,6 +2,7 @@ import threading import logging from pylspclient import lsp_structs +from typing import Dict, List class LspEndpoint(threading.Thread): @@ -17,7 +18,7 @@ def __init__( self.next_id = 0 self.timeout = timeout self.shutdown_flag = False - self.diagnostics = {} + self.diagnostics: Dict[str, List[lsp_structs.Diagnostic]] = {} def handle_result(self, rpc_id, result, error): self.response_dict[rpc_id] = (result, error) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index 52d0f86..7e39ed4 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -82,7 +82,18 @@ def __init__( class Diagnostic(object): - def __init__(self, range, severity, code, source, message, relatedInformation): + def __init__( + self, + range, + message, + severity=None, + code=None, + codeDescription=None, + source=None, + tags=None, + relatedInformation=None, + data=None, + ): """ Constructs a new Diagnostic instance. :param Range range: The range at which the message applies.Resource file. @@ -95,7 +106,7 @@ def __init__(self, range, severity, code, source, message, relatedInformation): :param list relatedInformation: An array of related diagnostic information, e.g. when symbol-names within a scope collide all definitions can be marked via this property. """ - self.range = range + self.range: Range = Range(**range) self.severity = severity self.code = code self.source = source @@ -532,30 +543,6 @@ def __init__( self.score = score -class Diagnostic(object): - def __init__( - self, - range, - message, - severity=None, - code=None, - codeDescription=None, - source=None, - tags=None, - relatedInformation=None, - data=None, - ): - self.range = range - self.message = message - self.severity = severity - self.code = code - self.codeDescription = codeDescription - self.source = source - self.tags = tags - self.relatedInformation = relatedInformation - self.data = data - - class CompletionItemKind(enum.Enum): Text = 1 Method = 2 diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index 8ed8d69..ee41432 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -67,6 +67,7 @@ def test_is_invalid_1(setup, teardown): @pytest.mark.parametrize("setup", ["test_invalid_2.v"], indirect=True) def test_is_invalid_2(setup, teardown): assert not coq_file.is_valid + coq_file.run() @pytest.mark.parametrize("setup", ["test_module_type.v"], indirect=True) From afd67bd4b186e1731ac278d9433362d28723d9f4 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 15 Sep 2023 17:35:47 +0100 Subject: [PATCH 18/27] Only copy files from Coq.Init and not all files from Coq --- coqlspclient/coq_file.py | 2 +- tests/test_coq_file.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 6d2ea05..a54e716 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -80,7 +80,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __init_path(self, file_path, library): self.file_module = [] if library is None else library.split(".") - self.__from_lib = self.file_module[:1] == ["Coq"] + self.__from_lib = self.file_module[:2] == ["Coq", "Init"] self.path = file_path if not self.__from_lib: self.__path = file_path diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index ee41432..aa6340a 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -62,12 +62,15 @@ def test_get_notation(setup, teardown): @pytest.mark.parametrize("setup", ["test_invalid_1.v"], indirect=True) def test_is_invalid_1(setup, teardown): assert not coq_file.is_valid + with pytest.raises(RuntimeError): + coq_file.run() @pytest.mark.parametrize("setup", ["test_invalid_2.v"], indirect=True) def test_is_invalid_2(setup, teardown): assert not coq_file.is_valid - coq_file.run() + with pytest.raises(RuntimeError): + coq_file.run() @pytest.mark.parametrize("setup", ["test_module_type.v"], indirect=True) From 6ccf3196606ebea23938b3568a75eee46558d679 Mon Sep 17 00:00:00 2001 From: Nfsaavedra Date: Thu, 21 Sep 2023 16:53:34 +0100 Subject: [PATCH 19/27] CoqFile refactor. Add diagnostics to steps --- coqlspclient/coq_file.py | 51 ++++++++++++++++++++----------------- coqlspclient/coq_structs.py | 2 ++ coqlspclient/proof_state.py | 7 +++-- pylspclient/lsp_structs.py | 8 ++++++ tests/test_coq_file.py | 13 ++++++---- 5 files changed, 48 insertions(+), 33 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index a54e716..3c39caf 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -58,13 +58,18 @@ def __init__( text, text_id = "\n".join(self.__lines), TextDocumentIdentifier(uri) self.coq_lsp_client.didOpen(TextDocumentItem(uri, "coq", 1, text)) self.ast = self.coq_lsp_client.get_document(text_id).spans + self.steps_taken: int = 0 + self.steps: List[Step] = [] + for curr_step in self.ast: + text = self.__get_text(curr_step.range) + self.steps.append(Step(text, curr_step)) + self.steps_taken += 1 + self.steps_taken = 0 except Exception as e: self.__handle_exception(e) raise e - self.__first_error: Optional[Diagnostic] = None self.__validate() - self.steps_taken: int = 0 self.curr_module: List[str] = [] self.curr_module_type: List[str] = [] self.curr_section: List[str] = [] @@ -107,8 +112,8 @@ def __handle_exception(self, e): def __validate(self): uri = f"file://{self.__path}" + self.is_valid = True if uri not in self.coq_lsp_client.lsp_endpoint.diagnostics: - self.is_valid = True return self.coq_lsp_client.lsp_endpoint.diagnostics[uri].sort( @@ -117,9 +122,22 @@ def __validate(self): for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: if diagnostic.severity == 1: self.is_valid = False - self.__first_error = diagnostic - return - self.is_valid = True + + for step in self.steps: + if (diagnostic.range.start.line < step.ast.range.start.line) or ( + diagnostic.range.start.line == step.ast.range.start.line + and diagnostic.range.start.character < step.ast.range.start.character + ): + early_range, late_range = diagnostic.range, step.ast.range + else: + early_range, late_range = step.ast.range, diagnostic.range + + if (late_range.start.line < early_range.end.line) or ( + late_range.start.line == early_range.end.line + and late_range.start.character < early_range.end.character + ): + step.diagnostics.append(diagnostic) + break @staticmethod def get_id(id: List) -> str: @@ -201,10 +219,6 @@ def slice_line( text = "\n".join(lines) return " ".join(text.split()) if trim else text - def __step_text(self, trim=False): - curr_step = self.ast[self.steps_taken] - return self.__get_text(curr_step.range, trim=trim) - def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType): term = Term(text, ast, term_type, self.path, self.curr_module[:]) if term.type == TermType.NOTATION: @@ -359,7 +373,7 @@ def traverse_expr(expr): # TODO: A negative sign should handle things differently. For example: # - names should be removed from the context # - curr_module should change as you leave or re-enter modules - text = self.__step_text(trim=True) + text = self.__get_text(self.steps[self.steps_taken].ast.range, trim=True) # FIXME Let (and maybe Variable) should be handled. However, # I think we can't handle them as normal Locals since they are # specific to a section. @@ -520,26 +534,15 @@ def exec(self, nsteps=1) -> List[Step]: Returns: List[Step]: List of steps executed. """ - steps: List[Step] = [] sign = 1 if nsteps > 0 else -1 + initial_steps_taken = self.steps_taken nsteps = min( nsteps * sign, len(self.ast) - self.steps_taken if sign > 0 else self.steps_taken, ) for _ in range(nsteps): - ast = self.ast[self.steps_taken] - if not self.is_valid and ( - (ast.range.end.line > self.__first_error.range.start.line) - or ( - ast.range.end.line == self.__first_error.range.start.line - and ast.range.end.character - >= self.__first_error.range.start.character - ) - ): - raise RuntimeError(self.__first_error.message) - steps.append(Step(self.__step_text(), ast)) self.__process_step(sign) - return steps + return self.steps[initial_steps_taken:self.steps_taken] def run(self) -> List[Step]: """Executes all the steps in the file. diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 5cc8ec4..92a491f 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Dict, List from coqlspclient.coq_lsp_structs import RangedSpan, GoalAnswer +from pylspclient.lsp_structs import Diagnostic class SegmentType(Enum): @@ -53,6 +54,7 @@ class Step(object): def __init__(self, text: str, ast: RangedSpan): self.text = text self.ast = ast + self.diagnostics: List[Diagnostic] = [] class Term: diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index a265463..e6198d7 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -101,12 +101,11 @@ def __get_queries(self, keyword): searches = {} lines = self.read().split("\n") for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: - command = lines[ - diagnostic.range.start.line : diagnostic.range.end.line + 1 - ] + command = lines[diagnostic.range.start.line : diagnostic.range.end.line + 1] if len(command) == 1: command[0] = command[0][ - diagnostic.range.start.character : diagnostic.range.end.character + 1 + diagnostic.range.start.character : diagnostic.range.end.character + + 1 ] else: command[0] = command[0][diagnostic.range.start.character :] diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index 7e39ed4..cded2ad 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -26,6 +26,11 @@ def __init__(self, line, character, offset=0): self.character = character self.offset = offset + def __repr__(self) -> str: + return str( + {"line": self.line, "character": self.character, "offset": self.offset} + ) + class Range(object): def __init__(self, start, end): @@ -38,6 +43,9 @@ def __init__(self, start, end): self.start = to_type(start, Position) self.end = to_type(end, Position) + def __repr__(self) -> str: + return str({"start": repr(self.start), "end": repr(self.end)}) + class Location(object): """ diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index aa6340a..508cc99 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -62,16 +62,19 @@ def test_get_notation(setup, teardown): @pytest.mark.parametrize("setup", ["test_invalid_1.v"], indirect=True) def test_is_invalid_1(setup, teardown): assert not coq_file.is_valid - with pytest.raises(RuntimeError): - coq_file.run() + steps = coq_file.run() + assert len(steps[11].diagnostics) == 1 + assert steps[11].diagnostics[0].message == "Found no subterm matching \"0 + ?M152\" in the current goal." + assert steps[11].diagnostics[0].severity == 1 @pytest.mark.parametrize("setup", ["test_invalid_2.v"], indirect=True) def test_is_invalid_2(setup, teardown): assert not coq_file.is_valid - with pytest.raises(RuntimeError): - coq_file.run() - + steps = coq_file.run() + assert len(steps[15].diagnostics) == 1 + assert steps[15].diagnostics[0].message == "Syntax error: '.' expected after [command] (in [vernac_aux])." + assert steps[15].diagnostics[0].severity == 1 @pytest.mark.parametrize("setup", ["test_module_type.v"], indirect=True) def test_module_type(setup, teardown): From 69f9bdfc9604c350ca41f72fb36d859d026d6521 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 21 Sep 2023 18:04:54 +0100 Subject: [PATCH 20/27] Fix diagnostic range for steps --- coqlspclient/coq_file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 3c39caf..3b20693 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -126,7 +126,7 @@ def __validate(self): for step in self.steps: if (diagnostic.range.start.line < step.ast.range.start.line) or ( diagnostic.range.start.line == step.ast.range.start.line - and diagnostic.range.start.character < step.ast.range.start.character + and diagnostic.range.start.character <= step.ast.range.start.character ): early_range, late_range = diagnostic.range, step.ast.range else: @@ -134,7 +134,7 @@ def __validate(self): if (late_range.start.line < early_range.end.line) or ( late_range.start.line == early_range.end.line - and late_range.start.character < early_range.end.character + and late_range.start.character <= early_range.end.character ): step.diagnostics.append(diagnostic) break From 22835d27021c4aeb779d51077d24f709b33324d3 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Thu, 21 Sep 2023 19:19:30 +0100 Subject: [PATCH 21/27] Comparison operators for Position --- coqlspclient/coq_file.py | 19 ++++--------------- coqlspclient/proof_state.py | 5 +---- pylspclient/lsp_structs.py | 36 ++++++++++++++++++++++++++++++++++++ tests/test_coq_file.py | 11 +++++++++-- 4 files changed, 50 insertions(+), 21 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 3b20693..fff2b65 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -116,25 +116,14 @@ def __validate(self): if uri not in self.coq_lsp_client.lsp_endpoint.diagnostics: return - self.coq_lsp_client.lsp_endpoint.diagnostics[uri].sort( - key=lambda d: d.range.start.line - ) for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: if diagnostic.severity == 1: self.is_valid = False for step in self.steps: - if (diagnostic.range.start.line < step.ast.range.start.line) or ( - diagnostic.range.start.line == step.ast.range.start.line - and diagnostic.range.start.character <= step.ast.range.start.character - ): - early_range, late_range = diagnostic.range, step.ast.range - else: - early_range, late_range = step.ast.range, diagnostic.range - - if (late_range.start.line < early_range.end.line) or ( - late_range.start.line == early_range.end.line - and late_range.start.character <= early_range.end.character + if ( + step.ast.range.start <= diagnostic.range.start + and step.ast.range.end >= diagnostic.range.end ): step.diagnostics.append(diagnostic) break @@ -542,7 +531,7 @@ def exec(self, nsteps=1) -> List[Step]: ) for _ in range(nsteps): self.__process_step(sign) - return self.steps[initial_steps_taken:self.steps_taken] + return self.steps[initial_steps_taken : self.steps_taken] def run(self) -> List[Step]: """Executes all the steps in the file. diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index e6198d7..7fbaeee 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -271,10 +271,7 @@ def __get_last_term(self): return None last_term = terms[0] for term in terms[1:]: - if (term.ast.range.end.line > last_term.ast.range.end.line) or ( - term.ast.range.end.line == last_term.ast.range.end.line - and term.ast.range.end.character > last_term.ast.range.end.character - ): + if last_term.ast.range.end < term.ast.range.end: last_term = term return last_term diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index cded2ad..1d4741f 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -31,6 +31,42 @@ def __repr__(self) -> str: {"line": self.line, "character": self.character, "offset": self.offset} ) + def __eq__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) == ( + __value.line, + __value.character, + ) + + def __ne__(self, __value: object) -> bool: + return not isinstance(__value, Position) or (self.line, self.character) != ( + __value.line, + __value.character, + ) + + def __gt__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) > ( + __value.line, + __value.character, + ) + + def __ge__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) >= ( + __value.line, + __value.character, + ) + + def __lt__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) < ( + __value.line, + __value.character, + ) + + def __le__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) <= ( + __value.line, + __value.character, + ) + class Range(object): def __init__(self, start, end): diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index 508cc99..a3eb435 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -64,7 +64,10 @@ def test_is_invalid_1(setup, teardown): assert not coq_file.is_valid steps = coq_file.run() assert len(steps[11].diagnostics) == 1 - assert steps[11].diagnostics[0].message == "Found no subterm matching \"0 + ?M152\" in the current goal." + assert ( + steps[11].diagnostics[0].message + == 'Found no subterm matching "0 + ?M152" in the current goal.' + ) assert steps[11].diagnostics[0].severity == 1 @@ -73,9 +76,13 @@ def test_is_invalid_2(setup, teardown): assert not coq_file.is_valid steps = coq_file.run() assert len(steps[15].diagnostics) == 1 - assert steps[15].diagnostics[0].message == "Syntax error: '.' expected after [command] (in [vernac_aux])." + assert ( + steps[15].diagnostics[0].message + == "Syntax error: '.' expected after [command] (in [vernac_aux])." + ) assert steps[15].diagnostics[0].severity == 1 + @pytest.mark.parametrize("setup", ["test_module_type.v"], indirect=True) def test_module_type(setup, teardown): coq_file.run() From 245c09d116e11224fe63e0ba5b4d2bad9cf0fe17 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 00:01:56 +0100 Subject: [PATCH 22/27] Remove lines and AST from CoqFile, keep list of Steps --- README.md | 2 - coqlspclient/coq_file.py | 128 ++++++++++++++++++++------------------- 2 files changed, 67 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 0027b9a..6f98121 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,6 @@ from coqlspclient.coq_structs import TermType # Open Coq file with CoqFile(os.path.join(os.getcwd(), "examples/example.v")) as coq_file: - # Print AST - print(coq_file.ast) coq_file.exec(nsteps=2) # Get all terms defined until now print("Number of terms:", len(coq_file.context.terms)) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index fff2b65..0d37101 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -52,23 +52,16 @@ def __init__( self.coq_lsp_client = CoqLspClient(uri, timeout=timeout) uri = f"file://{self.__path}" with open(self.__path, "r") as f: - self.__lines = f.read().split("\n") + text = f.read() try: - text, text_id = "\n".join(self.__lines), TextDocumentIdentifier(uri) self.coq_lsp_client.didOpen(TextDocumentItem(uri, "coq", 1, text)) - self.ast = self.coq_lsp_client.get_document(text_id).spans - self.steps_taken: int = 0 - self.steps: List[Step] = [] - for curr_step in self.ast: - text = self.__get_text(curr_step.range) - self.steps.append(Step(text, curr_step)) - self.steps_taken += 1 - self.steps_taken = 0 + ast = self.coq_lsp_client.get_document(TextDocumentIdentifier(uri)).spans except Exception as e: self.__handle_exception(e) raise e + self.__init_steps(text, ast) self.__validate() self.curr_module: List[str] = [] self.curr_module_type: List[str] = [] @@ -110,6 +103,22 @@ def __handle_exception(self, e): if self.__from_lib: os.remove(self.__path) + def __init_steps(self, text: str, ast: List[RangedSpan]): + lines = text.split("\n") + self.steps: List[Step] = [] + for i, curr_ast in enumerate(ast): + start_line = 0 if i == 0 else ast[i - 1].range.end.line + start_char = 0 if i == 0 else ast[i - 1].range.end.character + end_line = curr_ast.range.end.line + end_char = curr_ast.range.end.character + + curr_lines = lines[start_line : end_line + 1] + curr_lines[-1] = curr_lines[-1][:end_char] + curr_lines[0] = curr_lines[0][start_char:] + step_text = "\n".join(curr_lines) + self.steps.append(Step(step_text, curr_ast)) + self.steps_taken: int = 0 + def __validate(self): uri = f"file://{self.__path}" self.is_valid = True @@ -128,6 +137,14 @@ def __validate(self): step.diagnostics.append(diagnostic) break + @property + def __curr_step(self): + return self.steps[self.steps_taken] + + @property + def __prev_step(self): + return self.steps[self.steps_taken - 1] + @staticmethod def get_id(id: List) -> str: if id[0] == "Ser_Qualid": @@ -171,42 +188,39 @@ def expr(step: RangedSpan) -> Optional[List]: return [None] - def __step_expr(self): - curr_step = self.ast[self.steps_taken] - return CoqFile.expr(curr_step) - - def __get_text(self, range: Range, trim: bool = False, encode: bool = False): + def __short_text(self, range: Optional[Range] = None): def slice_line( line: str, start: Optional[int] = None, stop: Optional[int] = None ): - # The encode flag indicates if range.character is measured in bytes, - # rather than characters. If true, the string must be encoded before - # indexing. This special treatment is necessary for handling Unicode - # characters which take up more than 1 byte. - return ( - line.encode("utf-8")[start:stop].decode() - if encode - else line[start:stop] - ) - - end_line = range.end.line - end_character = range.end.character + if range is None: + return line[start:stop] - if trim: - start_line = range.start.line - start_character = range.start.character + # A range will be provided when range.character is measured in bytes, + # rather than characters. If so, the string must be encoded before + # indexing. This special treatment is necessary for handling Unicode + # characters which take up more than 1 byte. Currently necessary to + # handle where-notations. + line = line.encode("utf-8") + + # For where-notations, the character count does not include the closing + # parenthesis when present. + if stop is not None and chr(line[stop]) == ")": + stop += 1 + return line[start:stop].decode() + + curr_range = self.__curr_step.ast.range if range is None else range + nlines = curr_range.end.line - curr_range.start.line + 1 + lines = self.__curr_step.text.split("\n")[-nlines:] + + prev_range = None if self.steps_taken == 0 else self.__prev_step.ast.range + if prev_range is None or prev_range.end.line < curr_range.start.line: + start = curr_range.start.character else: - prev_step = ( - None if self.steps_taken == 0 else self.ast[self.steps_taken - 1] - ) - start_line = 0 if prev_step is None else prev_step.range.end.line - start_character = 0 if prev_step is None else prev_step.range.end.character + start = curr_range.start.character - prev_range.end.character - lines = self.__lines[start_line : end_line + 1] - lines[-1] = slice_line(lines[-1], stop=end_character) - lines[0] = slice_line(lines[0], start=start_character) - text = "\n".join(lines) - return " ".join(text.split()) if trim else text + lines[-1] = slice_line(lines[-1], stop=curr_range.end.character) + lines[0] = slice_line(lines[0], start=start) + return " ".join(" ".join(lines).split()) def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType): term = Term(text, ast, term_type, self.path, self.curr_module[:]) @@ -310,10 +324,8 @@ def __handle_where_notations(self, expr: List, term_type: TermType): span["decl_ntn_interp"]["loc"]["ep"] - span["decl_ntn_interp"]["loc"]["bol_pos"], ) - if chr(self.__lines[end.line].encode("utf-8")[end.character]) == ")": - end.character += 1 range = Range(start, end) - text = self.__get_text(range, trim=True, encode=True) + text = self.__short_text(range=range) name = FileContext.get_notation_key( span["decl_ntn_string"]["v"], span["decl_ntn_scope"] ) @@ -362,7 +374,7 @@ def traverse_expr(expr): # TODO: A negative sign should handle things differently. For example: # - names should be removed from the context # - curr_module should change as you leave or re-enter modules - text = self.__get_text(self.steps[self.steps_taken].ast.range, trim=True) + text = self.__short_text() # FIXME Let (and maybe Variable) should be handled. However, # I think we can't handle them as normal Locals since they are # specific to a section. @@ -375,7 +387,7 @@ def traverse_expr(expr): ]: if text.startswith(keyword): return - expr = self.__step_expr() + expr = CoqFile.expr(self.__curr_step.ast) if expr == [None]: return if expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve": @@ -413,16 +425,12 @@ def traverse_expr(expr): name = text.split('"')[1].strip() if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() - self.__add_term( - name, self.ast[self.steps_taken], text, TermType.NOTATION - ) + self.__add_term(name, self.__curr_step.ast, text, TermType.NOTATION) elif expr[0] == "VernacSyntacticDefinition": name = text.split(" ")[1] if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() - self.__add_term( - name, self.ast[self.steps_taken], text, TermType.NOTATION - ) + self.__add_term(name, self.__curr_step.ast, text, TermType.NOTATION) elif expr[0] == "VernacDefinition" and expr[2][0]["v"][0] == "Anonymous": # We associate the anonymous term to the same name used internally by Coq if self.__anonymous_id is None: @@ -430,17 +438,17 @@ def traverse_expr(expr): else: name = f"Unnamed_thm{self.__anonymous_id}" self.__anonymous_id += 1 - self.__add_term(name, self.ast[self.steps_taken], text, term_type) + self.__add_term(name, self.__curr_step.ast, text, term_type) elif term_type == TermType.DERIVE: name = CoqFile.get_ident(expr[2][0]) - self.__add_term(name, self.ast[self.steps_taken], text, term_type) + self.__add_term(name, self.__curr_step.ast, text, term_type) if expr[1][0] == "Derive": prop = CoqFile.get_ident(expr[2][2]) - self.__add_term(prop, self.ast[self.steps_taken], text, term_type) + self.__add_term(prop, self.__curr_step.ast, text, term_type) else: names = traverse_expr(expr) for name in names: - self.__add_term(name, self.ast[self.steps_taken], text, term_type) + self.__add_term(name, self.__curr_step.ast, text, term_type) self.__handle_where_notations(expr, term_type) finally: @@ -461,7 +469,7 @@ def checked(self) -> bool: Returns: bool: True if the whole file was already executed """ - return self.steps_taken == len(self.ast) + return self.steps_taken == len(self.steps) @property def in_proof(self) -> bool: @@ -527,7 +535,7 @@ def exec(self, nsteps=1) -> List[Step]: initial_steps_taken = self.steps_taken nsteps = min( nsteps * sign, - len(self.ast) - self.steps_taken if sign > 0 else self.steps_taken, + len(self.steps) - self.steps_taken if sign > 0 else self.steps_taken, ) for _ in range(nsteps): self.__process_step(sign) @@ -539,7 +547,7 @@ def run(self) -> List[Step]: Returns: List[Step]: List of all the steps in the file. """ - return self.exec(len(self.ast)) + return self.exec(len(self.steps)) def current_goals(self) -> Optional[GoalAnswer]: """Get goals in current position. @@ -549,9 +557,7 @@ def current_goals(self) -> Optional[GoalAnswer]: """ uri = f"file://{self.__path}" end_pos = ( - Position(0, 0) - if self.steps_taken == 0 - else self.ast[self.steps_taken - 1].range.end + Position(0, 0) if self.steps_taken == 0 else self.__prev_step.ast.range.end ) try: return self.coq_lsp_client.proof_goals(TextDocumentIdentifier(uri), end_pos) From 36e048c23c44b7fcb03241545cc4df940dcdfd4c Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 16:46:47 +0100 Subject: [PATCH 23/27] Temporary fix to anonymous instances --- coqlspclient/coq_file.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 0d37101..6abefcc 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -431,6 +431,9 @@ def traverse_expr(expr): if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() self.__add_term(name, self.__curr_step.ast, text, TermType.NOTATION) + elif expr[0] == "VernacInstance" and expr[1][0]["v"][0] == "Anonymous": + # FIXME: The name should be "_instance_N" + self.__add_term("_anonymous", self.__curr_step.ast, text, term_type) elif expr[0] == "VernacDefinition" and expr[2][0]["v"][0] == "Anonymous": # We associate the anonymous term to the same name used internally by Coq if self.__anonymous_id is None: From 4cb9601143170a177569e071bc619ecd517f8b3e Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 17:44:46 +0100 Subject: [PATCH 24/27] Simplify comparison operators --- pylspclient/lsp_structs.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index 1d4741f..35711fe 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -37,35 +37,26 @@ def __eq__(self, __value: object) -> bool: __value.character, ) - def __ne__(self, __value: object) -> bool: - return not isinstance(__value, Position) or (self.line, self.character) != ( - __value.line, - __value.character, - ) - def __gt__(self, __value: object) -> bool: return isinstance(__value, Position) and (self.line, self.character) > ( __value.line, __value.character, ) - def __ge__(self, __value: object) -> bool: - return isinstance(__value, Position) and (self.line, self.character) >= ( - __value.line, - __value.character, - ) - def __lt__(self, __value: object) -> bool: return isinstance(__value, Position) and (self.line, self.character) < ( __value.line, __value.character, ) + def __ne__(self, __value: object) -> bool: + return not self.__eq__(__value) + + def __ge__(self, __value: object) -> bool: + return not self.__lt__(__value) + def __le__(self, __value: object) -> bool: - return isinstance(__value, Position) and (self.line, self.character) <= ( - __value.line, - __value.character, - ) + return not self.__gt__(__value) class Range(object): From 5ef87c2ab8932fd08230158be853e621c7de2de5 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 17:48:13 +0100 Subject: [PATCH 25/27] minor --- pylspclient/lsp_structs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index 35711fe..174ae2e 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -44,10 +44,7 @@ def __gt__(self, __value: object) -> bool: ) def __lt__(self, __value: object) -> bool: - return isinstance(__value, Position) and (self.line, self.character) < ( - __value.line, - __value.character, - ) + return not self.__eq__(__value) and not self.__gt__(__value) def __ne__(self, __value: object) -> bool: return not self.__eq__(__value) From c1188cae5bda5f9c33063af4f6b14f292fe6612f Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 19:05:04 +0100 Subject: [PATCH 26/27] Raise exception in comparison between different types --- pylspclient/lsp_structs.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index 174ae2e..c608235 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -38,10 +38,9 @@ def __eq__(self, __value: object) -> bool: ) def __gt__(self, __value: object) -> bool: - return isinstance(__value, Position) and (self.line, self.character) > ( - __value.line, - __value.character, - ) + if not isinstance(__value, Position): + raise TypeError(f"Invalid type for comparison: {type(__value).__name__}") + return (self.line, self.character) > (__value.line, __value.character) def __lt__(self, __value: object) -> bool: return not self.__eq__(__value) and not self.__gt__(__value) From e67c1ec46a3e5594f1cbabc76a9dd10b70631079 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Fri, 22 Sep 2023 22:19:28 +0100 Subject: [PATCH 27/27] Update to Coq 8.18 --- coqlspclient/coq_file.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 6abefcc..40138c8 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -184,7 +184,8 @@ def expr(step: RangedSpan) -> Optional[List]: and isinstance(step.span["v"], dict) and "expr" in step.span["v"] ): - return step.span["v"]["expr"] + # We ignore the tags [VernacSynterp] and [VernacSynPure] + return step.span["v"]["expr"][1] return [None] @@ -315,22 +316,22 @@ def __handle_where_notations(self, expr: List, term_type: TermType): # handles when multiple notations are defined for span in spans: start = Position( - span["decl_ntn_string"]["loc"]["line_nb"] - 1, - span["decl_ntn_string"]["loc"]["bp"] - - span["decl_ntn_string"]["loc"]["bol_pos"], + span["ntn_decl_string"]["loc"]["line_nb"] - 1, + span["ntn_decl_string"]["loc"]["bp"] + - span["ntn_decl_string"]["loc"]["bol_pos"], ) end = Position( - span["decl_ntn_interp"]["loc"]["line_nb_last"] - 1, - span["decl_ntn_interp"]["loc"]["ep"] - - span["decl_ntn_interp"]["loc"]["bol_pos"], + span["ntn_decl_interp"]["loc"]["line_nb_last"] - 1, + span["ntn_decl_interp"]["loc"]["ep"] + - span["ntn_decl_interp"]["loc"]["bol_pos"], ) range = Range(start, end) text = self.__short_text(range=range) name = FileContext.get_notation_key( - span["decl_ntn_string"]["v"], span["decl_ntn_scope"] + span["ntn_decl_string"]["v"], span["ntn_decl_scope"] ) - if span["decl_ntn_scope"] is not None: - text += " : " + span["decl_ntn_scope"] + if span["ntn_decl_scope"] is not None: + text += " : " + span["ntn_decl_scope"] text = "Notation " + text self.__add_term(name, RangedSpan(range, span), text, TermType.NOTATION)