diff --git a/zxlive/app.py b/zxlive/app.py index 6fc133ce..ca2e0f22 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -24,7 +24,7 @@ from .mainwindow import MainWindow from .common import get_data, GraphT -from typing import Optional +from typing import Optional, cast # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 @@ -73,8 +73,9 @@ def edit_graph(self, g: GraphT, name: str) -> None: self.main_window.show() self.main_window.open_graph_from_notebook(g, name) - def get_copy_of_graph(self, name: str) -> GraphT: + def get_copy_of_graph(self, name: str) -> Optional[GraphT]: """Returns a copy of the graph which has the given name.""" + assert self.main_window return self.main_window.get_copy_of_graph(name) @@ -82,7 +83,7 @@ def get_embedded_app() -> ZXLive: """Main entry point for ZXLive as an embedded app inside a jupyter notebook.""" app = QApplication.instance() or ZXLive() app.__class__ = ZXLive - return app + return cast(ZXLive, app) def main() -> None: diff --git a/zxlive/commands.py b/zxlive/commands.py index 86728568..82e96125 100644 --- a/zxlive/commands.py +++ b/zxlive/commands.py @@ -4,7 +4,7 @@ from collections import namedtuple from dataclasses import dataclass, field from fractions import Fraction -from typing import Iterable, Optional, Set, Union +from typing import Iterable, Optional, Set, Union, Callable from PySide6.QtCore import QModelIndex from PySide6.QtGui import QUndoCommand diff --git a/zxlive/common.py b/zxlive/common.py index 76e5644f..1efb1f1d 100644 --- a/zxlive/common.py +++ b/zxlive/common.py @@ -94,14 +94,15 @@ class ToolType(IntEnum): for key, value in defaults.items(): if not settings.contains(key): settings.setValue(key, value) - + class Settings(object): SNAP_DIVISION = 4 # Should be an integer dividing SCALE - def __init__(self): + + def __init__(self) -> None: self.update() - def update(self): + def update(self) -> None: settings = QSettings("zxlive", "zxlive") self.SNAP_DIVISION = int(settings.value("snap-granularity")) self.SNAP = SCALE / self.SNAP_DIVISION diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 38a1b71f..95bc12d0 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -1,6 +1,7 @@ import json -from typing import TYPE_CHECKING, Callable +from fractions import Fraction +from typing import TYPE_CHECKING, Callable, Sequence, Dict, Union import networkx as nx import numpy as np @@ -11,13 +12,14 @@ from pyzx.utils import EdgeType, VertexType, get_w_io from shapely import Polygon -from pyzx.symbolic import Poly +from pyzx.symbolic import Poly, Var from .common import ET, VT, GraphT if TYPE_CHECKING: from .rewrite_data import RewriteData + class CustomRule: def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: str) -> None: lhs_graph.auto_detect_io() @@ -91,15 +93,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu return etab, vertices_to_remove, [], True - def unfuse_subgraph_for_rewrite(self, graph, vertices): - def get_adjacent_boundary_vertices(graph, v): - return [n for n in graph.neighbors(v) if graph.nodes()[n]['type'] == VertexType.BOUNDARY] + def unfuse_subgraph_for_rewrite(self, graph, vertices) -> None: + def get_adjacent_boundary_vertices(g, v) -> Sequence[VT]: + return [n for n in g.neighbors(v) if g.nodes()[n]['type'] == VertexType.BOUNDARY] subgraph_nx_without_boundaries = nx.Graph(to_networkx(graph).subgraph(vertices)) lhs_vertices = [v for v in self.lhs_graph.vertices() if self.lhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY] lhs_graph_nx = nx.Graph(self.lhs_graph_nx.subgraph(lhs_vertices)) graph_matcher = GraphMatcher(lhs_graph_nx, subgraph_nx_without_boundaries, - node_match=categorical_node_match('type', 1)) + node_match=categorical_node_match('type', 1)) matching = list(graph_matcher.match())[0] subgraph_nx, _ = create_subgraph(graph, vertices) @@ -119,26 +121,26 @@ def get_adjacent_boundary_vertices(graph, v): elif vtype == VertexType.W_OUTPUT or vtype == VertexType.W_INPUT: self.unfuse_w_vertex(graph, subgraph_nx, matching[v], vtype) - def unfuse_update_edges(self, graph, subgraph_nx, old_v, new_v): + def unfuse_update_edges(self, graph, subgraph_nx, old_v, new_v) -> None: neighbors = list(graph.neighbors(old_v)) for b in neighbors: if b not in subgraph_nx.nodes: graph.add_edge((new_v, b), graph.edge_type((old_v, b))) graph.remove_edge(graph.edge(old_v, b)) - def unfuse_zx_vertex(self, graph, subgraph_nx, v, vtype): + def unfuse_zx_vertex(self, graph, subgraph_nx, v, vtype) -> None: new_v = graph.add_vertex(vtype, qubit=graph.qubit(v), row=graph.row(v)) self.unfuse_update_edges(graph, subgraph_nx, v, new_v) graph.add_edge(graph.edge(new_v, v)) - def unfuse_h_box_vertex(self, graph, subgraph_nx, v): + def unfuse_h_box_vertex(self, graph, subgraph_nx, v) -> None: new_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v)+0.3, row=graph.row(v)+0.3) new_mid_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v), row=graph.row(v)) self.unfuse_update_edges(graph, subgraph_nx, v, new_h) graph.add_edge((new_mid_h, v)) graph.add_edge((new_h, new_mid_h)) - def unfuse_w_vertex(self, graph, subgraph_nx, v, vtype): + def unfuse_w_vertex(self, graph, subgraph_nx, v, vtype) -> None: w_in, w_out = get_w_io(graph, v) new_w_in = graph.add_vertex(VertexType.W_INPUT, qubit=graph.qubit(w_in), row=graph.row(w_in)) new_w_out = graph.add_vertex(VertexType.W_OUTPUT, qubit=graph.qubit(w_out), row=graph.row(w_out)) @@ -229,16 +231,16 @@ def get_linear(v): return coeff, var, const -def match_symbolic_parameters(match, left, right): - params = {} +def match_symbolic_parameters(match, left: nx.Graph, right: nx.Graph) -> Dict[Var, Union[float, complex, Fraction]]: + params: Dict[Var, Union[float, complex, Fraction]] = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) - def check_phase_equality(v): + def check_phase_equality(v) -> None: if left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") - def update_params(v, var, coeff, const): + def update_params(v, var, coeff, const) -> None: var_value = (right_phase[match[v]] - const) / coeff if var in params and params[var] != var_value: raise ValueError("Symbolic parameters do not match") diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 3063b9e6..9724a418 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -190,15 +190,15 @@ def vert_double_clicked(self, v: VT) -> None: if len(graph.variable_types) != len(old_variables): new_vars = graph.variable_types.keys() - old_variables.keys() #self.graph.variable_types.update(graph.variable_types) - for v in new_vars: - self.variable_viewer.add_item(v) + for nv in new_vars: + self.variable_viewer.add_item(nv) class VariableViewer(QScrollArea): def __init__(self, parent: EditorBasePanel) -> None: super().__init__() - self.parent = parent + self.parent_panel = parent self._widget = QWidget() lpal = QApplication.palette("QListWidget") # type: ignore palette = QPalette() @@ -238,7 +238,7 @@ def __init__(self, parent: EditorBasePanel) -> None: self._layout.addItem(QSpacerItem(0, 0, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding), 2, 2) - for name in self.parent.graph.variable_types.keys(): + for name in self.parent_panel.graph.variable_types.keys(): self.add_item(name) self.setWidget(self._widget) @@ -259,7 +259,7 @@ def sizeHint(self) -> QSize: def add_item(self, name: str) -> None: combobox = QComboBox() combobox.insertItems(0, ["Parametric", "Boolean"]) - if self.parent.graph.variable_types[name]: + if self.parent_panel.graph.variable_types[name]: combobox.setCurrentIndex(1) else: combobox.setCurrentIndex(0) @@ -280,9 +280,9 @@ def add_item(self, name: str) -> None: def _text_changed(self, name: str, text: str) -> None: if text == "Parametric": - self.parent.graph.variable_types[name] = False + self.parent_panel.graph.variable_types[name] = False elif text == "Boolean": - self.parent.graph.variable_types[name] = True + self.parent_panel.graph.variable_types[name] = True def toolbar_select_node_edge(parent: EditorBasePanel) -> ToolbarSection: @@ -315,7 +315,7 @@ def toolbar_select_node_edge(parent: EditorBasePanel) -> ToolbarSection: def create_list_widget(parent: EditorBasePanel, data: dict[VertexType.Type, DrawPanelNodeType] | dict[EdgeType.Type, DrawPanelNodeType], onclick: Callable[[VertexType.Type], None] | Callable[[EdgeType.Type], None], - ondoubleclick: Callable[[VertexType.Type, None] | Callable[[EdgeType.Type], None]]) \ + ondoubleclick: Callable[[VertexType.Type], None] | Callable[[EdgeType.Type], None]) \ -> QListWidget: list_widget = QListWidget(parent) list_widget.setResizeMode(QListView.ResizeMode.Adjust) @@ -333,7 +333,7 @@ def create_list_widget(parent: EditorBasePanel, def populate_list_widget(list_widget: QListWidget, data: dict[VertexType.Type, DrawPanelNodeType] | dict[EdgeType.Type, DrawPanelNodeType], onclick: Callable[[VertexType.Type], None] | Callable[[EdgeType.Type], None], - ondoubleclick: Callable[[VertexType.Type, None] | Callable[[EdgeType.Type], None]]) \ + ondoubleclick: Callable[[VertexType.Type], None] | Callable[[EdgeType.Type], None]) \ -> None: row = list_widget.currentRow() list_widget.clear() diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index 1176b1e7..9361de6d 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -16,7 +16,7 @@ from __future__ import annotations import copy -from typing import Callable, Optional +from typing import Callable, Optional, cast from PySide6.QtCore import (QByteArray, QEvent, QFile, QFileInfo, QIODevice, QSettings, QTextStream, Qt) @@ -443,7 +443,7 @@ def new_graph(self, graph: Optional[GraphT] = None, name: Optional[str] = None) if name is None: name = "New Graph" self._new_panel(panel, name) - def open_graph_from_notebook(self, graph: GraphT, name: str = None) -> None: + def open_graph_from_notebook(self, graph: GraphT, name: str) -> None: """Opens a ZXLive window from within a notebook to edit a graph. Replaces the graph in an existing tab if it has the same name.""" @@ -451,15 +451,17 @@ def open_graph_from_notebook(self, graph: GraphT, name: str = None) -> None: for i in range(self.tab_widget.count()): if self.tab_widget.tabText(i) == name or self.tab_widget.tabText(i) == name + "*": self.tab_widget.setCurrentIndex(i) + assert self.active_panel self.active_panel.replace_graph(graph) return self.new_graph(copy.deepcopy(graph), name) - def get_copy_of_graph(self, name: str): + def get_copy_of_graph(self, name: str) -> Optional[GraphT]: # TODO: handle multiple tabs with the same name somehow for i in range(self.tab_widget.count()): if self.tab_widget.tabText(i) == name or self.tab_widget.tabText(i) == name + "*": - return copy.deepcopy(self.tab_widget.widget(i).graph_scene.g) + panel = cast(BasePanel, self.tab_widget.widget(i)) + return cast(GraphT, copy.deepcopy(panel.graph_scene.g)) return None def new_rule_editor(self, rule: Optional[CustomRule] = None, name: Optional[str] = None) -> None: diff --git a/zxlive/proof.py b/zxlive/proof.py index dca41c37..af2efa9f 100644 --- a/zxlive/proof.py +++ b/zxlive/proof.py @@ -122,7 +122,7 @@ def get_graph(self, index: int) -> GraphT: assert isinstance(copy, GraphT) # type: ignore return copy - def rename_step(self, index: int, name: str): + def rename_step(self, index: int, name: str) -> None: """Change the display name""" old_step = self.steps[index] diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index b253b81d..9427ce52 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -65,11 +65,11 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: self.step_view.setCurrentIndex(self.proof_model.index(0, 0)) self.step_view.selectionModel().selectionChanged.connect(self._proof_step_selected) self.step_view.viewport().setAttribute(Qt.WidgetAttribute.WA_Hover) - self.step_view.doubleClicked.connect(self.__doubleClickHandler) + self.step_view.doubleClicked.connect(self._double_click_handler) self.splitter.addWidget(self.step_view) - def __doubleClickHandler(self, index: QModelIndex | QPersistentModelIndex): + def _double_click_handler(self, index: QModelIndex | QPersistentModelIndex) -> None: # The first row in the item list is the START step, which is not interactive if index.row() == 0: return @@ -79,7 +79,7 @@ def __doubleClickHandler(self, index: QModelIndex | QPersistentModelIndex): if ok: # Subtract 1 from index since the START step isn't part of the model old_name = self.proof_model.steps[index.row()-1].display_name - cmd = UndoableChange(self, + cmd = UndoableChange(self.graph_view, lambda: self.proof_model.rename_step(index.row()-1, old_name), lambda: self.proof_model.rename_step(index.row()-1, new_name) ) @@ -300,7 +300,7 @@ def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> Non ).normalized() perp_dir = QVector2D(mouse_dir - QPointF(self.graph.row(v)/SCALE, self.graph.qubit(v)/SCALE)).normalized() - perp_dir -= QVector2D.dotProduct(perp_dir, par_dir) * par_dir + perp_dir -= par_dir * QVector2D.dotProduct(perp_dir, par_dir) perp_dir.normalize() out_offset_x = par_dir.x() * 0.5 + perp_dir.x() * 0.5 @@ -328,7 +328,8 @@ def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> Non cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "unfuse") self.undo_stack.push(cmd, anim_after=anim) - def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: FractionLike) -> None: + def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: Union[FractionLike, complex]) -> \ + None: def snap_vector(v: QVector2D) -> None: if abs(v.x()) > abs(v.y()): v.setY(0.0) diff --git a/zxlive/rewrite_action.py b/zxlive/rewrite_action.py index caf18906..4c0f945f 100644 --- a/zxlive/rewrite_action.py +++ b/zxlive/rewrite_action.py @@ -2,10 +2,10 @@ import copy from dataclasses import dataclass, field -from typing import Callable, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING, Iterable, Any, Optional, cast, Union import pyzx -from PySide6.QtCore import Qt, QAbstractItemModel, QModelIndex +from PySide6.QtCore import Qt, QAbstractItemModel, QModelIndex, QPersistentModelIndex from .animations import make_animation from .commands import AddRewriteStep @@ -58,15 +58,16 @@ def do_rewrite(self, panel: ProofPanel) -> None: try: g, rem_verts = self.apply_rewrite(g, matches) - except Exception as e: - show_error_msg('Error while applying rewrite rule', str(e)) + except Exception as ex: + show_error_msg('Error while applying rewrite rule', str(ex)) return cmd = AddRewriteStep(panel.graph_view, g, panel.step_view, self.name) anim_before, anim_after = make_animation(self, panel, g, matches, rem_verts) panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) - def apply_rewrite(self, g: GraphT, matches: list): + # TODO: Narrow down the type of the first return value. + def apply_rewrite(self, g: GraphT, matches: list) -> tuple[Any, Optional[Iterable[VT]]]: if self.returns_new_graph: return self.rule(g, matches), None @@ -128,14 +129,14 @@ def enabled(self) -> bool: def from_dict(cls, d: dict, header: str = "", parent: RewriteActionTree | None = None) -> RewriteActionTree: if is_rewrite_data(d): return RewriteActionTree( - header, RewriteAction.from_rewrite_data(d), [], parent + header, RewriteAction.from_rewrite_data(cast(RewriteData, d)), [], parent ) ret = RewriteActionTree(header, None, [], parent) for group, actions in d.items(): ret.append_child(cls.from_dict(actions, group, ret)) return ret - def update_on_selection(self, g, selection, edges): + def update_on_selection(self, g, selection, edges) -> None: for child in self.child_items: child.update_on_selection(g, selection, edges) if self.rewrite is not None: @@ -157,13 +158,14 @@ def from_dict(cls, d: dict, proof_panel: ProofPanel): proof_panel ) - def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + def index(self, row: int, column: int, parent: Union[QModelIndex, QPersistentModelIndex] = QModelIndex()) -> \ + QModelIndex: if not self.hasIndex(row, column, parent): return QModelIndex() - parentItem = parent.internalPointer() if parent.isValid() else self.root_item + parent_item = parent.internalPointer() if parent.isValid() else self.root_item - if childItem := parentItem.child(row): + if childItem := parent_item.child(row): return self.createIndex(row, column, childItem) return QModelIndex() @@ -171,18 +173,18 @@ def parent(self, index: QModelIndex = None) -> QModelIndex: if not index.isValid(): return QModelIndex() - parentItem = index.internalPointer().parent + parent_item = index.internalPointer().parent - if parentItem == self.root_item: + if parent_item == self.root_item: return QModelIndex() - return self.createIndex(parentItem.row(), 0, parentItem) + return self.createIndex(parent_item.row(), 0, parent_item) def rowCount(self, parent: QModelIndex = None) -> int: if parent.column() > 0: return 0 - parentItem = parent.internalPointer() if parent.isValid() else self.root_item - return parentItem.child_count() + parent_item = parent.internalPointer() if parent.isValid() else self.root_item + return parent_item.child_count() def columnCount(self, parent: QModelIndex = None) -> int: return 1 diff --git a/zxlive/rewrite_data.py b/zxlive/rewrite_data.py index 76de351e..07d0e32d 100644 --- a/zxlive/rewrite_data.py +++ b/zxlive/rewrite_data.py @@ -2,7 +2,8 @@ import copy import os -from typing import Callable, Literal, TypedDict +from typing import Callable, Literal, cast, Optional +from typing_extensions import TypedDict, NotRequired import pyzx from pyzx import simplify, extract_circuit @@ -25,8 +26,8 @@ class RewriteData(TypedDict): rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]] type: MatchType tooltip: str - copy_first: bool | None - returns_new_graph: bool | None + copy_first: NotRequired[bool] + returns_new_graph: NotRequired[bool] def is_rewrite_data(d: dict) -> bool: @@ -79,7 +80,7 @@ def read_custom_rules() -> list[RewriteData]: const_true = lambda graph, matches: matches -def apply_simplification(simplification: Callable[[GraphT], GraphT]) -> Callable[ +def apply_simplification(simplification: Callable[[GraphT], Optional[int]]) -> Callable[ [GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]]: def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]: simplification(g) @@ -91,7 +92,7 @@ def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]: def _extract_circuit(graph: GraphT, matches: list) -> GraphT: graph.auto_detect_io() simplify.full_reduce(graph) - return extract_circuit(graph).to_graph() + return cast(GraphT, extract_circuit(graph).to_graph()) simplifications: dict[str, RewriteData] = { diff --git a/zxlive/settings_dialog.py b/zxlive/settings_dialog.py index 089d772b..92f8d0bb 100644 --- a/zxlive/settings_dialog.py +++ b/zxlive/settings_dialog.py @@ -15,7 +15,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Any, Optional +from typing import TYPE_CHECKING, Dict, Any, Optional, Union from PySide6.QtCore import QSettings from PySide6.QtWidgets import (QDialog, QFileDialog, @@ -186,7 +186,7 @@ def __init__(self, main_window: MainWindow) -> None: def add_setting(self,form:QFormLayout, name:str, label:str, ty:str, data: Optional[dict[str, str]] = None) -> None: val = self.settings.value(name) - widget: QWidget + widget: Union[QWidget, QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox] if val is None: val = defaults[name] if ty == 'str': widget = QLineEdit() @@ -203,8 +203,7 @@ def add_setting(self,form:QFormLayout, name:str, label:str, ty:str, data: Option hlayout = QHBoxLayout() widget.setLayout(hlayout) widget_line = QLineEdit() - val = str(val) - widget_line.setText(val) + widget_line.setText(str(val)) def browse() -> None: directory = QFileDialog.getExistingDirectory(self,"Pick folder",options=QFileDialog.Option.ShowDirsOnly) if directory: @@ -218,7 +217,7 @@ def browse() -> None: widget = QComboBox() assert data is not None widget.addItems(list(data.values())) - widget.setCurrentText(data[val]) + widget.setCurrentText(data[str(val)]) setattr(widget, "data", data)