Skip to content

Commit

Permalink
Support versions 8.17 and 8.18 in CoqFile
Browse files Browse the repository at this point in the history
  • Loading branch information
pcarrott committed Oct 30, 2023
1 parent 4d8775c commit 5562895
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
32 changes: 21 additions & 11 deletions coq/coq_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
import shutil
import uuid
import tempfile
import subprocess
from lsp.lsp_structs import (
TextDocumentItem,
TextDocumentIdentifier,
VersionedTextDocumentIdentifier,
TextDocumentContentChangeEvent,
)
from lsp.lsp_structs import ResponseError, ErrorCodes, Diagnostic
from coq.lsp.coq_lsp_structs import Position, GoalAnswer, RangedSpan, Range
from coq.lsp.coq_lsp_structs import Position, GoalAnswer, RangedSpan
from coq.coq_structs import Step, FileContext, Term, TermType, SegmentType
from coq.lsp.coq_lsp_client import CoqLspClient
from coq.coq_exceptions import *
from coq.coq_changes import *
from typing import List, Optional, Callable
from typing import List, Optional
from packaging import version


class CoqFile(object):
Expand Down Expand Up @@ -70,6 +72,7 @@ def __init__(
self.steps_taken: int = 0
self.__init_steps(text, ast)
self.__validate()
self.__init_coq_version()
self.curr_module: List[str] = []
self.curr_module_type: List[str] = []
self.curr_section: List[str] = []
Expand Down Expand Up @@ -161,6 +164,15 @@ def __validate(self):
step.diagnostics.append(diagnostic)
break

def __init_coq_version(self):
output = subprocess.check_output("coqtop -v", shell=True)
coq_version = output.decode("utf-8").split("\n")[0].split()[-1]
outdated = version.parse(coq_version) < version.parse("8.18")

# We ignore the tags [VernacSynterp] and [VernacSynPure]
self.__expr = lambda e: e if outdated else e[1]
self.__where_notation_key = "decl_ntn" if outdated else "ntn_decl"

@property
def curr_step(self):
return self.steps[self.steps_taken]
Expand Down Expand Up @@ -199,17 +211,15 @@ def get_v(el):
return el[1]
return None

@staticmethod
def expr(step: RangedSpan) -> Optional[List]:
def expr(self, step: RangedSpan) -> Optional[List]:
if (
step.span is not None
and isinstance(step.span, dict)
and "v" in step.span
and isinstance(step.span["v"], dict)
and "expr" in step.span["v"]
):
# We ignore the tags [VernacSynterp] and [VernacSynPure]
return step.span["v"]["expr"][1]
return self.__expr(step.span["v"]["expr"])

return [None]

Expand Down Expand Up @@ -324,7 +334,8 @@ def __handle_where_notations(self, expr: List, term_type: TermType):
# handles when multiple notations are defined
for span in spans:
name = FileContext.get_notation_key(
span["ntn_decl_string"]["v"], span["ntn_decl_scope"]
span[f"{self.__where_notation_key}_string"]["v"],
span[f"{self.__where_notation_key}_scope"],
)
self.__add_term(name, self.curr_step, TermType.NOTATION)

Expand Down Expand Up @@ -383,7 +394,7 @@ def traverse_expr(expr):
]:
if text.startswith(keyword):
return
expr = CoqFile.expr(self.curr_step.ast)
expr = self.expr(self.curr_step.ast)
if expr == [None]:
return
if expr[0] == "VernacExtend" and expr[1][0] == "VernacSolve":
Expand Down Expand Up @@ -688,9 +699,8 @@ def diagnostics(self) -> List[Diagnostic]:
uri = f"file://{self.__path}"
return self.coq_lsp_client.lsp_endpoint.diagnostics[uri]

@staticmethod
def get_term_type(ast: RangedSpan) -> TermType:
expr = CoqFile.expr(ast)
def get_term_type(self, ast: RangedSpan) -> TermType:
expr = self.expr(ast)
if expr is not None:
return CoqFile.__get_term_type(expr)
return TermType.OTHER
Expand Down
14 changes: 7 additions & 7 deletions coq/proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __search_notation(call):
stack.append(v)
return res

return traverse_expr(CoqFile.expr(step.ast))
return traverse_expr(self.expr(step.ast))

def __get_last_term(self):
terms = self.terms
Expand Down Expand Up @@ -313,9 +313,9 @@ def traverse_expr(expr):
# 3 - Obligation N
# 4 - Next Obligation of id
# 5 - Next Obligation
tag = CoqFile.expr(self.prev_step.ast)[1][1]
tag = self.expr(self.prev_step.ast)[1][1]
if tag in [0, 1, 4]:
id = traverse_expr(CoqFile.expr(self.prev_step.ast))
id = traverse_expr(self.expr(self.prev_step.ast))
# This works because the obligation must be in the
# same module as the program
id = ".".join(self.curr_module + [id])
Expand Down Expand Up @@ -359,7 +359,7 @@ def __get_steps(

self.__step()
# Nested proofs
if CoqFile.get_term_type(self.prev_step.ast) != TermType.OTHER:
if self.get_term_type(self.prev_step.ast) != TermType.OTHER:
self.__get_proof(proofs)
# Pass Qed if it exists
while not self.in_proof and not self.checked:
Expand All @@ -377,17 +377,17 @@ def __get_steps(
raise e
if (
self.steps_taken < len(self.steps)
and CoqFile.expr(self.curr_step.ast)[0] == "VernacEndProof"
and self.expr(self.curr_step.ast)[0] == "VernacEndProof"
):
steps.append((self.curr_step, goals, []))

return steps

def __get_proof(self, proofs):
term, statement_context = None, None
if CoqFile.get_term_type(self.prev_step.ast) == TermType.OBLIGATION:
if self.get_term_type(self.prev_step.ast) == TermType.OBLIGATION:
term, statement_context = self.__get_program_context()
elif CoqFile.get_term_type(self.prev_step.ast) != TermType.OTHER:
elif self.get_term_type(self.prev_step.ast) != TermType.OTHER:
term = self.__get_last_term()
statement_context = self.__step_context(self.prev_step)
# HACK: We ignore proofs inside a Module Type since they can't be used outside
Expand Down

0 comments on commit 5562895

Please sign in to comment.