Skip to content

Commit

Permalink
Refactor where notation. Add short_text to Steps. Fix __find_prev
Browse files Browse the repository at this point in the history
  • Loading branch information
Nfsaavedra committed Oct 27, 2023
1 parent 1995442 commit 1c543eb
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 121 deletions.
110 changes: 34 additions & 76 deletions coqlspclient/coq_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def __init__(
uri = f"file://{self.__path}"
self.coq_lsp_client = CoqLspClient(uri, timeout=timeout)
uri = f"file://{self.__path}"
with open(self.__path, "r") as f:
text = f.read()
text = self.__read()

try:
self.coq_lsp_client.didOpen(TextDocumentItem(uri, "coq", 1, text))
Expand Down Expand Up @@ -130,7 +129,13 @@ def __init_step(
curr_lines[-1] = curr_lines[-1][:end_char]
curr_lines[0] = curr_lines[0][start_char:]
step_text = "\n".join(curr_lines)
return Step(step_text, step_ast)

if index == 0:
short_text = self.__short_text(step_text, step_ast)
else:
short_text = self.__short_text(step_text, step_ast, prev_step_ast)

return Step(step_text, short_text, step_ast)

def __init_steps(self, text: str, ast: List[RangedSpan]):
lines = text.split("\n")
Expand Down Expand Up @@ -208,48 +213,24 @@ def expr(step: RangedSpan) -> Optional[List]:

return [None]

def __slice_line(
self,
line: str,
start: Optional[int] = None,
stop: Optional[int] = None,
range: Optional[Range] = None,
):
if range is None:
return line[start:stop]

# A range will be provided when range.character is measured in bytes,
# rather than characters. If so, the string must be encoded before
# indexing. This special treatment is necessary for handling Unicode
# characters which take up more than 1 byte. Currently necessary to
# handle where-notations.
line = line.encode("utf-8")

# For where-notations, the character count does not include the closing
# parenthesis when present.
if stop is not None and len(line) > stop and chr(line[stop]) == ")":
stop += 1
return line[start:stop].decode()

def __short_text(self, range: Optional[Range] = None):
curr_range = self.curr_step.ast.range if range is None else range
def __short_text(self, text: str, curr_step: RangedSpan, prev_step: Optional[RangedSpan] = None):
curr_range = curr_step.range
nlines = curr_range.end.line - curr_range.start.line + 1
lines = self.curr_step.text.split("\n")[-nlines:]
lines = text.split("\n")[-nlines:]

prev_range = None if self.steps_taken == 0 else self.prev_step.ast.range
prev_range = None if prev_step is None else prev_step.range
if prev_range is None or prev_range.end.line < curr_range.start.line:
start = curr_range.start.character
else:
start = curr_range.start.character - prev_range.end.character

lines[-1] = self.__slice_line(
lines[-1], stop=curr_range.end.character, range=range
)
lines[0] = self.__slice_line(lines[0], start=start, range=range)
lines[-1] = lines[-1][:curr_range.end.character]
lines[0] = lines[0][start:]

return " ".join(" ".join(lines).split())

def __add_term(self, name: str, ast: RangedSpan, text: str, term_type: TermType):
term = Term(text, ast, term_type, self.path, self.curr_module[:])
def __add_term(self, name: str, step: Step, term_type: TermType):
term = Term(step, term_type, self.path, self.curr_module[:])
if term.type == TermType.NOTATION:
self.context.update(terms={name: term})
return
Expand Down Expand Up @@ -340,25 +321,10 @@ def __handle_where_notations(self, expr: List, term_type: TermType):

# handles when multiple notations are defined
for span in spans:
start = Position(
span["ntn_decl_string"]["loc"]["line_nb"] - 1,
span["ntn_decl_string"]["loc"]["bp"]
- span["ntn_decl_string"]["loc"]["bol_pos"],
)
end = Position(
span["ntn_decl_interp"]["loc"]["line_nb_last"] - 1,
span["ntn_decl_interp"]["loc"]["ep"]
- span["ntn_decl_interp"]["loc"]["bol_pos"],
)
range = Range(start, end)
text = self.__short_text(range=range)
name = FileContext.get_notation_key(
span["ntn_decl_string"]["v"], span["ntn_decl_scope"]
)
if span["ntn_decl_scope"] is not None:
text += " : " + span["ntn_decl_scope"]
text = "Notation " + text
self.__add_term(name, RangedSpan(range, span), text, TermType.NOTATION)
self.__add_term(name, self.curr_step, TermType.NOTATION)

def __process_step(self, sign):
def traverse_expr(expr):
Expand Down Expand Up @@ -400,7 +366,9 @@ def traverse_expr(expr):
# TODO: A negative sign should handle things differently. For example:
# - names should be removed from the context
# - curr_module should change as you leave or re-enter modules
text = self.__short_text()

text = self.curr_step.short_text

# FIXME Let (and maybe Variable) should be handled. However,
# I think we can't handle them as normal Locals since they are
# specific to a section.
Expand Down Expand Up @@ -451,33 +419,33 @@ def traverse_expr(expr):
name = text.split('"')[1].strip()
if text[:-1].split(":")[-1].endswith("_scope"):
name += " : " + text[:-1].split(":")[-1].strip()
self.__add_term(name, self.curr_step.ast, text, TermType.NOTATION)
self.__add_term(name, self.curr_step, TermType.NOTATION)
elif expr[0] == "VernacSyntacticDefinition":
name = text.split(" ")[1]
if text[:-1].split(":")[-1].endswith("_scope"):
name += " : " + text[:-1].split(":")[-1].strip()
self.__add_term(name, self.curr_step.ast, text, TermType.NOTATION)
self.__add_term(name, self.curr_step, TermType.NOTATION)
elif expr[0] == "VernacInstance" and expr[1][0]["v"][0] == "Anonymous":
# FIXME: The name should be "<Class>_instance_N"
self.__add_term("_anonymous", self.curr_step.ast, text, term_type)
self.__add_term("_anonymous", self.curr_step, term_type)
elif expr[0] == "VernacDefinition" and expr[2][0]["v"][0] == "Anonymous":
# We associate the anonymous term to the same name used internally by Coq
if self.__anonymous_id is None:
name, self.__anonymous_id = "Unnamed_thm", 0
else:
name = f"Unnamed_thm{self.__anonymous_id}"
self.__anonymous_id += 1
self.__add_term(name, self.curr_step.ast, text, term_type)
self.__add_term(name, self.curr_step, term_type)
elif term_type == TermType.DERIVE:
name = CoqFile.get_ident(expr[2][0])
self.__add_term(name, self.curr_step.ast, text, term_type)
self.__add_term(name, self.curr_step, term_type)
if expr[1][0] == "Derive":
prop = CoqFile.get_ident(expr[2][2])
self.__add_term(prop, self.curr_step.ast, text, term_type)
self.__add_term(prop, self.curr_step, term_type)
else:
names = traverse_expr(expr)
for name in names:
self.__add_term(name, self.curr_step.ast, text, term_type)
self.__add_term(name, self.curr_step, term_type)

self.__handle_where_notations(expr, term_type)
finally:
Expand Down Expand Up @@ -531,8 +499,7 @@ def _make_change(self, change_function, *args):
previous_steps = self.steps
old_steps_taken = self.steps_taken
old_diagnostics = self.coq_lsp_client.lsp_endpoint.diagnostics
with open(self.__path, "r") as f:
lines = f.read().split("\n")
lines = self.__read().split("\n")
old_text = "\n".join(lines)

try:
Expand Down Expand Up @@ -569,14 +536,9 @@ def _delete_step(
start_line = lines[prev_step.ast.range.end.line]
end_line = lines[step.ast.range.end.line]

start_line = self.__slice_line(
start_line,
stop=prev_step.ast.range.end.character,
range=prev_step.ast.range,
)
end_line = self.__slice_line(
end_line, start=step.ast.range.end.character, range=step.ast.range
)
start_line = start_line[: prev_step.ast.range.end.character]
end_line = end_line[step.ast.range.end.character:]

if prev_step.ast.range.end.line == step.ast.range.end.line:
lines[prev_step.ast.range.end.line] = start_line + end_line
else:
Expand Down Expand Up @@ -628,13 +590,9 @@ def _add_step(
end_line = previous_step.ast.range.end.line
end_char = previous_step.ast.range.end.character
lines[end_line] = (
self.__slice_line(
lines[end_line], stop=end_char, range=previous_step.ast.range
)
lines[end_line][:end_char]
+ step_text
+ self.__slice_line(
lines[end_line], start=end_char + 1, range=previous_step.ast.range
)
+ lines[end_line][end_char + 1:]
)
text = "\n".join(lines)
f.write(text)
Expand Down
18 changes: 12 additions & 6 deletions coqlspclient/coq_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ def __init__(self, code: CoqErrorCodes, message: str):


class Step(object):
def __init__(self, text: str, ast: RangedSpan):
def __init__(self, text: str, short_text: str, ast: RangedSpan):
self.text = text
self.short_text = short_text
self.ast = ast
self.diagnostics: List[Diagnostic] = []


class Term:
def __init__(
self,
text: str,
ast: RangedSpan,
step: Step,
term_type: TermType,
file_path: str,
module: List[str],
Expand All @@ -75,8 +75,7 @@ def __init__(
file_path (str): The file where the term is.
module (str): The module where the term is.
"""
self.text = text
self.ast = ast
self.step = step
self.type = term_type
self.file_path = file_path
self.module = module
Expand All @@ -89,6 +88,13 @@ def __eq__(self, __value: object) -> bool:
def __hash__(self) -> int:
return hash(self.text)

@property
def text(self) -> str:
return self.step.short_text

@property
def ast(self) -> RangedSpan:
return self.step.ast

class ProofStep:
def __init__(
Expand Down Expand Up @@ -126,7 +132,7 @@ def diagnostics(self) -> List[Diagnostic]:

class ProofTerm(Term):
def __init__(self, term: Term, context: List[Term], steps: List[ProofStep]):
super().__init__(term.text, term.ast, term.type, term.file_path, term.module)
super().__init__(term.step, term.type, term.file_path, term.module)
self.steps = steps
self.context = context

Expand Down
29 changes: 5 additions & 24 deletions coqlspclient/proof_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,20 +429,9 @@ def __find_step(self, range: Range) -> Optional[Tuple[ProofTerm, int]]:
for proof in self.proofs:
for i, proof_step in enumerate(proof.steps):
if proof_step.ast.range == range:
break
else:
continue
break
return (proof, i)
else:
return None
return (proof, i)

def __find_proof(self, range: Range) -> Optional[ProofTerm]:
for proof in self.proofs:
for step in proof.steps:
if step.ast.range >= range:
return proof
return None

def __find_prev(self, range: Range) -> Tuple[ProofTerm, Optional[int]]:
optional = self.__find_step(range)
Expand All @@ -452,16 +441,9 @@ def __find_prev(self, range: Range) -> Tuple[ProofTerm, Optional[int]]:
if proof.ast.range == range:
return (proof, -1)

# When the step is the first step of the proof
proof = self.__find_proof(range)
# For a proof that did not end on the end of the file.
if proof is None and len(self.proofs[-1].steps) == 0:
proof = self.proofs[-1]
elif proof is None:
raise NotImplementedError(
"Adding steps outside of a proof is not implemented yet"
)
return (proof, None)
raise NotImplementedError(
"Adding steps outside of a proof is not implemented yet"
)
else:
return optional

Expand Down Expand Up @@ -491,8 +473,7 @@ def proofs(self) -> List[ProofTerm]:

def add_step(self, step_text: str, previous_step_index: int):
proof, prev = self.__find_prev(self.steps[previous_step_index].ast.range)
if prev is None:
prev = -1
if prev == -1:
self._make_change(self._add_step, step_text, previous_step_index)
else:
self._make_change(self._add_step, step_text, previous_step_index, True)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_coq_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,16 @@ def test_where_notation(setup, teardown):
assert "n + m : test_scope" in coq_file.context.terms
assert (
coq_file.context.terms["n + m : test_scope"].text
== 'Notation "n + m" := (plus n m) : test_scope'
== 'Fixpoint plus_test (n m : nat) {struct n} : nat := match n with | O => m | S p => S (p + m) end where "n + m" := (plus n m) : test_scope and "n - m" := (minus n m).'
)
assert "n - m" in coq_file.context.terms
assert coq_file.context.terms["n - m"].text == 'Notation "n - m" := (minus n m)'
assert coq_file.context.terms["n - m"].text == 'Fixpoint plus_test (n m : nat) {struct n} : nat := match n with | O => m | S p => S (p + m) end where "n + m" := (plus n m) : test_scope and "n - m" := (minus n m).'
assert "A & B" in coq_file.context.terms
assert coq_file.context.terms["A & B"].text == 'Notation "A & B" := (and\' A B)'
assert coq_file.context.terms["A & B"].text == "Inductive and' (A B : Prop) : Prop := conj' : A -> B -> A & B where \"A & B\" := (and' A B)."
assert "'ONE'" in coq_file.context.terms
assert coq_file.context.terms["'ONE'"].text == "Notation \"'ONE'\" := 1"
assert coq_file.context.terms["'ONE'"].text == "Fixpoint incr (n : nat) : nat := n + ONE where \"'ONE'\" := 1."
assert "x πŸ€„ y" in coq_file.context.terms
assert coq_file.context.terms["x πŸ€„ y"].text == 'Notation "x πŸ€„ y" := (plus_test x y)'
assert coq_file.context.terms["x πŸ€„ y"].text == 'Fixpoint unicode x y := x πŸ€„ y where "x πŸ€„ y" := (plus_test x y).'


@pytest.mark.parametrize("setup", ["test_get_notation.v"], indirect=True)
Expand Down
Loading

0 comments on commit 1c543eb

Please sign in to comment.