From d3bafb7a5ca046c3b2bcfc4ec4107be8d4b68c9e Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Fri, 5 Jul 2024 00:28:54 +0100 Subject: [PATCH] adding a lot of type annotations in custom_rules.py --- zxlive/custom_rule.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index aa67632c..79fd2fbc 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -1,7 +1,7 @@ import json from fractions import Fraction -from typing import TYPE_CHECKING, Callable, Sequence, Dict, Union +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Dict, Union import networkx as nx import numpy as np @@ -94,7 +94,7 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu return etab, vertices_to_remove, [], True def unfuse_subgraph_for_rewrite(self, graph: GraphT, vertices: list[VT]) -> None: - def get_adjacent_boundary_vertices(g, v) -> Sequence[VT]: + def get_adjacent_boundary_vertices(g: nx.Graph, v: VT) -> Sequence[VT]: return [n for n in g.neighbors(v) if g.nodes()[n]['type'] == VertexType.BOUNDARY] subgraph_nx_without_boundaries = nx.Graph(to_networkx(graph).subgraph(vertices)) @@ -121,26 +121,26 @@ def get_adjacent_boundary_vertices(g, v) -> Sequence[VT]: elif vtype == VertexType.W_OUTPUT or vtype == VertexType.W_INPUT: self.unfuse_w_vertex(graph, subgraph_nx, matching[v], vtype) - def unfuse_update_edges(self, graph, subgraph_nx, old_v, new_v) -> None: + def unfuse_update_edges(self, graph: GraphT, subgraph_nx: nx.Graph, old_v: VT, new_v: VT) -> None: neighbors = list(graph.neighbors(old_v)) for b in neighbors: if b not in subgraph_nx.nodes: graph.add_edge((new_v, b), graph.edge_type((old_v, b))) graph.remove_edge(graph.edge(old_v, b)) - def unfuse_zx_vertex(self, graph, subgraph_nx, v, vtype) -> None: + def unfuse_zx_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT, vtype: VertexType) -> None: new_v = graph.add_vertex(vtype, qubit=graph.qubit(v), row=graph.row(v)) self.unfuse_update_edges(graph, subgraph_nx, v, new_v) graph.add_edge(graph.edge(new_v, v)) - def unfuse_h_box_vertex(self, graph, subgraph_nx, v) -> None: + def unfuse_h_box_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT) -> None: new_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v)+0.3, row=graph.row(v)+0.3) new_mid_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v), row=graph.row(v)) self.unfuse_update_edges(graph, subgraph_nx, v, new_h) graph.add_edge((new_mid_h, v)) graph.add_edge((new_h, new_mid_h)) - def unfuse_w_vertex(self, graph, subgraph_nx, v, vtype) -> None: + def unfuse_w_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT, vtype: VertexType) -> None: w_in, w_out = get_w_io(graph, v) new_w_in = graph.add_vertex(VertexType.W_INPUT, qubit=graph.qubit(w_in), row=graph.row(w_in)) new_w_out = graph.add_vertex(VertexType.W_OUTPUT, qubit=graph.qubit(w_out), row=graph.row(w_out)) @@ -203,7 +203,7 @@ def is_rewrite_unfusable(lhs_graph: GraphT) -> bool: return False return True -def get_linear(v): +def get_linear(v: Poly) -> tuple[Union[int, float, complex, Fraction], Optional[Var], Union[int, float, complex, Fraction]]: if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") if len(v.terms) > 2 or len(v.free_vars()) > 1: @@ -213,7 +213,7 @@ def get_linear(v): elif len(v.terms) == 1: if len(v.terms[0][1].vars) > 0: var_term = v.terms[0] - const = 0 + const: Union[int, float, complex, Fraction] = 0 else: const = v.terms[0][0] return 1, None, const @@ -231,16 +231,16 @@ def get_linear(v): return coeff, var, const -def match_symbolic_parameters(match, left: nx.Graph, right: nx.Graph) -> Dict[Var, Union[float, complex, Fraction]]: - params: Dict[Var, Union[float, complex, Fraction]] = {} - left_phase = left.nodes.data('phase', default=0) - right_phase = right.nodes.data('phase', default=0) +def match_symbolic_parameters(match: Dict[VT, VT], left: nx.Graph, right: nx.Graph) -> Dict[Var, Union[int, float, complex, Fraction]]: + params: Dict[Var, Union[int, float, complex, Fraction]] = {} + left_phase = left.nodes.data('phase', default=0) # type: ignore + right_phase = right.nodes.data('phase', default=0) # type: ignore - def check_phase_equality(v) -> None: + def check_phase_equality(v: VT) -> None: if left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") - def update_params(v, var, coeff, const) -> None: + def update_params(v: VT, var: Var, coeff: Union[int, float, complex, Fraction], const: Union[int, float, complex, Fraction]) -> None: var_value = (right_phase[match[v]] - const) / coeff if var in params and params[var] != var_value: raise ValueError("Symbolic parameters do not match") @@ -259,7 +259,7 @@ def update_params(v, var, coeff, const) -> None: return params -def filter_matchings_if_symbolic_compatible(matchings, left, right): +def filter_matchings_if_symbolic_compatible(matchings: list[Dict[VT, VT]], left: nx.Graph, right: nx.Graph) -> list[Dict[VT, VT]]: new_matchings = [] for matching in matchings: if len(matching) != len(left):