Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom rewrites with parameteric phases #185

Merged
merged 21 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zxlive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from PySide6.QtCore import QCommandLineParser
from PySide6.QtGui import QIcon
import sys
sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx

from .mainwindow import MainWindow
from .common import get_data

#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx

# The following hack is needed on windows in order to show the icon in the taskbar
# See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105
import os
Expand Down
2 changes: 1 addition & 1 deletion zxlive/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from PySide6.QtWidgets import QListView
from pyzx import basicrules
from pyzx.graph import GraphDiff
from pyzx.symbolic import Poly
from pyzx.utils import EdgeType, VertexType, get_w_partner, vertex_is_w, get_w_io, get_z_box_label, set_z_box_label

from .common import ET, VT, W_INPUT_OFFSET, GraphT
from .graphview import GraphView
from .poly import Poly
from .proof import ProofModel, Rewrite


Expand Down
125 changes: 111 additions & 14 deletions zxlive/custom_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pyzx.utils import EdgeType, VertexType
from shapely import Polygon

from pyzx.symbolic import Poly

from .common import ET, VT, GraphT

if TYPE_CHECKING:
Expand All @@ -31,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description:
def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]:
subgraph_nx, boundary_mapping = create_subgraph(graph, vertices)
graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx,
node_match=categorical_node_match(['type', 'phase'], default=[1, 0]))
matching = list(graph_matcher.match())[0]
node_match=categorical_node_match('type', 1))
matchings = graph_matcher.match()
matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx)
matching = matchings[0]
symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx)

vertices_to_remove = []
for v in matching:
Expand All @@ -53,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu
vertex_map = boundary_vertex_map
for v in self.rhs_graph_nx.nodes():
if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY:
phase = self.rhs_graph_nx.nodes()[v]['phase']
if isinstance(phase, Poly):
phase = phase.substitute(symbolic_params_map)
if phase.free_vars() == set():
phase = phase.terms[0][0]
vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'],
row = vertex_positions[v][0],
qubit = vertex_positions[v][1],
phase = self.rhs_graph_nx.nodes()[v]['phase'],)
phase = phase,)

# create etab to add edges
etab = {}
Expand All @@ -75,8 +85,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT]
vertices = [v for v in graph.vertices() if in_selection(v)]
subgraph_nx, _ = create_subgraph(graph, vertices)
graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx,
node_match=categorical_node_match(['type', 'phase'], default=[1, 0]))
if graph_matcher.is_isomorphic():
node_match=categorical_node_match('type', 1))
matchings = list(graph_matcher.match())
matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx)
if len(matchings) > 0:
return vertices
return []

Expand All @@ -102,6 +114,75 @@ def to_proof_action(self) -> "ProofAction":
return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description)


def get_linear(v):
if not isinstance(v, Poly):
raise ValueError("Not a symbolic parameter")
if len(v.terms) > 2 or len(v.free_vars()) > 1:
raise ValueError("Only linear symbolic parameters are supported")
if len(v.terms) == 0:
return 1, None, 0
elif len(v.terms) == 1:
if len(v.terms[0][1].vars) > 0:
var_term = v.terms[0]
const = 0
else:
const = v.terms[0][0]
return 1, None, const
else:
if len(v.terms[0][1].vars) > 0:
var_term = v.terms[0]
const = v.terms[1][0]
else:
var_term = v.terms[1]
const = v.terms[0][0]
coeff = var_term[0]
var, power = var_term[1].vars[0]
if power != 1:
raise ValueError("Only linear symbolic parameters are supported")
return coeff, var, const


def match_symbolic_parameters(match, left, right):
params = {}
left_phase = left.nodes.data('phase', default=0)
right_phase = right.nodes.data('phase', default=0)

def check_phase_equality(v):
if left_phase[v] != right_phase[match[v]]:
raise ValueError("Parameters do not match")

def update_params(v, var, coeff, const):
var_value = (right_phase[match[v]] - const) / coeff
if var in params and params[var] != var_value:
raise ValueError("Symbolic parameters do not match")
params[var] = var_value

for v in left.nodes():
if isinstance(left_phase[v], Poly):
coeff, var, const = get_linear(left_phase[v])
if var is None:
check_phase_equality(v)
continue
update_params(v, var, coeff, const)
else:
check_phase_equality(v)

return params


def filter_matchings_if_symbolic_compatible(matchings, left, right):
new_matchings = []
for matching in matchings:
if len(matching) != len(left):
continue
try:
match_symbolic_parameters(matching, left, right)
new_matchings.append(matching)
except ValueError:
pass
return new_matchings


def to_networkx(graph: GraphT) -> nx.Graph:
G = nx.Graph()
v_data = {v: {"type": graph.type(v),
Expand Down Expand Up @@ -154,13 +235,29 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool:
from .dialogs import show_error_msg
show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.")
return False
left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix()
if not np.allclose(left_matrix, right_matrix):
if show_error:
from .dialogs import show_error_msg
if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)):
show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.")
else:
show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.")
return False
if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types:
left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix()
if not np.allclose(left_matrix, right_matrix):
if show_error:
from .dialogs import show_error_msg
if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)):
show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.")
else:
show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.")
return False
else:
if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()):
if show_error:
from .dialogs import show_error_msg
show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.")
return False
for vertex in rule.lhs_graph.vertices():
if isinstance(rule.lhs_graph.phase(vertex), Poly):
try:
get_linear(rule.lhs_graph.phase(vertex))
except ValueError as e:
if show_error:
from .dialogs import show_error_msg
show_error_msg("Warning!", str(e))
return False
return True
4 changes: 2 additions & 2 deletions zxlive/edit_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from PySide6.QtCore import Signal
from PySide6.QtGui import QAction
from PySide6.QtWidgets import (QToolButton)
from pyzx import EdgeType, VertexType, Circuit
from pyzx import EdgeType, VertexType
from pyzx.circuit.qasmparser import QASMParser
from pyzx.symbolic import Poly

from .base_panel import ToolbarSection
from .commands import UpdateGraph
Expand All @@ -16,7 +17,6 @@
from .editor_base_panel import EditorBasePanel
from .graphscene import EditGraphScene
from .graphview import GraphView
from .poly import Poly


class GraphEditPanel(EditorBasePanel):
Expand Down
55 changes: 10 additions & 45 deletions zxlive/editor_base_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import copy
import re
from enum import Enum
from fractions import Fraction
from typing import Callable, Iterator, TypedDict, Union
from typing import Callable, Iterator, TypedDict

from PySide6.QtCore import QPoint, QSize, Qt, Signal
from PySide6.QtGui import (QAction, QColor, QIcon, QPainter, QPalette, QPen,
Expand All @@ -15,7 +14,7 @@
QSpacerItem, QSplitter, QToolButton, QWidget)
from pyzx import EdgeType, VertexType
from pyzx.utils import get_w_partner, vertex_is_w

from pyzx.graph.jsonparser import string_to_phase

from .base_panel import BasePanel, ToolbarSection
from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor,
Expand All @@ -25,8 +24,6 @@
from .dialogs import show_error_msg
from .eitem import HAD_EDGE_BLUE
from .graphscene import EditGraphScene
from .parse_poly import parse
from .poly import Poly, new_var
from .vitem import BLACK


Expand Down Expand Up @@ -98,16 +95,8 @@ def update_colors(self) -> None:
super().update_colors()
self.update_side_bar()

def update_variable_viewer(self) -> None:
self.update_side_bar()

def _populate_variables(self) -> None:
self.variable_types = {}
for vert in self.graph.vertices():
phase = self.graph.phase(vert)
if isinstance(phase, Poly):
for var in phase.free_vars():
self.variable_types[var.name] = var.is_bool
self.variable_types = self.graph.variable_types.copy()

def _tool_clicked(self, tool: ToolType) -> None:
self.graph_scene.curr_tool = tool
Expand Down Expand Up @@ -189,18 +178,18 @@ def vert_double_clicked(self, v: VT) -> None:
if not ok:
return None
try:
new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var)
new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph)
except ValueError:
show_error_msg("Invalid Input", error_msg)
return None
cmd = ChangePhase(self.graph_view, v, new_phase)
self.undo_stack.push(cmd)

def _new_var(self, name: str) -> Poly:
if name not in self.variable_types:
self.variable_types[name] = False
self.variable_viewer.add_item(name)
return new_var(name, self.variable_types)
# For some reason it is important we first push to the stack before we do the following.
if len(graph.variable_types) != len(self.variable_types):
new_vars = graph.variable_types.keys() - self.variable_types.keys()
self.variable_types.update(graph.variable_types)
for v in new_vars:
self.variable_viewer.add_item(v)


class VariableViewer(QScrollArea):
Expand Down Expand Up @@ -378,29 +367,5 @@ def create_icon(shape: ShapeType, color: QColor) -> QIcon:
return icon


def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]:
if not string:
return Fraction(0)
try:
s = string.lower().replace(' ', '')
s = re.sub('\\*?(pi|\u04c0)\\*?', '', s)
if '.' in s or 'e' in s:
return Fraction(float(s))
elif '/' in s:
a, b = s.split("/", 2)
if not a:
return Fraction(1, int(b))
if a == '-':
a = '-1'
return Fraction(int(a), int(b))
else:
return Fraction(int(s))
except ValueError:
try:
return parse(string, new_var_)
except Exception as e:
raise ValueError(e)


def string_to_complex(string: str) -> complex:
return complex(string) if string else complex(0)
58 changes: 0 additions & 58 deletions zxlive/parse_poly.py

This file was deleted.

Loading
Loading