diff --git a/coqpyt/coq/context.py b/coqpyt/coq/context.py index 2cbc0dc..4dd8aff 100644 --- a/coqpyt/coq/context.py +++ b/coqpyt/coq/context.py @@ -24,15 +24,24 @@ def __init__( def __init_coq_version(self, coqtop): output = subprocess.check_output(f"{coqtop} -v", shell=True) coq_version = output.decode("utf-8").split("\n")[0].split()[-1] - outdated = version.parse(coq_version) < version.parse("8.18") - # For version 8.18, we ignore the tags [VernacSynterp] and [VernacSynPure] + # For versions 8.18+, we ignore the tags [VernacSynterp] and [VernacSynPure] # and use the "ntn_decl" prefix when handling where notations - # For older versions, we only tested 8.17, so we provide no claims about - # versions prior to that. + post17 = version.parse(coq_version) > version.parse("8.17") + self.__expr = lambda e: e[1] if post17 else e + self.__where_notation_key = "ntn_decl" if post17 else "decl_ntn" + + # For versions 8.19+, VernacExtend has a dictionary instead of a list in the + # AST, so we use "ext_plugin","ext_entry" and "ext_index" instead of indices + post18 = version.parse(coq_version) > version.parse("8.18") + self.__ext_plugin = lambda e: e["ext_plugin"] if post18 else None + self.__ext_entry = lambda e: e["ext_entry"] if post18 else e[0] + # FIXME: This should be made private once [__get_program_context] is extracted + # from ProofFile to here. + self.ext_index = lambda e: e["ext_index"] if post18 else e[1] - self.__expr = lambda e: e if outdated else e[1] - self.__where_notation_key = "decl_ntn" if outdated else "ntn_decl" + # We only tested versions 8.17/8.18/8.19, so we provide no claims about + # versions prior to that. def __init_context(self, terms: Optional[Dict[str, Term]] = None): # NOTE: We use a stack for each term because of the following case: @@ -50,7 +59,7 @@ def __repr__(self) -> str: return res def __add_terms(self, step: Step, expr: List): - term_type = FileContext.__term_type(expr) + term_type = self.__term_type(expr) text = step.short_text # FIXME: Section-local terms are ignored. We do this to avoid @@ -73,7 +82,7 @@ def __add_terms(self, step: Step, expr: List): ) return - if expr[0] == "VernacExtend" and expr[1][0] == "VernacTacticNotation": + if self.__is_extend(expr, "VernacTacticNotation"): # FIXME: Handle this case return elif expr[0] == "VernacNotation": @@ -100,7 +109,7 @@ def __add_terms(self, step: Step, expr: List): elif term_type == TermType.DERIVE: name = FileContext.get_ident(expr[2][0]) self.__add_term(name, step, term_type) - if expr[1][0] == "Derive": + if self.__ext_entry(expr[1]) == "Derive": prop = FileContext.get_ident(expr[2][2]) self.__add_term(prop, step, term_type) elif term_type == TermType.OBLIGATION: @@ -198,6 +207,56 @@ def __handle_where_notations(self, step: Step, expr: List, term_type: TermType): ) self.__add_term(name, step, TermType.NOTATION) + def __is_extend( + self, expr: List, entry: str | Tuple[str], exact: bool = True + ) -> bool: + if expr[0] != "VernacExtend": + return False + if exact: + return self.__ext_entry(expr[1]) == entry + return self.__ext_entry(expr[1]).startswith(entry) + + def __term_type(self, expr: List) -> TermType: + if expr[0] == "VernacStartTheoremProof": + return getattr(TermType, expr[1][0].upper()) + if expr[0] == "VernacDefinition": + return TermType.DEFINITION + if expr[0] in ["VernacNotation", "VernacSyntacticDefinition"]: + return TermType.NOTATION + if expr[0] == "VernacInductive" and expr[1][0] == "Class": + return TermType.CLASS + if expr[0] == "VernacInductive" and expr[1][0] in ["Record", "Structure"]: + return TermType.RECORD + if expr[0] == "VernacInductive" and expr[1][0] == "Variant": + return TermType.VARIANT + if expr[0] == "VernacInductive" and expr[1][0] == "CoInductive": + return TermType.COINDUCTIVE + if expr[0] == "VernacInductive": + return TermType.INDUCTIVE + if expr[0] == "VernacInstance": + return TermType.INSTANCE + if expr[0] == "VernacCoFixpoint": + return TermType.COFIXPOINT + if expr[0] == "VernacFixpoint": + return TermType.FIXPOINT + if expr[0] == "VernacScheme": + return TermType.SCHEME + if self.__is_extend(expr, "Obligations"): + return TermType.OBLIGATION + if self.__is_extend(expr, "VernacDeclareTacticDefinition"): + return TermType.TACTIC + if self.__is_extend(expr, "Function"): + return TermType.FUNCTION + if self.__is_extend(expr, "Derive", exact=False): + return TermType.DERIVE + if self.__is_extend(expr, "AddSetoid", exact=False): + return TermType.SETOID + if self.__is_extend( + expr, ("AddRelation", "AddParametricRelation"), exact=False + ): + return TermType.RELATION + return TermType.OTHER + @staticmethod def __get_names(expr: List) -> List[str]: inductive = expr[0] == "VernacInductive" @@ -268,48 +327,6 @@ def __get_v(el: List) -> Optional[str]: return el[1] return None - @staticmethod - def __term_type(expr: List) -> TermType: - if expr[0] == "VernacStartTheoremProof": - return getattr(TermType, expr[1][0].upper()) - if expr[0] == "VernacDefinition": - return TermType.DEFINITION - if expr[0] in ["VernacNotation", "VernacSyntacticDefinition"]: - return TermType.NOTATION - if expr[0] == "VernacExtend" and expr[1][0] == "Obligations": - return TermType.OBLIGATION - if expr[0] == "VernacInductive" and expr[1][0] == "Class": - return TermType.CLASS - if expr[0] == "VernacInductive" and expr[1][0] in ["Record", "Structure"]: - return TermType.RECORD - if expr[0] == "VernacInductive" and expr[1][0] == "Variant": - return TermType.VARIANT - if expr[0] == "VernacInductive" and expr[1][0] == "CoInductive": - return TermType.COINDUCTIVE - if expr[0] == "VernacInductive": - return TermType.INDUCTIVE - if expr[0] == "VernacInstance": - return TermType.INSTANCE - if expr[0] == "VernacCoFixpoint": - return TermType.COFIXPOINT - if expr[0] == "VernacFixpoint": - return TermType.FIXPOINT - if expr[0] == "VernacScheme": - return TermType.SCHEME - if expr[0] == "VernacExtend" and expr[1][0].startswith("Derive"): - return TermType.DERIVE - if expr[0] == "VernacExtend" and expr[1][0].startswith("AddSetoid"): - return TermType.SETOID - if expr[0] == "VernacExtend" and expr[1][0].startswith( - ("AddRelation", "AddParametricRelation") - ): - return TermType.RELATION - if expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition": - return TermType.TACTIC - if expr[0] == "VernacExtend" and expr[1][0] == "Function": - return TermType.FUNCTION - return TermType.OTHER - @staticmethod def __get_notation_key(notation: str, scope: str) -> str: if scope != "" and scope is not None: @@ -376,7 +393,7 @@ def process_step(self, step: Step): if ( expr == [None] or expr[0] == "VernacProof" - or (expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve") + or self.__is_extend(expr, "VernacSolve") ): return @@ -460,7 +477,7 @@ def term_type(self, step: Step) -> TermType: Returns: List: The term type of the step. """ - return FileContext.__term_type(self.expr(step)) + return self.__term_type(self.expr(step)) def update(self, context: Union["FileContext", Dict[str, Term]] = {}): """Updates the context with new terms. diff --git a/coqpyt/coq/proof_file.py b/coqpyt/coq/proof_file.py index 2851abf..fd92517 100644 --- a/coqpyt/coq/proof_file.py +++ b/coqpyt/coq/proof_file.py @@ -354,7 +354,7 @@ def __get_program_context(self) -> Tuple[Term, List[Term]]: # 3 - Obligation N # 4 - Next Obligation of id # 5 - Next Obligation - tag = expr[1][1] + tag = self.context.ext_index(expr[1]) if tag in [0, 1, 4]: stack = expr[:0:-1] while len(stack) > 0: diff --git a/coqpyt/tests/proof_file/expected/imports.yml b/coqpyt/tests/proof_file/expected/imports.yml index 5348415..f33be24 100644 --- a/coqpyt/tests/proof_file/expected/imports.yml +++ b/coqpyt/tests/proof_file/expected/imports.yml @@ -167,8 +167,12 @@ proofs: line: 21 character: 0 context: - - text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." - type: NOTATION + - "8.19.0": + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION + default: + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION - text: 'Notation "x = y" := (eq x y) : type_scope.' type: NOTATION - text: 'Fixpoint add n m := match n with | 0 => m | S p => S (p + m) end where "n + m" := (add n m) : nat_scope.' diff --git a/coqpyt/tests/proof_file/expected/valid_file.yml b/coqpyt/tests/proof_file/expected/valid_file.yml index 450d300..fee90ee 100644 --- a/coqpyt/tests/proof_file/expected/valid_file.yml +++ b/coqpyt/tests/proof_file/expected/valid_file.yml @@ -198,8 +198,12 @@ proofs: line: 26 character: 10 context: - - text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." - type: NOTATION + - "8.19.0": + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION + default: + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION - text: 'Notation "x = y" := (eq x y) : type_scope.' type: NOTATION - text: 'Fixpoint add n m := match n with | 0 => m | S p => S (p + m) end where "n + m" := (add n m) : nat_scope.' @@ -360,8 +364,12 @@ proofs: line: 51 character: 4 context: - - text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." - type: NOTATION + - "8.19.0": + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION + default: + text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope." + type: NOTATION - text: 'Notation "x = y" := (eq x y) : type_scope.' type: NOTATION - text: 'Fixpoint mul n m := match n with | 0 => 0 | S p => m + p * m end where "n * m" := (mul n m) : nat_scope.' diff --git a/coqpyt/tests/proof_file/test_proof_file.py b/coqpyt/tests/proof_file/test_proof_file.py index 04ca957..e625a5b 100644 --- a/coqpyt/tests/proof_file/test_proof_file.py +++ b/coqpyt/tests/proof_file/test_proof_file.py @@ -13,7 +13,11 @@ def setup_method(self, method): def test_valid_file(self): proofs = self.proof_file.proofs - check_proofs("tests/proof_file/expected/valid_file.yml", proofs) + check_proofs( + "tests/proof_file/expected/valid_file.yml", + proofs, + coq_version=self.coq_version, + ) def test_exec(self): # Rollback whole file @@ -40,7 +44,11 @@ def setup_method(self, method): self.setup("test_imports/test_import.v", workspace="test_imports/") def test_imports(self): - check_proofs("tests/proof_file/expected/imports.yml", self.proof_file.proofs) + check_proofs( + "tests/proof_file/expected/imports.yml", + self.proof_file.proofs, + coq_version=self.coq_version, + ) def test_exec(self): # Rollback whole file diff --git a/coqpyt/tests/proof_file/utility.py b/coqpyt/tests/proof_file/utility.py index 0160c22..fd44db9 100644 --- a/coqpyt/tests/proof_file/utility.py +++ b/coqpyt/tests/proof_file/utility.py @@ -41,6 +41,9 @@ def setup(self, file_path, workspace=None): self.proof_file.run() self.versionId = VersionedTextDocumentIdentifier(uri, 1) + output = subprocess.check_output(f"coqtop -v", shell=True) + self.coq_version = output.decode("utf-8").split("\n")[0].split()[-1] + @abstractmethod def setup_method(self, method): pass @@ -144,8 +147,10 @@ def check_proof(test_proof: Dict, proof: ProofTerm): check_step(step, proof.steps[j]) -def check_proofs(yaml_file: str, proofs: List[ProofTerm]): - test_proofs = get_test_proofs(yaml_file) +def check_proofs( + yaml_file: str, proofs: List[ProofTerm], coq_version: Optional[str] = None +): + test_proofs = get_test_proofs(yaml_file, coq_version) assert len(proofs) == len(test_proofs["proofs"]) for i, test_proof in enumerate(test_proofs["proofs"]): check_proof(test_proof, proofs[i]) @@ -170,12 +175,23 @@ def add_step_defaults(step): step["context"] = [] -def get_test_proofs(yaml_file: str): +def get_test_proofs(yaml_file: str, coq_version: Optional[str] = None): with open(yaml_file, "r") as f: test_proofs = yaml.safe_load(f) for test_proof in test_proofs["proofs"]: if "context" not in test_proof: test_proof["context"] = [] + if coq_version is not None: + test_proof["context"] = list( + map( + lambda x: x[coq_version] + if coq_version in x + else x["default"] + if "default" in x + else x, + test_proof["context"], + ) + ) for step in test_proof["steps"]: add_step_defaults(step) return test_proofs