From 3bba01a7a7d1c2df845281734e1b2d5905bc3f6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:17:36 +0000 Subject: [PATCH] feat!: Replace GATs with `impl Iterator` returns (RPITIT) on `HugrView` (#1660) This will simplify the update to the next portgraph release (which also moved to RPITIT), avoiding the need to use boxes around iterators to be able to name them. - drops the dependency on `context-iterators` - drive-by: Cleanup the implementations of `SiblingGraph` / `SiblingMut` to avoid `collect_vec().into_iter()`s. BREAKING CHANGE: Removed `HugrView` associated iterator types, replaced with `impl Iterator` returns. --- Cargo.toml | 1 - hugr-core/Cargo.toml | 1 - hugr-core/src/export.rs | 22 +- hugr-core/src/extension/op_def.rs | 1 + hugr-core/src/hugr/rewrite/inline_dfg.rs | 12 +- hugr-core/src/hugr/rewrite/outline_cfg.rs | 4 +- hugr-core/src/hugr/views.rs | 143 +++------- hugr-core/src/hugr/views/descendants.rs | 112 ++++---- hugr-core/src/hugr/views/petgraph.rs | 60 +++- hugr-core/src/hugr/views/sibling.rs | 272 +++++++++++-------- hugr-core/src/hugr/views/sibling_subgraph.rs | 27 +- hugr-core/src/hugr/views/tests.rs | 9 +- hugr-core/src/import.rs | 10 +- hugr-passes/src/half_node.rs | 5 +- hugr-passes/src/merge_bbs.rs | 2 +- hugr-passes/src/nest_cfgs.rs | 21 +- 16 files changed, 366 insertions(+), 336 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e046c799f..36ef0526b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ portgraph = { version = "0.12.2" } insta = { version = "1.34.0" } bitvec = "1.0.1" cgmath = "0.18.0" -context-iterators = "0.2.0" cool_asserts = "2.0.3" criterion = "0.5.1" delegate = "0.13.0" diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 3dab5e709..84b83677f 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -47,7 +47,6 @@ bitvec = { workspace = true, features = ["serde"] } enum_dispatch = { workspace = true } lazy_static = { workspace = true } petgraph = { workspace = true } -context-iterators = { workspace = true } serde_json = { workspace = true } delegate = { workspace = true } paste = { workspace = true } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 68f3a15c0..32460ce5d 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -69,7 +69,7 @@ impl<'a> Context<'a> { /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { let hugr_children = self.hugr.children(self.hugr.root()); - let mut children = Vec::with_capacity(hugr_children.len()); + let mut children = Vec::with_capacity(hugr_children.size_hint().0); for child in self.hugr.children(self.hugr.root()) { children.push(self.export_node(child)); @@ -110,7 +110,7 @@ impl<'a> Context<'a> { num_ports: usize, ) -> &'a [model::LinkRef<'a>] { let ports = self.hugr.node_ports(node, direction); - let mut links = BumpVec::with_capacity_in(ports.len(), self.bump); + let mut links = BumpVec::with_capacity_in(ports.size_hint().0, self.bump); for port in ports.take(num_ports) { links.push(model::LinkRef::Id(self.get_link_id(node, port))); @@ -579,7 +579,7 @@ impl<'a> Context<'a> { let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); // Export the remaining children of the node. - let mut region_children = BumpVec::with_capacity_in(children.len(), self.bump); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0, self.bump); for child in children { region_children.push(self.export_node(child)); @@ -609,7 +609,7 @@ impl<'a> Context<'a> { /// Creates a control flow region from the given node's children. pub fn export_cfg(&mut self, node: Node) -> model::RegionId { let mut children = self.hugr.children(node); - let mut region_children = BumpVec::with_capacity_in(children.len() + 1, self.bump); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 + 1, self.bump); // The first child is the entry block. // We create a source port on the control flow region and connect it to the @@ -623,16 +623,16 @@ impl<'a> Context<'a> { let source = model::LinkRef::Id(self.get_link_id(entry_block, IncomingPort::from(0))); region_children.push(self.export_node(entry_block)); - // Export the remaining children of the node, except for the last one. - for _ in 0..children.len() - 1 { - region_children.push(self.export_node(children.next().unwrap())); - } - // The last child is the exit block. // Contrary to the entry block, the exit block does not have a dataflow subgraph. // We therefore do not export the block itself, but simply use its output ports // as the target ports of the control flow region. - let exit_block = children.next().unwrap(); + let exit_block = children.next_back().unwrap(); + + // Export the remaining children of the node, except for the last one. + for child in children { + region_children.push(self.export_node(child)); + } let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { panic!("expected an `ExitBlock` node as the last child node"); @@ -657,7 +657,7 @@ impl<'a> Context<'a> { /// Export the `Case` node children of a `Conditional` node as data flow regions. pub fn export_conditional_regions(&mut self, node: Node) -> &'a [model::RegionId] { let children = self.hugr.children(node); - let mut regions = BumpVec::with_capacity_in(children.len(), self.bump); + let mut regions = BumpVec::with_capacity_in(children.size_hint().0, self.bump); for child in children { let OpType::Case(case_op) = self.hugr.get_optype(child) else { diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 110da4124..b975018c6 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -436,6 +436,7 @@ impl OpDef { } /// Iterate over all miscellaneous data in the [OpDef]. + #[allow(unused)] // Unused when no features are enabled pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator { self.misc.iter().map(|(k, v)| (k.as_str(), v)) } diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index ca3f39cd3..ff400a1d3 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -208,7 +208,7 @@ mod test { // Sanity checks assert_eq!( - outer.children(inner.node()).len(), + outer.children(inner.node()).count(), if nonlocal { 3 } else { 6 } ); // Input, Output, add; + const, load_const, lift assert_eq!(find_dfgs(&outer), vec![outer.root(), inner.node()]); @@ -217,7 +217,7 @@ mod test { outer.get_parent(outer.get_parent(add).unwrap()), outer.get_parent(sub) ); - assert_eq!(outer.nodes().len(), 11); // 6 above + inner DFG + outer (DFG + Input + Output + sub) + assert_eq!(outer.nodes().count(), 11); // 6 above + inner DFG + outer (DFG + Input + Output + sub) { // Check we can't inline the outer DFG let mut h = outer.clone(); @@ -230,7 +230,7 @@ mod test { outer.apply_rewrite(InlineDFG(*inner.handle()))?; outer.validate(®)?; - assert_eq!(outer.nodes().len(), 8); + assert_eq!(outer.nodes().count(), 8); assert_eq!(find_dfgs(&outer), vec![outer.root()]); let [_lift, add, sub] = extension_ops(&outer).try_into().unwrap(); assert_eq!(outer.get_parent(add), Some(outer.root())); @@ -265,8 +265,8 @@ mod test { let mut h = h.finish_hugr_with_outputs(cx.outputs(), ®)?; assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]); - assert_eq!(h.nodes().len(), 8); // Dfg+I+O, H, CX, Dfg+I+O - // No permutation outside the swap DFG: + assert_eq!(h.nodes().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O + // No permutation outside the swap DFG: assert_eq!( h.node_connections(p_h.node(), swap.node()) .collect::>(), @@ -292,7 +292,7 @@ mod test { h.apply_rewrite(InlineDFG(*swap.handle()))?; assert_eq!(find_dfgs(&h), vec![h.root()]); - assert_eq!(h.nodes().len(), 5); // Dfg+I+O + assert_eq!(h.nodes().count(), 5); // Dfg+I+O let mut ops = extension_ops(&h); ops.sort_by_key(|n| h.num_outputs(*n)); // Put H before CX let [h_gate, cx] = ops.try_into().unwrap(); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index c15e3660d..7dd181f92 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -453,8 +453,8 @@ mod test { // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually: let cfg = cfg.node(); let exit_node = h.children(cfg).nth(1).unwrap(); - let tail = h.input_neighbours(exit_node).exactly_one().unwrap(); - let head = h.input_neighbours(tail).exactly_one().unwrap(); + let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap(); + let head = h.input_neighbours(tail).exactly_one().ok().unwrap(); // Just sanity-check we have the correct nodes assert!(h.get_optype(exit_node).is_exit_block()); assert_eq!( diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index d17eaf44f..7d744c150 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -10,8 +10,6 @@ pub mod sibling_subgraph; #[cfg(test)] mod tests; -use std::iter::Map; - pub use self::petgraph::PetgraphWrapper; use self::render::RenderConfig; pub use descendants::DescendantsGraph; @@ -19,10 +17,9 @@ pub use root_checked::RootChecked; pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; +use itertools::Itertools; use portgraph::render::{DotFormat, MermaidFormat}; -use portgraph::{multiportgraph, LinkView, PortView}; +use portgraph::{LinkView, PortView}; use super::internal::HugrInternals; use super::{ @@ -40,36 +37,6 @@ use itertools::Either; /// A trait for inspecting HUGRs. /// For end users we intend this to be superseded by region-specific APIs. pub trait HugrView: HugrInternals { - /// An Iterator over the nodes in a Hugr(View) - type Nodes<'a>: Iterator - where - Self: 'a; - - /// An Iterator over (some or all) ports of a node - type NodePorts<'a>: Iterator - where - Self: 'a; - - /// An Iterator over the children of a node - type Children<'a>: Iterator - where - Self: 'a; - - /// An Iterator over (some or all) the nodes neighbouring a node - type Neighbours<'a>: Iterator - where - Self: 'a; - - /// Iterator over the children of a node - type PortLinks<'a>: Iterator - where - Self: 'a; - - /// Iterator over the links between two nodes. - type NodeConnections<'a>: Iterator - where - Self: 'a; - /// Return the root node of this view. #[inline] fn root(&self) -> Node { @@ -147,16 +114,16 @@ pub trait HugrView: HugrInternals { fn edge_count(&self) -> usize; /// Iterates over the nodes in the port graph. - fn nodes(&self) -> Self::Nodes<'_>; + fn nodes(&self) -> impl Iterator + Clone; /// Iterator over ports of node in a given direction. - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_>; + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone; /// Iterator over output ports of node. /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Outgoing)` /// but preserves knowledge that the ports are [OutgoingPort]s. #[inline] - fn node_outputs(&self, node: Node) -> OutgoingPorts> { + fn node_outputs(&self, node: Node) -> impl Iterator + Clone { self.node_ports(node, Direction::Outgoing) .map(|p| p.as_outgoing().unwrap()) } @@ -165,16 +132,20 @@ pub trait HugrView: HugrInternals { /// Like [`node_ports`][HugrView::node_ports]`(node, Direction::Incoming)` /// but preserves knowledge that the ports are [IncomingPort]s. #[inline] - fn node_inputs(&self, node: Node) -> IncomingPorts> { + fn node_inputs(&self, node: Node) -> impl Iterator + Clone { self.node_ports(node, Direction::Incoming) .map(|p| p.as_incoming().unwrap()) } /// Iterator over both the input and output ports of node. - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_>; + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone; /// Iterator over the nodes and ports connected to a port. - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_>; + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone; /// Iterator over all the nodes and ports connected to a node in a given direction. fn all_linked_ports( @@ -245,7 +216,7 @@ pub trait HugrView: HugrInternals { &self, node: Node, port: impl Into, - ) -> OutgoingNodePorts> { + ) -> impl Iterator { self.linked_ports(node, port.into()) .map(|(n, p)| (n, p.as_outgoing().unwrap())) } @@ -257,13 +228,13 @@ pub trait HugrView: HugrInternals { &self, node: Node, port: impl Into, - ) -> IncomingNodePorts> { + ) -> impl Iterator { self.linked_ports(node, port.into()) .map(|(n, p)| (n, p.as_incoming().unwrap())) } /// Iterator the links between two nodes. - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_>; + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone; /// Returns whether a port is connected. fn is_linked(&self, node: Node, port: impl Into) -> bool { @@ -288,28 +259,28 @@ pub trait HugrView: HugrInternals { } /// Return iterator over the direct children of node. - fn children(&self, node: Node) -> Self::Children<'_>; + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone; /// Iterates over neighbour nodes in the given direction. /// May contain duplicates if the graph has multiple links between nodes. - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_>; + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone; /// Iterates over the input neighbours of the `node`. /// Shorthand for [`neighbours`][HugrView::neighbours]`(node, Direction::Incoming)`. #[inline] - fn input_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn input_neighbours(&self, node: Node) -> impl Iterator + Clone { self.neighbours(node, Direction::Incoming) } /// Iterates over the output neighbours of the `node`. /// Shorthand for [`neighbours`][HugrView::neighbours]`(node, Direction::Outgoing)`. #[inline] - fn output_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn output_neighbours(&self, node: Node) -> impl Iterator + Clone { self.neighbours(node, Direction::Outgoing) } /// Iterates over the input and output neighbours of the `node` in sequence. - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_>; + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone; /// Get the input and output child nodes of a dataflow parent. /// If the node isn't a dataflow parent, then return None @@ -469,18 +440,6 @@ pub trait HugrView: HugrInternals { } } -/// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s -pub type OutgoingPorts = Map OutgoingPort>; - -/// Wraps an iterator over [Port]s that are known to be [IncomingPort]s -pub type IncomingPorts = Map IncomingPort>; - -/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [OutgoingPort]s -pub type OutgoingNodePorts = Map (Node, OutgoingPort)>; - -/// Wraps an iterator over `(`[`Node`],[`Port`]`)` when the ports are known to be [IncomingPort]s -pub type IncomingNodePorts = Map (Node, IncomingPort)>; - /// Trait for views that provides a guaranteed bound on the type of the root node. pub trait RootTagged: HugrView { /// The kind of handle that can be used to refer to the root node. @@ -555,25 +514,6 @@ impl ExtractHugr for &mut Hugr { } impl> HugrView for T { - /// An Iterator over the nodes in a Hugr(View) - type Nodes<'a> = MapInto, Node> where Self: 'a; - - /// An Iterator over (some or all) ports of a node - type NodePorts<'a> = MapInto where Self: 'a; - - /// An Iterator over the children of a node - type Children<'a> = MapInto, Node> where Self: 'a; - - /// An Iterator over (some or all) the nodes neighbouring a node - type Neighbours<'a> = MapInto, Node> where Self: 'a; - - /// Iterator over the children of a node - type PortLinks<'a> = MapWithCtx, &'a Hugr, (Node, Port)> - where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx,&'a Hugr, [Port; 2]> where Self: 'a; - #[inline] fn contains_node(&self, node: Node) -> bool { self.as_ref().graph.contains_node(node.pg_index()) @@ -590,12 +530,12 @@ impl> HugrView for T { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { self.as_ref().graph.nodes_iter().map_into() } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.as_ref() .graph .port_offsets(node.pg_index(), dir) @@ -603,7 +543,7 @@ impl> HugrView for T { } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.as_ref() .graph .all_port_offsets(node.pg_index()) @@ -611,36 +551,33 @@ impl> HugrView for T { } #[inline] - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = port.into(); let hugr = self.as_ref(); let port = hugr .graph .port_index(node.pg_index(), port.pg_offset()) .unwrap(); - hugr.graph - .port_links(port) - .with_context(hugr) - .map_with_context(|(_, link), hugr| { - let port = link.port(); - let node = hugr.graph.port_node(port).unwrap(); - let offset = hugr.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + hugr.graph.port_links(port).map(|(_, link)| { + let port = link.port(); + let node = hugr.graph.port_node(port).unwrap(); + let offset = hugr.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) } #[inline] - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { let hugr = self.as_ref(); hugr.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(hugr) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link.port()).unwrap(); - offset.into() - }) + .map(|(p1, p2)| { + [p1, p2].map(|link| hugr.graph.port_offset(link.port()).unwrap().into()) }) } @@ -650,12 +587,12 @@ impl> HugrView for T { } #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { self.as_ref().hierarchy.children(node.pg_index()).map_into() } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.as_ref() .graph .neighbours(node.pg_index(), dir) @@ -663,7 +600,7 @@ impl> HugrView for T { } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.as_ref() .graph .all_neighbours(node.pg_index()) diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 83a7d6687..61db536ef 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -1,8 +1,7 @@ //! DescendantsGraph: view onto the subgraph of the HUGR starting from a root //! (all descendants at all depths). -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; +use itertools::Itertools; use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; use crate::hugr::HugrError; @@ -40,36 +39,6 @@ pub struct DescendantsGraph<'g, Root = Node> { _phantom: std::marker::PhantomData, } impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { - type Nodes<'a> = MapInto< as PortView>::Nodes<'a>, Node> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - #[inline] fn contains_node(&self, node: Node) -> bool { self.graph.contains_node(node.pg_index()) @@ -86,43 +55,43 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { self.graph.nodes_iter().map_into() } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.port_offsets(node.pg_index(), dir).map_into() } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = self .graph .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + self.graph.port_links(port).map(|(_, link)| { + let port: PortIndex = link.into(); + let node = self.graph.port_node(port).unwrap(); + let offset = self.graph.port_offset(port).unwrap(); + (node.into(), offset.into()) + }) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { + .map(|(p1, p2)| { [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); + let offset = self.graph.port_offset(link).unwrap(); offset.into() }) }) @@ -134,7 +103,7 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { match self.graph.contains_node(node.pg_index()) { true => self .base_hugr() @@ -146,12 +115,12 @@ impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> { } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.neighbours(node.pg_index(), dir).map_into() } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph.all_neighbours(node.pg_index()).map_into() } } @@ -204,6 +173,7 @@ pub(super) mod test { use rstest::rstest; use crate::extension::PRELUDE_REGISTRY; + use crate::IncomingPort; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, type_row, @@ -253,6 +223,7 @@ pub(super) mod test { let (hugr, def, inner) = make_module_hgr()?; let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; + let def_io = region.get_io(def).unwrap(); assert_eq!(region.node_count(), 7); assert!(region.nodes().all(|n| n == def @@ -268,11 +239,48 @@ pub(super) mod test { .into() ) ); + let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( inner_region.inner_function_type(), Some(Signature::new(type_row![NAT], type_row![NAT])) ); + assert_eq!(inner_region.node_count(), 3); + assert_eq!(inner_region.edge_count(), 2); + assert_eq!(inner_region.children(inner).count(), 2); + assert_eq!(inner_region.children(hugr.root()).count(), 0); + assert_eq!( + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.node_ports(inner, Direction::Outgoing).count() + ); + assert_eq!( + inner_region.num_ports(inner, Direction::Incoming) + + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.all_node_ports(inner).count() + ); + + // The inner region filters out the connections to the main function I/O nodes, + // while the outer region includes them. + assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0); + assert_eq!(region.node_connections(inner, def_io[1]).count(), 1); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1); + assert_eq!( + inner_region.neighbours(inner, Direction::Outgoing).count(), + 0 + ); + assert_eq!(inner_region.all_neighbours(inner).count(), 0); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); Ok(()) } diff --git a/hugr-core/src/hugr/views/petgraph.rs b/hugr-core/src/hugr/views/petgraph.rs index 9ae2a2331..0f909b332 100644 --- a/hugr-core/src/hugr/views/petgraph.rs +++ b/hugr-core/src/hugr/views/petgraph.rs @@ -6,7 +6,6 @@ use crate::types::EdgeKind; use crate::NodeIndex; use crate::{Node, Port}; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; use petgraph::visit as pv; /// Wrapper for a HugrView that implements petgraph's traits. @@ -99,13 +98,14 @@ where T: HugrView, { type NodeRef = HugrNodeRef<'a>; - type NodeReferences = MapWithCtx<::Nodes<'a>, Self, HugrNodeRef<'a>>; + type NodeReferences = Box> + 'a>; fn node_references(self) -> Self::NodeReferences { - self.hugr - .nodes() - .with_context(self) - .map_with_context(|n, &wrapper| HugrNodeRef::from_node(n, wrapper.hugr)) + Box::new( + self.hugr + .nodes() + .map(|n| HugrNodeRef::from_node(n, self.hugr)), + ) } } @@ -113,10 +113,10 @@ impl<'a, T> pv::IntoNodeIdentifiers for PetgraphWrapper<'a, T> where T: HugrView, { - type NodeIdentifiers = ::Nodes<'a>; + type NodeIdentifiers = Box + 'a>; fn node_identifiers(self) -> Self::NodeIdentifiers { - self.hugr.nodes() + Box::new(self.hugr.nodes()) } } @@ -124,10 +124,10 @@ impl<'a, T> pv::IntoNeighbors for PetgraphWrapper<'a, T> where T: HugrView, { - type Neighbors = ::Neighbours<'a>; + type Neighbors = Box + 'a>; fn neighbors(self, n: Self::NodeId) -> Self::Neighbors { - self.hugr.output_neighbours(n) + Box::new(self.hugr.output_neighbours(n)) } } @@ -135,14 +135,14 @@ impl<'a, T> pv::IntoNeighborsDirected for PetgraphWrapper<'a, T> where T: HugrView, { - type NeighborsDirected = ::Neighbours<'a>; + type NeighborsDirected = Box + 'a>; fn neighbors_directed( self, n: Self::NodeId, d: petgraph::Direction, ) -> Self::NeighborsDirected { - self.hugr.neighbours(n, d.into()) + Box::new(self.hugr.neighbours(n, d.into())) } } @@ -211,3 +211,39 @@ impl pv::NodeRef for HugrNodeRef<'_> { self.op } } + +#[cfg(test)] +mod test { + use petgraph::visit::{ + EdgeCount, GetAdjacencyMatrix, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, + }; + + use crate::hugr::views::tests::sample_hugr; + use crate::ops::handle::NodeHandle; + use crate::HugrView; + + use super::PetgraphWrapper; + + #[test] + fn test_petgraph_wrapper() { + let (hugr, cx1, cx2) = sample_hugr(); + let wrapper = PetgraphWrapper::from(&hugr); + + assert_eq!(wrapper.node_count(), 5); + assert_eq!(wrapper.node_bound(), 5); + assert_eq!(wrapper.edge_count(), 7); + + let cx1_index = cx1.node().pg_index().index(); + assert_eq!(wrapper.to_index(cx1.node()), cx1_index); + assert_eq!(wrapper.from_index(cx1_index), cx1.node()); + + let cx1_ref = wrapper + .node_references() + .find(|n| n.id() == cx1.node()) + .unwrap(); + assert_eq!(cx1_ref.weight(), hugr.get_optype(cx1.node())); + + let adj = wrapper.adjacency_matrix(); + assert!(wrapper.is_adjacent(&adj, cx1.node(), cx2.node())); + } +} diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 1125bad25..07aaba1da 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -2,9 +2,8 @@ use std::iter; -use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx}; -use itertools::{Itertools, MapInto}; -use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView}; +use itertools::{Either, Itertools}; +use portgraph::{LinkView, MultiPortGraph, PortView}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrError, HugrMut}; @@ -48,19 +47,6 @@ pub struct SiblingGraph<'g, Root = Node> { /// i.e. that rely only on [HugrInternals::base_hugr] macro_rules! impl_base_members { () => { - - type Nodes<'a> = iter::Chain, MapInto, Node>> - where - Self: 'a; - - type NodePorts<'a> = MapInto< as PortView>::NodePortOffsets<'a>, Port> - where - Self: 'a; - - type Children<'a> = MapInto, Node> - where - Self: 'a; - #[inline] fn node_count(&self) -> usize { self.base_hugr().hierarchy.child_count(self.root.pg_index()) + 1 @@ -75,7 +61,7 @@ macro_rules! impl_base_members { } #[inline] - fn nodes(&self) -> Self::Nodes<'_> { + fn nodes(&self) -> impl Iterator + Clone { // Faster implementation than filtering all the nodes in the internal graph. let children = self .base_hugr() @@ -85,10 +71,14 @@ macro_rules! impl_base_members { iter::once(self.root).chain(children) } - fn children(&self, node: Node) -> Self::Children<'_> { + fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { // Same as SiblingGraph match node == self.root { - true => self.base_hugr().hierarchy.children(node.pg_index()).map_into(), + true => self + .base_hugr() + .hierarchy + .children(node.pg_index()) + .map_into(), false => portgraph::hierarchy::Children::default().map_into(), } } @@ -96,24 +86,6 @@ macro_rules! impl_base_members { } impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { - type Neighbours<'a> = MapInto< as LinkView>::Neighbours<'a>, Node> - where - Self: 'a; - - type PortLinks<'a> = MapWithCtx< - as LinkView>::PortLinks<'a>, - &'a Self, - (Node, Port), - > where - Self: 'a; - - type NodeConnections<'a> = MapWithCtx< - as LinkView>::NodeConnections<'a>, - &'a Self, - [Port; 2], - > where - Self: 'a; - impl_base_members! {} #[inline] @@ -122,41 +94,35 @@ impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { } #[inline] - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.port_offsets(node.pg_index(), dir).map_into() } #[inline] - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph.all_port_offsets(node.pg_index()).map_into() } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { let port = self .graph .port_index(node.pg_index(), port.into().pg_offset()) .unwrap(); - self.graph - .port_links(port) - .with_context(self) - .map_with_context(|(_, link), region| { - let port: PortIndex = link.into(); - let node = region.graph.port_node(port).unwrap(); - let offset = region.graph.port_offset(port).unwrap(); - (node.into(), offset.into()) - }) + self.graph.port_links(port).map(|(_, link)| { + let node = self.graph.port_node(link).unwrap(); + let offset = self.graph.port_offset(link).unwrap(); + (node.into(), offset.into()) + }) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph .get_connections(node.pg_index(), other.pg_index()) - .with_context(self) - .map_with_context(|(p1, p2), hugr| { - [p1, p2].map(|link| { - let offset = hugr.graph.port_offset(link).unwrap(); - offset.into() - }) - }) + .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into())) } #[inline] @@ -165,12 +131,12 @@ impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> { } #[inline] - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph.neighbours(node.pg_index(), dir).map_into() } #[inline] - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph.all_neighbours(node.pg_index()).map_into() } } @@ -293,16 +259,6 @@ impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { } impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> { - type Neighbours<'a> = as IntoIterator>::IntoIter - where - Self: 'a; - - type PortLinks<'a> = as IntoIterator>::IntoIter - where - Self: 'a; - - type NodeConnections<'a> = as IntoIterator>::IntoIter where Self: 'a; - impl_base_members! {} fn contains_node(&self, node: Node) -> bool { @@ -311,61 +267,54 @@ impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> { node == self.root || self.base_hugr().get_parent(node) == Some(self.root) } - fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> { - match self.contains_node(node) { - true => self.base_hugr().node_ports(node, dir), - false => ::NodePortOffsets::default().map_into(), - } + fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + self.base_hugr().node_ports(node, dir) } - fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> { - match self.contains_node(node) { - true => self.base_hugr().all_node_ports(node), - false => ::NodePortOffsets::default().map_into(), - } + fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { + self.base_hugr().all_node_ports(node) } - fn linked_ports(&self, node: Node, port: impl Into) -> Self::PortLinks<'_> { - // Need to filter only to links inside the sibling graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn linked_ports( + &self, + node: Node, + port: impl Into, + ) -> impl Iterator + Clone { + self.hugr .linked_ports(node, port) - .collect::>() - .into_iter() + .filter(|(n, _)| self.contains_node(*n)) } - fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> { - // Need to filter only to connections inside the sibling graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) - .node_connections(node, other) - .collect::>() - .into_iter() + fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { + match self.contains_node(node) && self.contains_node(other) { + // The nodes are not in the sibling graph + false => Either::Left(iter::empty()), + // The nodes are in the sibling graph + true => Either::Right(self.hugr.node_connections(node, other)), + } } fn num_ports(&self, node: Node, dir: Direction) -> usize { - match self.contains_node(node) { - true => self.base_hugr().num_ports(node, dir), - false => 0, - } + self.base_hugr().num_ports(node, dir) } - fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> { - // Need to filter to neighbours in the Sibling Graph - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + self.hugr .neighbours(node, dir) - .collect::>() - .into_iter() + .filter(|n| self.contains_node(*n)) } - fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> { - SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root) + fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { + self.hugr .all_neighbours(node) - .collect::>() - .into_iter() + .filter(|n| self.contains_node(*n)) } } + impl<'g, Root: NodeHandle> RootTagged for SiblingMut<'g, Root> { type RootHandle = Root; } + impl<'g, Root: NodeHandle> HugrMutInternals for SiblingMut<'g, Root> { fn hugr_mut(&mut self) -> &mut Hugr { self.hugr @@ -384,28 +333,116 @@ mod test { use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::ops::{OpTrait, OpType}; - use crate::type_row; use crate::types::{Signature, Type}; + use crate::utils::test_quantum_extension::EXTENSION_ID; + use crate::{type_row, IncomingPort}; + + const NAT: Type = crate::extension::prelude::USIZE_T; + const QB: Type = crate::extension::prelude::QB_T; use super::super::descendants::test::make_module_hgr; use super::*; - #[test] - fn flat_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; - - let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; + fn test_properties( + hugr: &Hugr, + def: Node, + inner: Node, + region: T, + inner_region: T, + ) -> Result<(), Box> + where + T: HugrView + Sized, + { + let def_io = region.get_io(def).unwrap(); assert_eq!(region.node_count(), 5); - assert!(region - .nodes() - .all(|n| n == def || hugr.get_parent(n) == Some(def))); + assert_eq!(region.portgraph().node_count(), 5); + assert!(region.nodes().all(|n| n == def + || hugr.get_parent(n) == Some(def) + || hugr.get_parent(n) == Some(inner))); assert_eq!(region.children(inner).count(), 0); + assert_eq!( + region.poly_func_type(), + Some( + Signature::new_endo(type_row![NAT, QB]) + .with_extension_delta(EXTENSION_ID) + .into() + ) + ); + + assert_eq!( + inner_region.inner_function_type(), + Some(Signature::new(type_row![NAT], type_row![NAT])) + ); + assert_eq!(inner_region.node_count(), 3); + assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.children(inner).count(), 2); + assert_eq!(inner_region.children(hugr.root()).count(), 0); + assert_eq!( + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.node_ports(inner, Direction::Outgoing).count() + ); + assert_eq!( + inner_region.num_ports(inner, Direction::Incoming) + + inner_region.num_ports(inner, Direction::Outgoing), + inner_region.all_node_ports(inner).count() + ); + + // The inner region filters out the connections to the main function I/O nodes, + // while the outer region includes them. + assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0); + assert_eq!(region.node_connections(inner, def_io[1]).count(), 1); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1); + assert_eq!( + inner_region.neighbours(inner, Direction::Outgoing).count(), + 0 + ); + assert_eq!(inner_region.all_neighbours(inner).count(), 0); + assert_eq!( + inner_region + .linked_ports(inner, IncomingPort::from(0)) + .count(), + 0 + ); + Ok(()) } - const NAT: Type = crate::extension::prelude::USIZE_T; + #[rstest] + fn sibling_graph_properties() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; + + test_properties::( + &hugr, + def, + inner, + SiblingGraph::try_new(&hugr, def).unwrap(), + SiblingGraph::try_new(&hugr, inner).unwrap(), + ) + } + + #[rstest] + fn sibling_mut_properties() -> Result<(), Box> { + let (hugr, def, inner) = make_module_hgr()?; + let mut def_region_hugr = hugr.clone(); + let mut inner_region_hugr = hugr.clone(); + + test_properties::( + &hugr, + def, + inner, + SiblingMut::try_new(&mut def_region_hugr, def).unwrap(), + SiblingMut::try_new(&mut inner_region_hugr, inner).unwrap(), + ) + } + #[test] fn nested_flat() -> Result<(), Box> { let mut module_builder = ModuleBuilder::new(); @@ -417,11 +454,13 @@ mod test { let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?; let sub_dfg = sub_dfg.node(); - // Can create a view from a child or grandchild of a hugr: + + // We can create a view from a child or grandchild of a hugr: let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?; let fun_view: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&h, fun.node())?; - assert_eq!(fun_view.children(sub_dfg).len(), 0); - // And can create a view from a child of another SiblingGraph + assert_eq!(fun_view.children(sub_dfg).count(), 0); + + // And also create a view from a child of another SiblingGraph let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?; // Both ways work: @@ -439,6 +478,7 @@ mod test { Ok(()) } + /// Mutate a SiblingMut wrapper #[rstest] fn flat_mut(mut simple_dfg_hugr: Hugr) { simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 2d4c2c2bf..4fcebe179 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -457,15 +457,21 @@ impl SiblingSubgraph { // Connect the inserted nodes in-between the input and output nodes. let [inp, out] = extracted.get_io(extracted.root()).unwrap(); - for (inp_port, repl_ports) in extracted.node_outputs(inp).zip(self.inputs.iter()) { + let inputs = extracted.node_outputs(inp).zip(self.inputs.iter()); + let outputs = extracted.node_inputs(out).zip(self.outputs.iter()); + let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0); + + for (inp_port, repl_ports) in inputs { for (repl_node, repl_port) in repl_ports { - extracted.connect(inp, inp_port, node_map[repl_node], *repl_port); + connections.push((inp, inp_port, node_map[repl_node], *repl_port)); } } - for (out_port, (repl_node, repl_port)) in - extracted.node_inputs(out).zip(self.outputs.iter()) - { - extracted.connect(node_map[repl_node], *repl_port, out, out_port); + for (out_port, (repl_node, repl_port)) in outputs { + connections.push((node_map[repl_node], *repl_port, out, out_port)); + } + + for (src, src_port, dst, dst_port) in connections { + extracted.connect(src, src_port, dst, dst_port); } extracted @@ -1063,9 +1069,9 @@ mod tests { let (hugr, func_root) = build_3not_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let [inp, _out] = hugr.get_io(func_root).unwrap(); - let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); - let not2 = hugr.output_neighbours(not1).exactly_one().unwrap(); - let not3 = hugr.output_neighbours(not2).exactly_one().unwrap(); + let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap(); + let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap(); + let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap(); let not1_inp = hugr.node_inputs(not1).next().unwrap(); let not1_out = hugr.node_outputs(not1).next().unwrap(); let not3_inp = hugr.node_inputs(not3).next().unwrap(); @@ -1086,11 +1092,12 @@ mod tests { fn convex_multiports() { let (hugr, func_root) = build_multiport_hugr().unwrap(); let [inp, out] = hugr.get_io(func_root).unwrap(); - let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); + let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap(); let not2 = hugr .output_neighbours(not1) .filter(|&n| n != out) .exactly_one() + .ok() .unwrap(); let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap(); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index c52957b2c..c958c8b9f 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -17,8 +17,11 @@ use crate::{ Hugr, HugrView, }; +/// A Dataflow graph from two qubits to two qubits that applies two CX operations on them. +/// +/// Returns the Hugr and the two CX node ids. #[fixture] -fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle) { +pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle) { let mut dfg = DFGBuilder::new(endo_sig(type_row![QB_T, QB_T])).unwrap(); let [q1, q2] = dfg.input_wires_arr(); @@ -99,6 +102,7 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle, BuildHandle Context<'a> { // Connect the input node to the tag node let input_outputs = self.hugr.node_outputs(node_input); let tag_inputs = self.hugr.node_inputs(node_tag); + let mut connections = + Vec::with_capacity(input_outputs.size_hint().0 + tag_inputs.size_hint().0); for (a, b) in input_outputs.zip(tag_inputs) { - self.hugr.connect(node_input, a, node_tag, b); + connections.push((node_input, a, node_tag, b)); } // Connect the tag node to the output node @@ -820,7 +822,11 @@ impl<'a> Context<'a> { let output_inputs = self.hugr.node_inputs(node_output); for (a, b) in tag_outputs.zip(output_inputs) { - self.hugr.connect(node_tag, a, node_output, b); + connections.push((node_tag, a, node_output, b)); + } + + for (src, src_port, dst, dst_port) in connections { + self.hugr.connect(src, src_port, dst, dst_port); } } diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index 4bacac548..336d8992c 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -64,7 +64,6 @@ impl> HalfNodeView { } impl> CfgNodeMap for HalfNodeView { - type Iterator<'c> = as IntoIterator>::IntoIter where Self: 'c; fn entry_node(&self) -> HalfNode { HalfNode::N(self.entry) } @@ -72,7 +71,7 @@ impl> CfgNodeMap for HalfNodeView assert!(self.bb_succs(self.exit).count() == 0); HalfNode::N(self.exit) } - fn predecessors(&self, h: HalfNode) -> Self::Iterator<'_> { + fn predecessors(&self, h: HalfNode) -> impl Iterator { let mut ps = Vec::new(); match h { HalfNode::N(ni) => ps.extend(self.bb_preds(ni).map(|n| self.resolve_out(n))), @@ -83,7 +82,7 @@ impl> CfgNodeMap for HalfNodeView } ps.into_iter() } - fn successors(&self, n: HalfNode) -> Self::Iterator<'_> { + fn successors(&self, n: HalfNode) -> impl Iterator { let mut succs = Vec::new(); match n { HalfNode::N(ni) if self.is_multi_node(ni) => succs.push(HalfNode::X(ni)), diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 51fd07d9b..249eed6b4 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -276,7 +276,7 @@ mod test { .nodes() .filter(|n| h.get_optype(*n).cast::().is_some()); let (entry_nop, expected_backedge_target) = if self_loop { - assert_eq!(h.children(r).len(), 2); + assert_eq!(h.children(r).count(), 2); (nops.exactly_one().ok().unwrap(), entry) } else { let [_, _, no_b2] = h.children(r).collect::>().try_into().unwrap(); diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 31f91b9e4..fa7106432 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -67,15 +67,10 @@ pub trait CfgNodeMap { fn entry_node(&self) -> T; /// The unique exit node of the CFG. The only node to have no successors. fn exit_node(&self) -> T; - /// Allows the trait implementor to define a type of iterator it will return from - /// `successors` and `predecessors`. - type Iterator<'c>: Iterator - where - Self: 'c; /// Returns an iterator over the successors of the specified basic block. - fn successors(&self, node: T) -> Self::Iterator<'_>; + fn successors(&self, node: T) -> impl Iterator; /// Returns an iterator over the predecessors of the specified basic block. - fn predecessors(&self, node: T) -> Self::Iterator<'_>; + fn predecessors(&self, node: T) -> impl Iterator; } /// Extension of [CfgNodeMap] to that can perform (mutable/destructive) @@ -242,15 +237,11 @@ impl CfgNodeMap for IdentityCfgMap { self.exit } - type Iterator<'c> = ::Neighbours<'c> - where - Self: 'c; - - fn successors(&self, node: Node) -> Self::Iterator<'_> { + fn successors(&self, node: Node) -> impl Iterator { self.h.neighbours(node, Direction::Outgoing) } - fn predecessors(&self, node: Node) -> Self::Iterator<'_> { + fn predecessors(&self, node: Node) -> impl Iterator { self.h.neighbours(node, Direction::Incoming) } } @@ -731,9 +722,9 @@ pub(crate) mod test { // | \-> right -/ | // \---<---<---<---<---<---<---<---<---/ // split is unique successor of head - let split = h.output_neighbours(head).exactly_one().unwrap(); + let split = h.output_neighbours(head).exactly_one().ok().unwrap(); // merge is unique predecessor of tail - let merge = h.input_neighbours(tail).exactly_one().unwrap(); + let merge = h.input_neighbours(tail).exactly_one().ok().unwrap(); // There's no need to use a view of a region here but we do so just to check // that we *can* (as we'll need to for "real" module Hugr's)