Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): automatically add state order edges for inter-graph edges #1165

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
from ._exceptions import NoSiblingAncestor
from hugr.serialization.tys import FunctionType, Type


Expand Down Expand Up @@ -67,12 +68,25 @@

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
# breaks if further edges are added
self.hugr.add_link(
src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst))
)
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)

Check warning on line 78 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L78

Added line #L78 was not covered by tests
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(i))


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
src_parent = h[src].parent

while (tgt_parent := h[tgt].parent) is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could get away with this:

Suggested change
while (tgt_parent := h[tgt].parent) is not None:
while tgt_parent := h[tgt].parent:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i prefer explicit checks, python truthiness is gross

if tgt_parent == src_parent:
return tgt
tgt = tgt_parent

return None

Check warning on line 92 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L92

Added line #L92 was not covered by tests
11 changes: 11 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass


@dataclass
class NoSiblingAncestor(Exception):
src: int
tgt: int

@property
def msg(self):
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."

Check warning on line 11 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L11

Added line #L11 was not covered by tests
36 changes: 29 additions & 7 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def linked_ports(self, port: OutPort | InPort):

# TODO: single linked port

def outgoing_order_links(self, node: Node) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.out(-1)))

def incoming_order_links(self, node: Node) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.inp(-1)))

def _node_links(
self, node: Node, links: dict[_SubPort[P], _SubPort[K]]
) -> Iterable[tuple[P, list[K]]]:
Expand Down Expand Up @@ -320,19 +326,35 @@ def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) ->

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

def _serialise_link(
link: tuple[_SO, _SI],
) -> tuple[tuple[int, int], tuple[int, int]]:
src, dst = link
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
return (src.port.node.idx, s), (dst.port.node.idx, d)

return SerialHugr(
version="v1",
# non contiguous indices will be erased
nodes=[node.to_serial(Node(idx), self) for idx, node in enumerate(node_it)],
edges=[
(
(src.port.node.idx, src.port.offset),
(dst.port.node.idx, dst.port.offset),
)
for src, dst in self._links.items()
],
edges=[_serialise_link(link) for link in self._links.items()],
)

def _constrain_offset(self, p: P) -> int:
# negative offsets are used to refer to the last port
if p.offset < 0:
match p.direction:
case Direction.INCOMING:
current = self.num_incoming(p.node)
case Direction.OUTGOING:
current = self.num_outgoing(p.node)
offset = current + p.offset + 1
else:
offset = p.offset

return offset

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
assert serial.nodes, "Empty Hugr is invalid"
Expand Down
30 changes: 22 additions & 8 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import subprocess
import os
import pathlib
from hugr._hugr import Hugr, Node, Wire
from hugr._dfg import Dfg
from hugr._hugr import Hugr, Node, Wire, _SubPort
from hugr._dfg import Dfg, _ancestral_sibling
from hugr._ops import Custom, Command
import hugr._ops as ops
from hugr.serialization import SerialHugr
Expand Down Expand Up @@ -226,15 +226,29 @@ def _nested_nop(dfg: Dfg):


def test_build_inter_graph():
h = Dfg.endo([BOOL_T, BOOL_T])
(a, b) = h.inputs()
nested = h.add_nested([], [BOOL_T])

nt = nested.add(Not(a))
nested.set_outputs(nt)

h.set_outputs(nested.root, b)

_validate(h.hugr, True)

assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1
assert len(list(h.hugr.incoming_order_links(nested.root))) == 1
assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0


def test_ancestral_sibling():
h = Dfg.endo([BOOL_T])
(a,) = h.inputs()
nested = h.add_nested([], [BOOL_T])

nt = nested.add(Not(a))
nested.set_outputs(nt)
# TODO a context manager could add this state order edge on
# exit by tracking parents of source nodes
h.add_state_order(h.input_node, nested.root)
h.set_outputs(nested.root)

_validate(h.hugr)
assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root