From 55628953851f077f60dd5a7780376751f20fa309 Mon Sep 17 00:00:00 2001 From: pcarrott Date: Mon, 30 Oct 2023 00:10:10 +0000 Subject: [PATCH] Support versions 8.17 and 8.18 in CoqFile --- coq/coq_file.py | 32 +++++++++++++++++++++----------- coq/proof_file.py | 14 +++++++------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/coq/coq_file.py b/coq/coq_file.py index 02a870c..9c51cdf 100644 --- a/coq/coq_file.py +++ b/coq/coq_file.py @@ -2,6 +2,7 @@ import shutil import uuid import tempfile +import subprocess from lsp.lsp_structs import ( TextDocumentItem, TextDocumentIdentifier, @@ -9,12 +10,13 @@ TextDocumentContentChangeEvent, ) from lsp.lsp_structs import ResponseError, ErrorCodes, Diagnostic -from coq.lsp.coq_lsp_structs import Position, GoalAnswer, RangedSpan, Range +from coq.lsp.coq_lsp_structs import Position, GoalAnswer, RangedSpan from coq.coq_structs import Step, FileContext, Term, TermType, SegmentType from coq.lsp.coq_lsp_client import CoqLspClient from coq.coq_exceptions import * from coq.coq_changes import * -from typing import List, Optional, Callable +from typing import List, Optional +from packaging import version class CoqFile(object): @@ -70,6 +72,7 @@ def __init__( self.steps_taken: int = 0 self.__init_steps(text, ast) self.__validate() + self.__init_coq_version() self.curr_module: List[str] = [] self.curr_module_type: List[str] = [] self.curr_section: List[str] = [] @@ -161,6 +164,15 @@ def __validate(self): step.diagnostics.append(diagnostic) break + def __init_coq_version(self): + output = subprocess.check_output("coqtop -v", shell=True) + coq_version = output.decode("utf-8").split("\n")[0].split()[-1] + outdated = version.parse(coq_version) < version.parse("8.18") + + # We ignore the tags [VernacSynterp] and [VernacSynPure] + self.__expr = lambda e: e if outdated else e[1] + self.__where_notation_key = "decl_ntn" if outdated else "ntn_decl" + @property def curr_step(self): return self.steps[self.steps_taken] @@ -199,8 +211,7 @@ def get_v(el): return el[1] return None - @staticmethod - def expr(step: RangedSpan) -> Optional[List]: + def expr(self, step: RangedSpan) -> Optional[List]: if ( step.span is not None and isinstance(step.span, dict) @@ -208,8 +219,7 @@ def expr(step: RangedSpan) -> Optional[List]: and isinstance(step.span["v"], dict) and "expr" in step.span["v"] ): - # We ignore the tags [VernacSynterp] and [VernacSynPure] - return step.span["v"]["expr"][1] + return self.__expr(step.span["v"]["expr"]) return [None] @@ -324,7 +334,8 @@ def __handle_where_notations(self, expr: List, term_type: TermType): # handles when multiple notations are defined for span in spans: name = FileContext.get_notation_key( - span["ntn_decl_string"]["v"], span["ntn_decl_scope"] + span[f"{self.__where_notation_key}_string"]["v"], + span[f"{self.__where_notation_key}_scope"], ) self.__add_term(name, self.curr_step, TermType.NOTATION) @@ -383,7 +394,7 @@ def traverse_expr(expr): ]: if text.startswith(keyword): return - expr = CoqFile.expr(self.curr_step.ast) + expr = self.expr(self.curr_step.ast) if expr == [None]: return if expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve": @@ -688,9 +699,8 @@ def diagnostics(self) -> List[Diagnostic]: uri = f"file://{self.__path}" return self.coq_lsp_client.lsp_endpoint.diagnostics[uri] - @staticmethod - def get_term_type(ast: RangedSpan) -> TermType: - expr = CoqFile.expr(ast) + def get_term_type(self, ast: RangedSpan) -> TermType: + expr = self.expr(ast) if expr is not None: return CoqFile.__get_term_type(expr) return TermType.OTHER diff --git a/coq/proof_file.py b/coq/proof_file.py index 90bd09f..a6cc929 100644 --- a/coq/proof_file.py +++ b/coq/proof_file.py @@ -279,7 +279,7 @@ def __search_notation(call): stack.append(v) return res - return traverse_expr(CoqFile.expr(step.ast)) + return traverse_expr(self.expr(step.ast)) def __get_last_term(self): terms = self.terms @@ -313,9 +313,9 @@ def traverse_expr(expr): # 3 - Obligation N # 4 - Next Obligation of id # 5 - Next Obligation - tag = CoqFile.expr(self.prev_step.ast)[1][1] + tag = self.expr(self.prev_step.ast)[1][1] if tag in [0, 1, 4]: - id = traverse_expr(CoqFile.expr(self.prev_step.ast)) + id = traverse_expr(self.expr(self.prev_step.ast)) # This works because the obligation must be in the # same module as the program id = ".".join(self.curr_module + [id]) @@ -359,7 +359,7 @@ def __get_steps( self.__step() # Nested proofs - if CoqFile.get_term_type(self.prev_step.ast) != TermType.OTHER: + if self.get_term_type(self.prev_step.ast) != TermType.OTHER: self.__get_proof(proofs) # Pass Qed if it exists while not self.in_proof and not self.checked: @@ -377,7 +377,7 @@ def __get_steps( raise e if ( self.steps_taken < len(self.steps) - and CoqFile.expr(self.curr_step.ast)[0] == "VernacEndProof" + and self.expr(self.curr_step.ast)[0] == "VernacEndProof" ): steps.append((self.curr_step, goals, [])) @@ -385,9 +385,9 @@ def __get_steps( def __get_proof(self, proofs): term, statement_context = None, None - if CoqFile.get_term_type(self.prev_step.ast) == TermType.OBLIGATION: + if self.get_term_type(self.prev_step.ast) == TermType.OBLIGATION: term, statement_context = self.__get_program_context() - elif CoqFile.get_term_type(self.prev_step.ast) != TermType.OTHER: + elif self.get_term_type(self.prev_step.ast) != TermType.OTHER: term = self.__get_last_term() statement_context = self.__step_context(self.prev_step) # HACK: We ignore proofs inside a Module Type since they can't be used outside