diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index f40703243..8bde38cc6 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -256,7 +256,7 @@ pub(crate) mod test { pub(super) const QB: Type = crate::extension::prelude::QB_T; /// Wire up inputs of a Dataflow container to the outputs. - pub(super) fn n_identity( + pub(crate) fn n_identity( dataflow_builder: T, ) -> Result { let w = dataflow_builder.input_wires(); diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 3018adce0..026a996b8 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -2,12 +2,16 @@ use std::collections::{HashMap, HashSet}; +use crate::hugr::hugrmut::InsertionResult; +pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite}; +use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; use thiserror::Error; +use super::inline_dfg::InlineDFGError; + /// Specification of a simple replacement operation. #[derive(Debug, Clone)] pub struct SimpleReplacement { @@ -62,7 +66,7 @@ impl Rewrite for SimpleReplacement { unimplemented!() } - fn apply(mut self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { let parent = self.subgraph.get_parent(h); // 1. Check the parent node exists and is a DataflowParent. if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) { @@ -75,39 +79,23 @@ impl Rewrite for SimpleReplacement { } } // 3. Do the replacement. - // 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes. - // Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self). - let mut index_map: HashMap = HashMap::new(); - let replacement_nodes = self - .replacement - .children(self.replacement.root()) - .collect::>(); - // slice of nodes omitting Input and Output: - let replacement_inner_nodes = &replacement_nodes[2..]; - let self_output_node = h.children(parent).nth(1).unwrap(); - let replacement_output_node = *replacement_nodes.get(1).unwrap(); - for &node in replacement_inner_nodes { - // Add the nodes. - let op: &OpType = self.replacement.get_optype(node); - let new_node = h.add_node_after(self_output_node, op.clone()); - index_map.insert(node, new_node); - - // Move the metadata - let meta: Option = self.replacement.take_node_metadata(node); - h.overwrite_node_metadata(new_node, meta); + // 3.1. Insert the replacement as a whole. + let InsertionResult { + new_root, + node_map: index_map, + } = h.insert_hugr(parent, self.replacement.clone()); + + // remove the Input and Output nodes from the replacement graph + let replace_children = h.children(new_root).collect::>(); + for &io in &replace_children[..2] { + h.remove_node(io); } - // Add edges between all newly added nodes matching those in replacement. - for &node in replacement_inner_nodes { - let new_node = index_map.get(&node).unwrap(); - for outport in self.replacement.node_outputs(node) { - for target in self.replacement.linked_inputs(node, outport) { - if self.replacement.get_optype(target.0).tag() != OpTag::Output { - let new_target = index_map.get(&target.0).unwrap(); - h.connect(*new_node, outport, *new_target, target.1); - } - } - } + // make all replacement top level children children of the parent + for &child in &replace_children[2..] { + h.set_parent(child, parent); } + // remove the replacement root (which now has no children and no edges) + h.remove_node(new_root); // Now we proceed to connect the edges between the newly inserted // replacement and the rest of the graph. @@ -136,6 +124,10 @@ impl Rewrite for SimpleReplacement { )); } } + let replacement_output_node = self + .replacement + .get_io(self.replacement.root()) + .expect("parent already checked.")[1]; // 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an // edge from (the new copy of) the predecessor of q to p. for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out { @@ -213,6 +205,9 @@ pub enum SimpleReplacementError { /// Node in replacement graph is invalid. #[error("A node in the replacement graph is invalid.")] InvalidReplacementNode(), + /// Inlining replacement failed. + #[error("Inlining replacement failed: {0}")] + InliningFailed(#[from] InlineDFGError), } #[cfg(test)] @@ -221,11 +216,12 @@ pub(in crate::hugr::rewrite) mod test { use rstest::{fixture, rstest}; use std::collections::{HashMap, HashSet}; + use crate::builder::test::n_identity; use crate::builder::{ endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; - use crate::extension::prelude::BOOL_T; + use crate::extension::prelude::{BOOL_T, QB_T}; use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; @@ -774,6 +770,51 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.node_count(), 4); } + #[rstest] + fn test_nested_replace(dfg_hugr2: Hugr) { + // replace a node with a hugr with children + + let mut h = dfg_hugr2; + let h_node = h + .nodes() + .find(|node: &Node| *h.get_optype(*node) == h_gate().into()) + .unwrap(); + + // build a nested identity dfg + let mut nest_build = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap(); + let [input] = nest_build.input_wires_arr(); + let inner_build = nest_build.dfg_builder_endo([(QB_T, input)]).unwrap(); + let inner_dfg = n_identity(inner_build).unwrap(); + let inner_dfg_node = inner_dfg.node(); + let replacement = nest_build + .finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)]) + .unwrap(); + let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap(); + let nu_inp = vec![( + (inner_dfg_node, IncomingPort::from(0)), + (h_node, IncomingPort::from(0)), + )] + .into_iter() + .collect(); + + let nu_out = vec![( + (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)), + IncomingPort::from(0), + )] + .into_iter() + .collect(); + + let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); + + assert_eq!(h.node_count(), 4); + + rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); + h.update_validate(&PRELUDE_REGISTRY) + .unwrap_or_else(|e| panic!("{e}")); + + assert_eq!(h.node_count(), 6); + } + use crate::hugr::rewrite::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};