Skip to content

Commit

Permalink
adding a lot of type annotations in custom_rules.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RazinShaikh committed Jul 4, 2024
1 parent 640313a commit d3bafb7
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions zxlive/custom_rule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import json
from fractions import Fraction
from typing import TYPE_CHECKING, Callable, Sequence, Dict, Union
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Dict, Union

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -94,7 +94,7 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu
return etab, vertices_to_remove, [], True

def unfuse_subgraph_for_rewrite(self, graph: GraphT, vertices: list[VT]) -> None:
def get_adjacent_boundary_vertices(g, v) -> Sequence[VT]:
def get_adjacent_boundary_vertices(g: nx.Graph, v: VT) -> Sequence[VT]:
return [n for n in g.neighbors(v) if g.nodes()[n]['type'] == VertexType.BOUNDARY]

subgraph_nx_without_boundaries = nx.Graph(to_networkx(graph).subgraph(vertices))
Expand All @@ -121,26 +121,26 @@ def get_adjacent_boundary_vertices(g, v) -> Sequence[VT]:
elif vtype == VertexType.W_OUTPUT or vtype == VertexType.W_INPUT:
self.unfuse_w_vertex(graph, subgraph_nx, matching[v], vtype)

def unfuse_update_edges(self, graph, subgraph_nx, old_v, new_v) -> None:
def unfuse_update_edges(self, graph: GraphT, subgraph_nx: nx.Graph, old_v: VT, new_v: VT) -> None:
neighbors = list(graph.neighbors(old_v))
for b in neighbors:
if b not in subgraph_nx.nodes:
graph.add_edge((new_v, b), graph.edge_type((old_v, b)))
graph.remove_edge(graph.edge(old_v, b))

def unfuse_zx_vertex(self, graph, subgraph_nx, v, vtype) -> None:
def unfuse_zx_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT, vtype: VertexType) -> None:
new_v = graph.add_vertex(vtype, qubit=graph.qubit(v), row=graph.row(v))
self.unfuse_update_edges(graph, subgraph_nx, v, new_v)
graph.add_edge(graph.edge(new_v, v))

def unfuse_h_box_vertex(self, graph, subgraph_nx, v) -> None:
def unfuse_h_box_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT) -> None:
new_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v)+0.3, row=graph.row(v)+0.3)
new_mid_h = graph.add_vertex(VertexType.H_BOX, qubit=graph.qubit(v), row=graph.row(v))
self.unfuse_update_edges(graph, subgraph_nx, v, new_h)
graph.add_edge((new_mid_h, v))
graph.add_edge((new_h, new_mid_h))

def unfuse_w_vertex(self, graph, subgraph_nx, v, vtype) -> None:
def unfuse_w_vertex(self, graph: GraphT, subgraph_nx: nx.Graph, v: VT, vtype: VertexType) -> None:
w_in, w_out = get_w_io(graph, v)
new_w_in = graph.add_vertex(VertexType.W_INPUT, qubit=graph.qubit(w_in), row=graph.row(w_in))
new_w_out = graph.add_vertex(VertexType.W_OUTPUT, qubit=graph.qubit(w_out), row=graph.row(w_out))
Expand Down Expand Up @@ -203,7 +203,7 @@ def is_rewrite_unfusable(lhs_graph: GraphT) -> bool:
return False
return True

def get_linear(v):
def get_linear(v: Poly) -> tuple[Union[int, float, complex, Fraction], Optional[Var], Union[int, float, complex, Fraction]]:
if not isinstance(v, Poly):
raise ValueError("Not a symbolic parameter")
if len(v.terms) > 2 or len(v.free_vars()) > 1:
Expand All @@ -213,7 +213,7 @@ def get_linear(v):
elif len(v.terms) == 1:
if len(v.terms[0][1].vars) > 0:
var_term = v.terms[0]
const = 0
const: Union[int, float, complex, Fraction] = 0
else:
const = v.terms[0][0]
return 1, None, const
Expand All @@ -231,16 +231,16 @@ def get_linear(v):
return coeff, var, const


def match_symbolic_parameters(match, left: nx.Graph, right: nx.Graph) -> Dict[Var, Union[float, complex, Fraction]]:
params: Dict[Var, Union[float, complex, Fraction]] = {}
left_phase = left.nodes.data('phase', default=0)
right_phase = right.nodes.data('phase', default=0)
def match_symbolic_parameters(match: Dict[VT, VT], left: nx.Graph, right: nx.Graph) -> Dict[Var, Union[int, float, complex, Fraction]]:
params: Dict[Var, Union[int, float, complex, Fraction]] = {}
left_phase = left.nodes.data('phase', default=0) # type: ignore
right_phase = right.nodes.data('phase', default=0) # type: ignore

def check_phase_equality(v) -> None:
def check_phase_equality(v: VT) -> None:
if left_phase[v] != right_phase[match[v]]:
raise ValueError("Parameters do not match")

def update_params(v, var, coeff, const) -> None:
def update_params(v: VT, var: Var, coeff: Union[int, float, complex, Fraction], const: Union[int, float, complex, Fraction]) -> None:
var_value = (right_phase[match[v]] - const) / coeff
if var in params and params[var] != var_value:
raise ValueError("Symbolic parameters do not match")
Expand All @@ -259,7 +259,7 @@ def update_params(v, var, coeff, const) -> None:
return params


def filter_matchings_if_symbolic_compatible(matchings, left, right):
def filter_matchings_if_symbolic_compatible(matchings: list[Dict[VT, VT]], left: nx.Graph, right: nx.Graph) -> list[Dict[VT, VT]]:
new_matchings = []
for matching in matchings:
if len(matching) != len(left):
Expand Down

0 comments on commit d3bafb7

Please sign in to comment.