Skip to content

Commit

Permalink
Merge pull request #11 from sr-lab/general_fixes
Browse files Browse the repository at this point in the history
General fixes
  • Loading branch information
pcarrott authored Sep 22, 2023
2 parents e107d7d + e67c1ec commit 695726b
Show file tree
Hide file tree
Showing 15 changed files with 532 additions and 275 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ from coqlspclient.coq_structs import TermType

# Open Coq file
with CoqFile(os.path.join(os.getcwd(), "examples/example.v")) as coq_file:
# Print AST
print(coq_file.ast)
coq_file.exec(nsteps=2)
# Get all terms defined until now
print("Number of terms:", len(coq_file.context.terms))
Expand Down
333 changes: 211 additions & 122 deletions coqlspclient/coq_file.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions coqlspclient/coq_lsp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CoqLspClient(LspClient):
def __init__(
self,
root_uri: str,
timeout: int = 2,
timeout: int = 30,
memory_limit: int = 2097152,
init_options: Dict = __DEFAULT_INIT_OPTIONS,
):
Expand Down Expand Up @@ -107,7 +107,7 @@ def __wait_for_operation(self):
if timeout <= 0:
self.shutdown()
self.exit()
raise ResponseError(ErrorCodes.ServerQuit, "Server quit")
raise ResponseError(ErrorCodes.ServerTimeout, "Server timeout")

def didOpen(self, textDocument: TextDocumentItem):
"""Open a text document in the server.
Expand Down
10 changes: 0 additions & 10 deletions coqlspclient/coq_lsp_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,3 @@ def parse(coqFileProgressParams: Dict) -> Optional["CoqFileProgressParams"]:
)
)
return CoqFileProgressParams(textDocument, processing)


class CoqErrorCodes(Enum):
InvalidFile = 0


class CoqError(Exception):
def __init__(self, code: CoqErrorCodes, message: str):
self.code = code
self.message = message
74 changes: 53 additions & 21 deletions coqlspclient/coq_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
from enum import Enum
from typing import Dict, List
from coqlspclient.coq_lsp_structs import RangedSpan, GoalAnswer


class Step(object):
def __init__(self, text: str, ast: RangedSpan):
self.text = text
self.ast = ast
from pylspclient.lsp_structs import Diagnostic


class SegmentType(Enum):
Expand All @@ -22,22 +17,46 @@ class TermType(Enum):
DEFINITION = 3
NOTATION = 4
INDUCTIVE = 5
RECORD = 6
CLASS = 7
INSTANCE = 8
FIXPOINT = 9
TACTIC = 10
SCHEME = 11
VARIANT = 12
FACT = 13
REMARK = 14
COROLLARY = 15
PROPOSITION = 16
PROPERTY = 17
OBLIGATION = 18
COINDUCTIVE = 6
RECORD = 7
CLASS = 8
INSTANCE = 9
FIXPOINT = 10
COFIXPOINT = 11
SCHEME = 12
VARIANT = 13
FACT = 14
REMARK = 15
COROLLARY = 16
PROPOSITION = 17
PROPERTY = 18
OBLIGATION = 19
TACTIC = 20
RELATION = 21
SETOID = 22
FUNCTION = 23
DERIVE = 24
OTHER = 100


class CoqErrorCodes(Enum):
InvalidFile = 0
NotationNotFound = 1


class CoqError(Exception):
def __init__(self, code: CoqErrorCodes, message: str):
self.code = code
self.message = message


class Step(object):
def __init__(self, text: str, ast: RangedSpan):
self.text = text
self.ast = ast
self.diagnostics: List[Diagnostic] = []


class Term:
def __init__(
self,
Expand Down Expand Up @@ -112,16 +131,26 @@ def get_notation(self, notation: str, scope: str) -> Term:
regex = f"{re.escape(notation_id)}".split("\\ ")
for i, sub in enumerate(regex):
if sub == "_":
regex[i] = "(.+)"
# We match the wildcard with the description from here:
# https://coq.inria.fr/distrib/current/refman/language/core/basic.html#grammar-token-ident
# Coq accepts more characters, but no one should need more than these...
chars = "A-Za-zÀ-ÖØ-öø-ˁˆ-ˑˠ-ˤˬˮͰ-ʹͶͷͺ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-Ֆՙա-և"
regex[i] = f"([{chars}][{chars}0-9_']*|_[{chars}0-9_']+)"
else:
# Handle '_'
regex[i] = f"({sub}|('{sub}'))"
regex = "^" + "\\ ".join(regex) + "$"

# Search notations
unscoped = None
for term in self.terms.keys():
if re.match(regex, term):
return self.terms[term]
if re.match(regex, term.split(":")[0].strip()):
unscoped = term
# In case the stored id contains the scope but no scope is provided
if unscoped is not None:
return self.terms[unscoped]

# Search Infix
if re.match("^_ ([^ ]*) _$", notation):
Expand All @@ -130,7 +159,10 @@ def get_notation(self, notation: str, scope: str) -> Term:
if key in self.terms:
return self.terms[key]

raise RuntimeError(f"Notation not found in context: {notation_id}")
raise CoqError(
CoqErrorCodes.NotationNotFound,
f"Notation not found in context: {notation_id}",
)

@staticmethod
def get_notation_key(notation: str, scope: str):
Expand Down
147 changes: 69 additions & 78 deletions coqlspclient/proof_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,23 @@
ErrorCodes,
)
from coqlspclient.coq_structs import (
ProofStep,
FileContext,
TermType,
CoqErrorCodes,
CoqError,
Step,
Term,
ProofStep,
ProofTerm,
TermType,
)
from coqlspclient.coq_lsp_structs import (
CoqError,
CoqErrorCodes,
Result,
Query,
GoalAnswer,
FileContext,
)
from coqlspclient.coq_lsp_structs import Result, Query, GoalAnswer
from coqlspclient.coq_file import CoqFile
from coqlspclient.coq_lsp_client import CoqLspClient
from typing import List, Dict, Optional, Tuple


class _AuxFile(object):
def __init__(self, file_path: Optional[str] = None, timeout: int = 2):
def __init__(self, file_path: Optional[str] = None, timeout: int = 30):
self.__init_path(file_path)
uri = f"file://{self.path}"
self.coq_lsp_client = CoqLspClient(uri, timeout=timeout)
Expand Down Expand Up @@ -69,7 +65,10 @@ def append(self, text):
f.write(text)

def __handle_exception(self, e):
if not (isinstance(e, ResponseError) and e.code == ErrorCodes.ServerQuit.value):
if not isinstance(e, ResponseError) or e.code not in [
ErrorCodes.ServerQuit.value,
ErrorCodes.ServerTimeout.value,
]:
self.coq_lsp_client.shutdown()
self.coq_lsp_client.exit()
os.remove(self.path)
Expand Down Expand Up @@ -102,19 +101,15 @@ def __get_queries(self, keyword):
searches = {}
lines = self.read().split("\n")
for diagnostic in self.coq_lsp_client.lsp_endpoint.diagnostics[uri]:
command = lines[
diagnostic.range["start"]["line"] : diagnostic.range["end"]["line"] + 1
]
command = lines[diagnostic.range.start.line : diagnostic.range.end.line + 1]
if len(command) == 1:
command[0] = command[0][
diagnostic.range["start"]["character"] : diagnostic.range["end"][
"character"
]
diagnostic.range.start.character : diagnostic.range.end.character
+ 1
]
else:
command[0] = command[0][diagnostic.range["start"]["character"] :]
command[-1] = command[-1][: diagnostic.range["end"]["character"] + 1]
command[0] = command[0][diagnostic.range.start.character :]
command[-1] = command[-1][: diagnostic.range.end.character + 1]
command = "".join(command).strip()

if command.startswith(keyword):
Expand All @@ -133,7 +128,7 @@ def get_diagnostics(self, keyword, search, line):
for query in self.__get_queries(keyword):
if query.query == f"{search}":
for result in query.results:
if result.range["start"]["line"] == line:
if result.range.start.line == line:
return result.message
break
return None
Expand Down Expand Up @@ -162,7 +157,7 @@ def get_context(file_path: str, timeout: int):
for i, library in enumerate(libraries):
v_file = aux_file.get_diagnostics(
"Locate Library", library, last_line + i + 1
).split("\n")[-1][:-1]
).split()[-1][:-1]
coq_file = CoqFile(v_file, library=library, timeout=timeout)
coq_file.run()

Expand Down Expand Up @@ -231,72 +226,68 @@ def __locate(self, search, line):
fun = lambda x: x.endswith("(default interpretation)")
return nots[0][:-25] if fun(nots[0]) else nots[0]

def __step_context(self, step=None):
def traverse_ast(el):
if isinstance(el, dict):
return [x for v in el.values() for x in traverse_ast(v)]
elif isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid":
id = ".".join([l[1] for l in el[1][1][::-1]] + [el[2][1]])
term = self.__get_term(id)
return [] if term is None else [(lambda x: x, term)]
elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation":
line = len(self.__aux_file.read().split("\n"))
self.__aux_file.append(f'\nLocate "{el[2][1]}".')

def __search_notation(call):
notation_name = call[0]
scope = ""
notation = call[1](*call[2:])
if notation == "Unknown notation":
return None
if notation.split(":")[-1].endswith("_scope"):
scope = notation.split(":")[-1].strip()
return self.context.get_notation(notation_name, scope)

return [
(__search_notation, (el[2][1], self.__locate, el[2][1], line))
] + traverse_ast(el[1:])
elif isinstance(el, list):
return [x for v in el for x in traverse_ast(v)]

return []

if step is None:
step = self.__current_step.ast
return traverse_ast(step.span)
def __step_context(self):
def traverse_expr(expr):
stack, res = expr[:0:-1], []
while len(stack) > 0:
el = stack.pop()
if isinstance(el, list) and len(el) == 3 and el[0] == "Ser_Qualid":
term = self.__get_term(CoqFile.get_id(el))
if term is not None:
res.append((lambda x: x, term))
elif isinstance(el, list) and len(el) == 4 and el[0] == "CNotation":
line = len(self.__aux_file.read().split("\n"))
self.__aux_file.append(f'\nLocate "{el[2][1]}".')

def __search_notation(call):
notation_name = call[0]
scope = ""
notation = call[1](*call[2:])
if notation == "Unknown notation":
return None
if notation.split(":")[-1].endswith("_scope"):
scope = notation.split(":")[-1].strip()
return self.context.get_notation(notation_name, scope)

res.append(
(__search_notation, (el[2][1], self.__locate, el[2][1], line))
)
stack.append(el[1:])
elif isinstance(el, list):
for v in reversed(el):
if isinstance(v, (dict, list)):
stack.append(v)
elif isinstance(el, dict):
for v in reversed(el.values()):
if isinstance(v, (dict, list)):
stack.append(v)
return res

return traverse_expr(CoqFile.expr(self.__current_step.ast))

def __get_last_term(self):
terms = self.coq_file.terms
if len(terms) == 0:
return None
last_term = terms[0]
for term in terms[1:]:
if (term.ast.range.end.line > last_term.ast.range.end.line) or (
term.ast.range.end.line == last_term.ast.range.end.line
and term.ast.range.end.character > last_term.ast.range.end.character
):
if last_term.ast.range.end < term.ast.range.end:
last_term = term
return last_term

def __get_program_context(self):
def traverse_ast(el, keep_id=False):
if (
isinstance(el, list)
and len(el) == 2
and (
(el[0] == "Id" and keep_id)
or (el[0] == "ExtraArg" and el[1] == "identref")
)
):
return el[1]
elif isinstance(el, list):
for x in el:
id = traverse_ast(x, keep_id=keep_id)
if id == "identref":
keep_id = True
elif id is not None:
return id
return "identref" if keep_id else None
def traverse_expr(expr):
stack = expr[:0:-1]
while len(stack) > 0:
el = stack.pop()
if isinstance(el, list):
ident = CoqFile.get_ident(el)
if ident is not None:
return ident

for v in reversed(el):
if isinstance(v, list):
stack.append(v)
return None

# Tags:
Expand All @@ -308,7 +299,7 @@ def traverse_ast(el, keep_id=False):
# 5 - Next Obligation
tag = CoqFile.expr(self.__current_step.ast)[1][1]
if tag in [0, 1, 4]:
id = traverse_ast(CoqFile.expr(self.__current_step.ast))
id = traverse_expr(CoqFile.expr(self.__current_step.ast))
# This works because the obligation must be in the
# same module as the program
id = ".".join(self.coq_file.curr_module + [id])
Expand Down
3 changes: 2 additions & 1 deletion pylspclient/lsp_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import logging
from pylspclient import lsp_structs
from typing import Dict, List


class LspEndpoint(threading.Thread):
Expand All @@ -17,7 +18,7 @@ def __init__(
self.next_id = 0
self.timeout = timeout
self.shutdown_flag = False
self.diagnostics = {}
self.diagnostics: Dict[str, List[lsp_structs.Diagnostic]] = {}

def handle_result(self, rpc_id, result, error):
self.response_dict[rpc_id] = (result, error)
Expand Down
Loading

0 comments on commit 695726b

Please sign in to comment.