diff --git a/coqpyt/coq/context.py b/coqpyt/coq/context.py index 1c59bcf..7fda3e9 100644 --- a/coqpyt/coq/context.py +++ b/coqpyt/coq/context.py @@ -107,11 +107,10 @@ def __add_terms(self, step: Step, expr: List): self.__anonymous_id += 1 self.__add_term(name, step, term_type) elif term_type == TermType.DERIVE: - name = FileContext.get_ident(expr[2][0]) - self.__add_term(name, step, term_type) - if self.__ext_entry(expr[1]) == "Derive": - prop = FileContext.get_ident(expr[2][2]) - self.__add_term(prop, step, term_type) + for arg in expr[2]: + name = FileContext.get_ident(arg) + if name is not None: + self.__add_term(name, step, term_type) elif term_type == TermType.OBLIGATION: self.__last_terms[-1].append( ("", Term(step, term_type, self.__path, self.__segments.modules[:])) @@ -241,12 +240,15 @@ def __term_type(self, expr: List) -> TermType: return TermType.FIXPOINT if expr[0] == "VernacScheme": return TermType.SCHEME + # FIXME: These are plugins and should probably be handled differently if self.__is_extend(expr, "Obligations"): return TermType.OBLIGATION if self.__is_extend(expr, "VernacDeclareTacticDefinition"): return TermType.TACTIC if self.__is_extend(expr, "Function"): return TermType.FUNCTION + if self.__is_extend(expr, "Define_equations", exact=False): + return TermType.EQUATION if self.__is_extend(expr, "Derive", exact=False): return TermType.DERIVE if self.__is_extend(expr, "AddSetoid", exact=False): @@ -293,6 +295,16 @@ def __get_names(expr: List) -> List[str]: stack.append(v) return res + @staticmethod + def is_id(el) -> bool: + return isinstance(el, list) and (len(el) == 3 and el[0] == "Ser_Qualid") + + @staticmethod + def is_notation(el) -> bool: + return isinstance(el, list) and ( + len(el) == 4 and el[0] == "CNotation" and el[2][1] != "" + ) + @staticmethod def get_id(id: List) -> Optional[str]: # FIXME: This should be made private once [__step_context] is extracted @@ -305,18 +317,23 @@ def get_id(id: List) -> Optional[str]: @staticmethod def get_ident(el: List) -> Optional[str]: - # FIXME: This should be made private once [__get_program_context] is extracted - # from ProofFile to here. - if ( - len(el) == 3 - and el[0] == "GenArg" - and el[1][0] == "Rawwit" - and el[1][1][0] == "ExtraArg" - ): - if el[1][1][1] == "identref": - return el[2][0][1][1] - elif el[1][1][1] == "ident": - return el[2][1] + # FIXME: This method should be made private once [__get_program_context] + # is extracted from ProofFile to here. + def handle_arg_type(args, ids): + # FIXME: Other options for arg[0] are "OptArg" and "PairArg". + if args[0] == "ExtraArg": + if args[1] == "identref": + return ids[0][1][1] + elif args[1] == "ident": + return ids[1] + elif args[0] == "ListArg": + # FIXME: This recursive case works when the list is of length 1, + # but it should be generalized to handle any length. + return handle_arg_type(args[1], ids[0]) + return None + + if len(el) == 3 and el[0] == "GenArg" and el[1][0] == "Rawwit": + return handle_arg_type(el[1][1], el[2]) return None @staticmethod @@ -546,6 +563,19 @@ def get_term(self, name: str) -> Optional[Term]: return self.__terms[curr_name][-1] return None + @staticmethod + def get_notation_scope(notation: str) -> str: + """Get the scope of a notation. + Args: + notation (str): Possibly scoped notation pattern. E.g. "_ + _ : nat_scope". + + Returns: + str: The scope of the notation. E.g. "nat_scope". + """ + if notation.split(":")[-1].endswith("_scope"): + return notation.split(":")[-1].strip() + return "" + def get_notation(self, notation: str, scope: str) -> Term: """Get a notation from the context. diff --git a/coqpyt/coq/proof_file.py b/coqpyt/coq/proof_file.py index fc323a5..0597f4f 100644 --- a/coqpyt/coq/proof_file.py +++ b/coqpyt/coq/proof_file.py @@ -313,7 +313,7 @@ def __init__( Args: file_path (str): Path of the Coq file. library (Optional[str], optional): The library of the file. Defaults to None. - timeout (int, optional): Timeout used in coq-lsp. Defaults to 2. + timeout (int, optional): Timeout used in coq-lsp. Defaults to 30. workspace (Optional[str], optional): Absolute path for the workspace. If the workspace is not defined, the workspace is equal to the path of the file. @@ -372,42 +372,44 @@ def _handle_exception(self, e): raise e def __locate(self, search, line): - nots = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line).split( - "\n" - ) - fun = lambda x: x.endswith("(default interpretation)") - return nots[0][:-25] if fun(nots[0]) else nots[0] + located = self.__aux_file.get_diagnostics("Locate", f'"{search}"', line) + trim = lambda x: x[:-25] if x.endswith("(default interpretation)") else x + return list(map(trim, located.split("\n"))) def __step_context(self, step: Step) -> List[Term]: stack, res = self.context.expr(step)[:0:-1], [] while len(stack) > 0: el = stack.pop() - if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid": + if FileContext.is_id(el): term = self.context.get_term(FileContext.get_id(el)) if term is not None and term not in res: res.append(term) - elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation": + elif FileContext.is_notation(el): + stack.append(el[1:]) + + notation_name = el[2][1] line = len(self.__aux_file.read().split("\n")) - self.__aux_file.append(f'\nLocate "{el[2][1]}".') + self.__aux_file.append(f'\nLocate "{notation_name}".') self.__aux_file.didChange() + notations = self.__locate(notation_name, line) + if len(notations) == 1 and notations[0] == "Unknown notation": + continue - notation_name, scope = el[2][1], "" - notation = self.__locate(notation_name, line) - if notation.split(":")[-1].endswith("_scope"): - scope = notation.split(":")[-1].strip() - - if notation != "Unknown notation": + for notation in notations: + scope = FileContext.get_notation_scope(notation) try: term = self.context.get_notation(notation_name, scope) if term not in res: res.append(term) - except NotationNotFoundException as e: - if self.__error_mode == "strict": - raise e - else: - logging.warning(str(e)) - - stack.append(el[1:]) + break + except NotationNotFoundException: + continue + else: + e = NotationNotFoundException(notation_name) + if self.__error_mode == "strict": + raise e + else: + logging.warning(str(e)) elif isinstance(el, list): for v in reversed(el): if isinstance(v, (dict, list)): diff --git a/coqpyt/coq/structs.py b/coqpyt/coq/structs.py index c8c704d..c8ae587 100644 --- a/coqpyt/coq/structs.py +++ b/coqpyt/coq/structs.py @@ -36,6 +36,7 @@ class TermType(Enum): SETOID = 22 FUNCTION = 23 DERIVE = 24 + EQUATION = 25 OTHER = 100