diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index a1cd8d57a..272933389 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -97,7 +97,6 @@ impl Rewrite for SimpleReplacement { h.overwrite_node_metadata(new_node, meta); } // Add edges between all newly added nodes matching those in replacement. - // TODO This will probably change when implicit copies are implemented. for &node in replacement_inner_nodes { let new_node = index_map.get(&node).unwrap(); for outport in self.replacement.node_outputs(node) { @@ -109,6 +108,17 @@ impl Rewrite for SimpleReplacement { } } } + + // Now we proceed to connect the edges between the newly inserted + // replacement and the rest of the graph. + // + // We delay creating these connections to avoid them getting mixed with + // the pre-existing ones in the following logic. + // + // Existing connections to the removed subgraph will be automatically + // removed when the nodes are removed. + let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new(); + // 3.2. For each p = self.nu_inp[q] such that q is not an Output port, add an edge from the // predecessor of p to (the new copy of) q. for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp { @@ -117,14 +127,13 @@ impl Rewrite for SimpleReplacement { let (rem_inp_pred_node, rem_inp_pred_port) = h .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); - h.disconnect(*rem_inp_node, *rem_inp_port); let new_inp_node = index_map.get(rep_inp_node).unwrap(); - h.connect( + connect.insert(( rem_inp_pred_node, rem_inp_pred_port, *new_inp_node, *rep_inp_port, - ); + )); } } // 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an @@ -136,21 +145,18 @@ impl Rewrite for SimpleReplacement { .unwrap(); if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input { let new_out_node = index_map.get(&rep_out_pred_node).unwrap(); - h.disconnect(*rem_out_node, *rem_out_port); - h.connect( + connect.insert(( *new_out_node, rep_out_pred_port, *rem_out_node, *rem_out_port, - ); + )); } } // 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 // to p1. // // i.e. the replacement graph has direct edges between the input and output nodes. - let mut disconnect: HashSet<(Node, IncomingPort)> = HashSet::new(); - let mut connect: HashSet<(Node, OutgoingPort, Node, IncomingPort)> = HashSet::new(); for ((rem_out_node, rem_out_port), &rep_out_port) in &self.nu_out { let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port)); if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport { @@ -158,13 +164,11 @@ impl Rewrite for SimpleReplacement { let (rem_inp_pred_node, rem_inp_pred_port) = h .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); - // Delay connecting/disconnecting the nodes until after - // processing all nu_out entries. + // Delay connecting the nodes until after processing all nu_out + // entries. // // Otherwise, we might disconnect other wires in `rem_inp_node` // that are needed for the following iterations. - disconnect.insert((*rem_inp_node, *rem_inp_port)); - disconnect.insert((*rem_out_node, *rem_out_port)); connect.insert(( rem_inp_pred_node, rem_inp_pred_port, @@ -173,9 +177,6 @@ impl Rewrite for SimpleReplacement { )); } } - disconnect.into_iter().for_each(|(node, port)| { - h.disconnect(node, port); - }); connect .into_iter() .for_each(|(src_node, src_port, tgt_node, tgt_port)| { @@ -339,26 +340,56 @@ pub(in crate::hugr::rewrite) mod test { /// Returns the hugr and the nodes of the NOT gates, in order. #[fixture] pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { - fn build() -> Result<(Hugr, Vec), BuildError> { - let mut dfg_builder = - DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]))?; - let [b] = dfg_builder.input_wires_arr(); - - let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b])?; - let [b] = not_inp.outputs_arr(); - - let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b])?; - let [b0] = not_0.outputs_arr(); - let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b])?; - let [b1] = not_1.outputs_arr(); - - Ok(( - dfg_builder.finish_prelude_hugr_with_outputs([b0, b1])?, - vec![not_inp.node(), not_0.node(), not_1.node()], - )) - } + let mut dfg_builder = + DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + let [b] = dfg_builder.input_wires_arr(); + + let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let [b] = not_inp.outputs_arr(); + + let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let [b0] = not_0.outputs_arr(); + let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let [b1] = not_1.outputs_arr(); + + ( + dfg_builder + .finish_prelude_hugr_with_outputs([b0, b1]) + .unwrap(), + vec![not_inp.node(), not_0.node(), not_1.node()], + ) + } - build().unwrap() + /// A hugr with a DFG root mapping BOOL_T to (BOOL_T, BOOL_T) + /// ┌─────────┐ + /// ┌────┤ (1) NOT ├── + /// ┌─────────┐ │ └─────────┘ + /// ─┤ (0) NOT ├───┤ + /// └─────────┘ │ + /// └───────────────── + /// + /// This can be replaced with a single NOT op, coping the input to the first output. + /// + /// Returns the hugr and the nodes of the NOT ops, in order. + #[fixture] + pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { + let mut dfg_builder = + DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + let [b] = dfg_builder.input_wires_arr(); + + let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let [b] = not_inp.outputs_arr(); + + let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let [b0] = not_0.outputs_arr(); + let b1 = b; + + ( + dfg_builder + .finish_prelude_hugr_with_outputs([b0, b1]) + .unwrap(), + vec![not_inp.node(), not_0.node()], + ) } #[rstest] @@ -623,24 +654,28 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), orig.node_count()); } + /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input + /// directly to the outputs. + /// + /// https://github.com/CQCL/hugr/issues/1190 #[rstest] - fn test_copy_inputs( - dfg_hugr_copy_bools: (Hugr, Vec), - ) -> Result<(), Box> { + fn test_copy_inputs(dfg_hugr_copy_bools: (Hugr, Vec)) { let (mut hugr, nodes) = dfg_hugr_copy_bools; let (input_not, output_not_0, output_not_1) = nodes.into_iter().collect_tuple().unwrap(); let [_input, output] = hugr.get_io(hugr.root()).unwrap(); let replacement = { - let b = DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]))?; + let b = DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])) + .unwrap(); let [w] = b.input_wires_arr(); - b.finish_prelude_hugr_with_outputs([w, w])? + b.finish_prelude_hugr_with_outputs([w, w]).unwrap() }; let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap(); let subgraph = - SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr)?; + SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr) + .unwrap(); // A map from (target ports of edges from the Input node of `replacement`) to (target ports of // edges from nodes not in `removal` to nodes in `removal`). let nu_inp = [ @@ -674,8 +709,67 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(())); assert_eq!(hugr.node_count(), 3); + } - Ok(()) + /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input + /// directly to the output. + /// + /// https://github.com/CQCL/hugr/issues/1323 + #[rstest] + fn test_half_nots(dfg_hugr_half_not_bools: (Hugr, Vec)) { + let (mut hugr, nodes) = dfg_hugr_half_not_bools; + let (input_not, output_not_0) = nodes.into_iter().collect_tuple().unwrap(); + + let [_input, output] = hugr.get_io(hugr.root()).unwrap(); + + let (replacement, repl_not) = { + let mut b = + DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); + let [w] = b.input_wires_arr(); + let not = b.add_dataflow_op(NotOp, vec![w]).unwrap(); + let [w_not] = not.outputs_arr(); + ( + b.finish_prelude_hugr_with_outputs([w, w_not]).unwrap(), + not.node(), + ) + }; + let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap(); + + let subgraph = + SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap(); + // A map from (target ports of edges from the Input node of `replacement`) to (target ports of + // edges from nodes not in `removal` to nodes in `removal`). + let nu_inp = [ + ( + (repl_output, IncomingPort::from(0)), + (input_not, IncomingPort::from(0)), + ), + ( + (repl_not, IncomingPort::from(0)), + (input_not, IncomingPort::from(0)), + ), + ] + .into_iter() + .collect(); + // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to + // (input ports of the Output node of `replacement`). + let nu_out = [ + ((output, IncomingPort::from(0)), IncomingPort::from(0)), + ((output, IncomingPort::from(1)), IncomingPort::from(1)), + ] + .into_iter() + .collect(); + + let rewrite = SimpleReplacement { + subgraph, + replacement, + nu_inp, + nu_out, + }; + rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); + + assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(hugr.node_count(), 4); } use crate::hugr::rewrite::replace::Replacement;