diff --git a/README.md b/README.md index 0027b9a..6f98121 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,6 @@ from coqlspclient.coq_structs import TermType # Open Coq file with CoqFile(os.path.join(os.getcwd(), "examples/example.v")) as coq_file: - # Print AST - print(coq_file.ast) coq_file.exec(nsteps=2) # Get all terms defined until now print("Number of terms:", len(coq_file.context.terms)) diff --git a/coqlspclient/coq_file.py b/coqlspclient/coq_file.py index 039d46c..40138c8 100644 --- a/coqlspclient/coq_file.py +++ b/coqlspclient/coq_file.py @@ -3,7 +3,7 @@ import uuid import tempfile from pylspclient.lsp_structs import TextDocumentItem, TextDocumentIdentifier -from pylspclient.lsp_structs import ResponseError, ErrorCodes +from pylspclient.lsp_structs import ResponseError, ErrorCodes, Diagnostic from coqlspclient.coq_lsp_structs import Position, GoalAnswer, RangedSpan, Range from coqlspclient.coq_structs import Step, FileContext, Term, TermType, SegmentType from coqlspclient.coq_lsp_client import CoqLspClient @@ -31,7 +31,7 @@ def __init__( self, file_path: str, library: Optional[str] = None, - timeout: int = 2, + timeout: int = 30, workspace: Optional[str] = None, ): """Creates a CoqFile. @@ -52,26 +52,23 @@ def __init__( self.coq_lsp_client = CoqLspClient(uri, timeout=timeout) uri = f"file://{self.__path}" with open(self.__path, "r") as f: - self.__lines = f.read().split("\n") + text = f.read() try: - self.coq_lsp_client.didOpen( - TextDocumentItem(uri, "coq", 1, "\n".join(self.__lines)) - ) - self.ast = self.coq_lsp_client.get_document( - TextDocumentIdentifier(uri) - ).spans + self.coq_lsp_client.didOpen(TextDocumentItem(uri, "coq", 1, text)) + ast = self.coq_lsp_client.get_document(TextDocumentIdentifier(uri)).spans except Exception as e: self.__handle_exception(e) raise e + self.__init_steps(text, ast) self.__validate() - self.steps_taken: int = 0 self.curr_module: List[str] = [] self.curr_module_type: List[str] = [] self.curr_section: List[str] = [] self.__segment_stack: List[SegmentType] = [] self.context = FileContext() + self.__anonymous_id: Optional[int] = None def __enter__(self): return self @@ -81,7 +78,7 @@ def __exit__(self, exc_type, exc_value, traceback): def __init_path(self, file_path, library): self.file_module = [] if library is None else library.split(".") - self.__from_lib = self.file_module[:1] == ["Coq"] + self.__from_lib = self.file_module[:2] == ["Coq", "Init"] self.path = file_path if not self.__from_lib: self.__path = file_path @@ -97,31 +94,86 @@ def __init_path(self, file_path, library): self.__path = new_path def __handle_exception(self, e): - if not (isinstance(e, ResponseError) and e.code == ErrorCodes.ServerQuit.value): + if not isinstance(e, ResponseError) or e.code not in [ + ErrorCodes.ServerQuit.value, + ErrorCodes.ServerTimeout.value, + ]: self.coq_lsp_client.shutdown() self.coq_lsp_client.exit() if self.__from_lib: os.remove(self.__path) + def __init_steps(self, text: str, ast: List[RangedSpan]): + lines = text.split("\n") + self.steps: List[Step] = [] + for i, curr_ast in enumerate(ast): + start_line = 0 if i == 0 else ast[i - 1].range.end.line + start_char = 0 if i == 0 else ast[i - 1].range.end.character + end_line = curr_ast.range.end.line + end_char = curr_ast.range.end.character + + curr_lines = lines[start_line : end_line + 1] + curr_lines[-1] = curr_lines[-1][:end_char] + curr_lines[0] = curr_lines[0][start_char:] + step_text = "\n".join(curr_lines) + self.steps.append(Step(step_text, curr_ast)) + self.steps_taken: int = 0 + def 
__validate(self): uri = f"file://{self.__path}" + self.is_valid = True if uri not in self.coq_lsp_client.lsp_endpoint.diagnostics: - self.is_valid = True return for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: if diagnostic.severity == 1: self.is_valid = False - return - self.is_valid = True + + for step in self.steps: + if ( + step.ast.range.start <= diagnostic.range.start + and step.ast.range.end >= diagnostic.range.end + ): + step.diagnostics.append(diagnostic) + break + + @property + def __curr_step(self): + return self.steps[self.steps_taken] + + @property + def __prev_step(self): + return self.steps[self.steps_taken - 1] @staticmethod def get_id(id: List) -> str: if id[0] == "Ser_Qualid": - return ".".join([l[1] for l in id[1][1][::-1]] + [id[2][1]]) + return ".".join([l[1] for l in reversed(id[1][1])] + [id[2][1]]) elif id[0] == "Id": return id[1] - return "" + return None + + @staticmethod + def get_ident(el: List) -> Optional[str]: + if ( + len(el) == 3 + and el[0] == "GenArg" + and el[1][0] == "Rawwit" + and el[1][1][0] == "ExtraArg" + ): + if el[1][1][1] == "identref": + return el[2][0][1][1] + elif el[1][1][1] == "ident": + return el[2][1] + return None + + @staticmethod + def get_v(el): + if isinstance(el, dict) and "v" in el: + return el["v"] + elif isinstance(el, list) and len(el) == 2 and el[0] == "v": + return el[1] + return None @staticmethod def expr(step: RangedSpan) -> Optional[List]: @@ -132,37 +184,44 @@ def expr(step: RangedSpan) -> Optional[List]: and isinstance(step.span["v"], dict) and "expr" in step.span["v"] ): - return step.span["v"]["expr"] + # We ignore the tags [VernacSynterp] and [VernacSynPure] + return step.span["v"]["expr"][1] return [None] - def __step_expr(self): - curr_step = self.ast[self.steps_taken] - return CoqFile.expr(curr_step) - - def __get_text(self, range: Range, trim: bool = False): - end_line = range.end.line - end_character = range.end.character - - if trim: - start_line = range.start.line - start_character = range.start.character + def __short_text(self, range: Optional[Range] = None): + def slice_line( + line: str, start: Optional[int] = None, stop: Optional[int] = None + ): + if range is None: + return line[start:stop] + + # A range will be provided when range.character is measured in bytes, + # rather than characters. If so, the string must be encoded before + # indexing. This special treatment is necessary for handling Unicode + # characters which take up more than 1 byte. Currently necessary to + # handle where-notations. + line = line.encode("utf-8") + + # For where-notations, the character count does not include the closing + # parenthesis when present. 
+ if stop is not None and chr(line[stop]) == ")": + stop += 1 + return line[start:stop].decode() + + curr_range = self.__curr_step.ast.range if range is None else range + nlines = curr_range.end.line - curr_range.start.line + 1 + lines = self.__curr_step.text.split("\n")[-nlines:] + + prev_range = None if self.steps_taken == 0 else self.__prev_step.ast.range + if prev_range is None or prev_range.end.line < curr_range.start.line: + start = curr_range.start.character else: - prev_step = ( - None if self.steps_taken == 0 else self.ast[self.steps_taken - 1] - ) - start_line = 0 if prev_step is None else prev_step.range.end.line - start_character = 0 if prev_step is None else prev_step.range.end.character - - lines = self.__lines[start_line : end_line + 1] - lines[-1] = lines[-1][: end_character + 1] - lines[0] = lines[0][start_character:] - text = "\n".join(lines) - return " ".join(text.split()) if trim else text + start = curr_range.start.character - prev_range.end.character - def __step_text(self, trim=False): - curr_step = self.ast[self.steps_taken] - return self.__get_text(curr_step.range, trim=trim) + lines[-1] = slice_line(lines[-1], stop=curr_range.end.character) + lines[0] = slice_line(lines[0], start=start) + return " ".join(" ".join(lines).split()) def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType): term = Term(text, ast, term_type, self.path, self.curr_module[:]) @@ -174,7 +233,7 @@ def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType) self.context.update(terms={full_name(name): term}) curr_file_module = "" - for module in self.file_module[::-1]: + for module in reversed(self.file_module): curr_file_module = module + "." + curr_file_module self.context.update(terms={curr_file_module + name: term}) @@ -194,20 +253,32 @@ def __get_term_type(expr: List) -> TermType: return TermType.RECORD elif expr[0] == "VernacInductive" and expr[1][0] == "Variant": return TermType.VARIANT + elif expr[0] == "VernacInductive" and expr[1][0] == "CoInductive": + return TermType.COINDUCTIVE elif expr[0] == "VernacInductive": return TermType.INDUCTIVE elif expr[0] == "VernacInstance": return TermType.INSTANCE + elif expr[0] == "VernacCoFixpoint": + return TermType.COFIXPOINT elif expr[0] == "VernacFixpoint": return TermType.FIXPOINT elif expr[0] == "VernacScheme": return TermType.SCHEME + elif expr[0] == "VernacExtend" and expr[1][0].startswith("Derive"): + return TermType.DERIVE + elif expr[0] == "VernacExtend" and expr[1][0].startswith("AddSetoid"): + return TermType.SETOID + elif expr[0] == "VernacExtend" and expr[1][0].startswith( + ("AddRelation", "AddParametricRelation") + ): + return TermType.RELATION elif ( - len(expr) > 1 - and isinstance(expr[1], list) - and expr[1][0] == "VernacDeclareTacticDefinition" + expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition" ): return TermType.TACTIC + elif expr[0] == "VernacExtend" and expr[1][0] == "Function": + return TermType.FUNCTION else: return TermType.OTHER @@ -244,66 +315,67 @@ def __handle_where_notations(self, expr: List, term_type: TermType): # handles when multiple notations are defined for span in spans: - range = Range( - Position( - span["decl_ntn_string"]["loc"]["line_nb"] - 1, - span["decl_ntn_string"]["loc"]["bp"] - - span["decl_ntn_string"]["loc"]["bol_pos"], - ), - Position( - span["decl_ntn_interp"]["loc"]["line_nb_last"] - 1, - span["decl_ntn_interp"]["loc"]["ep"] - - span["decl_ntn_interp"]["loc"]["bol_pos"], - ), + start = Position( + 
span["ntn_decl_string"]["loc"]["line_nb"] - 1, + span["ntn_decl_string"]["loc"]["bp"] + - span["ntn_decl_string"]["loc"]["bol_pos"], + ) + end = Position( + span["ntn_decl_interp"]["loc"]["line_nb_last"] - 1, + span["ntn_decl_interp"]["loc"]["ep"] + - span["ntn_decl_interp"]["loc"]["bol_pos"], ) - text = self.__get_text(range, trim=True) + range = Range(start, end) + text = self.__short_text(range=range) name = FileContext.get_notation_key( - span["decl_ntn_string"]["v"], span["decl_ntn_scope"] + span["ntn_decl_string"]["v"], span["ntn_decl_scope"] ) - if span["decl_ntn_scope"] is not None: - text += " : " + span["decl_ntn_scope"] + if span["ntn_decl_scope"] is not None: + text += " : " + span["ntn_decl_scope"] text = "Notation " + text self.__add_term(name, RangedSpan(range, span), text, TermType.NOTATION) - def __get_tactic_name(self, expr): - if len(expr[2][0][2][0][1][0]) == 2 and expr[2][0][2][0][1][0][0] == "v": - name = CoqFile.get_id(expr[2][0][2][0][1][0][1]) - if name != "": - return name - - return None - def __process_step(self, sign): - def traverse_ast(el, inductive=False): - if isinstance(el, dict): - if "v" in el and isinstance(el["v"], list) and len(el["v"]) == 2: - if el["v"][0] == "Id": - return [el["v"][1]] - if el["v"][0] == "Name": - return [el["v"][1][1]] - - return [x for v in el.values() for x in traverse_ast(v, inductive)] - elif isinstance(el, list): - if len(el) > 0: - if el[0] == "CLocalAssum": - return [] - if el[0] == "VernacInductive": - inductive = True - - res = [] - for v in el: - res.extend(traverse_ast(v, inductive)) - if not inductive and len(res) > 0: - return [res[0]] - return res - - return [] + def traverse_expr(expr): + inductive = expr[0] == "VernacInductive" + extend = expr[0] == "VernacExtend" + stack, res = expr[:0:-1], [] + while len(stack) > 0: + el = stack.pop() + v = CoqFile.get_v(el) + if v is not None and isinstance(v, list) and len(v) == 2: + id = CoqFile.get_id(v) + if id is not None: + if not inductive: + return [id] + res.append(id) + elif v[0] == "Name": + if not inductive: + return [v[1][1]] + res.append(v[1][1]) + + elif isinstance(el, dict): + for v in reversed(el.values()): + if isinstance(v, (dict, list)): + stack.append(v) + elif isinstance(el, list): + if len(el) > 0 and el[0] == "CLocalAssum": + continue + + ident = CoqFile.get_ident(el) + if ident is not None and extend: + return [ident] + + for v in reversed(el): + if isinstance(v, (dict, list)): + stack.append(v) + return res try: # TODO: A negative sign should handle things differently. For example: # - names should be removed from the context # - curr_module should change as you leave or re-enter modules - text = self.__step_text(trim=True) + text = self.__short_text() # FIXME Let (and maybe Variable) should be handled. However, # I think we can't handle them as normal Locals since they are # specific to a section. @@ -316,9 +388,11 @@ def traverse_ast(el, inductive=False): ]: if text.startswith(keyword): return - expr = self.__step_expr() + expr = CoqFile.expr(self.__curr_step.ast) if expr == [None]: return + if expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve": + return term_type = CoqFile.__get_term_type(expr) if expr[0] == "VernacEndSegment": @@ -345,32 +419,40 @@ def traverse_ast(el, inductive=False): # and should be overriden. 
elif len(self.curr_module_type) > 0: return - elif ( - len(expr) >= 2 - and isinstance(expr[1], list) - and len(expr[1]) == 2 - and expr[1][0] == "VernacDeclareTacticDefinition" - ): - name = self.__get_tactic_name(expr) - self.__add_term(name, self.ast[self.steps_taken], text, TermType.TACTIC) + elif expr[0] == "VernacExtend" and expr[1][0] == "VernacTacticNotation": + # FIXME: Handle this case + return elif expr[0] == "VernacNotation": - name = text.split('"')[1] + name = text.split('"')[1].strip() if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() - self.__add_term( - name, self.ast[self.steps_taken], text, TermType.NOTATION - ) + self.__add_term(name, self.__curr_step.ast, text, TermType.NOTATION) elif expr[0] == "VernacSyntacticDefinition": name = text.split(" ")[1] if text[:-1].split(":")[-1].endswith("_scope"): name += " : " + text[:-1].split(":")[-1].strip() - self.__add_term( - name, self.ast[self.steps_taken], text, TermType.NOTATION - ) + self.__add_term(name, self.__curr_step.ast, text, TermType.NOTATION) + elif expr[0] == "VernacInstance" and expr[1][0]["v"][0] == "Anonymous": + # FIXME: The name should be "_instance_N" + self.__add_term("_anonymous", self.__curr_step.ast, text, term_type) + elif expr[0] == "VernacDefinition" and expr[2][0]["v"][0] == "Anonymous": + # We associate the anonymous term to the same name used internally by Coq + if self.__anonymous_id is None: + name, self.__anonymous_id = "Unnamed_thm", 0 + else: + name = f"Unnamed_thm{self.__anonymous_id}" + self.__anonymous_id += 1 + self.__add_term(name, self.__curr_step.ast, text, term_type) + elif term_type == TermType.DERIVE: + name = CoqFile.get_ident(expr[2][0]) + self.__add_term(name, self.__curr_step.ast, text, term_type) + if expr[1][0] == "Derive": + prop = CoqFile.get_ident(expr[2][2]) + self.__add_term(prop, self.__curr_step.ast, text, term_type) else: - names = traverse_ast(expr) + names = traverse_expr(expr) for name in names: - self.__add_term(name, self.ast[self.steps_taken], text, term_type) + self.__add_term(name, self.__curr_step.ast, text, term_type) self.__handle_where_notations(expr, term_type) finally: @@ -391,7 +473,7 @@ def checked(self) -> bool: Returns: bool: True if the whole file was already executed """ - return self.steps_taken == len(self.ast) + return self.steps_taken == len(self.steps) @property def in_proof(self) -> bool: @@ -427,6 +509,16 @@ def terms(self) -> List[Term]: ) ) + @property + def diagnostics(self) -> List[Diagnostic]: + """ + Returns: + List[Diagnostic]: The diagnostics of the file. + Includes all messages given by Coq. + """ + uri = f"file://{self.__path}" + return self.coq_lsp_client.lsp_endpoint.diagnostics[uri] + @staticmethod def get_term_type(ast: RangedSpan) -> TermType: expr = CoqFile.expr(ast) @@ -443,16 +535,15 @@ def exec(self, nsteps=1) -> List[Step]: Returns: List[Step]: List of steps executed. """ - steps: List[Step] = [] sign = 1 if nsteps > 0 else -1 + initial_steps_taken = self.steps_taken nsteps = min( nsteps * sign, - len(self.ast) - self.steps_taken if sign > 0 else self.steps_taken, + len(self.steps) - self.steps_taken if sign > 0 else self.steps_taken, ) for _ in range(nsteps): - steps.append(Step(self.__step_text(), self.ast[self.steps_taken])) self.__process_step(sign) - return steps + return self.steps[initial_steps_taken : self.steps_taken] def run(self) -> List[Step]: """Executes all the steps in the file. 
@@ -460,7 +551,7 @@ def run(self) -> List[Step]: Returns: List[Step]: List of all the steps in the file. """ - return self.exec(len(self.ast)) + return self.exec(len(self.steps)) def current_goals(self) -> Optional[GoalAnswer]: """Get goals in current position. @@ -470,9 +561,7 @@ def current_goals(self) -> Optional[GoalAnswer]: """ uri = f"file://{self.__path}" end_pos = ( - Position(0, 0) - if self.steps_taken == 0 - else self.ast[self.steps_taken - 1].range.end + Position(0, 0) if self.steps_taken == 0 else self.__prev_step.ast.range.end ) try: return self.coq_lsp_client.proof_goals(TextDocumentIdentifier(uri), end_pos) diff --git a/coqlspclient/coq_lsp_client.py b/coqlspclient/coq_lsp_client.py index cd8ebef..306b399 100644 --- a/coqlspclient/coq_lsp_client.py +++ b/coqlspclient/coq_lsp_client.py @@ -25,7 +25,7 @@ class CoqLspClient(LspClient): def __init__( self, root_uri: str, - timeout: int = 2, + timeout: int = 30, memory_limit: int = 2097152, init_options: Dict = __DEFAULT_INIT_OPTIONS, ): @@ -107,7 +107,7 @@ def __wait_for_operation(self): if timeout <= 0: self.shutdown() self.exit() - raise ResponseError(ErrorCodes.ServerQuit, "Server quit") + raise ResponseError(ErrorCodes.ServerTimeout, "Server timeout") def didOpen(self, textDocument: TextDocumentItem): """Open a text document in the server. diff --git a/coqlspclient/coq_lsp_structs.py b/coqlspclient/coq_lsp_structs.py index f457cd1..c81103a 100644 --- a/coqlspclient/coq_lsp_structs.py +++ b/coqlspclient/coq_lsp_structs.py @@ -214,13 +214,3 @@ def parse(coqFileProgressParams: Dict) -> Optional["CoqFileProgressParams"]: ) ) return CoqFileProgressParams(textDocument, processing) - - -class CoqErrorCodes(Enum): - InvalidFile = 0 - - -class CoqError(Exception): - def __init__(self, code: CoqErrorCodes, message: str): - self.code = code - self.message = message diff --git a/coqlspclient/coq_structs.py b/coqlspclient/coq_structs.py index 069c8e0..92a491f 100644 --- a/coqlspclient/coq_structs.py +++ b/coqlspclient/coq_structs.py @@ -2,12 +2,7 @@ from enum import Enum from typing import Dict, List from coqlspclient.coq_lsp_structs import RangedSpan, GoalAnswer - - -class Step(object): - def __init__(self, text: str, ast: RangedSpan): - self.text = text - self.ast = ast +from pylspclient.lsp_structs import Diagnostic class SegmentType(Enum): @@ -22,22 +17,46 @@ class TermType(Enum): DEFINITION = 3 NOTATION = 4 INDUCTIVE = 5 - RECORD = 6 - CLASS = 7 - INSTANCE = 8 - FIXPOINT = 9 - TACTIC = 10 - SCHEME = 11 - VARIANT = 12 - FACT = 13 - REMARK = 14 - COROLLARY = 15 - PROPOSITION = 16 - PROPERTY = 17 - OBLIGATION = 18 + COINDUCTIVE = 6 + RECORD = 7 + CLASS = 8 + INSTANCE = 9 + FIXPOINT = 10 + COFIXPOINT = 11 + SCHEME = 12 + VARIANT = 13 + FACT = 14 + REMARK = 15 + COROLLARY = 16 + PROPOSITION = 17 + PROPERTY = 18 + OBLIGATION = 19 + TACTIC = 20 + RELATION = 21 + SETOID = 22 + FUNCTION = 23 + DERIVE = 24 OTHER = 100 +class CoqErrorCodes(Enum): + InvalidFile = 0 + NotationNotFound = 1 + + +class CoqError(Exception): + def __init__(self, code: CoqErrorCodes, message: str): + self.code = code + self.message = message + + +class Step(object): + def __init__(self, text: str, ast: RangedSpan): + self.text = text + self.ast = ast + self.diagnostics: List[Diagnostic] = [] + + class Term: def __init__( self, @@ -112,16 +131,26 @@ def get_notation(self, notation: str, scope: str) -> Term: regex = f"{re.escape(notation_id)}".split("\\ ") for i, sub in enumerate(regex): if sub == "_": - regex[i] = "(.+)" + # We match the wildcard with the description 
from here: + # https://coq.inria.fr/distrib/current/refman/language/core/basic.html#grammar-token-ident + # Coq accepts more characters, but no one should need more than these... + chars = "A-Za-zÀ-ÖØ-öø-ˁˆ-ˑˠ-ˤˬˮͰ-ʹͶͷͺ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-Ֆՙա-և" + regex[i] = f"([{chars}][{chars}0-9_']*|_[{chars}0-9_']+)" else: # Handle '_' regex[i] = f"({sub}|('{sub}'))" regex = "^" + "\\ ".join(regex) + "$" # Search notations + unscoped = None for term in self.terms.keys(): if re.match(regex, term): return self.terms[term] + if re.match(regex, term.split(":")[0].strip()): + unscoped = term + # In case the stored id contains the scope but no scope is provided + if unscoped is not None: + return self.terms[unscoped] # Search Infix if re.match("^_ ([^ ]*) _$", notation): @@ -130,7 +159,10 @@ def get_notation(self, notation: str, scope: str) -> Term: if key in self.terms: return self.terms[key] - raise RuntimeError(f"Notation not found in context: {notation_id}") + raise CoqError( + CoqErrorCodes.NotationNotFound, + f"Notation not found in context: {notation_id}", + ) @staticmethod def get_notation_key(notation: str, scope: str): diff --git a/coqlspclient/proof_state.py b/coqlspclient/proof_state.py index 08930ba..7fbaeee 100644 --- a/coqlspclient/proof_state.py +++ b/coqlspclient/proof_state.py @@ -10,27 +10,23 @@ ErrorCodes, ) from coqlspclient.coq_structs import ( - ProofStep, - FileContext, + TermType, + CoqErrorCodes, + CoqError, Step, Term, + ProofStep, ProofTerm, - TermType, -) -from coqlspclient.coq_lsp_structs import ( - CoqError, - CoqErrorCodes, - Result, - Query, - GoalAnswer, + FileContext, ) +from coqlspclient.coq_lsp_structs import Result, Query, GoalAnswer from coqlspclient.coq_file import CoqFile from coqlspclient.coq_lsp_client import CoqLspClient from typing import List, Dict, Optional, Tuple class _AuxFile(object): - def __init__(self, file_path: Optional[str] = None, timeout: int = 2): + def __init__(self, file_path: Optional[str] = None, timeout: int = 30): self.__init_path(file_path) uri = f"file://{self.path}" self.coq_lsp_client = CoqLspClient(uri, timeout=timeout) @@ -69,7 +65,10 @@ def append(self, text): f.write(text) def __handle_exception(self, e): - if not (isinstance(e, ResponseError) and e.code == ErrorCodes.ServerQuit.value): + if not isinstance(e, ResponseError) or e.code not in [ + ErrorCodes.ServerQuit.value, + ErrorCodes.ServerTimeout.value, + ]: self.coq_lsp_client.shutdown() self.coq_lsp_client.exit() os.remove(self.path) @@ -102,19 +101,15 @@ def __get_queries(self, keyword): searches = {} lines = self.read().split("\n") for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]: - command = lines[ - diagnostic.range["start"]["line"] : diagnostic.range["end"]["line"] + 1 - ] + command = lines[diagnostic.range.start.line : diagnostic.range.end.line + 1] if len(command) == 1: command[0] = command[0][ - diagnostic.range["start"]["character"] : diagnostic.range["end"][ - "character" - ] + diagnostic.range.start.character : diagnostic.range.end.character + 1 ] else: - command[0] = command[0][diagnostic.range["start"]["character"] :] - command[-1] = command[-1][: diagnostic.range["end"]["character"] + 1] + command[0] = command[0][diagnostic.range.start.character :] + command[-1] = command[-1][: diagnostic.range.end.character + 1] command = "".join(command).strip() if command.startswith(keyword): @@ -133,7 +128,7 @@ def get_diagnostics(self, keyword, search, line): for query in self.__get_queries(keyword): if query.query == f"{search}": for result in query.results: 
- if result.range["start"]["line"] == line: + if result.range.start.line == line: return result.message break return None @@ -162,7 +157,7 @@ def get_context(file_path: str, timeout: int): for i, library in enumerate(libraries): v_file = aux_file.get_diagnostics( "Locate Library", library, last_line + i + 1 - ).split("\n")[-1][:-1] + ).split()[-1][:-1] coq_file = CoqFile(v_file, library=library, timeout=timeout) coq_file.run() @@ -231,39 +226,44 @@ def __locate(self, search, line): fun = lambda x: x.endswith("(default interpretation)") return nots[0][:-25] if fun(nots[0]) else nots[0] - def __step_context(self, step=None): - def traverse_ast(el): - if isinstance(el, dict): - return [x for v in el.values() for x in traverse_ast(v)] - elif isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": - id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]]) - term = self.__get_term(id) - return [] if term is None else [(lambda x: x, term)] - elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": - line = len(self.__aux_file.read().split("\n")) - self.__aux_file.append(f'\nLocate "{el[2][1]}".') - - def __search_notation(call): - notation_name = call[0] - scope = "" - notation = call[1](*call[2:]) - if notation == "Unknown notation": - return None - if notation.split(":")[-1].endswith("_scope"): - scope = notation.split(":")[-1].strip() - return self.context.get_notation(notation_name, scope) - - return [ - (__search_notation, (el[2][1], self.__locate, el[2][1], line)) - ] + traverse_ast(el[1:]) - elif isinstance(el, list): - return [x for v in el for x in traverse_ast(v)] - - return [] - - if step is None: - step = self.__current_step.ast - return traverse_ast(step.span) + def __step_context(self): + def traverse_expr(expr): + stack, res = expr[:0:-1], [] + while len(stack) > 0: + el = stack.pop() + if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": + term = self.__get_term(CoqFile.get_id(el)) + if term is not None: + res.append((lambda x: x, term)) + elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": + line = len(self.__aux_file.read().split("\n")) + self.__aux_file.append(f'\nLocate "{el[2][1]}".') + + def __search_notation(call): + notation_name = call[0] + scope = "" + notation = call[1](*call[2:]) + if notation == "Unknown notation": + return None + if notation.split(":")[-1].endswith("_scope"): + scope = notation.split(":")[-1].strip() + return self.context.get_notation(notation_name, scope) + + res.append( + (__search_notation, (el[2][1], self.__locate, el[2][1], line)) + ) + stack.append(el[1:]) + elif isinstance(el, list): + for v in reversed(el): + if isinstance(v, (dict, list)): + stack.append(v) + elif isinstance(el, dict): + for v in reversed(el.values()): + if isinstance(v, (dict, list)): + stack.append(v) + return res + + return traverse_expr(CoqFile.expr(self.__current_step.ast)) def __get_last_term(self): terms = self.coq_file.terms @@ -271,32 +271,23 @@ def __get_last_term(self): return None last_term = terms[0] for term in terms[1:]: - if (term.ast.range.end.line > last_term.ast.range.end.line) or ( - term.ast.range.end.line == last_term.ast.range.end.line - and term.ast.range.end.character > last_term.ast.range.end.character - ): + if last_term.ast.range.end < term.ast.range.end: last_term = term return last_term def __get_program_context(self): - def traverse_ast(el, keep_id=False): - if ( - isinstance(el, list) - and len(el) == 2 - and ( - (el[0] == "Id" and keep_id) - or (el[0] == "ExtraArg" and el[1] == 
"identref") - ) - ): - return el[1] - elif isinstance(el, list): - for x in el: - id = traverse_ast(x, keep_id=keep_id) - if id == "identref": - keep_id = True - elif id is not None: - return id - return "identref" if keep_id else None + def traverse_expr(expr): + stack = expr[:0:-1] + while len(stack) > 0: + el = stack.pop() + if isinstance(el, list): + ident = CoqFile.get_ident(el) + if ident is not None: + return ident + + for v in reversed(el): + if isinstance(v, list): + stack.append(v) return None # Tags: @@ -308,7 +299,7 @@ def traverse_ast(el, keep_id=False): # 5 - Next Obligation tag = CoqFile.expr(self.__current_step.ast)[1][1] if tag in [0, 1, 4]: - id = traverse_ast(CoqFile.expr(self.__current_step.ast)) + id = traverse_expr(CoqFile.expr(self.__current_step.ast)) # This works because the obligation must be in the # same module as the program id = ".".join(self.coq_file.curr_module + [id]) diff --git a/pylspclient/lsp_endpoint.py b/pylspclient/lsp_endpoint.py index 120018c..19dab99 100644 --- a/pylspclient/lsp_endpoint.py +++ b/pylspclient/lsp_endpoint.py @@ -2,6 +2,7 @@ import threading import logging from pylspclient import lsp_structs +from typing import Dict, List class LspEndpoint(threading.Thread): @@ -17,7 +18,7 @@ def __init__( self.next_id = 0 self.timeout = timeout self.shutdown_flag = False - self.diagnostics = {} + self.diagnostics: Dict[str, List[lsp_structs.Diagnostic]] = {} def handle_result(self, rpc_id, result, error): self.response_dict[rpc_id] = (result, error) diff --git a/pylspclient/lsp_structs.py b/pylspclient/lsp_structs.py index d6ae4ea..c608235 100644 --- a/pylspclient/lsp_structs.py +++ b/pylspclient/lsp_structs.py @@ -26,6 +26,34 @@ def __init__(self, line, character, offset=0): self.character = character self.offset = offset + def __repr__(self) -> str: + return str( + {"line": self.line, "character": self.character, "offset": self.offset} + ) + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, Position) and (self.line, self.character) == ( + __value.line, + __value.character, + ) + + def __gt__(self, __value: object) -> bool: + if not isinstance(__value, Position): + raise TypeError(f"Invalid type for comparison: {type(__value).__name__}") + return (self.line, self.character) > (__value.line, __value.character) + + def __lt__(self, __value: object) -> bool: + return not self.__eq__(__value) and not self.__gt__(__value) + + def __ne__(self, __value: object) -> bool: + return not self.__eq__(__value) + + def __ge__(self, __value: object) -> bool: + return not self.__lt__(__value) + + def __le__(self, __value: object) -> bool: + return not self.__gt__(__value) + class Range(object): def __init__(self, start, end): @@ -38,6 +66,9 @@ def __init__(self, start, end): self.start = to_type(start, Position) self.end = to_type(end, Position) + def __repr__(self) -> str: + return str({"start": repr(self.start), "end": repr(self.end)}) + class Location(object): """ @@ -82,7 +113,18 @@ def __init__( class Diagnostic(object): - def __init__(self, range, severity, code, source, message, relatedInformation): + def __init__( + self, + range, + message, + severity=None, + code=None, + codeDescription=None, + source=None, + tags=None, + relatedInformation=None, + data=None, + ): """ Constructs a new Diagnostic instance. :param Range range: The range at which the message applies.Resource file. 
@@ -95,7 +137,7 @@ def __init__(self, range, severity, code, source, message, relatedInformation): :param list relatedInformation: An array of related diagnostic information, e.g. when symbol-names within a scope collide all definitions can be marked via this property. """ - self.range = range + self.range: Range = Range(**range) self.severity = severity self.code = code self.source = source @@ -532,30 +574,6 @@ def __init__( self.score = score -class Diagnostic(object): - def __init__( - self, - range, - message, - severity=None, - code=None, - codeDescription=None, - source=None, - tags=None, - relatedInformation=None, - data=None, - ): - self.range = range - self.message = message - self.severity = severity - self.code = code - self.codeDescription = codeDescription - self.source = source - self.tags = tags - self.relatedInformation = relatedInformation - self.data = data - - class CompletionItemKind(enum.Enum): Text = 1 Method = 2 @@ -609,6 +627,7 @@ class ErrorCodes(enum.Enum): InternalError = -32603 serverErrorStart = -32099 serverErrorEnd = -32000 + ServerTimeout = -32004 ServerQuit = -32003 ServerNotInitialized = -32002 UnknownErrorCode = -32001 diff --git a/tests/resources/test_bullets.v b/tests/resources/test_bullets.v index 1681031..929d850 100644 --- a/tests/resources/test_bullets.v +++ b/tests/resources/test_bullets.v @@ -1,6 +1,6 @@ Theorem bullets: forall x y: nat, x = x /\ y = y. Proof. intros x y. split. - - reflexivity. + -reflexivity. - reflexivity. Qed. diff --git a/tests/resources/test_derive.v b/tests/resources/test_derive.v new file mode 100644 index 0000000..30e4a48 --- /dev/null +++ b/tests/resources/test_derive.v @@ -0,0 +1,14 @@ +Require Import Coq.derive.Derive. + +Derive incr +SuchThat (forall n, incr n = plus 1 n) +As incr_correct. +Proof. intros n. simpl. subst incr. reflexivity. Qed. + +Inductive Le : nat -> nat -> Set := +| LeO : forall n:nat, Le 0 n +| LeS : forall n m:nat, Le n m -> Le (S n) (S m). +Derive Inversion leminv1 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Inversion_clear leminv2 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Dependent Inversion leminv3 with (forall n m:nat, Le (S n) m) Sort Prop. +Derive Dependent Inversion_clear leminv4 with (forall n m:nat, Le (S n) m) Sort Prop. \ No newline at end of file diff --git a/tests/resources/test_goal.v b/tests/resources/test_goal.v new file mode 100644 index 0000000..9f3ade8 --- /dev/null +++ b/tests/resources/test_goal.v @@ -0,0 +1,21 @@ +(* http://d.hatena.ne.jp/hzkr/20100902 *) + +Definition ignored : forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + intros. (* 全部剥がす *) + apply H. + exact H0. (* apply H0 でも同じ *) +Save opaque. + + +Goal forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + intros p q f. (* 名前の数だけ剥がす *) + assumption. (* exact f でも同じ *) +Qed. + + +Goal forall P Q: Prop, (P -> Q) -> P -> Q. +Proof. + exact (fun p q f x => f x). +Defined transparent. \ No newline at end of file diff --git a/tests/resources/test_list_notation.v b/tests/resources/test_list_notation.v new file mode 100644 index 0000000..5730ba4 --- /dev/null +++ b/tests/resources/test_list_notation.v @@ -0,0 +1,5 @@ +Require Import List. +Import List.ListNotations. + +Goal [1] ++ [2] = [1; 2]. +Proof. reflexivity. Qed. \ No newline at end of file diff --git a/tests/resources/test_where_notation.v b/tests/resources/test_where_notation.v index 141ea9a..6238dcd 100644 --- a/tests/resources/test_where_notation.v +++ b/tests/resources/test_where_notation.v @@ -1,4 +1,6 @@ Reserved Notation "A & B" (at level 80). 
+Reserved Notation "'ONE'" (at level 80). +Reserved Notation "x 🀄 y" (at level 80). Fixpoint plus_test (n m : nat) {struct n} : nat := match n with @@ -8,4 +10,10 @@ end where "n + m" := (plus n m) : test_scope and "n - m" := (minus n m). Inductive and' (A B : Prop) : Prop := conj' : A -> B -> A & B -where "A & B" := (and' A B). \ No newline at end of file +where "A & B" := (and' A B). + +Fixpoint incr (n : nat) : nat := n + ONE +where "'ONE'" := 1. + +Fixpoint unicode x y := x 🀄 y +where "x 🀄 y" := (plus_test x y). \ No newline at end of file diff --git a/tests/test_coq_file.py b/tests/test_coq_file.py index bd9b943..a3eb435 100644 --- a/tests/test_coq_file.py +++ b/tests/test_coq_file.py @@ -36,6 +36,10 @@ def test_where_notation(setup, teardown): assert coq_file.context.terms["n - m"].text == 'Notation "n - m" := (minus n m)' assert "A & B" in coq_file.context.terms assert coq_file.context.terms["A & B"].text == 'Notation "A & B" := (and\' A B)' + assert "'ONE'" in coq_file.context.terms + assert coq_file.context.terms["'ONE'"].text == "Notation \"'ONE'\" := 1" + assert "x 🀄 y" in coq_file.context.terms + assert coq_file.context.terms["x 🀄 y"].text == 'Notation "x 🀄 y" := (plus_test x y)' @pytest.mark.parametrize("setup", ["test_get_notation.v"], indirect=True) @@ -58,11 +62,25 @@ def test_get_notation(setup, teardown): @pytest.mark.parametrize("setup", ["test_invalid_1.v"], indirect=True) def test_is_invalid_1(setup, teardown): assert not coq_file.is_valid + steps = coq_file.run() + assert len(steps[11].diagnostics) == 1 + assert ( + steps[11].diagnostics[0].message + == 'Found no subterm matching "0 + ?M152" in the current goal.' + ) + assert steps[11].diagnostics[0].severity == 1 @pytest.mark.parametrize("setup", ["test_invalid_2.v"], indirect=True) def test_is_invalid_2(setup, teardown): assert not coq_file.is_valid + steps = coq_file.run() + assert len(steps[15].diagnostics) == 1 + assert ( + steps[15].diagnostics[0].message + == "Syntax error: '.' expected after [command] (in [vernac_aux])." + ) + assert steps[15].diagnostics[0].severity == 1 @pytest.mark.parametrize("setup", ["test_module_type.v"], indirect=True) @@ -71,3 +89,27 @@ def test_module_type(setup, teardown): # We ignore terms inside a Module Type since they can't be used outside # and should be overriden. assert len(coq_file.context.terms) == 1 + + +@pytest.mark.parametrize("setup", ["test_derive.v"], indirect=True) +def test_derive(setup, teardown): + coq_file.run() + for key in ["incr", "incr_correct"]: + assert key in coq_file.context.terms + assert ( + coq_file.context.terms[key].text + == "Derive incr SuchThat (forall n, incr n = plus 1 n) As incr_correct." + ) + keywords = [ + "Inversion", + "Inversion_clear", + "Dependent Inversion", + "Dependent Inversion_clear", + ] + for i in range(4): + key = f"leminv{i + 1}" + assert key in coq_file.context.terms + assert ( + coq_file.context.terms[key].text + == f"Derive {keywords[i]} {key} with (forall n m:nat, Le (S n) m) Sort Prop." 
+ ) diff --git a/tests/test_proof_state.py b/tests/test_proof_state.py index a7104ca..c964a38 100644 --- a/tests/test_proof_state.py +++ b/tests/test_proof_state.py @@ -3,7 +3,7 @@ import pytest from typing import List, Tuple from coqlspclient.coq_lsp_structs import * -from coqlspclient.coq_structs import TermType, Term +from coqlspclient.coq_structs import TermType, Term, CoqError, CoqErrorCodes from coqlspclient.proof_state import ProofState, CoqFile versionId: VersionedTextDocumentIdentifier = None @@ -94,7 +94,7 @@ def test_get_proofs(setup, teardown): ] statement_context = [ ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ] @@ -182,7 +182,7 @@ def test_get_proofs(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ('Notation "n * m" := (mul n m) : nat_scope', TermType.NOTATION, []), ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), @@ -254,7 +254,7 @@ def test_get_proofs(setup, teardown): ] statement_context = [ ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), ] @@ -345,7 +345,7 @@ def test_get_proofs(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ('Notation "n * m" := (mul n m) : nat_scope', TermType.NOTATION, []), ("Inductive nat : Set := | O : nat | S : nat -> nat.", TermType.INDUCTIVE, []), ('Notation "n + m" := (add n m) : nat_scope', TermType.NOTATION, []), @@ -404,13 +404,38 @@ def test_exists_notation(setup, teardown): ) +@pytest.mark.parametrize("setup", [("test_list_notation.v", None)], indirect=True) +def test_list_notation(setup, teardown): + assert len(state.proofs) == 1 + context = [ + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), + ( + 'Infix "++" := app (right associativity, at level 60) : list_scope.', + TermType.NOTATION, + [], + ), + ( + 'Notation "[ x ]" := (cons x nil) : list_scope.', + TermType.NOTATION, + ["ListNotations"], + ), + ( + "Notation \"[ x ; y ; .. ; z ]\" := (cons x (cons y .. (cons z nil) ..)) (format \"[ '[' x ; '/' y ; '/' .. ; '/' z ']' ]\") : list_scope.", + TermType.NOTATION, + ["ListNotations"], + ), + ] + compare_context(context, state.proofs[0].context) + + @pytest.mark.parametrize("setup", [("test_unknown_notation.v", None)], indirect=True) def test_unknown_notation(setup, teardown): """Checks if it is able to handle the notation { _ } that is unknown for the Locate command because it is a default notation. 
""" - with pytest.raises(RuntimeError): + with pytest.raises(CoqError) as e_info: assert state.context.get_notation("{ _ }", "") + assert e_info.value.code == CoqErrorCodes.NotationNotFound @pytest.mark.parametrize("setup", [("test_nested_proofs.v", None)], indirect=True) @@ -467,11 +492,11 @@ def test_bullets(setup, teardown): proofs = state.proofs assert len(proofs) == 1 steps = [ - "\n intros x y. ", + "\n intros x y.", " split.", - "\n - ", - " reflexivity.", - "\n - ", + "\n -", + "reflexivity.", + "\n -", " reflexivity.", ] assert len(proofs[0].steps) == 6 @@ -494,7 +519,7 @@ def test_obligation(setup, teardown): TermType.NOTATION, [], ), - ('Notation "x = y :> A" := (@eq A x y) : type_scope', TermType.NOTATION, []), + ('Notation "x = y" := (eq x y) : type_scope.', TermType.NOTATION, []), ] programs = [ ("id1", "S (pred n)"), @@ -574,3 +599,25 @@ def test_type_class(setup, teardown): ("Inductive unit : Set := tt : unit.", TermType.INDUCTIVE, []), ] compare_context(context, state.proofs[1].context) + + +@pytest.mark.parametrize("setup", [("test_goal.v", None)], indirect=True) +def test_goal(setup, teardown): + assert len(state.proofs) == 3 + goals = [ + "Definition ignored : forall P Q: Prop, (P -> Q) -> P -> Q.", + "Goal forall P Q: Prop, (P -> Q) -> P -> Q.", + "Goal forall P Q: Prop, (P -> Q) -> P -> Q.", + ] + for i, proof in enumerate(state.proofs): + assert proof.text == goals[i] + compare_context( + [ + ( + 'Notation "A -> B" := (forall (_ : A), B) : type_scope.', + TermType.NOTATION, + [], + ) + ], + proof.context, + )