diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index ed291b614..66ffb09e6 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -87,11 +87,16 @@ pub type IncomingPorts = Vec>; pub type OutgoingPorts = Vec<(Node, Port)>; impl SiblingSubgraph { - /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR. + /// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted + /// HUGR. /// - /// The subgraph is given by the nodes between the input and output - /// children nodes of the root node. If you wish to create a subgraph - /// from another root, wrap the `region` argument in a [`super::SiblingGraph`]. + /// The subgraph is given by the nodes between the input and output children + /// nodes of the root node. If you wish to create a subgraph from another + /// root, wrap the `region` argument in a [`super::SiblingGraph`]. + /// + /// Wires connecting the input and output nodes are ignored. Note that due + /// to this the resulting subgraph's signature may not match the signature + /// of the dataflow parent. /// /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. @@ -350,8 +355,7 @@ impl SiblingSubgraph { if !OpTag::Dfg.is_superset(dfg_optype.tag()) { return Err(InvalidReplacement::InvalidDataflowGraph); } - let Some((rep_input, rep_output)) = replacement.children(rep_root).take(2).collect_tuple() - else { + let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; if dfg_optype.signature() != self.signature(hugr) { @@ -587,11 +591,7 @@ fn validate_subgraph( } fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPorts) { - let (inp, out) = hugr - .children(hugr.root()) - .take(2) - .collect_tuple() - .expect("invalid DFG"); + let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG"); if has_other_edge(hugr, inp, Direction::Outgoing) { unimplemented!("Non-dataflow output not supported at input node") } @@ -600,18 +600,23 @@ fn get_input_output_ports(hugr: &H) -> (IncomingPorts, OutgoingPort unimplemented!("Non-dataflow input not supported at output node") } let dfg_outputs = hugr.get_optype(out).signature().input_ports(); + + // Collect for each port in the input the set of target ports, filtering + // direct wires to the output. let inputs = dfg_inputs .into_iter() - .map(|p| hugr.linked_ports(inp, p).collect()) + .map(|p| { + hugr.linked_ports(inp, p) + .filter(|&(n, _)| n != out) + .collect_vec() + }) + .filter(|v| !v.is_empty()) .collect(); + // Collect for each port in the output the set of source ports, filtering + // direct wires to the input. let outputs = dfg_outputs .into_iter() - .map(|p| { - hugr.linked_ports(out, p) - .exactly_one() - .ok() - .expect("invalid DFG") - }) + .filter_map(|p| hugr.linked_ports(out, p).find(|&(n, _)| n != inp)) .collect(); (inputs, outputs) } @@ -762,12 +767,13 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new_linear(type_row![QB_T, QB_T]).pure(), + FunctionType::new_linear(type_row![QB_T, QB_T, QB_T]).pure(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; - let outs = dfg.add_dataflow_op(cx_gate(), dfg.input_wires())?; - dfg.finish_with_outputs(outs.outputs())? + let [w0, w1, w2] = dfg.input_wires_arr(); + let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr(); + dfg.finish_with_outputs([w0, w1, w2])? }; let hugr = mod_builder .finish_prelude_hugr() @@ -857,6 +863,8 @@ mod tests { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + // The identity wire on the third qubit is ignored, so the subgraph's signature only contains + // the first two qubits. assert_eq!( sub.signature(&func), FunctionType::new_linear(type_row![QB_T, QB_T]) @@ -899,15 +907,17 @@ mod tests { #[test] fn convex_subgraph_2() { let (hugr, func_root) = build_hugr().unwrap(); - let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let [inp, out] = hugr.get_io(func_root).unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); // All graph except input/output nodes SiblingSubgraph::try_new( hugr.node_outputs(inp) + .take(2) .map(|p| hugr.linked_ports(inp, p).collect_vec()) .filter(|ps| !ps.is_empty()) .collect(), hugr.node_inputs(out) + .take(2) .filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok()) .collect(), &func, @@ -919,7 +929,7 @@ mod tests { fn degen_boundary() { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let [inp, _] = hugr.get_io(func_root).unwrap(); let first_cx_edge = hugr.node_outputs(inp).next().unwrap(); // All graph but one edge assert_matches!( @@ -938,7 +948,7 @@ mod tests { fn non_convex_subgraph() { let (hugr, func_root) = build_3not_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let [inp, _out] = hugr.get_io(func_root).unwrap(); let not1 = hugr.output_neighbours(inp).exactly_one().unwrap(); let not2 = hugr.output_neighbours(not1).exactly_one().unwrap(); let not3 = hugr.output_neighbours(not2).exactly_one().unwrap(); @@ -960,7 +970,7 @@ mod tests { fn invalid_boundary() { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap(); + let [inp, out] = hugr.get_io(func_root).unwrap(); let cx_edges_in = hugr.node_outputs(inp); let cx_edges_out = hugr.node_inputs(out); // All graph but the CX