From cdcc73381586f20a03d7ca453ff2ba99d782d95f Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 11:26:04 +0100 Subject: [PATCH 1/3] feat: Compute affected nodes for `SimpleReplacement` --- src/hugr/rewrite/simple_replace.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index c40f9776c..b4b477475 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -53,6 +53,19 @@ impl SimpleReplacement { pub fn subgraph(&self) -> &SiblingSubgraph { &self.subgraph } + + /// Returns the nodes affected by the replacement. + /// + /// This includes the nodes in the subgraph and the boundary neighbours that + /// are referenced by the replacement. + /// + /// Two `SimpleReplacement`s can be composed if their affected nodes are + /// disjoint. + pub fn affected_nodes(&self) -> impl Iterator + '_ { + let subcirc = self.subgraph.nodes().iter().copied(); + let out_neighs = self.nu_out.keys().map(|&(n, _)| n); + subcirc.chain(out_neighs) + } } impl Rewrite for SimpleReplacement { @@ -203,7 +216,7 @@ pub(in crate::hugr::rewrite) mod test { use itertools::Itertools; use portgraph::Direction; use rstest::{fixture, rstest}; - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use crate::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, @@ -384,6 +397,11 @@ pub(in crate::hugr::rewrite) mod test { nu_inp, nu_out, }; + assert_eq!( + HashSet::<_>::from_iter(r.affected_nodes()), + HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]), + ); + h.apply_rewrite(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ From a34662f0c0d373f10e18db3b91a698cb011aec97 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 11:30:03 +0100 Subject: [PATCH 2/3] drive-by: add some inlines --- src/hugr/rewrite/simple_replace.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index b4b477475..c2c075c5d 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -30,6 +30,7 @@ pub struct SimpleReplacement { impl SimpleReplacement { /// Create a new [`SimpleReplacement`] specification. + #[inline] pub fn new( subgraph: SiblingSubgraph, replacement: Hugr, @@ -45,11 +46,13 @@ impl SimpleReplacement { } /// The replacement hugr. + #[inline] pub fn replacement(&self) -> &Hugr { &self.replacement } /// Subgraph to be replaced. + #[inline] pub fn subgraph(&self) -> &SiblingSubgraph { &self.subgraph } @@ -61,6 +64,7 @@ impl SimpleReplacement { /// /// Two `SimpleReplacement`s can be composed if their affected nodes are /// disjoint. + #[inline] pub fn affected_nodes(&self) -> impl Iterator + '_ { let subcirc = self.subgraph.nodes().iter().copied(); let out_neighs = self.nu_out.keys().map(|&(n, _)| n); From ffeb85ed29771d2aed02cd83c4dc4c6494975b63 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 14:36:36 +0100 Subject: [PATCH 3/3] Reword docs, s/affected_nodes/invalidation_set/ --- src/hugr/rewrite/simple_replace.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index c2c075c5d..2f5d9f5fd 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -57,15 +57,13 @@ impl SimpleReplacement { &self.subgraph } - /// Returns the nodes affected by the replacement. - /// - /// This includes the nodes in the subgraph and the boundary neighbours that - /// are referenced by the replacement. + /// Returns a set of nodes referenced by the replacement. Modifying any + /// these nodes will invalidate the replacement. /// /// Two `SimpleReplacement`s can be composed if their affected nodes are /// disjoint. #[inline] - pub fn affected_nodes(&self) -> impl Iterator + '_ { + pub fn invalidation_set(&self) -> impl Iterator + '_ { let subcirc = self.subgraph.nodes().iter().copied(); let out_neighs = self.nu_out.keys().map(|&(n, _)| n); subcirc.chain(out_neighs) @@ -402,7 +400,7 @@ pub(in crate::hugr::rewrite) mod test { nu_out, }; assert_eq!( - HashSet::<_>::from_iter(r.affected_nodes()), + HashSet::<_>::from_iter(r.invalidation_set()), HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]), );