From 9d87e77a419b9fb2d27d201d01d67f51cb43bd60 Mon Sep 17 00:00:00 2001 From: Tuomas Laakkonen Date: Mon, 13 Nov 2023 21:35:11 +0000 Subject: [PATCH] allow unfusing W nodes --- zxlive/proof_panel.py | 54 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index 5c3975be..b8b3d97c 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -15,7 +15,7 @@ QStyleOptionViewItem, QToolButton, QWidget, QVBoxLayout, QTabWidget, QInputDialog) from pyzx import VertexType, basicrules -from pyzx.utils import get_z_box_label, set_z_box_label +from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType from . import animations as anims from . import proof_actions @@ -29,7 +29,7 @@ from .graphscene import GraphScene from .graphview import GraphTool, GraphView, WandTrace from .proof import ProofModel -from .vitem import DragState, VItem +from .vitem import DragState, VItem, get_w_partner_vitem, W_INPUT_OFFSET, SCALE from .editor_base_panel import string_to_complex, string_to_fraction from .poly import Poly @@ -220,14 +220,14 @@ def cross(a: QPointF, b: QPointF) -> float: return False item = filtered[0] vertex = item.v - if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX): + if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX, VertexType.W_OUTPUT): return False if not trace.shift and basicrules.check_remove_id(self.graph, vertex): self._remove_id(vertex) return True - if trace.shift: + if trace.shift and self.graph.type(vertex) != VertexType.W_OUTPUT: phase_is_complex = (self.graph.type(vertex) == VertexType.Z_BOX) if phase_is_complex: prompt = "Enter desired phase value (complex value):" @@ -245,7 +245,7 @@ def new_var(_): except ValueError: show_error_msg("Invalid Input", error_msg) return False - else: + elif self.graph.type(vertex) != VertexType.W_OUTPUT: if self.graph.type(vertex) == VertexType.Z_BOX: phase = get_z_box_label(self.graph, vertex) else: @@ -268,7 +268,11 @@ def new_var(_): else: right.append(neighbor) mouse_dir = ((start + end) * (1/2)) - pos - self._unfuse(vertex, left, mouse_dir, phase) + + if self.graph.type(vertex) == VertexType.W_OUTPUT: + self._unfuse_w(vertex, left, mouse_dir) + else: + self._unfuse(vertex, left, mouse_dir, phase) return True def _remove_id(self, v: VT) -> None: @@ -278,6 +282,44 @@ def _remove_id(self, v: VT) -> None: cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "id") self.undo_stack.push(cmd, anim_before=anim) + def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> None: + new_g = copy.deepcopy(self.graph) + + vi = get_w_partner(self.graph, v) + par_dir = QVector2D( + self.graph.row(v) - self.graph.row(vi), + self.graph.qubit(v) - self.graph.qubit(vi) + ).normalized() + + perp_dir = QVector2D(mouse_dir - QPointF(self.graph.row(v)/SCALE, self.graph.qubit(v)/SCALE)).normalized() + perp_dir -= QVector2D.dotProduct(perp_dir, par_dir) * par_dir + perp_dir.normalize() + + out_offset_x = par_dir.x() * 0.5 + perp_dir.x() * 0.5 + out_offset_y = par_dir.y() * 0.5 + perp_dir.y() * 0.5 + + in_offset_x = out_offset_x - par_dir.x()*W_INPUT_OFFSET + in_offset_y = out_offset_y - par_dir.y()*W_INPUT_OFFSET + + left_vert = new_g.add_vertex(VertexType.W_OUTPUT, + qubit=self.graph.qubit(v) + out_offset_y, + row=self.graph.row(v) + out_offset_x) + left_vert_i = new_g.add_vertex(VertexType.W_INPUT, + qubit=self.graph.qubit(v) + in_offset_y, + row=self.graph.row(v) + in_offset_x) + new_g.add_edge((left_vert_i, left_vert), EdgeType.W_IO) + new_g.add_edge((v, left_vert_i)) + new_g.set_row(v, self.graph.row(v)) + new_g.set_qubit(v, self.graph.qubit(v)) + for neighbor in left_neighbours: + new_g.add_edge((neighbor, left_vert), + self.graph.edge_type((v, neighbor))) + new_g.remove_edge((v, neighbor)) + + anim = anims.unfuse(self.graph, new_g, v, self.graph_scene) + cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "unfuse") + self.undo_stack.push(cmd, anim_after=anim) + def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: Poly | complex | Fraction) -> None: def snap_vector(v: QVector2D) -> None: if abs(v.x()) > abs(v.y()):