Skip to content

Commit

Permalink
Downgrade python features to support python >= 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Oct 12, 2023
1 parent b724f67 commit ab9a090
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 97 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ authors = [
{name = "Miroslav Olšák", email = "mirek@olsak.net"},
]
readme = "pytact/README.md"
requires-python = ">=3.10"
requires-python = ">=3.8"
license = {text = "MIT License"}
classifiers = [
"Development Status :: 1 - Planning",
Expand Down
18 changes: 13 additions & 5 deletions pytact/data_reader.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ as can be found in `fake_coq_client.py`.
from __future__ import annotations
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from typing import Any, Callable, TypeVar, TypeAlias, Union, cast, BinaryIO
from typing import Any, Callable, TypeVar, Union, cast, BinaryIO
from collections.abc import Iterable, Sequence, Generator
from pathlib import Path
from immutables import Map
Expand Down Expand Up @@ -583,8 +583,8 @@ cdef class Argument_List:
def __len__(self):
return self.reader.size()

Unresolvable: TypeAlias = None
Unknown: TypeAlias = None
Unresolvable = None # TypeAlias
Unknown = None # TypeAlias
cdef class Outcome:
"""An outcome is the result of running a tactic on a proof state. A tactic may run on multiple proof states."""

Expand Down Expand Up @@ -1584,6 +1584,13 @@ def capnp_message_generator_from_file(message_file: BinaryIO,
pg = prediction_generator(lgenerator, defs)
return GlobalContextMessage(defs, [], None, pg)


@contextmanager
def _new_context() -> Generator[GlobalContextSets, None, None]:
"""Crate a new caching context where global-context-set's can be retrieved and cached."""
yield GlobalContextSets(Map(), None, lambda _: False)


class GlobalContextSets:
"""Lazily retrieve a the global context of a definition as a set, with memoization.

Expand Down Expand Up @@ -1651,11 +1658,12 @@ class GlobalContextSets:
the result is propagated to the parent's cache."""
yield GlobalContextSets(self.cache, self, propagate)

@contextmanager
# TODO: Hack: Python <= 3.9 cannot deal with a simultaneous @contextmanager and @staticmethod
# Therefore, we have a helper function _new_context(). This can be merged once python 3.9 is deprecated.
@staticmethod
def new_context() -> Generator[GlobalContextSets, None, None]:
"""Crate a new caching context where global-context-set's can be retrieved and cached."""
yield GlobalContextSets(Map(), None, lambda _: False)
return _new_context()

cdef struct GlobalNode:
GraphId graphid
Expand Down
21 changes: 10 additions & 11 deletions pytact/graph_sanity_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,16 @@ def process1(args, fname: Path):
file_errors.append(f"{fname}: Node {d.node} should be a definition but is not")

# Check correctness of Definition struct
match d.status:
case Original():
file_original_definitions += 1
case Discharged(original):
if not original.node.definition:
file_errors.append(f"{fname}: Discharged node of definition {d.name} "
f"is not a definition")
case Substituted(original):
if not original.node.definition:
file_errors.append(f"{fname}: Substituted node of definition {d.name} "
f"is not a definition")
if isinstance(d.status, Original):
file_original_definitions += 1
elif isinstance(d.status, Discharged):
if not d.status.original.node.definition:
file_errors.append(f"{fname}: Discharged node of definition {d.name} "
f"is not a definition")
elif isinstance(d.status, Substituted):
if not d.status.original.node.definition:
file_errors.append(f"{fname}: Substituted node of definition {d.name} "
f"is not a definition")

def check_in_global_context(s):
global_context = sub_global_contexts.global_context_set(d)
Expand Down
158 changes: 79 additions & 79 deletions pytact/graph_visualize_browse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os
import graphviz
from collections import defaultdict
from collections.abc import Sequence
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, List, Dict, Set, Tuple, Sequence
from pytact.data_reader import Dataset, Definition, Node, ProofState, Original, Discharged, Substituted
import html
import inflection
Expand Down Expand Up @@ -36,8 +36,8 @@ def root_folder(self) -> str:
@dataclass
class Settings:
no_defaults: bool = False # Should we use default settings?
ignore_edges: list[int] = field(default_factory=lambda: [])
unshare_nodes: list[int] = field(default_factory=lambda: [])
ignore_edges: List[int] = field(default_factory=lambda: [])
unshare_nodes: List[int] = field(default_factory=lambda: [])
show_trivial_evar_substs: bool = False
hide_proof_terms: bool = False
show_edge_labels: bool = False
Expand All @@ -56,9 +56,9 @@ def __post_init__(self):

@dataclass
class GraphVisualizationData:
data: dict[Path, Dataset]
trans_deps: dict[Path, set[Path]] = field(init=False)
graphid2path: list[Path] = field(init=False)
data: Dict[Path, Dataset]
trans_deps: Dict[Path, Set[Path]] = field(init=False)
graphid2path: List[Path] = field(init=False)

def __post_init__(self):
self.trans_deps = transitive_closure({d.filename: list(d.dependencies)
Expand All @@ -68,12 +68,12 @@ def __post_init__(self):
@dataclass
class GraphVisualizationOutput:
svg: str
location: list[tuple[str, str]] # tuple[Name, URL]
location: List[Tuple[str, str]] # tuple[Name, URL]
active_location: int
text: list[str] = field(default_factory=lambda: [])
popups: list[tuple[str, str]] = field(default_factory=lambda: []) # DOM id, text
text: List[str] = field(default_factory=lambda: [])
popups: List[Tuple[str, str]] = field(default_factory=lambda: []) # DOM id, text

def node_label_map(node: Node) -> tuple[str, str, str]:
def node_label_map(node: Node) -> Tuple[str, str, str]:
enum = apic.Graph_Node_Label_Which
label = node.label
if d := node.definition:
Expand All @@ -82,30 +82,30 @@ def node_label_map(node: Node) -> tuple[str, str, str]:
'box', name.split('.')[-1],
f"{inflection.camelize(node.label.definition.which.name.lower())} {d.name}"
)
match label.which:
case enum.SORT_PROP:
return 'ellipse', 'Prop', 'SortProp'
case enum.SORT_S_PROP:
return 'ellipse', 'SProp', 'SortSProp'
case enum.SORT_SET:
return 'ellipse', 'Set', 'SortSet'
case enum.SORT_TYPE:
return 'ellipse', 'Type', 'SortType'
case enum.REL:
return 'circle', '↑', 'rel'
case enum.PROD:
return 'circle', '∀', 'prod'
case enum.LAMBDA:
return 'circle', 'λ', 'lambda'
case enum.LET_IN:
return 'ellipse', 'let', 'LetIn'
case enum.APP:
return 'circle', '@', 'app'
case enum.CASE_BRANCH:
return 'ellipse', 'branch', 'CaseBranch'
case _:
name = inflection.camelize(label.which.name.lower())
return 'ellipse', name, name
which = label.which
if which == enum.SORT_PROP:
return 'ellipse', 'Prop', 'SortProp'
elif which == enum.SORT_S_PROP:
return 'ellipse', 'SProp', 'SortSProp'
elif which == enum.SORT_SET:
return 'ellipse', 'Set', 'SortSet'
elif which == enum.SORT_TYPE:
return 'ellipse', 'Type', 'SortType'
elif which == enum.REL:
return 'circle', '↑', 'rel'
elif which == enum.PROD:
return 'circle', '∀', 'prod'
elif which == enum.LAMBDA:
return 'circle', 'λ', 'lambda'
elif which == enum.LET_IN:
return 'ellipse', 'let', 'LetIn'
elif which == enum.APP:
return 'circle', '@', 'app'
elif which == enum.CASE_BRANCH:
return 'ellipse', 'branch', 'CaseBranch'
else:
name = inflection.camelize(label.which.name.lower())
return 'ellipse', name, name

def truncate_string(data, maximum):
return data[:(maximum-2)] + '..' if len(data) > maximum else data
Expand Down Expand Up @@ -165,8 +165,8 @@ def dot_apply_style(self, dot):
dot.attr('graph', concentrate='true')


def render_node(self, dot, node: Node, shape: str, label: str, id: str | None = None,
tooltip: str | None = None):
def render_node(self, dot, node: Node, shape: str, label: str, id: Union[str, None] = None,
tooltip: Union[str, None] = None):
if not id:
id = str(node)
if not tooltip:
Expand Down Expand Up @@ -197,27 +197,28 @@ def render_def(dot2, d: Definition):
if representative and representative.node == d.node:
label = "Representative: " + label
tooltip = make_tooltip(d)
match d.status:
case Original():
if isinstance(d.status, Original):
id = self.render_node(dot2, d.node, 'box', label, tooltip=tooltip)
elif isinstance(d.status, Discharged):
id = self.render_node(dot2, d.node, 'box', label, tooltip=tooltip)
target = d.status.original
dot.edge(id, repr(target.node),
arrowtail="inv", dir="both", constraint="false", style="dashed")
elif isinstance(d.status, Substituted):
target = d.status.original
if d.node.graph == target.node.graph:
id = self.render_node(dot2, d.node, 'box', label, tooltip=tooltip)
case Discharged(target):
id = self.render_node(dot2, d.node, 'box', label, tooltip=tooltip)
dot.edge(id, repr(target.node),
arrowtail="inv", dir="both", constraint="false", style="dashed")
case Substituted(target):
if d.node.graph == target.node.graph:
id = self.render_node(dot2, d.node, 'box', label, tooltip=tooltip)
dot.edge(id, str(target.node),
arrowtail="odot", dir="both", constraint="false", style="dashed")
else:
with dot2.subgraph() as dot3:
dot3.attr(rank='same')
id = self.render_node(dot3, d.node, 'box', label, tooltip=tooltip)
id2 = self.render_node(dot3, target.node, 'box',
make_label(module_name, target.name),
tooltip=make_tooltip(target))
dot.edge(id, id2,
arrowtail="odot", dir="both", constraint="false", style="dashed")
dot.edge(id, str(target.node),
arrowtail="odot", dir="both", constraint="false", style="dashed")
else:
with dot2.subgraph() as dot3:
dot3.attr(rank='same')
id = self.render_node(dot3, d.node, 'box', label, tooltip=tooltip)
id2 = self.render_node(dot3, target.node, 'box',
make_label(module_name, target.name),
tooltip=make_tooltip(target))
dot.edge(id, id2,
arrowtail="odot", dir="both", constraint="false", style="dashed")

for cluster in dataset.clustered_definitions():

Expand Down Expand Up @@ -249,10 +250,10 @@ def render_def(dot2, d: Definition):
location = self.path2location(fname)
return GraphVisualizationOutput(dot.source, location, len(location) - 1)

def visualize_term(self, dot, start: Node, depth, depth_ignore: set[Node] = set(),
max_nodes=100, seen: dict[str, str]|None=None,
def visualize_term(self, dot, start: Node, depth, depth_ignore: Set[Node] = set(),
max_nodes=100, seen: Union[Dict[str, str], None]=None,
node_label_map=node_label_map,
prefix='', before_prefix='', proof_state_prefix: dict[int, str] = {}
prefix='', before_prefix='', proof_state_prefix: Dict[int, str] = {}
) -> str:
if seen == None:
seen = {}
Expand All @@ -262,17 +263,17 @@ def recurse(node: Node, depth, context_prefix):
nonlocal nodes_left

enum = graph_api_capnp.Graph.Node.Label
match node.label.which:
case enum.proofState:
node_prefix = context_prefix
case enum.contextAssum:
node_prefix = context_prefix
case enum.contextDef:
node_prefix = context_prefix
case enum.evarSubst:
node_prefix = prefix + context_prefix
case _:
node_prefix = prefix
which = node.label.which
if which == enum.proofState:
node_prefix = context_prefix
elif which == enum.contextAssum:
node_prefix = context_prefix
elif which == enum.contextDef:
node_prefix = context_prefix
elif which == enum.evarSubst:
node_prefix = prefix + context_prefix
else:
node_prefix = prefix
id = node_prefix + str(node)
if id in seen:
return seen[id]
Expand All @@ -294,7 +295,6 @@ def recurse(node: Node, depth, context_prefix):
# Find the evar-id
evarid = [c.label.proof_state.value for _, c in node.children
if c.label.which == graph_api_capnp.Graph.Node.Label.proofState][0]
print(evarid)
context_prefix = proof_state_prefix.get(evarid, context_prefix)

for edge, child in node.children:
Expand Down Expand Up @@ -431,14 +431,14 @@ def node_label_map_with_ctx_names(context: Sequence[Node],
mapping = {n: s for n, s in zip(context, context_text)}
def nlm(node: Node):
enum = graph_api_capnp.Graph.Node.Label
match node.label.which:
case enum.contextAssum:
name = graphviz_escape(mapping[node])
return 'ellipse', truncate_string(name, 20), f"ContextAssum {name}"
case enum.contextDef:
name = graphviz_escape(mapping[node])
return 'ellipse', truncate_string(name, 20), f"ContextDef {name}"
case _:
which = node.label.which
if which == enum.contextAssum:
name = graphviz_escape(mapping[node])
return 'ellipse', truncate_string(name, 20), f"ContextAssum {name}"
elif which == enum.contextDef:
name = graphviz_escape(mapping[node])
return 'ellipse', truncate_string(name, 20), f"ContextDef {name}"
else:
return node_label_map(node)
return nlm

Expand Down
3 changes: 2 additions & 1 deletion pytact/oracle_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import socketserver
import argparse
import contextlib
from typing import Union, Tuple
from pytact.data_reader import (data_reader, Original, capnp_message_generator, ProofState,
TacticPredictionGraph, TacticPredictionsGraph,
TacticPredictionText, TacticPredictionsText,
Expand All @@ -20,7 +21,7 @@ class LocalArgument:
@dataclass(eq=True, frozen=True)
class OracleTactic:
tactic_id: int
arguments: tuple[GlobalArgument | LocalArgument, ...]
arguments: Tuple[Union[GlobalArgument, LocalArgument], ...]
clean: bool

def text_prediction_loop(text_oracle_data, context: GlobalContextMessage):
Expand Down

0 comments on commit ab9a090

Please sign in to comment.