Skip to content

Commit

Permalink
implement w node unfusion for custom rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
RazinShaikh committed Nov 19, 2023
1 parent c0b7709 commit b074998
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions zxlive/custom_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from networkx.algorithms.isomorphism import (GraphMatcher,
categorical_node_match)
from networkx.classes.reportviews import NodeView
from pyzx.utils import EdgeType, VertexType
from pyzx.utils import EdgeType, VertexType, get_w_io
from shapely import Polygon

from pyzx.symbolic import Poly
Expand Down Expand Up @@ -43,6 +43,8 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu
node_match=categorical_node_match('type', 1))
matchings = graph_matcher.match()
matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx)
if len(matchings) == 0:
raise ValueError("No matchings found")
matching = matchings[0]
symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx)

Expand Down Expand Up @@ -103,15 +105,18 @@ def get_adjacent_boundary_vertices(graph, v):
for v in matching:
if len(get_adjacent_boundary_vertices(self.lhs_graph_nx, v)) != 1:
continue
vtype = self.lhs_graph_nx.nodes()[v]['type']
outside_verts = get_adjacent_boundary_vertices(subgraph_nx, matching[v])
if len(outside_verts) == 1 and \
subgraph_nx.edges()[(matching[v], outside_verts[0])]['type'] == EdgeType.SIMPLE:
subgraph_nx.edges()[(matching[v], outside_verts[0])]['type'] == EdgeType.SIMPLE and \
vtype != VertexType.W_INPUT:
continue
vtype = self.lhs_graph_nx.nodes()[v]['type']
if vtype == VertexType.Z or vtype == VertexType.X or vtype == VertexType.Z_BOX:
self.unfuse_zx_vertex(graph, subgraph_nx, matching[v], vtype)
elif vtype == VertexType.H_BOX:
self.unfuse_h_box_vertex(graph, subgraph_nx, matching[v])
elif vtype == VertexType.W_OUTPUT or vtype == VertexType.W_INPUT:
self.unfuse_w_vertex(graph, subgraph_nx, matching[v], vtype)

def unfuse_update_edges(self, graph, subgraph_nx, old_v, new_v):
neighbors = list(graph.neighbors(old_v))
Expand All @@ -132,6 +137,17 @@ def unfuse_h_box_vertex(self, graph, subgraph_nx, v):
graph.add_edge((new_mid_h, v))
graph.add_edge((new_h, new_mid_h))

def unfuse_w_vertex(self, graph, subgraph_nx, v, vtype):
w_in, w_out = get_w_io(graph, v)
new_w_in = graph.add_vertex(VertexType.W_INPUT, qubit=graph.qubit(w_in), row=graph.row(w_in))
new_w_out = graph.add_vertex(VertexType.W_OUTPUT, qubit=graph.qubit(w_out), row=graph.row(w_out))
self.unfuse_update_edges(graph, subgraph_nx, w_in, new_w_in)
self.unfuse_update_edges(graph, subgraph_nx, w_out, new_w_out)
if vtype == VertexType.W_OUTPUT:
graph.add_edge((new_w_in, w_out))
else:
graph.add_edge((w_in, new_w_out))
graph.add_edge((new_w_in, new_w_out), EdgeType.W_IO)

def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT]:
vertices = [v for v in graph.vertices() if in_selection(v)]
Expand Down

0 comments on commit b074998

Please sign in to comment.