Skip to content

Commit

Permalink
fix: ignore input->output links in SiblingSubgraph::try_new_dataflow_…
Browse files Browse the repository at this point in the history
…subgraph (#589)

And some drive-bys to modernise the code.
  • Loading branch information
aborgna-q authored Oct 5, 2023
1 parent 416a665 commit e119a99
Showing 1 changed file with 35 additions and 25 deletions.
60 changes: 35 additions & 25 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,16 @@ pub type IncomingPorts = Vec<Vec<(Node, Port)>>;
pub type OutgoingPorts = Vec<(Node, Port)>;

impl SiblingSubgraph {
/// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR.
/// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted
/// HUGR.
///
/// The subgraph is given by the nodes between the input and output
/// children nodes of the root node. If you wish to create a subgraph
/// from another root, wrap the `region` argument in a [`super::SiblingGraph`].
/// The subgraph is given by the nodes between the input and output children
/// nodes of the root node. If you wish to create a subgraph from another
/// root, wrap the `region` argument in a [`super::SiblingGraph`].
///
/// Wires connecting the input and output nodes are ignored. Note that due
/// to this the resulting subgraph's signature may not match the signature
/// of the dataflow parent.
///
/// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the
/// subgraph is empty.
Expand Down Expand Up @@ -350,8 +355,7 @@ impl SiblingSubgraph {
if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
return Err(InvalidReplacement::InvalidDataflowGraph);
}
let Some((rep_input, rep_output)) = replacement.children(rep_root).take(2).collect_tuple()
else {
let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else {
return Err(InvalidReplacement::InvalidDataflowParent);
};
if dfg_optype.signature() != self.signature(hugr) {
Expand Down Expand Up @@ -587,11 +591,7 @@ fn validate_subgraph<H: HugrView>(
}

fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPorts) {
let (inp, out) = hugr
.children(hugr.root())
.take(2)
.collect_tuple()
.expect("invalid DFG");
let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG");
if has_other_edge(hugr, inp, Direction::Outgoing) {
unimplemented!("Non-dataflow output not supported at input node")
}
Expand All @@ -600,18 +600,23 @@ fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPort
unimplemented!("Non-dataflow input not supported at output node")
}
let dfg_outputs = hugr.get_optype(out).signature().input_ports();

// Collect for each port in the input the set of target ports, filtering
// direct wires to the output.
let inputs = dfg_inputs
.into_iter()
.map(|p| hugr.linked_ports(inp, p).collect())
.map(|p| {
hugr.linked_ports(inp, p)
.filter(|&(n, _)| n != out)
.collect_vec()
})
.filter(|v| !v.is_empty())
.collect();
// Collect for each port in the output the set of source ports, filtering
// direct wires to the input.
let outputs = dfg_outputs
.into_iter()
.map(|p| {
hugr.linked_ports(out, p)
.exactly_one()
.ok()
.expect("invalid DFG")
})
.filter_map(|p| hugr.linked_ports(out, p).find(|&(n, _)| n != inp))
.collect();
(inputs, outputs)
}
Expand Down Expand Up @@ -762,12 +767,13 @@ mod tests {
let mut mod_builder = ModuleBuilder::new();
let func = mod_builder.declare(
"test",
FunctionType::new_linear(type_row![QB_T, QB_T]).pure(),
FunctionType::new_linear(type_row![QB_T, QB_T, QB_T]).pure(),
)?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let outs = dfg.add_dataflow_op(cx_gate(), dfg.input_wires())?;
dfg.finish_with_outputs(outs.outputs())?
let [w0, w1, w2] = dfg.input_wires_arr();
let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
dfg.finish_with_outputs([w0, w1, w2])?
};
let hugr = mod_builder
.finish_prelude_hugr()
Expand Down Expand Up @@ -857,6 +863,8 @@ mod tests {
let (hugr, dfg) = build_hugr().unwrap();
let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
// The identity wire on the third qubit is ignored, so the subgraph's signature only contains
// the first two qubits.
assert_eq!(
sub.signature(&func),
FunctionType::new_linear(type_row![QB_T, QB_T])
Expand Down Expand Up @@ -899,15 +907,17 @@ mod tests {
#[test]
fn convex_subgraph_2() {
let (hugr, func_root) = build_hugr().unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, out] = hugr.get_io(func_root).unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
// All graph except input/output nodes
SiblingSubgraph::try_new(
hugr.node_outputs(inp)
.take(2)
.map(|p| hugr.linked_ports(inp, p).collect_vec())
.filter(|ps| !ps.is_empty())
.collect(),
hugr.node_inputs(out)
.take(2)
.filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok())
.collect(),
&func,
Expand All @@ -919,7 +929,7 @@ mod tests {
fn degen_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, _] = hugr.get_io(func_root).unwrap();
let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
// All graph but one edge
assert_matches!(
Expand All @@ -938,7 +948,7 @@ mod tests {
fn non_convex_subgraph() {
let (hugr, func_root) = build_3not_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, _out] = hugr.get_io(func_root).unwrap();
let not1 = hugr.output_neighbours(inp).exactly_one().unwrap();
let not2 = hugr.output_neighbours(not1).exactly_one().unwrap();
let not3 = hugr.output_neighbours(not2).exactly_one().unwrap();
Expand All @@ -960,7 +970,7 @@ mod tests {
fn invalid_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, out] = hugr.get_io(func_root).unwrap();
let cx_edges_in = hugr.node_outputs(inp);
let cx_edges_out = hugr.node_inputs(out);
// All graph but the CX
Expand Down

0 comments on commit e119a99

Please sign in to comment.