diff --git a/pyproject.toml b/pyproject.toml index 4bfbc770..a02cd432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ dependencies = [ "PySide6", "pyzx @ git+https://github.com/Quantomatic/pyzx.git", "sympy>=1.12", + "networkx", + "numpy", + "shapely", ] [project.optional-dependencies] diff --git a/zxlive/dialogs.py b/zxlive/dialogs.py index 0895068b..e3b5bd9d 100644 --- a/zxlive/dialogs.py +++ b/zxlive/dialogs.py @@ -5,9 +5,11 @@ from dataclasses import dataclass from PySide6.QtCore import QFile, QIODevice, QTextStream -from PySide6.QtWidgets import QWidget, QFileDialog, QMessageBox +from PySide6.QtWidgets import QWidget, QFileDialog, QMessageBox, QDialog, QFormLayout, QLineEdit, QTextEdit, QPushButton, QDialogButtonBox +import numpy as np from pyzx import Circuit, extract_circuit from pyzx.graph.base import BaseGraph +from zxlive import proof_actions from zxlive.proof import ProofModel @@ -186,7 +188,6 @@ def export_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[Tuple[str, return file_path, selected_format - def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[Tuple[str, FileFormat]]: file_path_and_format = get_file_path_and_format(parent, FileFormat.ZXProof.filter) if file_path_and_format is None or not file_path_and_format[0]: @@ -196,3 +197,52 @@ def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[Tu if not write_to_file(file_path, data): return None return file_path, selected_format + +def create_new_rewrite(parent) -> None: + dialog = QDialog() + parent.rewrite_form = QFormLayout(dialog) + name = QLineEdit() + parent.rewrite_form.addRow("Name", name) + description = QTextEdit() + parent.rewrite_form.addRow("Description", description) + left_button = QPushButton("Left-hand side of the rule") + right_button = QPushButton("Right-hand side of the rule") + parent.left_graph = None + parent.right_graph = None + def get_file(self, button, side) -> None: + out = import_diagram_dialog(self) + if out is not None: + button.setText(out.file_path) + if side == "left": + self.left_graph = out.g + else: + self.right_graph = out.g + left_button.clicked.connect(lambda: get_file(parent, left_button, "left")) + right_button.clicked.connect(lambda: get_file(parent, right_button, "right")) + parent.rewrite_form.addRow(left_button) + parent.rewrite_form.addRow(right_button) + button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + parent.rewrite_form.addRow(button_box) + def add_rewrite() -> None: + if parent.left_graph is None or parent.right_graph is None: + return + parent.left_graph.auto_detect_io() + parent.right_graph.auto_detect_io() + left_matrix, right_matrix = parent.left_graph.to_matrix(), parent.right_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + rewrite = proof_actions.ProofAction.from_dict({ + "text":name.text(), + "tooltip":description.toPlainText(), + "matcher": proof_actions.create_custom_matcher(parent.left_graph), + "rule": proof_actions.create_custom_rule(parent.left_graph, parent.right_graph), + "type": proof_actions.MATCHES_VERTICES, + }) + proof_actions.rewrites.append(rewrite) + dialog.accept() + button_box.accepted.connect(add_rewrite) + button_box.rejected.connect(dialog.reject) + if not dialog.exec(): return diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index adfee8e3..96acb82e 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -166,6 +166,14 @@ def _vert_moved(self, vs: list[tuple[VT, float, float]]) -> None: def _vert_double_clicked(self, v: VT) -> None: if self.graph.type(v) == VertexType.BOUNDARY: + input_, ok = QInputDialog.getText( + self, "Input Dialog", "Enter Qubit Index:" + ) + try: + input_ = int(input_.strip()) + self.graph.set_qubit(v, input_) + except ValueError: + show_error_msg("Wrong Input Type", "Please enter a valid input (e.g. 1, 2)") return input_, ok = QInputDialog.getText( diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index c5ce4fd8..bf3790bc 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -19,17 +19,16 @@ import copy from PySide6.QtCore import QFile, QFileInfo, QTextStream, QIODevice, QSettings, QByteArray, QEvent -from PySide6.QtGui import QAction, QShortcut, QKeySequence, QCloseEvent -from PySide6.QtWidgets import QMessageBox, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, QFileDialog, QSizePolicy +from PySide6.QtGui import QAction, QKeySequence, QCloseEvent +from PySide6.QtWidgets import QMessageBox, QMainWindow, QWidget, QVBoxLayout, QTabWidget from pyzx.graph.base import BaseGraph from .commands import AddRewriteStep - from .base_panel import BasePanel from .edit_panel import GraphEditPanel from .proof_panel import ProofPanel from .construct import * -from .dialogs import ImportGraphOutput, export_proof_dialog, import_diagram_dialog, export_diagram_dialog, show_error_msg, FileFormat +from .dialogs import ImportGraphOutput, create_new_rewrite, export_proof_dialog, import_diagram_dialog, export_diagram_dialog, show_error_msg, FileFormat from .common import GraphT from pyzx import Graph, simplify, Circuit @@ -83,13 +82,13 @@ def __init__(self) -> None: close_action = self._new_action("Close", self.close_action, QKeySequence.StandardKey.Close, "Closes the window") close_action.setShortcuts([QKeySequence(QKeySequence.StandardKey.Close), QKeySequence("Ctrl+W")]) - # TODO: We should remember if we have saved the diagram before, + # TODO: We should remember if we have saved the diagram before, # and give an open to overwrite this file with a Save action save_file = self._new_action("&Save", self.save_file, QKeySequence.StandardKey.Save, "Save the diagram by overwriting the previous loaded file.") save_as = self._new_action("Save &as...", self.save_as, QKeySequence.StandardKey.SaveAs, "Opens a file-picker dialog to save the diagram in a chosen file format") - + file_menu = menu.addMenu("&File") file_menu.addAction(new_graph) file_menu.addAction(open_file) @@ -143,6 +142,10 @@ def __init__(self) -> None: view_menu.addAction(zoom_out) view_menu.addAction(fit_view) + new_rewrite = self._new_action("Create new rewrite", lambda: create_new_rewrite(self), None, "Create a new rewrite") + rewrite_menu = menu.addMenu("&Rewrite") + rewrite_menu.addAction(new_rewrite) + simplify_actions = [] for simp in simplifications.values(): simplify_actions.append(self._new_action(simp["text"], self.apply_pyzx_reduction(simp), None, simp["tool_tip"])) @@ -174,7 +177,7 @@ def active_panel(self) -> Optional[BasePanel]: def closeEvent(self, e: QCloseEvent) -> None: while self.active_panel is not None: # We close all the tabs and ask the user if they want to save progress success = self.close_action() - if not success: + if not success: e.ignore() # Abort the closing return @@ -229,7 +232,7 @@ def close_action(self) -> bool: self.close() if not self.active_panel.undo_stack.isClean(): name = self.tab_widget.tabText(i).replace("*","") - answer = QMessageBox.question(self, "Save Changes", + answer = QMessageBox.question(self, "Save Changes", f"Do you wish to save your changes to {name} before closing?", QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel) if answer == QMessageBox.StandardButton.Cancel: return False @@ -350,7 +353,7 @@ def reduce() -> None: cmd = AddRewriteStep(self.active_panel.graph_view, new_graph, self.active_panel.step_view, reduction["text"]) self.active_panel.undo_stack.push(cmd) return reduce - + class SimpEntry(TypedDict): text: str diff --git a/zxlive/proof_actions.py b/zxlive/proof_actions.py index 829b47b5..ecde7334 100644 --- a/zxlive/proof_actions.py +++ b/zxlive/proof_actions.py @@ -1,14 +1,19 @@ import copy from dataclasses import dataclass, field, replace -from typing import Callable, Literal, List, Optional, Final, TYPE_CHECKING - -from PySide6.QtWidgets import QPushButton, QButtonGroup +from typing import Callable, Literal, List, Optional, TYPE_CHECKING +import networkx as nx +from networkx.algorithms.isomorphism import GraphMatcher, categorical_node_match +import numpy as np import pyzx +from pyzx.utils import VertexType, EdgeType +from shapely import Polygon + +from PySide6.QtWidgets import QPushButton, QButtonGroup -from .commands import AddRewriteStep -from .common import VT,ET, GraphT from . import animations as anims +from .commands import AddRewriteStep +from .common import ET, Graph, GraphT, VT if TYPE_CHECKING: from .proof_panel import ProofPanel @@ -25,17 +30,17 @@ @dataclass class ProofAction(object): name: str - matcher: Callable[[GraphT,Callable],List] - rule: Callable[[GraphT,List],pyzx.rules.RewriteOutputType[ET,VT]] + matcher: Callable[[GraphT, Callable], List] + rule: Callable[[GraphT, List], pyzx.rules.RewriteOutputType[ET,VT]] match_type: MatchType tooltip: str button: Optional[QPushButton] = field(default=None, init=False) @classmethod def from_dict(cls, d: dict) -> "ProofAction": - return cls(d['text'],d['matcher'],d['rule'],d['type'],d['tooltip']) + return cls(d['text'], d['matcher'], d['rule'], d['type'], d['tooltip']) - def do_rewrite(self,panel: "ProofPanel") -> None: + def do_rewrite(self, panel: "ProofPanel") -> None: verts, edges = panel.parse_selection() g = copy.deepcopy(panel.graph_scene.g) @@ -112,17 +117,116 @@ def rewriter() -> None: return rewriter for action in self.actions: if action.button is not None: continue - btn = QPushButton(action.name,parent) + btn = QPushButton(action.name, parent) btn.setMaximumWidth(150) btn.setStatusTip(action.tooltip) btn.setEnabled(False) - btn.clicked.connect(create_rewrite(action,parent)) + btn.clicked.connect(create_rewrite(action, parent)) self.btn_group.addButton(btn) action.button = btn def update_active(self, g: GraphT, verts: List[VT], edges: List[ET]) -> None: for action in self.actions: - action.update_active(g,verts,edges) + action.update_active(g, verts, edges) + + +def to_networkx(graph: Graph) -> nx.Graph: + G = nx.Graph() + v_data = {v: {"type": graph.type(v), + "phase": graph.phase(v),} + for v in graph.vertices()} + for i, input_vertex in enumerate(graph.inputs()): + v_data[input_vertex]["boundary_index"] = f'input_{i}' + for i, output_vertex in enumerate(graph.outputs()): + v_data[output_vertex]["boundary_index"] = f'output_{i}' + G.add_nodes_from([(v, v_data[v]) for v in graph.vertices()]) + G.add_edges_from([(*v, {"type": graph.edge_type(v)}) for v in graph.edges()]) + return G + +def create_subgraph(graph: Graph, verts: List[VT]) -> nx.Graph: + graph_nx = to_networkx(graph) + subgraph_nx = nx.Graph(graph_nx.subgraph(verts)) + boundary_mapping = {} + i = 0 + for v in verts: + for vn in graph.neighbors(v): + if vn not in verts: + boundary_node = 'b' + str(i) + boundary_mapping[boundary_node] = vn + subgraph_nx.add_node(boundary_node, type=VertexType.BOUNDARY) + subgraph_nx.add_edge(v, boundary_node, type=EdgeType.SIMPLE) + i += 1 + return subgraph_nx, boundary_mapping + +def custom_matcher(graph: Graph, in_selection: Callable[[VT], bool], lhs_graph: nx.Graph) -> List[VT]: + verts = [v for v in graph.vertices() if in_selection(v)] + subgraph_nx, _ = create_subgraph(graph, verts) + graph_matcher = GraphMatcher(lhs_graph, subgraph_nx,\ + node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) + if graph_matcher.is_isomorphic(): + return verts + return [] + +def custom_rule(graph: Graph, vertices: List[VT], lhs_graph: nx.Graph, rhs_graph: nx.Graph) -> pyzx.rules.RewriteOutputType[ET,VT]: + subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) + graph_matcher = GraphMatcher(lhs_graph, subgraph_nx,\ + node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) + matching = list(graph_matcher.match())[0] + + vertices_to_remove = [] + for v in matching: + if subgraph_nx.nodes()[matching[v]]['type'] != VertexType.BOUNDARY: + vertices_to_remove.append(matching[v]) + + boundary_vertex_map = {} + for v in rhs_graph.nodes(): + if rhs_graph.nodes()[v]['type'] == VertexType.BOUNDARY: + for x, data in lhs_graph.nodes(data=True): + if data['type'] == VertexType.BOUNDARY and \ + data['boundary_index'] == rhs_graph.nodes()[v]['boundary_index']: + boundary_vertex_map[v] = boundary_mapping[matching[x]] + break + + vertex_positions = get_vertex_positions(graph, rhs_graph, boundary_vertex_map) + vertex_map = boundary_vertex_map + for v in rhs_graph.nodes(): + if rhs_graph.nodes()[v]['type'] != VertexType.BOUNDARY: + vertex_map[v] = graph.add_vertex(ty = rhs_graph.nodes()[v]['type'], + row = vertex_positions[v][0], + qubit = vertex_positions[v][1], + phase = rhs_graph.nodes()[v]['phase'],) + + # create etab to add edges + etab = {} + for v1, v2, data in rhs_graph.edges(data=True): + v1 = vertex_map[v1] + v2 = vertex_map[v2] + if (v1, v2) not in etab: etab[(v1, v2)] = [0, 0] + etab[(v1, v2)][data['type']-1] += 1 + + return etab, vertices_to_remove, [], True + +def get_vertex_positions(graph, rhs_graph, boundary_vertex_map): + pos_dict = {v: (graph.row(m), graph.qubit(m)) for v, m in boundary_vertex_map.items()} + coords = np.array(list(pos_dict.values())) + center = np.mean(coords, axis=0) + angles = np.arctan2(coords[:,1]-center[1], coords[:,0]-center[0]) + coords = coords[np.argsort(-angles)] + try: + area = Polygon(coords).area + except: + area = 1 + k = (area ** 0.5) / len(rhs_graph) + return nx.spring_layout(rhs_graph, k=k, pos=pos_dict, fixed=boundary_vertex_map.keys()) + +def create_custom_matcher(lhs_graph: Graph) -> Callable[[Graph, Callable[[VT], bool]], List[VT]]: + lhs_graph.auto_detect_io() + return lambda g, selection: custom_matcher(g, selection, to_networkx(lhs_graph)) + +def create_custom_rule(lhs_graph: Graph, rhs_graph: Graph) -> Callable[[Graph, List[VT]], pyzx.rules.RewriteOutputType[ET,VT]]: + lhs_graph.auto_detect_io() + rhs_graph.auto_detect_io() + return lambda g, verts: custom_rule(g, verts, to_networkx(lhs_graph), to_networkx(rhs_graph)) spider_fuse = ProofAction.from_dict(operations['spider']) @@ -133,5 +237,4 @@ def update_active(self, g: GraphT, verts: List[VT], edges: List[ET]) -> None: pauli = ProofAction.from_dict(operations['pauli']) bialgebra = ProofAction.from_dict(operations['bialgebra']) -actions_basic = ProofActionGroup(spider_fuse,to_z,to_x,rem_id,copy_action,pauli,bialgebra) - +rewrites = [spider_fuse, to_z, to_x, rem_id, copy_action, pauli, bialgebra] diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index e1b8a756..480b50a2 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -79,7 +79,7 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: yield ToolbarSection(*self.identity_choice, exclusive=True) def init_action_groups(self) -> None: - self.action_groups = [proof_actions.actions_basic.copy()] + self.action_groups = [proof_actions.ProofActionGroup(*proof_actions.rewrites).copy()] for group in reversed(self.action_groups): hlayout = QHBoxLayout() group.init_buttons(self) @@ -90,7 +90,7 @@ def init_action_groups(self) -> None: widget = QWidget() widget.setLayout(hlayout) - self.layout().insertWidget(1,widget) + self.layout().insertWidget(1, widget) def parse_selection(self) -> tuple[list[VT], list[ET]]: selection = list(self.graph_scene.selected_vertices) diff --git a/zxlive/vitem.py b/zxlive/vitem.py index d8c842c1..7c3ba15e 100644 --- a/zxlive/vitem.py +++ b/zxlive/vitem.py @@ -64,7 +64,7 @@ class VItem(QGraphicsPathItem): phase_item: PhaseItem adj_items: Set[EItem] # Connected edges graph_scene: GraphScene - + halftone = "1000100010001000" #QPixmap("images/halftone.png") # Set of animations that are currently running on this vertex @@ -108,7 +108,7 @@ def __init__(self, graph_scene: GraphScene, v: VT) -> None: path = QPainterPath() if self.g.type(self.v) == VertexType.H_BOX: path.addRect(-0.2 * SCALE, -0.2 * SCALE, 0.4 * SCALE, 0.4 * SCALE) - else: + else: path.addEllipse(-0.2 * SCALE, -0.2 * SCALE, 0.4 * SCALE, 0.4 * SCALE) self.setPath(path) self.refresh() @@ -180,7 +180,7 @@ def paint(self, painter: QPainter, option: QStyleOptionGraphicsItem, widget: Opt # we intercept the selected option here. option.state &= ~QStyle.StateFlag.State_Selected super().paint(painter, option, widget) - + def itemChange(self, change: QGraphicsItem.GraphicsItemChange, value: Any) -> Any: # Snap items to grid on movement by intercepting the position-change # event and returning a new position @@ -360,5 +360,7 @@ def refresh(self) -> None: phase = self.v_item.g.phase(self.v_item.v) # phase = self.v_item.v self.setPlainText(phase_to_s(phase, self.v_item.g.type(self.v_item.v))) + if self.v_item.g.type(self.v_item.v) == VertexType.BOUNDARY: + self.setPlainText(str(int(self.v_item.g.qubit(self.v_item.v)))) p = self.v_item.pos() self.setPos(p.x(), p.y() - 0.6 * SCALE)