diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 5915c5f02..9cfbf27bc 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -9,6 +9,8 @@ //! while the former provide views for subgraphs within a single level of the //! hierarchy. +use std::collections::HashSet; + use itertools::Itertools; use portgraph::{algorithms::ConvexChecker, view::Subgraph, Direction, PortView}; use thiserror::Error; @@ -18,11 +20,11 @@ use crate::{ handle::{ContainerHandle, DataflowOpID}, OpTag, OpTrait, }, - types::FunctionType, + types::{FunctionType, Type}, Hugr, Node, Port, SimpleReplacement, }; -use super::{sealed::HugrInternals, HugrView}; +use super::HugrView; /// A non-empty convex subgraph of a HUGR sibling graph. /// @@ -51,9 +53,24 @@ pub struct SiblingSubgraph<'g, Base> { base: &'g Base, /// The nodes of the induced subgraph. nodes: Vec, + /// The input ports of the subgraph. + /// + /// Grouped by input parameter. Each port must be unique and belong to a + /// node in `nodes`. + inputs: Vec>, + /// The output ports of the subgraph. + /// + /// Repeated ports are allowed and correspond to copying the output. Every + /// port must belong to a node in `nodes`. + outputs: Vec<(Node, Port)>, } -impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { +/// The type of the incoming boundary of [`SiblingSubgraph`]. +pub type IncomingPorts = Vec>; +/// The type of the outgoing boundary of [`SiblingSubgraph`]. +pub type OutgoingPorts = Vec<(Node, Port)>; + +impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR. /// /// The subgraph is given by the nodes between the input and output @@ -62,19 +79,25 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { /// /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. - pub fn from_dataflow_graph(dfg_graph: &'g Base) -> Result + pub fn try_from_dataflow_graph(dfg_graph: &'g Base) -> Result where Base: HugrView, Root: ContainerHandle, { let parent = dfg_graph.root(); let nodes = dfg_graph.children(parent).skip(2).collect_vec(); + let (inputs, outputs) = get_input_output_ports(dfg_graph); + + validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?; + if nodes.is_empty() { Err(InvalidSubgraph::EmptySubgraph) } else { Ok(Self { base: dfg_graph, nodes, + inputs, + outputs, }) } } @@ -83,14 +106,10 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { /// /// Any sibling subgraph can be defined using two sets of boundary edges /// $B_I$ and $B_O$, the incoming and outgoing boundary edges respectively. - /// Intuitively, the sibling subgraph is all the edges and nodes between + /// Intuitively, the sibling subgraph is all the edges and nodes "between" /// an edge of $B_I$ and an edge of $B_O$. /// - /// The `incoming` and `outgoing` arguments give $B_I$ and $B_O$ respectively. - /// They can be either source or target ports. We currently assume that if - /// the source port of an outgoing boundary edge is linked to multiple - /// target ports, then all edges from the same source port are outgoing - /// boundary edges. + /// ## Definition /// /// More formally, the sibling subgraph of a graph $G = (V, E)$ given /// by sets of incoming and outoing boundary edges $B_I, B_O \subseteq E$ @@ -105,75 +124,96 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { /// - it is in $B_O$ if and only if it has a source inside of the subgraph /// and a target outside of it. /// + /// ## Arguments + /// + /// The `incoming` and `outgoing` arguments give $B_I$ and $B_O$ respectively. + /// Incoming edges must be given by incoming ports and outgoing edges by + /// outgoing ports. The ordering of the incoming and outgoing ports defines + /// the signature of the subgraph. + /// + /// Incoming boundary ports must be unique and partitioned by input + /// parameter: two ports within the same set of the partition must be + /// copyable and will result in the input being copied. Outgoing + /// boundary ports are given in a list and can appear multiple times if + /// they are copyable, in which case the output will be copied. + /// + /// ## Errors + /// /// This function fails if the subgraph is not convex, if the nodes /// do not share a common parent or if the subgraph is empty. - /// - /// The order of the boundary edges is used to determine the order of the - /// signature. - pub fn from_boundary_edges( + pub fn try_from_boundary_ports( base: &'g Base, - incoming: impl IntoIterator, - outgoing: impl IntoIterator, + incoming: IncomingPorts, + outgoing: OutgoingPorts, ) -> Result where Base: HugrView, { let mut checker = ConvexChecker::new(base.portgraph()); - Self::from_boundary_edges_with_checker(base, incoming, outgoing, &mut checker) + Self::try_from_boundary_ports_with_checker(base, incoming, outgoing, &mut checker) } /// Create a new sibling subgraph from some boundary edges. /// /// Provide a [`ConvexChecker`] instance to avoid constructing one for /// faster convexity check. If you do not have one, use - /// [`SiblingSubgraph::from_boundary_edges`]. + /// [`SiblingSubgraph::try_from_boundary_ports`]. /// - /// Refer to [`SiblingSubgraph::from_boundary_edges`] for the full + /// Refer to [`SiblingSubgraph::try_from_boundary_ports`] for the full /// documentation. - pub fn from_boundary_edges_with_checker( + pub fn try_from_boundary_ports_with_checker( base: &'g Base, - incoming: impl IntoIterator, - outgoing: impl IntoIterator, + inputs: IncomingPorts, + outputs: OutgoingPorts, checker: &mut ConvexChecker<&'g Base::Portgraph>, ) -> Result where Base: HugrView, { let pg = base.portgraph(); - let incoming = incoming.into_iter().flat_map(|(n, p)| match p.direction() { - Direction::Outgoing => base.linked_ports(n, p).collect(), - Direction::Incoming => vec![(n, p)], - }); - let outgoing = outgoing.into_iter().flat_map(|(n, p)| match p.direction() { - Direction::Incoming => base.linked_ports(n, p).collect(), - Direction::Outgoing => vec![(n, p)], - }); let to_pg = |(n, p): (Node, Port)| pg.port_index(n.index, p.offset).expect("invalid port"); + // Ordering of the edges here is preserved and becomes ordering of the signature. - let subpg = Subgraph::new_subgraph(pg, incoming.chain(outgoing).map(to_pg)); + let subpg = Subgraph::new_subgraph( + pg, + inputs + .iter() + .flatten() + .copied() + .chain(outputs.iter().copied()) + .map(to_pg), + ); + let nodes = subpg.nodes_iter().map_into().collect_vec(); + + validate_subgraph(base, &nodes, &inputs, &outputs)?; + if !subpg.is_convex_with_checker(checker) { return Err(InvalidSubgraph::NotConvex); } - let nodes = subpg.nodes_iter().map_into().collect_vec(); - if nodes.is_empty() { - return Err(InvalidSubgraph::EmptySubgraph); - } - if !nodes.iter().map(|&n| base.get_parent(n)).all_equal() { - return Err(InvalidSubgraph::NoSharedParent); - } - Ok(Self { base, nodes }) + + Ok(Self { + base, + nodes, + inputs, + outputs, + }) } /// Create a new convex sibling subgraph from a set of nodes. /// /// This fails if the set of nodes is not convex, nodes do not share a /// common parent or the subgraph is empty. - pub fn try_new(base: &'g Base, nodes: Vec) -> Result + pub fn try_new( + base: &'g Base, + nodes: Vec, + inputs: IncomingPorts, + outputs: OutgoingPorts, + ) -> Result where Base: HugrView, { let mut checker = ConvexChecker::new(base.portgraph()); - Self::try_new_with_checker(base, nodes, &mut checker) + Self::try_new_with_checker(base, nodes, inputs, outputs, &mut checker) } /// Create a new convex sibling subgraph from a set of nodes. @@ -186,21 +226,25 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { pub fn try_new_with_checker( base: &'g Base, nodes: Vec, + inputs: IncomingPorts, + outputs: OutgoingPorts, checker: &mut ConvexChecker<&'g Base::Portgraph>, ) -> Result where Base: HugrView, { + validate_subgraph(base, &nodes, &inputs, &outputs)?; + if !checker.is_node_convex(nodes.iter().map(|n| n.index)) { return Err(InvalidSubgraph::NotConvex); } - if nodes.is_empty() { - return Err(InvalidSubgraph::EmptySubgraph); - } - if !nodes.iter().map(|&n| base.get_parent(n)).all_equal() { - return Err(InvalidSubgraph::NoSharedParent); - } - Ok(Self { base, nodes }) + + Ok(Self { + base, + nodes, + inputs, + outputs, + }) } /// An iterator over the nodes in the subgraph. @@ -208,66 +252,24 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { &self.nodes } - /// Whether a port is at the subgraph boundary. - fn is_boundary_port(&self, n: Node, p: Port) -> bool - where - Base: HugrView, - { - // TODO: handle state order edges - if is_order_edge(self.base, n, p) { - unimplemented!("State order edges not supported at boundary") - } - self.nodes.contains(&n) - && self - .base - .linked_ports(n, p) - .any(|(n, _)| !self.nodes.contains(&n)) - } - - /// An iterator of the incoming boundary ports. - pub fn incoming_ports(&self) -> impl Iterator + '_ - where - Base: HugrView, - { - self.boundary_ports(Direction::Incoming) - } - - /// An iterator of the outgoing boundary ports. - pub fn outgoing_ports(&self) -> impl Iterator + '_ - where - Base: HugrView, - { - self.boundary_ports(Direction::Outgoing) - } - - /// An iterator of the boundary ports, either incoming or outgoing. - pub fn boundary_ports(&self, dir: Direction) -> impl Iterator + '_ - where - Base: HugrView, - { - self.nodes.iter().flat_map(move |&n| { - self.base - .node_ports(n, dir) - .filter(move |&p| self.is_boundary_port(n, p)) - .map(move |p| (n, p)) - }) - } - /// The signature of the subgraph. pub fn signature(&self) -> FunctionType where Base: HugrView, { let input = self - .incoming_ports() - .map(|(n, p)| { + .inputs + .iter() + .map(|part| { + let &(n, p) = part.iter().next().expect("is non-empty"); let sig = self.base.get_optype(n).signature(); sig.get(p).cloned().expect("must be dataflow edge") }) .collect_vec(); let output = self - .outgoing_ports() - .map(|(n, p)| { + .outputs + .iter() + .map(|&(n, p)| { let sig = self.base.get_optype(n).signature(); sig.get(p).cloned().expect("must be dataflow edge") }) @@ -337,20 +339,24 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { unimplemented!("Found state order edges in replacement graph"); } - let self_inputs = self.incoming_ports(); - let self_outputs = self.outgoing_ports(); let nu_inp = rep_inputs .into_iter() - .zip_eq(self_inputs) - .flat_map(|((rep_source_n, rep_source_p), self_target)| { + .zip_eq(&self.inputs) + .flat_map(|((rep_source_n, rep_source_p), self_targets)| { replacement .linked_ports(rep_source_n, rep_source_p) - .map(move |rep_target| (rep_target, self_target)) + .flat_map(move |rep_target| { + self_targets + .iter() + .map(move |&self_target| (rep_target, self_target)) + }) }) .collect(); - let nu_out = self_outputs + let nu_out = self + .outputs + .iter() .zip_eq(rep_outputs) - .flat_map(|((self_source_n, self_source_p), (_, rep_target_p))| { + .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| { self.base .linked_ports(self_source_n, self_source_p) .map(move |self_target| (self_target, rep_target_p)) @@ -367,10 +373,139 @@ impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { } } +/// The type of all ports in the iterator. +/// +/// If the array is empty or a port does not exist, returns `None`. +fn get_edge_type(hugr: &H, ports: &[(Node, Port)]) -> Option { + let &(n, p) = ports.first()?; + let edge_t = hugr.get_optype(n).signature().get(p)?.clone(); + ports + .iter() + .all(|&(n, p)| hugr.get_optype(n).signature().get(p) == Some(&edge_t)) + .then_some(edge_t) +} + +/// Whether a subgraph is valid. +/// +/// Does NOT check for convexity. +fn validate_subgraph( + hugr: &H, + nodes: &[Node], + inputs: &IncomingPorts, + outputs: &OutgoingPorts, +) -> Result<(), InvalidSubgraph> { + // Check nodes is not empty + if nodes.is_empty() { + return Err(InvalidSubgraph::EmptySubgraph); + } + // Check all nodes share parent + if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() { + return Err(InvalidSubgraph::NoSharedParent); + } + + // Check there are no linked "other" ports + if inputs + .iter() + .flatten() + .chain(outputs) + .any(|&(n, p)| is_order_edge(hugr, n, p)) + { + unimplemented!("Linked other ports not supported at boundary") + } + + // Check inputs are incoming ports and outputs are outgoing ports + if inputs + .iter() + .flatten() + .any(|(_, p)| p.direction() == Direction::Outgoing) + { + return Err(InvalidSubgraph::InvalidBoundary); + } + if outputs + .iter() + .any(|(_, p)| p.direction() == Direction::Incoming) + { + return Err(InvalidSubgraph::InvalidBoundary); + } + + let mut ports_inside = inputs.iter().flatten().chain(outputs).copied(); + let mut ports_outside = ports_inside + .clone() + .flat_map(|(n, p)| hugr.linked_ports(n, p)); + // Check incoming & outgoing ports have target resp. source inside + let nodes = nodes.iter().copied().collect::>(); + if ports_inside.any(|(n, _)| !nodes.contains(&n)) { + return Err(InvalidSubgraph::InvalidBoundary); + } + // Check incoming & outgoing ports have source resp. target outside + if ports_outside.any(|(n, _)| nodes.contains(&n)) { + return Err(InvalidSubgraph::NotConvex); + } + + // Check inputs are unique + if !inputs.iter().flatten().all_unique() { + return Err(InvalidSubgraph::InvalidBoundary); + } + + // Check no incoming partition is empty + if inputs.iter().any(|p| p.is_empty()) { + return Err(InvalidSubgraph::InvalidBoundary); + } + + // Check edge types are equal within partition and copyable if partition size > 1 + if !inputs.iter().all(|ports| { + let Some(edge_t) = get_edge_type(hugr, ports) else { + return false; + }; + let require_copy = ports.len() > 1; + !require_copy || edge_t.copyable() + }) { + return Err(InvalidSubgraph::InvalidBoundary); + } + + Ok(()) +} + +fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPorts) { + let (inp, out) = hugr + .children(hugr.root()) + .take(2) + .collect_tuple() + .expect("invalid DFG"); + if has_other_edge(hugr, inp, Direction::Outgoing) { + unimplemented!("Non-dataflow output not supported at input node") + } + let dfg_inputs = hugr.get_optype(inp).signature().output_ports(); + if has_other_edge(hugr, out, Direction::Incoming) { + unimplemented!("Non-dataflow input not supported at output node") + } + let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + let inputs = dfg_inputs + .into_iter() + .map(|p| hugr.linked_ports(inp, p).collect()) + .collect(); + let outputs = dfg_outputs + .into_iter() + .map(|p| { + hugr.linked_ports(out, p) + .exactly_one() + .ok() + .expect("invalid DFG") + }) + .collect(); + (inputs, outputs) +} + /// Whether a port is linked to a state order edge. fn is_order_edge(hugr: &H, node: Node, port: Port) -> bool { - hugr.get_optype(node).signature().get(port).is_none() - && hugr.linked_ports(node, port).count() > 0 + let op = hugr.get_optype(node); + op.other_port_index(port.direction()) == Some(port) && hugr.is_linked(node, port) +} + +/// Whether node has a non-df linked port in the given direction. +fn has_other_edge(hugr: &H, node: Node, dir: Direction) -> bool { + let op = hugr.get_optype(node); + op.other_port(dir).is_some() && hugr.is_linked(node, op.other_port_index(dir).unwrap()) } /// Errors that can occur while constructing a [`SimpleReplacement`]. @@ -402,6 +537,9 @@ pub enum InvalidSubgraph { /// Empty subgraphs are not supported. #[error("Empty subgraphs are not supported.")] EmptySubgraph, + /// An invalid boundary port was found. + #[error("Invalid boundary port.")] + InvalidBoundary, } #[cfg(test)] @@ -411,16 +549,19 @@ mod tests { BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }, - extension::prelude::QB_T, + extension::prelude::{BOOL_T, QB_T}, hugr::views::{HierarchyView, SiblingGraph}, - ops::handle::{FuncID, NodeHandle}, - std_extensions::quantum::test::cx_gate, + ops::{ + handle::{FuncID, NodeHandle}, + OpType, + }, + std_extensions::{logic::test::and_op, quantum::test::cx_gate}, type_row, }; use super::*; - impl<'g, Base: HugrInternals> SiblingSubgraph<'g, Base> { + impl<'g, Base: HugrView> SiblingSubgraph<'g, Base> { /// A sibling subgraph from a HUGR. /// /// The subgraph is given by the sibling graph of the root. If you wish to @@ -441,6 +582,8 @@ mod tests { Ok(Self { base: sibling_graph, nodes, + inputs: Vec::new(), + outputs: Vec::new(), }) } } @@ -463,6 +606,25 @@ mod tests { Ok((hugr, func_id.node())) } + /// A HUGR with a copy + fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { + let mut mod_builder = ModuleBuilder::new(); + let func = mod_builder.declare( + "test", + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).pure(), + )?; + let func_id = { + let mut dfg = mod_builder.define_declaration(&func)?; + let in_wire = dfg.input_wires().exactly_one().unwrap(); + let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?; + dfg.finish_with_outputs(outs.outputs())? + }; + let hugr = mod_builder + .finish_hugr() + .map_err(|e| -> BuildError { e.into() })?; + Ok((hugr, func_id.node())) + } + #[test] fn construct_subgraph() -> Result<(), InvalidSubgraph> { let (hugr, func_root) = build_hugr().unwrap(); @@ -479,7 +641,7 @@ mod tests { fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { let (mut hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); - let sub = SiblingSubgraph::from_dataflow_graph(&func)?; + let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; let empty_dfg = { let builder = DFGBuilder::new(FunctionType::new_linear(type_row![QB_T, QB_T])).unwrap(); @@ -502,7 +664,7 @@ mod tests { fn test_signature() -> Result<(), InvalidSubgraph> { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, dfg); - let sub = SiblingSubgraph::from_dataflow_graph(&func)?; + let sub = SiblingSubgraph::try_from_dataflow_graph(&func)?; assert_eq!( sub.signature(), FunctionType::new_linear(type_row![QB_T, QB_T]) @@ -534,7 +696,7 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); assert_eq!( - SiblingSubgraph::from_dataflow_graph(&func) + SiblingSubgraph::try_from_dataflow_graph(&func) .unwrap() .nodes() .len(), @@ -548,28 +710,62 @@ mod tests { let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); // All graph except input/output nodes - SiblingSubgraph::from_boundary_edges( + SiblingSubgraph::try_from_boundary_ports( &func, - hugr.node_outputs(inp).map(|p| (inp, p)), - hugr.node_inputs(out).map(|p| (out, p)), + hugr.node_outputs(inp) + .map(|p| hugr.linked_ports(inp, p).collect_vec()) + .filter(|ps| !ps.is_empty()) + .collect(), + hugr.node_inputs(out) + .filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok()) + .collect(), ) .unwrap(); } #[test] - fn non_convex_subgraph() { + fn degen_boundary() { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap(); let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); // All graph but one edge assert!(matches!( - SiblingSubgraph::from_boundary_edges( + SiblingSubgraph::try_from_boundary_ports( &func, - [(inp, first_cx_edge)], - [(inp, first_cx_edge)], + vec![hugr.linked_ports(inp, first_cx_edge).collect()], + vec![(inp, first_cx_edge)], ), Err(InvalidSubgraph::NotConvex) )); } + + #[test] + fn non_convex_subgraph() { + let (hugr, func_root) = build_hugr().unwrap(); + let func: SiblingGraph<'_> = SiblingGraph::new(&hugr, func_root); + let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); + let snd_cx_edge = hugr.node_inputs(out).next().unwrap(); + // All graph but one edge + assert!(matches!( + SiblingSubgraph::try_from_boundary_ports( + &func, + vec![vec![(out, snd_cx_edge)]], + vec![(inp, first_cx_edge)], + ), + Err(InvalidSubgraph::NotConvex) + )); + } + + #[test] + fn preserve_signature() { + let (hugr, func_root) = build_hugr_classical().unwrap(); + let func: SiblingGraph<'_, FuncID> = SiblingGraph::new(&hugr, func_root); + let func = SiblingSubgraph::try_from_dataflow_graph(&func).unwrap(); + let OpType::FuncDefn(func_defn) = hugr.get_optype(func_root) else { + panic!() + }; + assert_eq!(func_defn.signature, func.signature()) + } }