From a4b4487a511a509a6ca3c1988e06547eee189a42 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 5 Jun 2024 16:49:25 +0100 Subject: [PATCH 1/4] feat: ancestral sibling function --- hugr-py/src/hugr/_dfg.py | 11 +++++++++++ hugr-py/tests/test_hugr_build.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f4eda9826..803427162 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -76,3 +76,14 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): src = p.out_port() self.hugr.add_link(src, node.inp(i)) + + +def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: + src_parent = h[src].parent + + while (tgt_parent := h[tgt].parent) is not None: + if tgt_parent == src_parent: + return tgt + tgt = tgt_parent + + return None diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 8691e8e08..8ffab4866 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -4,7 +4,7 @@ import os import pathlib from hugr._hugr import Hugr, Node, Wire -from hugr._dfg import Dfg +from hugr._dfg import Dfg, _ancestral_sibling from hugr._ops import Custom, Command import hugr._ops as ops from hugr.serialization import SerialHugr @@ -238,3 +238,13 @@ def test_build_inter_graph(): h.set_outputs(nested.root) _validate(h.hugr) + + +def test_ancestral_sibling(): + h = Dfg.endo([BOOL_T]) + (a,) = h.inputs() + nested = h.add_nested([], [BOOL_T]) + + nt = nested.add(Not(a)) + + assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root From db5a504f4dd1b10eceef4904563d74ade015b011 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 5 Jun 2024 17:24:00 +0100 Subject: [PATCH 2/4] fix: use -1 indexing to add state order edges --- hugr-py/src/hugr/_dfg.py | 5 +---- hugr-py/src/hugr/_hugr.py | 30 +++++++++++++++++++++++------- hugr-py/tests/test_hugr_build.py | 8 ++++---- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 803427162..bf19c3856 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -67,10 +67,7 @@ def set_outputs(self, *args: Wire) -> None: def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges - # breaks if further edges are added - self.hugr.add_link( - src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst)) - ) + self.hugr.add_link(src.out(-1), dst.inp(-1)) def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 21fc1e281..daeb893a4 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -320,19 +320,35 @@ def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> def to_serial(self) -> SerialHugr: node_it = (node for node in self._nodes if node is not None) + + def _serialise_link( + link: tuple[_SO, _SI], + ) -> tuple[tuple[int, int], tuple[int, int]]: + src, dst = link + s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port) + return (src.port.node.idx, s), (dst.port.node.idx, d) + return SerialHugr( version="v1", # non contiguous indices will be erased nodes=[node.to_serial(Node(idx), self) for idx, node in enumerate(node_it)], - edges=[ - ( - (src.port.node.idx, src.port.offset), - (dst.port.node.idx, dst.port.offset), - ) - for src, dst in self._links.items() - ], + edges=[_serialise_link(link) for link in self._links.items()], ) + def _constrain_offset(self, p: P) -> int: + # negative offsets are used to refer to the last port + if p.offset < 0: + match p.direction: + case Direction.INCOMING: + current = self.num_incoming(p.node) + case Direction.OUTGOING: + current = self.num_outgoing(p.node) + offset = current + p.offset + 1 + else: + offset = p.offset + + return offset + @classmethod def from_serial(cls, serial: SerialHugr) -> Hugr: assert serial.nodes, "Empty Hugr is invalid" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 8ffab4866..db0bbc307 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -226,8 +226,8 @@ def _nested_nop(dfg: Dfg): def test_build_inter_graph(): - h = Dfg.endo([BOOL_T]) - (a,) = h.inputs() + h = Dfg.endo([BOOL_T, BOOL_T]) + (a, b) = h.inputs() nested = h.add_nested([], [BOOL_T]) nt = nested.add(Not(a)) @@ -235,9 +235,9 @@ def test_build_inter_graph(): # TODO a context manager could add this state order edge on # exit by tracking parents of source nodes h.add_state_order(h.input_node, nested.root) - h.set_outputs(nested.root) + h.set_outputs(nested.root, b) - _validate(h.hugr) + _validate(h.hugr, True) def test_ancestral_sibling(): From e36df746c91467c037978ddef7317f7aad7d09d5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 5 Jun 2024 17:36:05 +0100 Subject: [PATCH 3/4] feat: automatically add state order edges for inter-graph edges --- hugr-py/src/hugr/_dfg.py | 6 ++++++ hugr-py/src/hugr/_exceptions.py | 11 +++++++++++ hugr-py/src/hugr/_hugr.py | 1 + hugr-py/tests/test_hugr_build.py | 8 ++++---- 4 files changed, 22 insertions(+), 4 deletions(-) create mode 100644 hugr-py/src/hugr/_exceptions.py diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index bf19c3856..539dfa499 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -4,6 +4,7 @@ from ._hugr import Hugr, Node, Wire, OutPort from ._ops import Op, Command, Input, Output, DFG +from ._exceptions import NoSiblingAncestor from hugr.serialization.tys import FunctionType, Type @@ -72,6 +73,11 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): src = p.out_port() + node_ancestor = _ancestral_sibling(self.hugr, src.node, node) + if node_ancestor is None: + raise NoSiblingAncestor(src.node.idx, node.idx) + if node_ancestor != src.node: + self.add_state_order(src.node, node_ancestor) self.hugr.add_link(src, node.inp(i)) diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py new file mode 100644 index 000000000..3245af0cc --- /dev/null +++ b/hugr-py/src/hugr/_exceptions.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class NoSiblingAncestor(Exception): + src: int + tgt: int + + @property + def msg(self): + return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index daeb893a4..46ed74d4a 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -269,6 +269,7 @@ def _node_links( return # iterate over known offsets for offset in range(self.num_ports(node, direction)): + # TODO should this also look for -1 state order edges? port = cast(P, node.port(offset, direction)) yield port, list(self._linked_ports(port, links)) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index db0bbc307..d023d0eca 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,7 +3,7 @@ import subprocess import os import pathlib -from hugr._hugr import Hugr, Node, Wire +from hugr._hugr import Hugr, Node, Wire, _SubPort from hugr._dfg import Dfg, _ancestral_sibling from hugr._ops import Custom, Command import hugr._ops as ops @@ -232,11 +232,11 @@ def test_build_inter_graph(): nt = nested.add(Not(a)) nested.set_outputs(nt) - # TODO a context manager could add this state order edge on - # exit by tracking parents of source nodes - h.add_state_order(h.input_node, nested.root) + h.set_outputs(nested.root, b) + assert _SubPort(h.input_node.out(-1)) in h.hugr._links + assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order _validate(h.hugr, True) From 286e4206edcfa172a2f88e567e0885781ac8883d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 6 Jun 2024 14:26:13 +0100 Subject: [PATCH 4/4] feat: add order link iterators and fix detected bug --- hugr-py/src/hugr/_dfg.py | 2 +- hugr-py/src/hugr/_hugr.py | 7 ++++++- hugr-py/tests/test_hugr_build.py | 6 +++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 539dfa499..f083e8ae0 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -76,7 +76,7 @@ def _wire_up(self, node: Node, ports: Iterable[Wire]): node_ancestor = _ancestral_sibling(self.hugr, src.node, node) if node_ancestor is None: raise NoSiblingAncestor(src.node.idx, node.idx) - if node_ancestor != src.node: + if node_ancestor != node: self.add_state_order(src.node, node_ancestor) self.hugr.add_link(src, node.inp(i)) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 46ed74d4a..d42f2edf1 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -260,6 +260,12 @@ def linked_ports(self, port: OutPort | InPort): # TODO: single linked port + def outgoing_order_links(self, node: Node) -> Iterable[Node]: + return (p.node for p in self.linked_ports(node.out(-1))) + + def incoming_order_links(self, node: Node) -> Iterable[Node]: + return (p.node for p in self.linked_ports(node.inp(-1))) + def _node_links( self, node: Node, links: dict[_SubPort[P], _SubPort[K]] ) -> Iterable[tuple[P, list[K]]]: @@ -269,7 +275,6 @@ def _node_links( return # iterate over known offsets for offset in range(self.num_ports(node, direction)): - # TODO should this also look for -1 state order edges? port = cast(P, node.port(offset, direction)) yield port, list(self._linked_ports(port, links)) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index d023d0eca..52f7a2b07 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -235,9 +235,13 @@ def test_build_inter_graph(): h.set_outputs(nested.root, b) + _validate(h.hugr, True) + assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order - _validate(h.hugr, True) + assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1 + assert len(list(h.hugr.incoming_order_links(nested.root))) == 1 + assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0 def test_ancestral_sibling():