Skip to content

Commit

Permalink
fix: Panic on SimpleReplace with multiports (#1324)
Browse files Browse the repository at this point in the history
Fixes #1323

This fix generalises the one from #1191.
SimpleReplace was too eager in disconnecting/connecting edges to the
replacement graph, and that caused issues when querying the neighbours
of multiports.

This gets resolved by delaying all new connections to the replacement
until after we have computed all of them.
We don't need to do explicit disconnections to the replaced subgraph,
since the nodes get removed anyway.

---------

Co-authored-by: Douglas Wilson <141026920+doug-q@users.noreply.github.com>
  • Loading branch information
aborgna-q and doug-q authored Jul 19, 2024
1 parent 01da7ba commit dae7e67
Showing 1 changed file with 136 additions and 42 deletions.
178 changes: 136 additions & 42 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ impl Rewrite for SimpleReplacement {
h.overwrite_node_metadata(new_node, meta);
}
// Add edges between all newly added nodes matching those in replacement.
// TODO This will probably change when implicit copies are implemented.
for &node in replacement_inner_nodes {
let new_node = index_map.get(&node).unwrap();
for outport in self.replacement.node_outputs(node) {
Expand All @@ -109,6 +108,17 @@ impl Rewrite for SimpleReplacement {
}
}
}

// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
//
// We delay creating these connections to avoid them getting mixed with
// the pre-existing ones in the following logic.
//
// Existing connections to the removed subgraph will be automatically
// removed when the nodes are removed.
let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new();

// 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
Expand All @@ -117,14 +127,13 @@ impl Rewrite for SimpleReplacement {
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port);
let new_inp_node = index_map.get(rep_inp_node).unwrap();
h.connect(
connect.insert((
rem_inp_pred_node,
rem_inp_pred_port,
*new_inp_node,
*rep_inp_port,
);
));
}
}
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
Expand All @@ -136,35 +145,30 @@ impl Rewrite for SimpleReplacement {
.unwrap();
if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input {
let new_out_node = index_map.get(&rep_out_pred_node).unwrap();
h.disconnect(*rem_out_node, *rem_out_port);
h.connect(
connect.insert((
*new_out_node,
rep_out_pred_port,
*rem_out_node,
*rem_out_port,
);
));
}
}
// 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
// to p1.
//
// i.e. the replacement graph has direct edges between the input and output nodes.
let mut disconnect: HashSet<(Node, IncomingPort)> = HashSet::new();
let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new();
for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out {
let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let (rem_inp_pred_node, rem_inp_pred_port) = h
.single_linked_output(*rem_inp_node, *rem_inp_port)
.unwrap();
// Delay connecting/disconnecting the nodes until after
// processing all nu_out entries.
// Delay connecting the nodes until after processing all nu_out
// entries.
//
// Otherwise, we might disconnect other wires in `rem_inp_node`
// that are needed for the following iterations.
disconnect.insert((*rem_inp_node, *rem_inp_port));
disconnect.insert((*rem_out_node, *rem_out_port));
connect.insert((
rem_inp_pred_node,
rem_inp_pred_port,
Expand All @@ -173,9 +177,6 @@ impl Rewrite for SimpleReplacement {
));
}
}
disconnect.into_iter().for_each(|(node, port)| {
h.disconnect(node, port);
});
connect
.into_iter()
.for_each(|(src_node, src_port, tgt_node, tgt_port)| {
Expand Down Expand Up @@ -339,26 +340,56 @@ pub(in crate::hugr::rewrite) mod test {
/// Returns the hugr and the nodes of the NOT gates, in order.
#[fixture]
pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec<Node>) {
fn build() -> Result<(Hugr, Vec<Node>), BuildError> {
let mut dfg_builder =
DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]))?;
let [b] = dfg_builder.input_wires_arr();

let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b])?;
let [b] = not_inp.outputs_arr();

let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b])?;
let [b0] = not_0.outputs_arr();
let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b])?;
let [b1] = not_1.outputs_arr();

Ok((
dfg_builder.finish_prelude_hugr_with_outputs([b0, b1])?,
vec![not_inp.node(), not_0.node(), not_1.node()],
))
}
let mut dfg_builder =
DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap();
let [b] = dfg_builder.input_wires_arr();

let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap();
let [b] = not_inp.outputs_arr();

let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap();
let [b0] = not_0.outputs_arr();
let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap();
let [b1] = not_1.outputs_arr();

(
dfg_builder
.finish_prelude_hugr_with_outputs([b0, b1])
.unwrap(),
vec![not_inp.node(), not_0.node(), not_1.node()],
)
}

build().unwrap()
/// A hugr with a DFG root mapping BOOL_T to (BOOL_T, BOOL_T)
/// ┌─────────┐
/// ┌────┤ (1) NOT ├──
/// ┌─────────┐ │ └─────────┘
/// ─┤ (0) NOT ├───┤
/// └─────────┘ │
/// └─────────────────
///
/// This can be replaced with a single NOT op, coping the input to the first output.
///
/// Returns the hugr and the nodes of the NOT ops, in order.
#[fixture]
pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec<Node>) {
let mut dfg_builder =
DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap();
let [b] = dfg_builder.input_wires_arr();

let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap();
let [b] = not_inp.outputs_arr();

let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap();
let [b0] = not_0.outputs_arr();
let b1 = b;

(
dfg_builder
.finish_prelude_hugr_with_outputs([b0, b1])
.unwrap(),
vec![not_inp.node(), not_0.node()],
)
}

#[rstest]
Expand Down Expand Up @@ -623,24 +654,28 @@ pub(in crate::hugr::rewrite) mod test {
assert_eq!(h.node_count(), orig.node_count());
}

/// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input
/// directly to the outputs.
///
/// https://github.com/CQCL/hugr/issues/1190
#[rstest]
fn test_copy_inputs(
dfg_hugr_copy_bools: (Hugr, Vec<Node>),
) -> Result<(), Box<dyn std::error::Error>> {
fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec<Node>)) {
let (mut hugr, nodes) = dfg_hugr_copy_bools;
let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap();

let [_input, output] = hugr.get_io(hugr.root()).unwrap();

let replacement = {
let b = DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]))?;
let b = DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]))
.unwrap();
let [w] = b.input_wires_arr();
b.finish_prelude_hugr_with_outputs([w, w])?
b.finish_prelude_hugr_with_outputs([w, w]).unwrap()
};
let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();

let subgraph =
SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)?;
SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)
.unwrap();
// A map from (target ports of edges from the Input node of `replacement`) to (target ports of
// edges from nodes not in `removal` to nodes in `removal`).
let nu_inp = [
Expand Down Expand Up @@ -674,8 +709,67 @@ pub(in crate::hugr::rewrite) mod test {

assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(()));
assert_eq!(hugr.node_count(), 3);
}

Ok(())
/// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input
/// directly to the output.
///
/// https://github.com/CQCL/hugr/issues/1323
#[rstest]
fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec<Node>)) {
let (mut hugr, nodes) = dfg_hugr_half_not_bools;
let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap();

let [_input, output] = hugr.get_io(hugr.root()).unwrap();

let (replacement, repl_not) = {
let mut b =
DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap();
let [w] = b.input_wires_arr();
let not = b.add_dataflow_op(NotOp, vec![w]).unwrap();
let [w_not] = not.outputs_arr();
(
b.finish_prelude_hugr_with_outputs([w, w_not]).unwrap(),
not.node(),
)
};
let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap();

let subgraph =
SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap();
// A map from (target ports of edges from the Input node of `replacement`) to (target ports of
// edges from nodes not in `removal` to nodes in `removal`).
let nu_inp = [
(
(repl_output, IncomingPort::from(0)),
(input_not, IncomingPort::from(0)),
),
(
(repl_not, IncomingPort::from(0)),
(input_not, IncomingPort::from(0)),
),
]
.into_iter()
.collect();
// A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to
// (input ports of the Output node of `replacement`).
let nu_out = [
((output, IncomingPort::from(0)), IncomingPort::from(0)),
((output, IncomingPort::from(1)), IncomingPort::from(1)),
]
.into_iter()
.collect();

let rewrite = SimpleReplacement {
subgraph,
replacement,
nu_inp,
nu_out,
};
rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));

assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(()));
assert_eq!(hugr.node_count(), 4);
}

use crate::hugr::rewrite::replace::Replacement;
Expand Down

0 comments on commit dae7e67

Please sign in to comment.