Skip to content

Commit

Permalink
handle custom entries and (partially) gen args for plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
pcarrott committed Oct 3, 2024
1 parent 226c0ff commit 441b1e6
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 39 deletions.
64 changes: 47 additions & 17 deletions coqpyt/coq/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ def __add_terms(self, step: Step, expr: List):
self.__anonymous_id += 1
self.__add_term(name, step, term_type)
elif term_type == TermType.DERIVE:
name = FileContext.get_ident(expr[2][0])
self.__add_term(name, step, term_type)
if self.__ext_entry(expr[1]) == "Derive":
prop = FileContext.get_ident(expr[2][2])
self.__add_term(prop, step, term_type)
for arg in expr[2]:
name = FileContext.get_ident(arg)
if name is not None:
self.__add_term(name, step, term_type)
elif term_type == TermType.OBLIGATION:
self.__last_terms[-1].append(
("", Term(step, term_type, self.__path, self.__segments.modules[:]))
Expand Down Expand Up @@ -241,12 +240,15 @@ def __term_type(self, expr: List) -> TermType:
return TermType.FIXPOINT
if expr[0] == "VernacScheme":
return TermType.SCHEME
# FIXME: These are plugins and should probably be handled differently
if self.__is_extend(expr, "Obligations"):
return TermType.OBLIGATION
if self.__is_extend(expr, "VernacDeclareTacticDefinition"):
return TermType.TACTIC
if self.__is_extend(expr, "Function"):
return TermType.FUNCTION
if self.__is_extend(expr, "Define_equations", exact=False):
return TermType.EQUATION
if self.__is_extend(expr, "Derive", exact=False):
return TermType.DERIVE
if self.__is_extend(expr, "AddSetoid", exact=False):
Expand Down Expand Up @@ -293,6 +295,16 @@ def __get_names(expr: List) -> List[str]:
stack.append(v)
return res

@staticmethod
def is_id(el) -> bool:
return isinstance(el, list) and (len(el) == 3 and el[0] == "Ser_Qualid")

@staticmethod
def is_notation(el) -> bool:
return isinstance(el, list) and (
len(el) == 4 and el[0] == "CNotation" and el[2][1] != ""
)

@staticmethod
def get_id(id: List) -> Optional[str]:
# FIXME: This should be made private once [__step_context] is extracted
Expand All @@ -305,18 +317,23 @@ def get_id(id: List) -> Optional[str]:

@staticmethod
def get_ident(el: List) -> Optional[str]:
# FIXME: This should be made private once [__get_program_context] is extracted
# from ProofFile to here.
if (
len(el) == 3
and el[0] == "GenArg"
and el[1][0] == "Rawwit"
and el[1][1][0] == "ExtraArg"
):
if el[1][1][1] == "identref":
return el[2][0][1][1]
elif el[1][1][1] == "ident":
return el[2][1]
# FIXME: This method should be made private once [__get_program_context]
# is extracted from ProofFile to here.
def handle_arg_type(args, ids):
# FIXME: Other options for arg[0] are "OptArg" and "PairArg".
if args[0] == "ExtraArg":
if args[1] == "identref":
return ids[0][1][1]
elif args[1] == "ident":
return ids[1]
elif args[0] == "ListArg":
# FIXME: This recursive case works when the list is of length 1,
# but it should be generalized to handle any length.
return handle_arg_type(args[1], ids[0])
return None

if len(el) == 3 and el[0] == "GenArg" and el[1][0] == "Rawwit":
return handle_arg_type(el[1][1], el[2])
return None

@staticmethod
Expand Down Expand Up @@ -546,6 +563,19 @@ def get_term(self, name: str) -> Optional[Term]:
return self.__terms[curr_name][-1]
return None

@staticmethod
def get_notation_scope(notation: str) -> str:
"""Get the scope of a notation.
Args:
notation (str): Possibly scoped notation pattern. E.g. "_ + _ : nat_scope".
Returns:
str: The scope of the notation. E.g. "nat_scope".
"""
if notation.split(":")[-1].endswith("_scope"):
return notation.split(":")[-1].strip()
return ""

def get_notation(self, notation: str, scope: str) -> Term:
"""Get a notation from the context.
Expand Down
46 changes: 24 additions & 22 deletions coqpyt/coq/proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def __init__(
Args:
file_path (str): Path of the Coq file.
library (Optional[str], optional): The library of the file. Defaults to None.
timeout (int, optional): Timeout used in coq-lsp. Defaults to 2.
timeout (int, optional): Timeout used in coq-lsp. Defaults to 30.
workspace (Optional[str], optional): Absolute path for the workspace.
If the workspace is not defined, the workspace is equal to the
path of the file.
Expand Down Expand Up @@ -372,42 +372,44 @@ def _handle_exception(self, e):
raise e

def __locate(self, search, line):
nots = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line).split(
"\n"
)
fun = lambda x: x.endswith("(default interpretation)")
return nots[0][:-25] if fun(nots[0]) else nots[0]
located = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line)
trim = lambda x: x[:-25] if x.endswith("(default interpretation)") else x
return list(map(trim, located.split("\n")))

def __step_context(self, step: Step) -> List[Term]:
stack, res = self.context.expr(step)[:0:-1], []
while len(stack) > 0:
el = stack.pop()
if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid":
if FileContext.is_id(el):
term = self.context.get_term(FileContext.get_id(el))
if term is not None and term not in res:
res.append(term)
elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation":
elif FileContext.is_notation(el):
stack.append(el[1:])

notation_name = el[2][1]
line = len(self.__aux_file.read().split("\n"))
self.__aux_file.append(f'\nLocate "{el[2][1]}".')
self.__aux_file.append(f'\nLocate "{notation_name}".')
self.__aux_file.didChange()
notations = self.__locate(notation_name, line)
if len(notations) == 1 and notations[0] == "Unknown notation":
continue

notation_name, scope = el[2][1], ""
notation = self.__locate(notation_name, line)
if notation.split(":")[-1].endswith("_scope"):
scope = notation.split(":")[-1].strip()

if notation != "Unknown notation":
for notation in notations:
scope = FileContext.get_notation_scope(notation)
try:
term = self.context.get_notation(notation_name, scope)
if term not in res:
res.append(term)
except NotationNotFoundException as e:
if self.__error_mode == "strict":
raise e
else:
logging.warning(str(e))

stack.append(el[1:])
break
except NotationNotFoundException:
continue
else:
e = NotationNotFoundException(notation_name)
if self.__error_mode == "strict":
raise e
else:
logging.warning(str(e))
elif isinstance(el, list):
for v in reversed(el):
if isinstance(v, (dict, list)):
Expand Down
1 change: 1 addition & 0 deletions coqpyt/coq/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class TermType(Enum):
SETOID = 22
FUNCTION = 23
DERIVE = 24
EQUATION = 25
OTHER = 100


Expand Down

0 comments on commit 441b1e6

Please sign in to comment.