Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ignore input->output links in SiblingSubgraph::try_new_dataflow_subgraph #589

Merged
merged 1 commit into from
Oct 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,16 @@ pub type IncomingPorts = Vec<Vec<(Node, Port)>>;
pub type OutgoingPorts = Vec<(Node, Port)>;

impl SiblingSubgraph {
/// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted HUGR.
/// A sibling subgraph from a [`crate::ops::OpTag::DataflowParent`]-rooted
/// HUGR.
///
/// The subgraph is given by the nodes between the input and output
/// children nodes of the root node. If you wish to create a subgraph
/// from another root, wrap the `region` argument in a [`super::SiblingGraph`].
/// The subgraph is given by the nodes between the input and output children
/// nodes of the root node. If you wish to create a subgraph from another
/// root, wrap the `region` argument in a [`super::SiblingGraph`].
///
/// Wires connecting the input and output nodes are ignored. Note that due
/// to this the resulting subgraph's signature may not match the signature
/// of the dataflow parent.
///
/// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the
/// subgraph is empty.
Expand Down Expand Up @@ -350,8 +355,7 @@ impl SiblingSubgraph {
if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
return Err(InvalidReplacement::InvalidDataflowGraph);
}
let Some((rep_input, rep_output)) = replacement.children(rep_root).take(2).collect_tuple()
else {
let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else {
return Err(InvalidReplacement::InvalidDataflowParent);
};
if dfg_optype.signature() != self.signature(hugr) {
Expand Down Expand Up @@ -587,11 +591,7 @@ fn validate_subgraph<H: HugrView>(
}

fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPorts) {
let (inp, out) = hugr
.children(hugr.root())
.take(2)
.collect_tuple()
.expect("invalid DFG");
let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG");
if has_other_edge(hugr, inp, Direction::Outgoing) {
unimplemented!("Non-dataflow output not supported at input node")
}
Expand All @@ -600,18 +600,23 @@ fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPort
unimplemented!("Non-dataflow input not supported at output node")
}
let dfg_outputs = hugr.get_optype(out).signature().input_ports();

// Collect for each port in the input the set of target ports, filtering
// direct wires to the output.
let inputs = dfg_inputs
.into_iter()
.map(|p| hugr.linked_ports(inp, p).collect())
.map(|p| {
hugr.linked_ports(inp, p)
.filter(|&(n, _)| n != out)
.collect_vec()
})
.filter(|v| !v.is_empty())
.collect();
// Collect for each port in the output the set of source ports, filtering
// direct wires to the input.
let outputs = dfg_outputs
.into_iter()
.map(|p| {
hugr.linked_ports(out, p)
.exactly_one()
.ok()
.expect("invalid DFG")
})
.filter_map(|p| hugr.linked_ports(out, p).find(|&(n, _)| n != inp))
.collect();
(inputs, outputs)
}
Expand Down Expand Up @@ -762,12 +767,13 @@ mod tests {
let mut mod_builder = ModuleBuilder::new();
let func = mod_builder.declare(
"test",
FunctionType::new_linear(type_row![QB_T, QB_T]).pure(),
FunctionType::new_linear(type_row![QB_T, QB_T, QB_T]).pure(),
)?;
let func_id = {
let mut dfg = mod_builder.define_declaration(&func)?;
let outs = dfg.add_dataflow_op(cx_gate(), dfg.input_wires())?;
dfg.finish_with_outputs(outs.outputs())?
let [w0, w1, w2] = dfg.input_wires_arr();
let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
dfg.finish_with_outputs([w0, w1, w2])?
};
let hugr = mod_builder
.finish_prelude_hugr()
Expand Down Expand Up @@ -857,6 +863,8 @@ mod tests {
let (hugr, dfg) = build_hugr().unwrap();
let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
// The identity wire on the third qubit is ignored, so the subgraph's signature only contains
// the first two qubits.
assert_eq!(
sub.signature(&func),
FunctionType::new_linear(type_row![QB_T, QB_T])
Expand Down Expand Up @@ -899,15 +907,17 @@ mod tests {
#[test]
fn convex_subgraph_2() {
let (hugr, func_root) = build_hugr().unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, out] = hugr.get_io(func_root).unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
// All graph except input/output nodes
SiblingSubgraph::try_new(
hugr.node_outputs(inp)
.take(2)
.map(|p| hugr.linked_ports(inp, p).collect_vec())
.filter(|ps| !ps.is_empty())
.collect(),
hugr.node_inputs(out)
.take(2)
.filter_map(|p| hugr.linked_ports(out, p).exactly_one().ok())
.collect(),
&func,
Expand All @@ -919,7 +929,7 @@ mod tests {
fn degen_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, _] = hugr.get_io(func_root).unwrap();
let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
// All graph but one edge
assert_matches!(
Expand All @@ -938,7 +948,7 @@ mod tests {
fn non_convex_subgraph() {
let (hugr, func_root) = build_3not_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, _out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, _out] = hugr.get_io(func_root).unwrap();
let not1 = hugr.output_neighbours(inp).exactly_one().unwrap();
let not2 = hugr.output_neighbours(not1).exactly_one().unwrap();
let not3 = hugr.output_neighbours(not2).exactly_one().unwrap();
Expand All @@ -960,7 +970,7 @@ mod tests {
fn invalid_boundary() {
let (hugr, func_root) = build_hugr().unwrap();
let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
let (inp, out) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let [inp, out] = hugr.get_io(func_root).unwrap();
let cx_edges_in = hugr.node_outputs(inp);
let cx_edges_out = hugr.node_inputs(out);
// All graph but the CX
Expand Down