diff --git a/zxlive/app.py b/zxlive/app.py index e525f4fc..56a2991c 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,11 +19,11 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon import sys +sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + from .mainwindow import MainWindow from .common import get_data -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx - # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 import os diff --git a/zxlive/commands.py b/zxlive/commands.py index 880142f3..29f15b84 100644 --- a/zxlive/commands.py +++ b/zxlive/commands.py @@ -11,11 +11,11 @@ from PySide6.QtWidgets import QListView from pyzx import basicrules from pyzx.graph import GraphDiff +from pyzx.symbolic import Poly from pyzx.utils import EdgeType, VertexType, get_w_partner, vertex_is_w, get_w_io, get_z_box_label, set_z_box_label from .common import ET, VT, W_INPUT_OFFSET, GraphT from .graphview import GraphView -from .poly import Poly from .proof import ProofModel, Rewrite diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index e3dee9f0..349a9463 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -11,6 +11,8 @@ from pyzx.utils import EdgeType, VertexType from shapely import Polygon +from pyzx.symbolic import Poly + from .common import ET, VT, GraphT if TYPE_CHECKING: @@ -31,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]: subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - matching = list(graph_matcher.match())[0] + node_match=categorical_node_match('type', 1)) + matchings = graph_matcher.match() + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + matching = matchings[0] + symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx) vertices_to_remove = [] for v in matching: @@ -53,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu vertex_map = boundary_vertex_map for v in self.rhs_graph_nx.nodes(): if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY: + phase = self.rhs_graph_nx.nodes()[v]['phase'] + if isinstance(phase, Poly): + phase = phase.substitute(symbolic_params_map) + if phase.free_vars() == set(): + phase = phase.terms[0][0] vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'], row = vertex_positions[v][0], qubit = vertex_positions[v][1], - phase = self.rhs_graph_nx.nodes()[v]['phase'],) + phase = phase,) # create etab to add edges etab = {} @@ -75,8 +85,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT] vertices = [v for v in graph.vertices() if in_selection(v)] subgraph_nx, _ = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - if graph_matcher.is_isomorphic(): + node_match=categorical_node_match('type', 1)) + matchings = list(graph_matcher.match()) + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + if len(matchings) > 0: return vertices return [] @@ -102,6 +114,75 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def get_linear(v): + if not isinstance(v, Poly): + raise ValueError("Not a symbolic parameter") + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + + +def match_symbolic_parameters(match, left, right): + params = {} + left_phase = left.nodes.data('phase', default=0) + right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + + for v in left.nodes(): + if isinstance(left_phase[v], Poly): + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + + return params + + +def filter_matchings_if_symbolic_compatible(matchings, left, right): + new_matchings = [] + for matching in matchings: + if len(matching) != len(left): + continue + try: + match_symbolic_parameters(matching, left, right) + new_matchings.append(matching) + except ValueError: + pass + return new_matchings + + def to_networkx(graph: GraphT) -> nx.Graph: G = nx.Graph() v_data = {v: {"type": graph.type(v), @@ -154,13 +235,29 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: from .dialogs import show_error_msg show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.") return False - left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() - if not np.allclose(left_matrix, right_matrix): - if show_error: - from .dialogs import show_error_msg - if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") - else: - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") - return False + if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types: + left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if show_error: + from .dialogs import show_error_msg + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + return False + else: + if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()): + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.") + return False + for vertex in rule.lhs_graph.vertices(): + if isinstance(rule.lhs_graph.phase(vertex), Poly): + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", str(e)) + return False return True diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 70fbf5b7..1bf64b54 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -6,8 +6,9 @@ from PySide6.QtCore import Signal from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) -from pyzx import EdgeType, VertexType, Circuit +from pyzx import EdgeType, VertexType from pyzx.circuit.qasmparser import QASMParser +from pyzx.symbolic import Poly from .base_panel import ToolbarSection from .commands import UpdateGraph @@ -16,7 +17,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView -from .poly import Poly class GraphEditPanel(EditorBasePanel): diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 9230491d..74f52022 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -3,8 +3,7 @@ import copy import re from enum import Enum -from fractions import Fraction -from typing import Callable, Iterator, TypedDict, Union +from typing import Callable, Iterator, TypedDict from PySide6.QtCore import QPoint, QSize, Qt, Signal from PySide6.QtGui import (QAction, QColor, QIcon, QPainter, QPalette, QPen, @@ -15,7 +14,7 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w - +from pyzx.graph.jsonparser import string_to_phase from .base_panel import BasePanel, ToolbarSection from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor, @@ -25,8 +24,6 @@ from .dialogs import show_error_msg from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene -from .parse_poly import parse -from .poly import Poly, new_var from .vitem import BLACK @@ -98,16 +95,8 @@ def update_colors(self) -> None: super().update_colors() self.update_side_bar() - def update_variable_viewer(self) -> None: - self.update_side_bar() - def _populate_variables(self) -> None: - self.variable_types = {} - for vert in self.graph.vertices(): - phase = self.graph.phase(vert) - if isinstance(phase, Poly): - for var in phase.free_vars(): - self.variable_types[var.name] = var.is_bool + self.variable_types = self.graph.variable_types.copy() def _tool_clicked(self, tool: ToolType) -> None: self.graph_scene.curr_tool = tool @@ -189,18 +178,18 @@ def vert_double_clicked(self, v: VT) -> None: if not ok: return None try: - new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var) + new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return None cmd = ChangePhase(self.graph_view, v, new_phase) self.undo_stack.push(cmd) - - def _new_var(self, name: str) -> Poly: - if name not in self.variable_types: - self.variable_types[name] = False - self.variable_viewer.add_item(name) - return new_var(name, self.variable_types) + # For some reason it is important we first push to the stack before we do the following. + if len(graph.variable_types) != len(self.variable_types): + new_vars = graph.variable_types.keys() - self.variable_types.keys() + self.variable_types.update(graph.variable_types) + for v in new_vars: + self.variable_viewer.add_item(v) class VariableViewer(QScrollArea): @@ -378,29 +367,5 @@ def create_icon(shape: ShapeType, color: QColor) -> QIcon: return icon -def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: - if not string: - return Fraction(0) - try: - s = string.lower().replace(' ', '') - s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) - if '.' in s or 'e' in s: - return Fraction(float(s)) - elif '/' in s: - a, b = s.split("/", 2) - if not a: - return Fraction(1, int(b)) - if a == '-': - a = '-1' - return Fraction(int(a), int(b)) - else: - return Fraction(int(s)) - except ValueError: - try: - return parse(string, new_var_) - except Exception as e: - raise ValueError(e) - - def string_to_complex(string: str) -> complex: return complex(string) if string else complex(0) diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py deleted file mode 100644 index 42e9da97..00000000 --- a/zxlive/parse_poly.py +++ /dev/null @@ -1,58 +0,0 @@ -from .poly import Poly, new_const - -from typing import Any, Callable -from lark import Lark, Transformer -from functools import reduce -from operator import add, mul -from fractions import Fraction - -poly_grammar = Lark(""" - start : "(" start ")" | term ("+" term)* - term : (intf | frac)? factor ("*" factor)* - ?factor : intf | frac | pi | pifrac | var - var : CNAME - intf : INT - pi : "\\pi" | "pi" - frac : INT "/" INT - pifrac : [INT] pi "/" INT - - %import common.INT - %import common.CNAME - %import common.WS - %ignore WS - """, - parser='lalr', - maybe_placeholders=True) - -class PolyTransformer(Transformer[Poly]): - def __init__(self, new_var: Callable[[str], Poly]): - super().__init__() - - self._new_var = new_var - - def start(self, items: list[Poly]) -> Poly: - return reduce(add, items) - - def term(self, items: list[Poly]) -> Poly: - return reduce(mul, items) - - def var(self, items: list[Any]) -> Poly: - v = str(items[0]) - return self._new_var(v) - - def pi(self, _: list[Any]) -> Poly: - return new_const(1) - - def intf(self, items: list[Any]) -> Poly: - return new_const(int(items[0])) - - def frac(self, items: list[Any]) -> Poly: - return new_const(Fraction(int(items[0]), int(items[1]))) - - def pifrac(self, items: list[Any]) -> Poly: - numerator = int(items[0]) if items[0] else 1 - return new_const(Fraction(numerator, int(items[2]))) - -def parse(expr: str, new_var: Callable[[str], Poly]) -> Poly: - tree = poly_grammar.parse(expr) - return PolyTransformer(new_var).transform(tree) diff --git a/zxlive/poly.py b/zxlive/poly.py deleted file mode 100644 index 0dfac718..00000000 --- a/zxlive/poly.py +++ /dev/null @@ -1,180 +0,0 @@ -from fractions import Fraction -from typing import Union, Optional - - -class Var: - name: str - _is_bool: bool - _types_dict: Optional[Union[bool, dict[str, bool]]] - - def __init__(self, name: str, data: Union[bool, dict[str, bool]]): - self.name = name - if isinstance(data, dict): - self._types_dict = data - self._frozen = False - self._is_bool = False - else: - self._types_dict = None - self._frozen = True - self._is_bool = data - - @property - def is_bool(self) -> bool: - if self._frozen: - return self._is_bool - else: - assert isinstance(self._types_dict, dict) - return self._types_dict[self.name] - - def __repr__(self) -> str: - return self.name - - def __lt__(self, other: 'Var') -> bool: - if int(self.is_bool) == int(other.is_bool): - return self.name < other.name - return int(self.is_bool) < int(other.is_bool) - - def __hash__(self) -> int: - # Variables with the same name map to the same type - # within the same graph, so no need to include is_bool - # in the hash. - return int(hash(self.name)) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - def freeze(self) -> None: - if not self._frozen: - assert isinstance(self._types_dict, dict) - self._is_bool = self._types_dict[self.name] - self._frozen = True - self._types_dict = None - - def __copy__(self) -> 'Var': - if self._frozen: - return Var(self.name, self.is_bool) - else: - assert isinstance(self._types_dict, dict) - return Var(self.name, self._types_dict) - - def __deepcopy__(self, _memo: object) -> 'Var': - return self.__copy__() - -class Term: - vars: list[tuple[Var, int]] - - def __init__(self, vars: list[tuple[Var,int]]) -> None: - self.vars = vars - - def freeze(self) -> None: - for var, _ in self.vars: - var.freeze() - - def free_vars(self) -> set[Var]: - return set(var for var, _ in self.vars) - - def __repr__(self) -> str: - vs = [] - for v, c in self.vars: - if c == 1: - vs.append(f'{v}') - else: - vs.append(f'{v}^{c}') - return '*'.join(vs) - - def __mul__(self, other: 'Term') -> 'Term': - vs = dict() - for v, c in self.vars + other.vars: - if v not in vs: vs[v] = c - else: vs[v] += c - # TODO deal with fractional / symbolic powers - if v.is_bool and c > 1: - vs[v] = 1 - return Term([(v, c) for v, c in vs.items()]) - - def __hash__(self) -> int: - return hash(tuple(sorted(self.vars))) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - -class Poly: - terms: list[tuple[Union[int, float, Fraction], Term]] - - def __init__(self, terms: list[tuple[Union[int, float, Fraction], Term]]) -> None: - self.terms = terms - - def freeze(self) -> None: - for _, term in self.terms: - term.freeze() - - def free_vars(self) -> set[Var]: - output = set() - for _, term in self.terms: - output.update(term.free_vars()) - return output - - def __add__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float, Fraction)): - other = Poly([(other, Term([]))]) - counter = dict() - for c, t in self.terms + other.terms: - if t not in counter: counter[t] = c - else: counter[t] += c - if all(tt[0].is_bool for tt in t.vars): - counter[t] = counter[t] % 2 - - # remove terms with coefficient 0 - for t in list(counter.keys()): - if counter[t] == 0: - del counter[t] - return Poly([(c, t) for t, c in counter.items()]) - - __radd__ = __add__ - - def __mul__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float)): - other = Poly([(other, Term([]))]) - p = Poly([]) - for c1, t1 in self.terms: - for c2, t2 in other.terms: - p += Poly([(c1 * c2, t1 * t2)]) - return p - - __rmul__ = __mul__ - - def __repr__(self) -> str: - ts = [] - for c, t in self.terms: - if t == Term([]): - ts.append(f'{c}') - elif c == 1: - ts.append(f'{t}') - else: - ts.append(f'{c}{t}') - return ' + '.join(ts) - - def __eq__(self, other: object) -> bool: - if isinstance(other, (int, float, Fraction)): - if other == 0: - other = Poly([]) - else: - other = Poly([(other, Term([]))]) - assert isinstance(other, Poly) - return set(self.terms) == set(other.terms) - - @property - def is_pauli(self) -> bool: - for c, t in self.terms: - if not all(v.is_bool for v, _ in t.vars): - return False - if c % 1 != 0: - return False - return True - -def new_var(name: str, types_dict: Union[bool, dict[str, bool]]) -> Poly: - return Poly([(1, Term([(Var(name, types_dict), 1)]))]) - -def new_const(coeff: Union[int, Fraction]) -> Poly: - return Poly([(coeff, Term([]))]) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index affd9414..1697257d 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -15,7 +15,8 @@ QStyleOptionViewItem, QToolButton, QWidget, QVBoxLayout, QTabWidget, QInputDialog) from pyzx import VertexType, basicrules -from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType +from pyzx.graph.jsonparser import string_to_phase +from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType, FractionLike from . import animations as anims from . import proof_actions @@ -29,9 +30,8 @@ from .graphscene import GraphScene from .graphview import GraphTool, GraphView, WandTrace from .proof import ProofModel -from .vitem import DragState, VItem, get_w_partner_vitem, W_INPUT_OFFSET, SCALE -from .editor_base_panel import string_to_complex, string_to_fraction -from .poly import Poly +from .vitem import DragState, VItem, W_INPUT_OFFSET, SCALE +from .editor_base_panel import string_to_complex class ProofPanel(BasePanel): @@ -245,9 +245,7 @@ def cross(a: QPointF, b: QPointF) -> float: if not ok: return False try: - def new_var(_: str) -> Poly: - raise ValueError() - phase = string_to_complex(text) if phase_is_complex else string_to_fraction(text, new_var) + phase = string_to_complex(text) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return False @@ -326,7 +324,7 @@ def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> Non cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "unfuse") self.undo_stack.push(cmd, anim_after=anim) - def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: Poly | complex | Fraction) -> None: + def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: FractionLike) -> None: def snap_vector(v: QVector2D) -> None: if abs(v.x()) > abs(v.y()): v.setY(0.0) diff --git a/zxlive/rule_panel.py b/zxlive/rule_panel.py index 61f4c567..179fca7e 100644 --- a/zxlive/rule_panel.py +++ b/zxlive/rule_panel.py @@ -6,6 +6,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import QLineEdit from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import RuleEditGraphView -from .poly import Poly class RulePanel(EditorBasePanel):