Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Coq 8.19 #40

Merged
merged 6 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 70 additions & 53 deletions coqpyt/coq/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@ def __init__(
def __init_coq_version(self, coqtop):
output = subprocess.check_output(f"{coqtop} -v", shell=True)
coq_version = output.decode("utf-8").split("\n")[0].split()[-1]
outdated = version.parse(coq_version) < version.parse("8.18")

# For version 8.18, we ignore the tags [VernacSynterp] and [VernacSynPure]
# For versions 8.18+, we ignore the tags [VernacSynterp] and [VernacSynPure]
# and use the "ntn_decl" prefix when handling where notations
# For older versions, we only tested 8.17, so we provide no claims about
# versions prior to that.
post17 = version.parse(coq_version) > version.parse("8.17")
self.__expr = lambda e: e[1] if post17 else e
self.__where_notation_key = "ntn_decl" if post17 else "decl_ntn"

# For versions 8.19+, VernacExtend has a dictionary instead of a list in the
# AST, so we use "ext_plugin","ext_entry" and "ext_index" instead of indices
post18 = version.parse(coq_version) > version.parse("8.18")
self.__ext_plugin = lambda e: e["ext_plugin"] if post18 else None
self.__ext_entry = lambda e: e["ext_entry"] if post18 else e[0]
# FIXME: This should be made private once [__get_program_context] is extracted
# from ProofFile to here.
self.ext_index = lambda e: e["ext_index"] if post18 else e[1]

self.__expr = lambda e: e if outdated else e[1]
self.__where_notation_key = "decl_ntn" if outdated else "ntn_decl"
# We only tested versions 8.17/8.18/8.19, so we provide no claims about
# versions prior to that.

def __init_context(self, terms: Optional[Dict[str, Term]] = None):
# NOTE: We use a stack for each term because of the following case:
Expand All @@ -50,7 +59,7 @@ def __repr__(self) -> str:
return res

def __add_terms(self, step: Step, expr: List):
term_type = FileContext.__term_type(expr)
term_type = self.__term_type(expr)
text = step.short_text

# FIXME: Section-local terms are ignored. We do this to avoid
Expand All @@ -73,7 +82,7 @@ def __add_terms(self, step: Step, expr: List):
)
return

if expr[0] == "VernacExtend" and expr[1][0] == "VernacTacticNotation":
if self.__is_extend(expr, "VernacTacticNotation"):
# FIXME: Handle this case
return
elif expr[0] == "VernacNotation":
Expand All @@ -100,7 +109,7 @@ def __add_terms(self, step: Step, expr: List):
elif term_type == TermType.DERIVE:
name = FileContext.get_ident(expr[2][0])
self.__add_term(name, step, term_type)
if expr[1][0] == "Derive":
if self.__ext_entry(expr[1]) == "Derive":
prop = FileContext.get_ident(expr[2][2])
self.__add_term(prop, step, term_type)
elif term_type == TermType.OBLIGATION:
Expand Down Expand Up @@ -198,6 +207,56 @@ def __handle_where_notations(self, step: Step, expr: List, term_type: TermType):
)
self.__add_term(name, step, TermType.NOTATION)

def __is_extend(
self, expr: List, entry: str | Tuple[str], exact: bool = True
) -> bool:
if expr[0] != "VernacExtend":
return False
if exact:
return self.__ext_entry(expr[1]) == entry
return self.__ext_entry(expr[1]).startswith(entry)

def __term_type(self, expr: List) -> TermType:
if expr[0] == "VernacStartTheoremProof":
return getattr(TermType, expr[1][0].upper())
if expr[0] == "VernacDefinition":
return TermType.DEFINITION
if expr[0] in ["VernacNotation", "VernacSyntacticDefinition"]:
return TermType.NOTATION
if expr[0] == "VernacInductive" and expr[1][0] == "Class":
return TermType.CLASS
if expr[0] == "VernacInductive" and expr[1][0] in ["Record", "Structure"]:
return TermType.RECORD
if expr[0] == "VernacInductive" and expr[1][0] == "Variant":
return TermType.VARIANT
if expr[0] == "VernacInductive" and expr[1][0] == "CoInductive":
return TermType.COINDUCTIVE
if expr[0] == "VernacInductive":
return TermType.INDUCTIVE
if expr[0] == "VernacInstance":
return TermType.INSTANCE
if expr[0] == "VernacCoFixpoint":
return TermType.COFIXPOINT
if expr[0] == "VernacFixpoint":
return TermType.FIXPOINT
if expr[0] == "VernacScheme":
return TermType.SCHEME
if self.__is_extend(expr, "Obligations"):
return TermType.OBLIGATION
if self.__is_extend(expr, "VernacDeclareTacticDefinition"):
return TermType.TACTIC
if self.__is_extend(expr, "Function"):
return TermType.FUNCTION
if self.__is_extend(expr, "Derive", exact=False):
return TermType.DERIVE
if self.__is_extend(expr, "AddSetoid", exact=False):
return TermType.SETOID
if self.__is_extend(
expr, ("AddRelation", "AddParametricRelation"), exact=False
):
return TermType.RELATION
return TermType.OTHER

@staticmethod
def __get_names(expr: List) -> List[str]:
inductive = expr[0] == "VernacInductive"
Expand Down Expand Up @@ -268,48 +327,6 @@ def __get_v(el: List) -> Optional[str]:
return el[1]
return None

@staticmethod
def __term_type(expr: List) -> TermType:
if expr[0] == "VernacStartTheoremProof":
return getattr(TermType, expr[1][0].upper())
if expr[0] == "VernacDefinition":
return TermType.DEFINITION
if expr[0] in ["VernacNotation", "VernacSyntacticDefinition"]:
return TermType.NOTATION
if expr[0] == "VernacExtend" and expr[1][0] == "Obligations":
return TermType.OBLIGATION
if expr[0] == "VernacInductive" and expr[1][0] == "Class":
return TermType.CLASS
if expr[0] == "VernacInductive" and expr[1][0] in ["Record", "Structure"]:
return TermType.RECORD
if expr[0] == "VernacInductive" and expr[1][0] == "Variant":
return TermType.VARIANT
if expr[0] == "VernacInductive" and expr[1][0] == "CoInductive":
return TermType.COINDUCTIVE
if expr[0] == "VernacInductive":
return TermType.INDUCTIVE
if expr[0] == "VernacInstance":
return TermType.INSTANCE
if expr[0] == "VernacCoFixpoint":
return TermType.COFIXPOINT
if expr[0] == "VernacFixpoint":
return TermType.FIXPOINT
if expr[0] == "VernacScheme":
return TermType.SCHEME
if expr[0] == "VernacExtend" and expr[1][0].startswith("Derive"):
return TermType.DERIVE
if expr[0] == "VernacExtend" and expr[1][0].startswith("AddSetoid"):
return TermType.SETOID
if expr[0] == "VernacExtend" and expr[1][0].startswith(
("AddRelation", "AddParametricRelation")
):
return TermType.RELATION
if expr[0] == "VernacExtend" and expr[1][0] == "VernacDeclareTacticDefinition":
return TermType.TACTIC
if expr[0] == "VernacExtend" and expr[1][0] == "Function":
return TermType.FUNCTION
return TermType.OTHER

@staticmethod
def __get_notation_key(notation: str, scope: str) -> str:
if scope != "" and scope is not None:
Expand Down Expand Up @@ -376,7 +393,7 @@ def process_step(self, step: Step):
if (
expr == [None]
or expr[0] == "VernacProof"
or (expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve")
or self.__is_extend(expr, "VernacSolve")
):
return

Expand Down Expand Up @@ -460,7 +477,7 @@ def term_type(self, step: Step) -> TermType:
Returns:
List: The term type of the step.
"""
return FileContext.__term_type(self.expr(step))
return self.__term_type(self.expr(step))

def update(self, context: Union["FileContext", Dict[str, Term]] = {}):
"""Updates the context with new terms.
Expand Down
2 changes: 1 addition & 1 deletion coqpyt/coq/proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def __get_program_context(self) -> Tuple[Term, List[Term]]:
# 3 - Obligation N
# 4 - Next Obligation of id
# 5 - Next Obligation
tag = expr[1][1]
tag = self.context.ext_index(expr[1])
if tag in [0, 1, 4]:
stack = expr[:0:-1]
while len(stack) > 0:
Expand Down
8 changes: 6 additions & 2 deletions coqpyt/tests/proof_file/expected/imports.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,12 @@ proofs:
line: 21
character: 0
context:
- text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- "8.19":
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
default:
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- text: 'Notation "x = y" := (eq x y) : type_scope.'
type: NOTATION
- text: 'Fixpoint add n m := match n with | 0 => m | S p => S (p + m) end where "n + m" := (add n m) : nat_scope.'
Expand Down
16 changes: 12 additions & 4 deletions coqpyt/tests/proof_file/expected/valid_file.yml
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,12 @@ proofs:
line: 26
character: 10
context:
- text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- "8.19":
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
default:
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- text: 'Notation "x = y" := (eq x y) : type_scope.'
type: NOTATION
- text: 'Fixpoint add n m := match n with | 0 => m | S p => S (p + m) end where "n + m" := (add n m) : nat_scope.'
Expand Down Expand Up @@ -360,8 +364,12 @@ proofs:
line: 51
character: 4
context:
- text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- "8.19":
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 10, x binder, y binder, P at level 200, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
default:
text: "Notation \"∀ x .. y , P\" := (forall x, .. (forall y, P) ..) (at level 200, x binder, y binder, right associativity, format \"'[ ' '[ ' ∀ x .. y ']' , '/' P ']'\") : type_scope."
type: NOTATION
- text: 'Notation "x = y" := (eq x y) : type_scope.'
type: NOTATION
- text: 'Fixpoint mul n m := match n with | 0 => 0 | S p => m + p * m end where "n * m" := (mul n m) : nat_scope.'
Expand Down
12 changes: 10 additions & 2 deletions coqpyt/tests/proof_file/test_proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def setup_method(self, method):

def test_valid_file(self):
proofs = self.proof_file.proofs
check_proofs("tests/proof_file/expected/valid_file.yml", proofs)
check_proofs(
"tests/proof_file/expected/valid_file.yml",
proofs,
coq_version=self.coq_version,
)

def test_exec(self):
# Rollback whole file
Expand All @@ -40,7 +44,11 @@ def setup_method(self, method):
self.setup("test_imports/test_import.v", workspace="test_imports/")

def test_imports(self):
check_proofs("tests/proof_file/expected/imports.yml", self.proof_file.proofs)
check_proofs(
"tests/proof_file/expected/imports.yml",
self.proof_file.proofs,
coq_version=self.coq_version,
)

def test_exec(self):
# Rollback whole file
Expand Down
21 changes: 18 additions & 3 deletions coqpyt/tests/proof_file/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from abc import ABC, abstractmethod
from typing import Tuple, List, Dict, Union, Any
from packaging import version

from coqpyt.coq.proof_file import ProofFile, ProofStep, ProofTerm
from coqpyt.coq.structs import TermType, Term
Expand Down Expand Up @@ -41,6 +42,11 @@ def setup(self, file_path, workspace=None):
self.proof_file.run()
self.versionId = VersionedTextDocumentIdentifier(uri, 1)

output = subprocess.check_output(f"coqtop -v", shell=True)
coq_version = output.decode("utf-8").split("\n")[0].split()[-1]
default = version.parse(coq_version) < version.parse("8.19")
self.coq_version = "default" if default else "8.19"
Nfsaavedra marked this conversation as resolved.
Show resolved Hide resolved

@abstractmethod
def setup_method(self, method):
pass
Expand Down Expand Up @@ -144,8 +150,10 @@ def check_proof(test_proof: Dict, proof: ProofTerm):
check_step(step, proof.steps[j])


def check_proofs(yaml_file: str, proofs: List[ProofTerm]):
test_proofs = get_test_proofs(yaml_file)
def check_proofs(
yaml_file: str, proofs: List[ProofTerm], coq_version: Optional[str] = None
):
test_proofs = get_test_proofs(yaml_file, coq_version)
assert len(proofs) == len(test_proofs["proofs"])
for i, test_proof in enumerate(test_proofs["proofs"]):
check_proof(test_proof, proofs[i])
Expand All @@ -170,12 +178,19 @@ def add_step_defaults(step):
step["context"] = []


def get_test_proofs(yaml_file: str):
def get_test_proofs(yaml_file: str, coq_version: Optional[str] = None):
Nfsaavedra marked this conversation as resolved.
Show resolved Hide resolved
with open(yaml_file, "r") as f:
test_proofs = yaml.safe_load(f)
for test_proof in test_proofs["proofs"]:
if "context" not in test_proof:
test_proof["context"] = []
if coq_version is not None:
test_proof["context"] = list(
map(
lambda x: x if coq_version not in x else x[coq_version],
test_proof["context"],
)
)
for step in test_proof["steps"]:
add_step_defaults(step)
return test_proofs
Loading