From 8f1a4d08817b8f41ffbd0bf8c142c115ec8de180 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 20:50:16 +0000 Subject: [PATCH 01/20] Use Poly class in PyZX --- zxlive/app.py | 4 +- zxlive/editor_base_panel.py | 2 +- zxlive/parse_poly.py | 58 ------------ zxlive/poly.py | 180 ------------------------------------ 4 files changed, 3 insertions(+), 241 deletions(-) delete mode 100644 zxlive/parse_poly.py delete mode 100644 zxlive/poly.py diff --git a/zxlive/app.py b/zxlive/app.py index 2b9fba83..590b8ab4 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,11 +19,11 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon import sys +sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + from .mainwindow import MainWindow from .common import get_data -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx - # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 import os diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 70d20882..6c78a50c 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,6 +15,7 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.symbolic import Poly from .base_panel import BasePanel, ToolbarSection @@ -26,7 +27,6 @@ from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene from .parse_poly import parse -from .poly import Poly, new_var from .vitem import BLACK diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py deleted file mode 100644 index f70a33ab..00000000 --- a/zxlive/parse_poly.py +++ /dev/null @@ -1,58 +0,0 @@ -from .poly import Poly, new_const - -from typing import Any, Callable -from lark import Lark, Transformer -from functools import reduce -from operator import add, mul -from fractions import Fraction - -poly_grammar = Lark(""" - start : "(" start ")" | term ("+" term)* - term : (intf | frac)? factor ("*" factor)* - ?factor : intf | frac | pi | pifrac | var - var : CNAME - intf : INT - pi : "\\pi" | "pi" - frac : INT "/" INT - pifrac : [INT] pi "/" INT - - %import common.INT - %import common.CNAME - %import common.WS - %ignore WS - """, - parser='lalr', - maybe_placeholders=True) - -class PolyTransformer(Transformer): - def __init__(self, new_var: Callable[[str], Poly]): - super().__init__() - - self._new_var = new_var - - def start(self, items: list[Poly]) -> Poly: - return reduce(add, items) - - def term(self, items: list[Poly]) -> Poly: - return reduce(mul, items) - - def var(self, items: list[Any]) -> Poly: - v = str(items[0]) - return self._new_var(v) - - def pi(self, _: list[Any]) -> Poly: - return new_const(1) - - def intf(self, items: list[Any]) -> Poly: - return new_const(int(items[0])) - - def frac(self, items: list[Any]) -> Poly: - return new_const(Fraction(int(items[0]), int(items[1]))) - - def pifrac(self, items: list[Any]) -> Poly: - numerator = int(items[0]) if items[0] else 1 - return new_const(Fraction(numerator, int(items[2]))) - -def parse(expr: str, new_var: Callable[[str], Poly]) -> Poly: - tree = poly_grammar.parse(expr) - return PolyTransformer(new_var).transform(tree) diff --git a/zxlive/poly.py b/zxlive/poly.py deleted file mode 100644 index 0dfac718..00000000 --- a/zxlive/poly.py +++ /dev/null @@ -1,180 +0,0 @@ -from fractions import Fraction -from typing import Union, Optional - - -class Var: - name: str - _is_bool: bool - _types_dict: Optional[Union[bool, dict[str, bool]]] - - def __init__(self, name: str, data: Union[bool, dict[str, bool]]): - self.name = name - if isinstance(data, dict): - self._types_dict = data - self._frozen = False - self._is_bool = False - else: - self._types_dict = None - self._frozen = True - self._is_bool = data - - @property - def is_bool(self) -> bool: - if self._frozen: - return self._is_bool - else: - assert isinstance(self._types_dict, dict) - return self._types_dict[self.name] - - def __repr__(self) -> str: - return self.name - - def __lt__(self, other: 'Var') -> bool: - if int(self.is_bool) == int(other.is_bool): - return self.name < other.name - return int(self.is_bool) < int(other.is_bool) - - def __hash__(self) -> int: - # Variables with the same name map to the same type - # within the same graph, so no need to include is_bool - # in the hash. - return int(hash(self.name)) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - def freeze(self) -> None: - if not self._frozen: - assert isinstance(self._types_dict, dict) - self._is_bool = self._types_dict[self.name] - self._frozen = True - self._types_dict = None - - def __copy__(self) -> 'Var': - if self._frozen: - return Var(self.name, self.is_bool) - else: - assert isinstance(self._types_dict, dict) - return Var(self.name, self._types_dict) - - def __deepcopy__(self, _memo: object) -> 'Var': - return self.__copy__() - -class Term: - vars: list[tuple[Var, int]] - - def __init__(self, vars: list[tuple[Var,int]]) -> None: - self.vars = vars - - def freeze(self) -> None: - for var, _ in self.vars: - var.freeze() - - def free_vars(self) -> set[Var]: - return set(var for var, _ in self.vars) - - def __repr__(self) -> str: - vs = [] - for v, c in self.vars: - if c == 1: - vs.append(f'{v}') - else: - vs.append(f'{v}^{c}') - return '*'.join(vs) - - def __mul__(self, other: 'Term') -> 'Term': - vs = dict() - for v, c in self.vars + other.vars: - if v not in vs: vs[v] = c - else: vs[v] += c - # TODO deal with fractional / symbolic powers - if v.is_bool and c > 1: - vs[v] = 1 - return Term([(v, c) for v, c in vs.items()]) - - def __hash__(self) -> int: - return hash(tuple(sorted(self.vars))) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - -class Poly: - terms: list[tuple[Union[int, float, Fraction], Term]] - - def __init__(self, terms: list[tuple[Union[int, float, Fraction], Term]]) -> None: - self.terms = terms - - def freeze(self) -> None: - for _, term in self.terms: - term.freeze() - - def free_vars(self) -> set[Var]: - output = set() - for _, term in self.terms: - output.update(term.free_vars()) - return output - - def __add__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float, Fraction)): - other = Poly([(other, Term([]))]) - counter = dict() - for c, t in self.terms + other.terms: - if t not in counter: counter[t] = c - else: counter[t] += c - if all(tt[0].is_bool for tt in t.vars): - counter[t] = counter[t] % 2 - - # remove terms with coefficient 0 - for t in list(counter.keys()): - if counter[t] == 0: - del counter[t] - return Poly([(c, t) for t, c in counter.items()]) - - __radd__ = __add__ - - def __mul__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float)): - other = Poly([(other, Term([]))]) - p = Poly([]) - for c1, t1 in self.terms: - for c2, t2 in other.terms: - p += Poly([(c1 * c2, t1 * t2)]) - return p - - __rmul__ = __mul__ - - def __repr__(self) -> str: - ts = [] - for c, t in self.terms: - if t == Term([]): - ts.append(f'{c}') - elif c == 1: - ts.append(f'{t}') - else: - ts.append(f'{c}{t}') - return ' + '.join(ts) - - def __eq__(self, other: object) -> bool: - if isinstance(other, (int, float, Fraction)): - if other == 0: - other = Poly([]) - else: - other = Poly([(other, Term([]))]) - assert isinstance(other, Poly) - return set(self.terms) == set(other.terms) - - @property - def is_pauli(self) -> bool: - for c, t in self.terms: - if not all(v.is_bool for v, _ in t.vars): - return False - if c % 1 != 0: - return False - return True - -def new_var(name: str, types_dict: Union[bool, dict[str, bool]]) -> Poly: - return Poly([(1, Term([(Var(name, types_dict), 1)]))]) - -def new_const(coeff: Union[int, Fraction]) -> Poly: - return Poly([(coeff, Term([]))]) From a9ff770179baf61d019a08ad133649759141c3ef Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 21:31:50 +0000 Subject: [PATCH 02/20] Fixed bugs with saving parameters --- zxlive/edit_panel.py | 2 +- zxlive/editor_base_panel.py | 71 ++++++++++++++++--------------------- zxlive/rule_panel.py | 2 +- 3 files changed, 33 insertions(+), 42 deletions(-) diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 4cb46930..78e7d32a 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -7,6 +7,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection from .common import GraphT @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView -from .poly import Poly class GraphEditPanel(EditorBasePanel): diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 6c78a50c..57386bf7 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,9 +15,9 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.graph.jsonparser import string_to_phase from pyzx.symbolic import Poly - from .base_panel import BasePanel, ToolbarSection from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor, ChangeNodeType, ChangePhase, MoveNode, SetGraph, @@ -26,7 +26,6 @@ from .dialogs import show_error_msg from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene -from .parse_poly import parse from .vitem import BLACK @@ -98,16 +97,8 @@ def update_colors(self) -> None: super().update_colors() self.update_side_bar() - def update_variable_viewer(self) -> None: - self.update_side_bar() - def _populate_variables(self) -> None: - self.variable_types = {} - for vert in self.graph.vertices(): - phase = self.graph.phase(vert) - if isinstance(phase, Poly): - for var in phase.free_vars(): - self.variable_types[var.name] = var.is_bool + self.variable_types = self.graph.variable_types.copy() def _tool_clicked(self, tool: ToolType) -> None: self.graph_scene.curr_tool = tool @@ -189,18 +180,18 @@ def vert_double_clicked(self, v: VT) -> None: if not ok: return None try: - new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var) + new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return None cmd = ChangePhase(self.graph_view, v, new_phase) self.undo_stack.push(cmd) - - def _new_var(self, name: str) -> Poly: - if name not in self.variable_types: - self.variable_types[name] = False - self.variable_viewer.add_item(name) - return new_var(name, self.variable_types) + # For some reason it is important we first push to the stack before we do the following. + if len(graph.variable_types) != len(self.variable_types): + new_vars = graph.variable_types.keys() - self.variable_types.keys() + self.variable_types.update(graph.variable_types) + for v in new_vars: + self.variable_viewer.add_item(v) class VariableViewer(QScrollArea): @@ -378,28 +369,28 @@ def create_icon(shape: ShapeType, color: str) -> QIcon: return icon -def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: - if not string: - return Fraction(0) - try: - s = string.lower().replace(' ', '') - s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) - if '.' in s or 'e' in s: - return Fraction(float(s)) - elif '/' in s: - a, b = s.split("/", 2) - if not a: - return Fraction(1, int(b)) - if a == '-': - a = '-1' - return Fraction(int(a), int(b)) - else: - return Fraction(int(s)) - except ValueError: - try: - return parse(string, new_var_) - except Exception as e: - raise ValueError(e) +#def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: +# if not string: +# return Fraction(0) +# try: +# s = string.lower().replace(' ', '') +# s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) +# if '.' in s or 'e' in s: +# return Fraction(float(s)) +# elif '/' in s: +# a, b = s.split("/", 2) +# if not a: +# return Fraction(1, int(b)) +# if a == '-': +# a = '-1' +# return Fraction(int(a), int(b)) +# else: +# return Fraction(int(s)) +# except ValueError: +# try: +# return parse(string, new_var_) +# except Exception as e: +# raise ValueError(e) def string_to_complex(string: str) -> complex: diff --git a/zxlive/rule_panel.py b/zxlive/rule_panel.py index 61f4c567..179fca7e 100644 --- a/zxlive/rule_panel.py +++ b/zxlive/rule_panel.py @@ -6,6 +6,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import QLineEdit from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import RuleEditGraphView -from .poly import Poly class RulePanel(EditorBasePanel): From f62718d6594d85f31fb3efa7d42a76cde385ab5c Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:25:36 +0100 Subject: [PATCH 03/20] only check matrix for non-symbolic rules --- zxlive/custom_rule.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 9d3e5a23..0f76eaaf 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -154,13 +154,14 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: from .dialogs import show_error_msg show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.") return False - left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() - if not np.allclose(left_matrix, right_matrix): - if show_error: - from .dialogs import show_error_msg - if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") - else: - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") - return False + if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types: + left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if show_error: + from .dialogs import show_error_msg + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + return False return True From 057179997b593ef36740eaaaf265748b694742df Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:26:39 +0100 Subject: [PATCH 04/20] rewrite rule matching with symbolic parameters --- zxlive/custom_rule.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 0f76eaaf..144a23c9 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -11,6 +11,8 @@ from pyzx.utils import EdgeType, VertexType from shapely import Polygon +from pyzx.symbolic import Poly + from .common import ET, VT, GraphT if TYPE_CHECKING: @@ -75,8 +77,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT] vertices = [v for v in graph.vertices() if in_selection(v)] subgraph_nx, _ = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - if graph_matcher.is_isomorphic(): + node_match=categorical_node_match('type', 1)) + matchings = list(graph_matcher.match()) + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + if len(matchings) > 0: return vertices return [] @@ -102,6 +106,32 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def match_symbolic_parameters(match, left, right): + params = {} + left_phase = left.nodes.data('phase', default=0) + right_phase = right.nodes.data('phase', default=0) + for v in left.nodes(): + if isinstance(left_phase[v], Poly): + if str(left_phase[v]) in params: + if params[str(left_phase)] != right_phase[match[v]]: + raise ValueError("Symbolic parameters do not match") + else: + params[str(left_phase[v])] = right_phase[match[v]] + elif left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + return params + +def filter_matchings_if_symbolic_compatible(matchings, left, right): + new_matchings = [] + for matching in matchings: + try: + match_symbolic_parameters(matching, left, right) + new_matchings.append(matching) + except ValueError: + pass + return new_matchings + + def to_networkx(graph: GraphT) -> nx.Graph: G = nx.Graph() v_data = {v: {"type": graph.type(v), From 745d55e21491cde50f0718e8d8ba67be3352976d Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:28:14 +0100 Subject: [PATCH 05/20] applying custom rule with symbolic parameters --- zxlive/custom_rule.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 144a23c9..8e33da6c 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -33,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]: subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - matching = list(graph_matcher.match())[0] + node_match=categorical_node_match('type', 1)) + matchings = graph_matcher.match() + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + matching = matchings[0] + symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx) vertices_to_remove = [] for v in matching: @@ -55,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu vertex_map = boundary_vertex_map for v in self.rhs_graph_nx.nodes(): if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY: + phase = self.rhs_graph_nx.nodes()[v]['phase'] + if isinstance(phase, Poly): + phase = phase.substitute(symbolic_params_map) + if phase.free_vars() == set(): + phase = phase.terms[0][0] vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'], row = vertex_positions[v][0], qubit = vertex_positions[v][1], - phase = self.rhs_graph_nx.nodes()[v]['phase'],) + phase = phase,) # create etab to add edges etab = {} From 6517efcd438f747d17f0793c3bba4df7bdce801e Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:39:35 +0100 Subject: [PATCH 06/20] typo in match_symbolic_parameters --- zxlive/custom_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 8e33da6c..1f810132 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -121,7 +121,7 @@ def match_symbolic_parameters(match, left, right): for v in left.nodes(): if isinstance(left_phase[v], Poly): if str(left_phase[v]) in params: - if params[str(left_phase)] != right_phase[match[v]]: + if params[str(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: params[str(left_phase[v])] = right_phase[match[v]] From 58ca1cc95e239712cbae3ef699025ca79664cec8 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:33:39 +0100 Subject: [PATCH 07/20] small bug fix --- zxlive/custom_rule.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 1f810132..82ed0e80 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -132,6 +132,8 @@ def match_symbolic_parameters(match, left, right): def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: + if len(matching) != len(left): + continue try: match_symbolic_parameters(matching, left, right) new_matchings.append(matching) From 726ad180078d69477a0fdcfded477ce0a485ab7f Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:34:02 +0100 Subject: [PATCH 08/20] get var method for symbolic parameters --- zxlive/custom_rule.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 82ed0e80..c838a40b 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,17 +114,26 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def get_var(v): + if not isinstance(v, Poly): + raise ValueError("Not a symbolic parameter") + if len(v.terms) != 1: + raise ValueError("Only single-term symbolic parameters are supported") + if len(v.terms[0][1].vars) != 1: + raise ValueError("Only single-variable symbolic parameters are supported") + return v.terms[0][1].vars[0][0] + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) for v in left.nodes(): if isinstance(left_phase[v], Poly): - if str(left_phase[v]) in params: - if params[str(left_phase[v])] != right_phase[match[v]]: + if get_var(left_phase[v]) in params: + if params[get_var(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: - params[str(left_phase[v])] = right_phase[match[v]] + params[get_var(left_phase[v])] = right_phase[match[v]] elif left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") return params From 3cf64733dbc31437e04940db47eb247c7146d034 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 17:35:43 +0100 Subject: [PATCH 09/20] add warnings for custom rules with symbolic parameters --- zxlive/custom_rule.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index c838a40b..a86fd0c9 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -213,4 +213,17 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: else: show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") return False + else: + if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()): + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.") + return False + for vertex in rule.lhs_graph.vertices(): + if isinstance(rule.lhs_graph.phase(vertex), Poly): + if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + return False return True From 04169436da874d8ce71d9211f70047d29e29a4e9 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 18:47:17 +0100 Subject: [PATCH 10/20] symbolic rewrites support linear terms Co-authored-by: Tuomas Laakkonen --- zxlive/custom_rule.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index a86fd0c9..c3c81abf 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,30 +114,62 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) -def get_var(v): +def get_linear(v): if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") - if len(v.terms) != 1: - raise ValueError("Only single-term symbolic parameters are supported") - if len(v.terms[0][1].vars) != 1: - raise ValueError("Only single-variable symbolic parameters are supported") - return v.terms[0][1].vars[0][0] + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + for v in left.nodes(): if isinstance(left_phase[v], Poly): - if get_var(left_phase[v]) in params: - if params[get_var(left_phase[v])] != right_phase[match[v]]: - raise ValueError("Symbolic parameters do not match") - else: - params[get_var(left_phase[v])] = right_phase[match[v]] - elif left_phase[v] != right_phase[match[v]]: - raise ValueError("Parameters do not match") + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + return params + def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: @@ -221,9 +253,11 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: return False for vertex in rule.lhs_graph.vertices(): if isinstance(rule.lhs_graph.phase(vertex), Poly): - if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: if show_error: from .dialogs import show_error_msg - show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + show_error_msg("Warning!", str(e)) return False return True From 8a1ab003f5fd781fd93a36fbfbab845cc06fa6ff Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 20:50:16 +0000 Subject: [PATCH 11/20] Use Poly class in PyZX --- zxlive/app.py | 4 +- zxlive/editor_base_panel.py | 2 +- zxlive/parse_poly.py | 58 ------------ zxlive/poly.py | 180 ------------------------------------ 4 files changed, 3 insertions(+), 241 deletions(-) delete mode 100644 zxlive/parse_poly.py delete mode 100644 zxlive/poly.py diff --git a/zxlive/app.py b/zxlive/app.py index e525f4fc..56a2991c 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,11 +19,11 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon import sys +sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + from .mainwindow import MainWindow from .common import get_data -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx - # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 import os diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 9230491d..79cfeccd 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,6 +15,7 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.symbolic import Poly from .base_panel import BasePanel, ToolbarSection @@ -26,7 +27,6 @@ from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene from .parse_poly import parse -from .poly import Poly, new_var from .vitem import BLACK diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py deleted file mode 100644 index 42e9da97..00000000 --- a/zxlive/parse_poly.py +++ /dev/null @@ -1,58 +0,0 @@ -from .poly import Poly, new_const - -from typing import Any, Callable -from lark import Lark, Transformer -from functools import reduce -from operator import add, mul -from fractions import Fraction - -poly_grammar = Lark(""" - start : "(" start ")" | term ("+" term)* - term : (intf | frac)? factor ("*" factor)* - ?factor : intf | frac | pi | pifrac | var - var : CNAME - intf : INT - pi : "\\pi" | "pi" - frac : INT "/" INT - pifrac : [INT] pi "/" INT - - %import common.INT - %import common.CNAME - %import common.WS - %ignore WS - """, - parser='lalr', - maybe_placeholders=True) - -class PolyTransformer(Transformer[Poly]): - def __init__(self, new_var: Callable[[str], Poly]): - super().__init__() - - self._new_var = new_var - - def start(self, items: list[Poly]) -> Poly: - return reduce(add, items) - - def term(self, items: list[Poly]) -> Poly: - return reduce(mul, items) - - def var(self, items: list[Any]) -> Poly: - v = str(items[0]) - return self._new_var(v) - - def pi(self, _: list[Any]) -> Poly: - return new_const(1) - - def intf(self, items: list[Any]) -> Poly: - return new_const(int(items[0])) - - def frac(self, items: list[Any]) -> Poly: - return new_const(Fraction(int(items[0]), int(items[1]))) - - def pifrac(self, items: list[Any]) -> Poly: - numerator = int(items[0]) if items[0] else 1 - return new_const(Fraction(numerator, int(items[2]))) - -def parse(expr: str, new_var: Callable[[str], Poly]) -> Poly: - tree = poly_grammar.parse(expr) - return PolyTransformer(new_var).transform(tree) diff --git a/zxlive/poly.py b/zxlive/poly.py deleted file mode 100644 index 0dfac718..00000000 --- a/zxlive/poly.py +++ /dev/null @@ -1,180 +0,0 @@ -from fractions import Fraction -from typing import Union, Optional - - -class Var: - name: str - _is_bool: bool - _types_dict: Optional[Union[bool, dict[str, bool]]] - - def __init__(self, name: str, data: Union[bool, dict[str, bool]]): - self.name = name - if isinstance(data, dict): - self._types_dict = data - self._frozen = False - self._is_bool = False - else: - self._types_dict = None - self._frozen = True - self._is_bool = data - - @property - def is_bool(self) -> bool: - if self._frozen: - return self._is_bool - else: - assert isinstance(self._types_dict, dict) - return self._types_dict[self.name] - - def __repr__(self) -> str: - return self.name - - def __lt__(self, other: 'Var') -> bool: - if int(self.is_bool) == int(other.is_bool): - return self.name < other.name - return int(self.is_bool) < int(other.is_bool) - - def __hash__(self) -> int: - # Variables with the same name map to the same type - # within the same graph, so no need to include is_bool - # in the hash. - return int(hash(self.name)) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - def freeze(self) -> None: - if not self._frozen: - assert isinstance(self._types_dict, dict) - self._is_bool = self._types_dict[self.name] - self._frozen = True - self._types_dict = None - - def __copy__(self) -> 'Var': - if self._frozen: - return Var(self.name, self.is_bool) - else: - assert isinstance(self._types_dict, dict) - return Var(self.name, self._types_dict) - - def __deepcopy__(self, _memo: object) -> 'Var': - return self.__copy__() - -class Term: - vars: list[tuple[Var, int]] - - def __init__(self, vars: list[tuple[Var,int]]) -> None: - self.vars = vars - - def freeze(self) -> None: - for var, _ in self.vars: - var.freeze() - - def free_vars(self) -> set[Var]: - return set(var for var, _ in self.vars) - - def __repr__(self) -> str: - vs = [] - for v, c in self.vars: - if c == 1: - vs.append(f'{v}') - else: - vs.append(f'{v}^{c}') - return '*'.join(vs) - - def __mul__(self, other: 'Term') -> 'Term': - vs = dict() - for v, c in self.vars + other.vars: - if v not in vs: vs[v] = c - else: vs[v] += c - # TODO deal with fractional / symbolic powers - if v.is_bool and c > 1: - vs[v] = 1 - return Term([(v, c) for v, c in vs.items()]) - - def __hash__(self) -> int: - return hash(tuple(sorted(self.vars))) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - -class Poly: - terms: list[tuple[Union[int, float, Fraction], Term]] - - def __init__(self, terms: list[tuple[Union[int, float, Fraction], Term]]) -> None: - self.terms = terms - - def freeze(self) -> None: - for _, term in self.terms: - term.freeze() - - def free_vars(self) -> set[Var]: - output = set() - for _, term in self.terms: - output.update(term.free_vars()) - return output - - def __add__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float, Fraction)): - other = Poly([(other, Term([]))]) - counter = dict() - for c, t in self.terms + other.terms: - if t not in counter: counter[t] = c - else: counter[t] += c - if all(tt[0].is_bool for tt in t.vars): - counter[t] = counter[t] % 2 - - # remove terms with coefficient 0 - for t in list(counter.keys()): - if counter[t] == 0: - del counter[t] - return Poly([(c, t) for t, c in counter.items()]) - - __radd__ = __add__ - - def __mul__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float)): - other = Poly([(other, Term([]))]) - p = Poly([]) - for c1, t1 in self.terms: - for c2, t2 in other.terms: - p += Poly([(c1 * c2, t1 * t2)]) - return p - - __rmul__ = __mul__ - - def __repr__(self) -> str: - ts = [] - for c, t in self.terms: - if t == Term([]): - ts.append(f'{c}') - elif c == 1: - ts.append(f'{t}') - else: - ts.append(f'{c}{t}') - return ' + '.join(ts) - - def __eq__(self, other: object) -> bool: - if isinstance(other, (int, float, Fraction)): - if other == 0: - other = Poly([]) - else: - other = Poly([(other, Term([]))]) - assert isinstance(other, Poly) - return set(self.terms) == set(other.terms) - - @property - def is_pauli(self) -> bool: - for c, t in self.terms: - if not all(v.is_bool for v, _ in t.vars): - return False - if c % 1 != 0: - return False - return True - -def new_var(name: str, types_dict: Union[bool, dict[str, bool]]) -> Poly: - return Poly([(1, Term([(Var(name, types_dict), 1)]))]) - -def new_const(coeff: Union[int, Fraction]) -> Poly: - return Poly([(coeff, Term([]))]) From dfaa9f62eb833354e6e882be92dc22bd67e8b116 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 21:31:50 +0000 Subject: [PATCH 12/20] Fixed bugs with saving parameters --- zxlive/edit_panel.py | 4 +-- zxlive/editor_base_panel.py | 71 ++++++++++++++++--------------------- zxlive/rule_panel.py | 2 +- 3 files changed, 34 insertions(+), 43 deletions(-) diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 70fbf5b7..1bf64b54 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -6,8 +6,9 @@ from PySide6.QtCore import Signal from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) -from pyzx import EdgeType, VertexType, Circuit +from pyzx import EdgeType, VertexType from pyzx.circuit.qasmparser import QASMParser +from pyzx.symbolic import Poly from .base_panel import ToolbarSection from .commands import UpdateGraph @@ -16,7 +17,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView -from .poly import Poly class GraphEditPanel(EditorBasePanel): diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 79cfeccd..d0881352 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,9 +15,9 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.graph.jsonparser import string_to_phase from pyzx.symbolic import Poly - from .base_panel import BasePanel, ToolbarSection from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor, ChangeNodeType, ChangePhase, MoveNode, SetGraph, @@ -26,7 +26,6 @@ from .dialogs import show_error_msg from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene -from .parse_poly import parse from .vitem import BLACK @@ -98,16 +97,8 @@ def update_colors(self) -> None: super().update_colors() self.update_side_bar() - def update_variable_viewer(self) -> None: - self.update_side_bar() - def _populate_variables(self) -> None: - self.variable_types = {} - for vert in self.graph.vertices(): - phase = self.graph.phase(vert) - if isinstance(phase, Poly): - for var in phase.free_vars(): - self.variable_types[var.name] = var.is_bool + self.variable_types = self.graph.variable_types.copy() def _tool_clicked(self, tool: ToolType) -> None: self.graph_scene.curr_tool = tool @@ -189,18 +180,18 @@ def vert_double_clicked(self, v: VT) -> None: if not ok: return None try: - new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var) + new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return None cmd = ChangePhase(self.graph_view, v, new_phase) self.undo_stack.push(cmd) - - def _new_var(self, name: str) -> Poly: - if name not in self.variable_types: - self.variable_types[name] = False - self.variable_viewer.add_item(name) - return new_var(name, self.variable_types) + # For some reason it is important we first push to the stack before we do the following. + if len(graph.variable_types) != len(self.variable_types): + new_vars = graph.variable_types.keys() - self.variable_types.keys() + self.variable_types.update(graph.variable_types) + for v in new_vars: + self.variable_viewer.add_item(v) class VariableViewer(QScrollArea): @@ -378,28 +369,28 @@ def create_icon(shape: ShapeType, color: QColor) -> QIcon: return icon -def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: - if not string: - return Fraction(0) - try: - s = string.lower().replace(' ', '') - s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) - if '.' in s or 'e' in s: - return Fraction(float(s)) - elif '/' in s: - a, b = s.split("/", 2) - if not a: - return Fraction(1, int(b)) - if a == '-': - a = '-1' - return Fraction(int(a), int(b)) - else: - return Fraction(int(s)) - except ValueError: - try: - return parse(string, new_var_) - except Exception as e: - raise ValueError(e) +#def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: +# if not string: +# return Fraction(0) +# try: +# s = string.lower().replace(' ', '') +# s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) +# if '.' in s or 'e' in s: +# return Fraction(float(s)) +# elif '/' in s: +# a, b = s.split("/", 2) +# if not a: +# return Fraction(1, int(b)) +# if a == '-': +# a = '-1' +# return Fraction(int(a), int(b)) +# else: +# return Fraction(int(s)) +# except ValueError: +# try: +# return parse(string, new_var_) +# except Exception as e: +# raise ValueError(e) def string_to_complex(string: str) -> complex: diff --git a/zxlive/rule_panel.py b/zxlive/rule_panel.py index 61f4c567..179fca7e 100644 --- a/zxlive/rule_panel.py +++ b/zxlive/rule_panel.py @@ -6,6 +6,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import QLineEdit from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import RuleEditGraphView -from .poly import Poly class RulePanel(EditorBasePanel): From 7cc448f7449b9c76808279640a9df71020a09eca Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:25:36 +0100 Subject: [PATCH 13/20] only check matrix for non-symbolic rules --- zxlive/custom_rule.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index e3dee9f0..515e5288 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -154,13 +154,14 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: from .dialogs import show_error_msg show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.") return False - left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() - if not np.allclose(left_matrix, right_matrix): - if show_error: - from .dialogs import show_error_msg - if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") - else: - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") - return False + if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types: + left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if show_error: + from .dialogs import show_error_msg + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + return False return True From ea0e3ec6c0036ece65acda22f5f51a06a708a5b6 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:26:39 +0100 Subject: [PATCH 14/20] rewrite rule matching with symbolic parameters --- zxlive/custom_rule.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 515e5288..ea0e232e 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -11,6 +11,8 @@ from pyzx.utils import EdgeType, VertexType from shapely import Polygon +from pyzx.symbolic import Poly + from .common import ET, VT, GraphT if TYPE_CHECKING: @@ -75,8 +77,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT] vertices = [v for v in graph.vertices() if in_selection(v)] subgraph_nx, _ = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - if graph_matcher.is_isomorphic(): + node_match=categorical_node_match('type', 1)) + matchings = list(graph_matcher.match()) + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + if len(matchings) > 0: return vertices return [] @@ -102,6 +106,32 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def match_symbolic_parameters(match, left, right): + params = {} + left_phase = left.nodes.data('phase', default=0) + right_phase = right.nodes.data('phase', default=0) + for v in left.nodes(): + if isinstance(left_phase[v], Poly): + if str(left_phase[v]) in params: + if params[str(left_phase)] != right_phase[match[v]]: + raise ValueError("Symbolic parameters do not match") + else: + params[str(left_phase[v])] = right_phase[match[v]] + elif left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + return params + +def filter_matchings_if_symbolic_compatible(matchings, left, right): + new_matchings = [] + for matching in matchings: + try: + match_symbolic_parameters(matching, left, right) + new_matchings.append(matching) + except ValueError: + pass + return new_matchings + + def to_networkx(graph: GraphT) -> nx.Graph: G = nx.Graph() v_data = {v: {"type": graph.type(v), From 4eaa8ca5182cb5d70d9acbfd4985f0a1185af15e Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:28:14 +0100 Subject: [PATCH 15/20] applying custom rule with symbolic parameters --- zxlive/custom_rule.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index ea0e232e..ca549d66 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -33,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]: subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - matching = list(graph_matcher.match())[0] + node_match=categorical_node_match('type', 1)) + matchings = graph_matcher.match() + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + matching = matchings[0] + symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx) vertices_to_remove = [] for v in matching: @@ -55,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu vertex_map = boundary_vertex_map for v in self.rhs_graph_nx.nodes(): if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY: + phase = self.rhs_graph_nx.nodes()[v]['phase'] + if isinstance(phase, Poly): + phase = phase.substitute(symbolic_params_map) + if phase.free_vars() == set(): + phase = phase.terms[0][0] vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'], row = vertex_positions[v][0], qubit = vertex_positions[v][1], - phase = self.rhs_graph_nx.nodes()[v]['phase'],) + phase = phase,) # create etab to add edges etab = {} From 56b1ee5ad652153686b0a7722fe788f9cabb70ca Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:39:35 +0100 Subject: [PATCH 16/20] typo in match_symbolic_parameters --- zxlive/custom_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index ca549d66..65a131eb 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -121,7 +121,7 @@ def match_symbolic_parameters(match, left, right): for v in left.nodes(): if isinstance(left_phase[v], Poly): if str(left_phase[v]) in params: - if params[str(left_phase)] != right_phase[match[v]]: + if params[str(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: params[str(left_phase[v])] = right_phase[match[v]] From f4264a14326852d86de84570d989734ae12d906c Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:33:39 +0100 Subject: [PATCH 17/20] small bug fix --- zxlive/custom_rule.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 65a131eb..44e26af8 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -132,6 +132,8 @@ def match_symbolic_parameters(match, left, right): def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: + if len(matching) != len(left): + continue try: match_symbolic_parameters(matching, left, right) new_matchings.append(matching) From 26bad0ce919ea280ff55755757c9889928003e46 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:34:02 +0100 Subject: [PATCH 18/20] get var method for symbolic parameters --- zxlive/custom_rule.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 44e26af8..c2533c85 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,17 +114,26 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def get_var(v): + if not isinstance(v, Poly): + raise ValueError("Not a symbolic parameter") + if len(v.terms) != 1: + raise ValueError("Only single-term symbolic parameters are supported") + if len(v.terms[0][1].vars) != 1: + raise ValueError("Only single-variable symbolic parameters are supported") + return v.terms[0][1].vars[0][0] + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) for v in left.nodes(): if isinstance(left_phase[v], Poly): - if str(left_phase[v]) in params: - if params[str(left_phase[v])] != right_phase[match[v]]: + if get_var(left_phase[v]) in params: + if params[get_var(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: - params[str(left_phase[v])] = right_phase[match[v]] + params[get_var(left_phase[v])] = right_phase[match[v]] elif left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") return params From 1053ac1eece6bfdcf178ff8f65bf107db5845099 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 17:35:43 +0100 Subject: [PATCH 19/20] add warnings for custom rules with symbolic parameters --- zxlive/custom_rule.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index c2533c85..accfe13d 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -213,4 +213,17 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: else: show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") return False + else: + if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()): + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.") + return False + for vertex in rule.lhs_graph.vertices(): + if isinstance(rule.lhs_graph.phase(vertex), Poly): + if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + return False return True From 40f1fa4a46cea4b9bd0988905e6efa47b4d960e1 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 18:47:17 +0100 Subject: [PATCH 20/20] symbolic rewrites support linear terms Co-authored-by: Tuomas Laakkonen --- zxlive/custom_rule.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index accfe13d..349a9463 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,30 +114,62 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) -def get_var(v): +def get_linear(v): if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") - if len(v.terms) != 1: - raise ValueError("Only single-term symbolic parameters are supported") - if len(v.terms[0][1].vars) != 1: - raise ValueError("Only single-variable symbolic parameters are supported") - return v.terms[0][1].vars[0][0] + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + for v in left.nodes(): if isinstance(left_phase[v], Poly): - if get_var(left_phase[v]) in params: - if params[get_var(left_phase[v])] != right_phase[match[v]]: - raise ValueError("Symbolic parameters do not match") - else: - params[get_var(left_phase[v])] = right_phase[match[v]] - elif left_phase[v] != right_phase[match[v]]: - raise ValueError("Parameters do not match") + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + return params + def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: @@ -221,9 +253,11 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: return False for vertex in rule.lhs_graph.vertices(): if isinstance(rule.lhs_graph.phase(vertex), Poly): - if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: if show_error: from .dialogs import show_error_msg - show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + show_error_msg("Warning!", str(e)) return False return True