Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: hierarchical simple replacement using insert_hugr #1718

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ pub(crate) mod test {
pub(super) const QB: Type = crate::extension::prelude::QB_T;

/// Wire up inputs of a Dataflow container to the outputs.
pub(super) fn n_identity<T: DataflowSubContainer>(
pub(crate) fn n_identity<T: DataflowSubContainer>(
dataflow_builder: T,
) -> Result<T::ContainerHandle, BuildError> {
let w = dataflow_builder.input_wires();
Expand Down
109 changes: 75 additions & 34 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

use std::collections::{HashMap, HashSet};

use crate::hugr::hugrmut::InsertionResult;
pub use crate::hugr::internal::HugrMutInternals;
use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrMut, HugrView, NodeMetadataMap, Rewrite};
use crate::hugr::{HugrMut, HugrView, Rewrite};
use crate::ops::{OpTag, OpTrait, OpType};
use crate::{Hugr, IncomingPort, Node, OutgoingPort};
use thiserror::Error;

use super::inline_dfg::InlineDFGError;

/// Specification of a simple replacement operation.
#[derive(Debug, Clone)]
pub struct SimpleReplacement {
Expand Down Expand Up @@ -62,7 +66,7 @@ impl Rewrite for SimpleReplacement {
unimplemented!()
}

fn apply(mut self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
Expand All @@ -75,39 +79,23 @@ impl Rewrite for SimpleReplacement {
}
}
// 3. Do the replacement.
// 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes.
// Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self).
let mut index_map: HashMap<Node, Node> = HashMap::new();
let replacement_nodes = self
.replacement
.children(self.replacement.root())
.collect::<Vec<Node>>();
// slice of nodes omitting Input and Output:
let replacement_inner_nodes = &replacement_nodes[2..];
let self_output_node = h.children(parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
let op: &OpType = self.replacement.get_optype(node);
let new_node = h.add_node_after(self_output_node, op.clone());
index_map.insert(node, new_node);

// Move the metadata
let meta: Option<NodeMetadataMap> = self.replacement.take_node_metadata(node);
h.overwrite_node_metadata(new_node, meta);
// 3.1. Insert the replacement as a whole.
let InsertionResult {
new_root,
node_map: index_map,
} = h.insert_hugr(parent, self.replacement.clone());

// remove the Input and Output nodes from the replacement graph
let replace_children = h.children(new_root).collect::<Vec<Node>>();
for &io in &replace_children[..2] {
h.remove_node(io);
}
// Add edges between all newly added nodes matching those in replacement.
for &node in replacement_inner_nodes {
let new_node = index_map.get(&node).unwrap();
for outport in self.replacement.node_outputs(node) {
for target in self.replacement.linked_inputs(node, outport) {
if self.replacement.get_optype(target.0).tag() != OpTag::Output {
let new_target = index_map.get(&target.0).unwrap();
h.connect(*new_node, outport, *new_target, target.1);
}
}
}
// make all replacement top level children children of the parent
for &child in &replace_children[2..] {
h.set_parent(child, parent);
}
// remove the replacement root (which now has no children and no edges)
h.remove_node(new_root);

// Now we proceed to connect the edges between the newly inserted
// replacement and the rest of the graph.
Expand Down Expand Up @@ -136,6 +124,10 @@ impl Rewrite for SimpleReplacement {
));
}
}
let replacement_output_node = self
.replacement
.get_io(self.replacement.root())
.expect("parent already checked.")[1];
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
Expand Down Expand Up @@ -213,6 +205,9 @@ pub enum SimpleReplacementError {
/// Node in replacement graph is invalid.
#[error("A node in the replacement graph is invalid.")]
InvalidReplacementNode(),
/// Inlining replacement failed.
#[error("Inlining replacement failed: {0}")]
InliningFailed(#[from] InlineDFGError),
}

#[cfg(test)]
Expand All @@ -221,11 +216,12 @@ pub(in crate::hugr::rewrite) mod test {
use rstest::{fixture, rstest};
use std::collections::{HashMap, HashSet};

use crate::builder::test::n_identity;
use crate::builder::{
endo_sig, inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::BOOL_T;
use crate::extension::prelude::{BOOL_T, QB_T};
use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::views::{HugrView, SiblingSubgraph};
use crate::hugr::{Hugr, HugrMut, Rewrite};
Expand Down Expand Up @@ -774,6 +770,51 @@ pub(in crate::hugr::rewrite) mod test {
assert_eq!(hugr.node_count(), 4);
}

#[rstest]
fn test_nested_replace(dfg_hugr2: Hugr) {
// replace a node with a hugr with children

let mut h = dfg_hugr2;
let h_node = h
.nodes()
.find(|node: &Node| *h.get_optype(*node) == h_gate().into())
.unwrap();

// build a nested identity dfg
let mut nest_build = DFGBuilder::new(Signature::new_endo(QB_T)).unwrap();
let [input] = nest_build.input_wires_arr();
let inner_build = nest_build.dfg_builder_endo([(QB_T, input)]).unwrap();
let inner_dfg = n_identity(inner_build).unwrap();
let inner_dfg_node = inner_dfg.node();
let replacement = nest_build
.finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)])
.unwrap();
let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap();
let nu_inp = vec![(
(inner_dfg_node, IncomingPort::from(0)),
(h_node, IncomingPort::from(0)),
)]
.into_iter()
.collect();

let nu_out = vec![(
(h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)),
IncomingPort::from(0),
)]
.into_iter()
.collect();

let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out);

assert_eq!(h.node_count(), 4);

rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}"));
h.update_validate(&PRELUDE_REGISTRY)
.unwrap_or_else(|e| panic!("{e}"));

assert_eq!(h.node_count(), 6);
}

use crate::hugr::rewrite::replace::Replacement;
fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement {
use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec};
Expand Down
Loading