diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index c308f656b..9a5d48559 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -398,12 +398,12 @@ impl EdgeClassifier { pub(crate) mod test { use super::*; use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; - use crate::extension::prelude::USIZE_T; + use crate::extension::{prelude::USIZE_T, ExtensionSet}; use crate::hugr::views::{HierarchyView, SiblingGraph}; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use crate::ops::Const; - use crate::types::Type; + use crate::types::{FunctionType, Type}; use crate::{type_row, Hugr}; const NAT: Type = USIZE_T; @@ -426,13 +426,13 @@ pub(crate) mod test { // /-> left --\ // entry -> split > merge -> head -> tail -> exit // \-> right -/ \-<--<-/ - let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1)?, + cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; @@ -611,7 +611,7 @@ pub(crate) mod test { unit_const: &ConstID, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { let split = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?, const_pred, )?; let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; @@ -624,15 +624,15 @@ pub(crate) mod test { split: BasicBlockID, ) -> Result { let merge = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, unit_const, )?; let left = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, unit_const, )?; let right = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, unit_const, )?; cfg.branch(&split, 0, &left)?; @@ -649,7 +649,7 @@ pub(crate) mod test { header: BasicBlockID, ) -> Result { let tail = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?, const_pred, )?; cfg.branch(&tail, 1, &header)?; @@ -663,7 +663,7 @@ pub(crate) mod test { unit_const: &ConstID, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { let header = n_identity( - cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?, + cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, unit_const, )?; let tail = build_loop_from_header(cfg, const_pred, header)?; @@ -674,18 +674,19 @@ pub(crate) mod test { pub fn build_cond_then_loop_cfg( separate: bool, ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { - let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 2)?, + cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, &pred_const, )?; let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; let head = if separate { let h = n_identity( - cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?, + cfg_builder + .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?, &const_unit, )?; cfg_builder.branch(&merge, 0, &h)?; @@ -708,13 +709,13 @@ pub(crate) mod test { ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { //let sum2_type = Type::new_predicate(2); - let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?; let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1)?, + cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 22c7267af..80bacadbd 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -315,6 +315,7 @@ pub trait Dataflow: Container { &mut self, inputs: impl IntoIterator, output_types: TypeRow, + extension_delta: ExtensionSet, ) -> Result, BuildError> { let (input_types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); @@ -325,8 +326,7 @@ pub trait Dataflow: Container { NodeType::open_extensions(ops::CFG { inputs: inputs.clone(), outputs: output_types.clone(), - // TODO: Make this a parameter - extension_delta: ExtensionSet::new(), + extension_delta, }), input_wires, )?; diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 027c8cab5..aa4fd1a99 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -57,20 +57,16 @@ impl + AsRef> SubContainer for CFGBuilder { impl CFGBuilder { /// New CFG rooted HUGR builder - pub fn new(input: impl Into, output: impl Into) -> Result { - let input = input.into(); - let output = output.into(); + pub fn new(signature: FunctionType) -> Result { let cfg_op = ops::CFG { - inputs: input.clone(), - outputs: output.clone(), - // TODO: Make this a parameter - extension_delta: ExtensionSet::new(), + inputs: signature.input.clone(), + outputs: signature.output.clone(), + extension_delta: signature.extension_reqs, }; - // TODO: Allow input extensions to be specified let base = Hugr::new(NodeType::open_extensions(cfg_op)); let cfg_node = base.root(); - CFGBuilder::create(base, cfg_node, input, output) + CFGBuilder::create(base, cfg_node, signature.input, signature.output) } } @@ -119,9 +115,16 @@ impl + AsRef> CFGBuilder { &mut self, inputs: TypeRow, predicate_variants: Vec, + extension_delta: ExtensionSet, other_outputs: TypeRow, ) -> Result, BuildError> { - self.any_block_builder(inputs, predicate_variants, other_outputs, false) + self.any_block_builder( + inputs, + predicate_variants, + other_outputs, + extension_delta, + false, + ) } fn any_block_builder( @@ -129,14 +132,14 @@ impl + AsRef> CFGBuilder { inputs: TypeRow, predicate_variants: Vec, other_outputs: TypeRow, + extension_delta: ExtensionSet, entry: bool, ) -> Result, BuildError> { let op = OpType::BasicBlock(BasicBlock::DFB { inputs: inputs.clone(), other_outputs: other_outputs.clone(), predicate_variants: predicate_variants.clone(), - // TODO: Make this a parameter - extension_delta: ExtensionSet::new(), + extension_delta, }); let parent = self.container_node(); let block_n = if entry { @@ -165,11 +168,15 @@ impl + AsRef> CFGBuilder { /// This function will return an error if there is an error adding the node. pub fn simple_block_builder( &mut self, - inputs: TypeRow, - outputs: TypeRow, + signature: FunctionType, n_cases: usize, ) -> Result, BuildError> { - self.block_builder(inputs, vec![type_row![]; n_cases], outputs) + self.block_builder( + signature.input, + vec![type_row![]; n_cases], + signature.extension_reqs, + signature.output, + ) } /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` @@ -183,12 +190,19 @@ impl + AsRef> CFGBuilder { &mut self, predicate_variants: Vec, other_outputs: TypeRow, + extension_delta: ExtensionSet, ) -> Result, BuildError> { let inputs = self .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder(inputs, predicate_variants, other_outputs, true) + self.any_block_builder( + inputs, + predicate_variants, + other_outputs, + extension_delta, + true, + ) } /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` @@ -201,8 +215,9 @@ impl + AsRef> CFGBuilder { &mut self, outputs: TypeRow, n_cases: usize, + extension_delta: ExtensionSet, ) -> Result, BuildError> { - self.entry_builder(vec![type_row![]; n_cases], outputs) + self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta) } /// Returns the exit block of this [`CFGBuilder`]. @@ -276,6 +291,7 @@ impl BlockBuilder { inputs: impl Into, predicate_variants: impl IntoIterator, other_outputs: impl Into, + extension_delta: ExtensionSet, ) -> Result { let inputs = inputs.into(); let predicate_variants: Vec<_> = predicate_variants.into_iter().collect(); @@ -284,11 +300,9 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), predicate_variants: predicate_variants.clone(), - // TODO: make this a parameter - extension_delta: ExtensionSet::new(), + extension_delta, }; - // TODO: Allow input extensions to be specified let base = Hugr::new(NodeType::open_extensions(op)); let root = base.root(); Self::create(base, root, predicate_variants, other_outputs, inputs) @@ -326,8 +340,11 @@ mod test { let [int] = func_builder.input_wires_arr(); let cfg_id = { - let mut cfg_builder = - func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?; + let mut cfg_builder = func_builder.cfg_builder( + vec![(NAT, int)], + type_row![NAT], + ExtensionSet::new(), + )?; build_basic_cfg(&mut cfg_builder)?; cfg_builder.finish_sub_container()? @@ -344,7 +361,7 @@ mod test { } #[test] fn basic_cfg_hugr() -> Result<(), BuildError> { - let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; build_basic_cfg(&mut cfg_builder)?; assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_)); @@ -355,14 +372,16 @@ mod test { cfg_builder: &mut CFGBuilder, ) -> Result<(), BuildError> { let sum2_variants = vec![type_row![NAT], type_row![NAT]]; - let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?; + let mut entry_b = + cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?; let entry = { let [inw] = entry_b.input_wires_arr(); let sum = entry_b.make_predicate(1, sum2_variants, [inw])?; entry_b.finish_with_outputs(sum, [])? }; - let mut middle_b = cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?; + let mut middle_b = cfg_builder + .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; let middle = { let c = middle_b.add_load_const(ops::Const::simple_unary_predicate())?; let [inw] = middle_b.input_wires_arr(); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 44c467bb9..5052f8469 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; -use crate::extension::PRELUDE_REGISTRY; +use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; use crate::ops; @@ -26,10 +26,13 @@ impl OutlineCfg { } } - fn compute_entry_exit_outside( + /// Compute the entry and exit nodes of the CFG which contains + /// [`self.blocks`], along with the output neighbour its parent graph and + /// the combined extension_deltas of all of the blocks. + fn compute_entry_exit_outside_extensions( &self, h: &impl HugrView, - ) -> Result<(Node, Node, Node), OutlineCfgError> { + ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> { let cfg_n = match self .blocks .iter() @@ -47,6 +50,7 @@ impl OutlineCfg { let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; let mut exit_succ = None; + let mut extension_delta = ExtensionSet::new(); for &n in self.blocks.iter() { if n == cfg_entry || h.input_neighbours(n) @@ -61,6 +65,7 @@ impl OutlineCfg { } } } + extension_delta = extension_delta.union(&o.signature().extension_reqs); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors @@ -76,7 +81,7 @@ impl OutlineCfg { }; } match (entry, exit_succ) { - (Some(e), Some((x, o))) => Ok((e, x, o)), + (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)), (None, _) => Err(OutlineCfgError::NoEntryNode), (_, None) => Err(OutlineCfgError::NoExitNode), } @@ -89,11 +94,12 @@ impl Rewrite for OutlineCfg { const UNCHANGED_ON_FAILURE: bool = true; fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside(h)?; + self.compute_entry_exit_outside_extensions(h)?; Ok(()) } fn apply(self, h: &mut impl HugrMut) -> Result<(), OutlineCfgError> { - let (entry, exit, outside) = self.compute_entry_exit_outside(h)?; + let (entry, exit, outside, extension_delta) = + self.compute_entry_exit_outside_extensions(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else { @@ -109,11 +115,19 @@ impl Rewrite for OutlineCfg { // 2. new_block contains input node, sub-cfg, exit node all connected let new_block = { - let mut new_block_bldr = - BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap(); + let mut new_block_bldr = BlockBuilder::new( + inputs.clone(), + vec![type_row![]], + outputs.clone(), + extension_delta.clone(), + ) + .unwrap(); let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); - let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap(); - cfg.exit_block(); // Makes inner exit block (but no entry block) + // N.B. By invoking the cfg_builder, we're forgetting any input + // extensions that may have existed on the original CFG. + let cfg = new_block_bldr + .cfg_builder(wires_in, outputs, extension_delta) + .unwrap(); let cfg_outputs = cfg.finish_sub_container().unwrap().outputs(); let predicate = new_block_bldr .add_constant(ops::Const::simple_unary_predicate())