diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index 8068748dac..275bb55e67 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -5,7 +5,10 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; +use crate::{ + extension::TO_BE_INFERRED, + ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}, +}; use crate::{ extension::{ExtensionRegistry, ExtensionSet}, types::FunctionType, @@ -43,7 +46,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// +------------+ /// */ /// use hugr::{ -/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder}, +/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, ft1, ft2}, /// extension::{prelude, ExtensionSet}, /// ops, type_row, /// types::{FunctionType, SumType, Type}, @@ -62,8 +65,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// /// // The second argument says what types will be passed through to every /// // successor, in addition to the appropriate `sum_variants` type. -/// let mut entry_b = -/// cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?; +/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; /// /// let [inw] = entry_b.input_wires_arr(); /// let entry = { @@ -82,7 +84,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// // `NAT` arguments: one from the `sum_variants` type, and another from the /// // entry node's `other_outputs`. /// let mut successor_builder = cfg_builder.simple_block_builder( -/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]), +/// ft2(type_row![NAT, NAT], NAT), /// 1, // only one successor to this block /// )?; /// let successor_a = { @@ -96,8 +98,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// }; /// /// // The only argument to this block is the entry node's `other_outputs`. -/// let mut successor_builder = cfg_builder -/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; +/// let mut successor_builder = cfg_builder.simple_block_builder(ft1(NAT), 1)?; /// let successor_b = { /// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum()); /// let [in_wire] = successor_builder.input_wires_arr(); @@ -197,7 +198,7 @@ impl + AsRef> CFGBuilder { /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. + /// specified by `sum_rows`. Extension delta will be inferred. /// /// # Errors /// @@ -206,18 +207,40 @@ impl + AsRef> CFGBuilder { &mut self, inputs: TypeRow, sum_rows: impl IntoIterator, - extension_delta: ExtensionSet, other_outputs: TypeRow, ) -> Result, BuildError> { - self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, false) + self.block_builder_exts(inputs, sum_rows, TO_BE_INFERRED, other_outputs) } - fn any_block_builder( + /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` + /// and `outputs` and the variants of the branching Sum value + /// specified by `sum_rows`. Extension delta will be inferred. + /// + /// # Errors + /// + /// This function will return an error if there is an error adding the node. + pub fn block_builder_exts( &mut self, inputs: TypeRow, sum_rows: impl IntoIterator, + extension_delta: impl Into, other_outputs: TypeRow, + ) -> Result, BuildError> { + self.any_block_builder( + inputs, + extension_delta.into(), + sum_rows, + other_outputs, + false, + ) + } + + fn any_block_builder( + &mut self, + inputs: TypeRow, extension_delta: ExtensionSet, + sum_rows: impl IntoIterator, + other_outputs: TypeRow, entry: bool, ) -> Result, BuildError> { let sum_rows: Vec<_> = sum_rows.into_iter().collect(); @@ -241,7 +264,8 @@ impl + AsRef> CFGBuilder { } /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. + /// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type + /// (a Sum of `n_cases` unit types) to select the successor. /// /// # Errors /// @@ -251,7 +275,7 @@ impl + AsRef> CFGBuilder { signature: FunctionType, n_cases: usize, ) -> Result, BuildError> { - self.block_builder( + self.block_builder_exts( signature.input, vec![type_row![]; n_cases], signature.extension_reqs, @@ -259,6 +283,21 @@ impl + AsRef> CFGBuilder { ) } + /// Return a builder for the entry [`DataflowBlock`] child graph with + /// `outputs` and the variants of the branching Sum value + /// specified by `sum_rows`. + /// + /// # Errors + /// + /// This function will return an error if an entry block has already been built. + pub fn entry_builder( + &mut self, + sum_rows: impl IntoIterator, + other_outputs: TypeRow, + ) -> Result, BuildError> { + self.entry_builder_exts(TO_BE_INFERRED, sum_rows, other_outputs) + } + /// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching Sum value /// specified by `sum_rows`. @@ -266,17 +305,23 @@ impl + AsRef> CFGBuilder { /// # Errors /// /// This function will return an error if an entry block has already been built. - pub fn entry_builder( + pub fn entry_builder_exts( &mut self, + extension_delta: impl Into, sum_rows: impl IntoIterator, other_outputs: TypeRow, - extension_delta: ExtensionSet, ) -> Result, BuildError> { let inputs = self .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, true) + self.any_block_builder( + inputs, + extension_delta.into(), + sum_rows, + other_outputs, + true, + ) } /// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` @@ -289,9 +334,23 @@ impl + AsRef> CFGBuilder { &mut self, outputs: TypeRow, n_cases: usize, - extension_delta: ExtensionSet, ) -> Result, BuildError> { - self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta) + self.entry_builder(vec![type_row![]; n_cases], outputs) + } + + /// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` + /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. + /// + /// # Errors + /// + /// This function will return an error if there is an error adding the node. + pub fn simple_entry_builder_exts( + &mut self, + outputs: TypeRow, + n_cases: usize, + extension_delta: impl Into, + ) -> Result, BuildError> { + self.entry_builder_exts(extension_delta, vec![type_row![]; n_cases], outputs) } /// Returns the exit block of this [`CFGBuilder`]. @@ -439,8 +498,11 @@ pub(crate) mod test { cfg_builder: &mut CFGBuilder, ) -> Result<(), BuildError> { let sum2_variants = vec![type_row![NAT], type_row![NAT]]; - let mut entry_b = - cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?; + let mut entry_b = cfg_builder.entry_builder_exts( + ExtensionSet::new(), + sum2_variants.clone(), + type_row![], + )?; let entry = { let [inw] = entry_b.input_wires_arr(); @@ -466,16 +528,19 @@ pub(crate) mod test { let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; - let mut entry_b = - cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?; + let mut entry_b = cfg_builder.entry_builder_exts( + ExtensionSet::new(), + sum_variants.clone(), + type_row![], + )?; let [inw] = entry_b.input_wires_arr(); let entry = { let sum = entry_b.load_const(&sum_tuple_const); entry_b.finish_with_outputs(sum, [])? }; - let mut middle_b = - cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?; + let mut middle_b = cfg_builder + .simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?; let middle = { let c = middle_b.load_const(&sum_tuple_const); middle_b.finish_with_outputs(c, [inw])? @@ -501,8 +566,7 @@ pub(crate) mod test { middle_b.finish_with_outputs(c, [inw])? }; - let mut entry_b = - cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?; + let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; let entry = { let sum = entry_b.load_const(&sum_tuple_const); // entry block uses wire from middle block even though middle block diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 0f206a2877..1d2ce2ee57 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -252,7 +252,7 @@ mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::USIZE_T; - use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::extension::PRELUDE_REGISTRY; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; @@ -295,7 +295,7 @@ mod test { }; let entry = n_identity( - cfg_builder.simple_entry_builder(USIZE_T.into(), 2, ExtensionSet::new())?, + cfg_builder.simple_entry_builder(USIZE_T.into(), 2)?, &pred_const, )?; diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 09467720f7..fe40bd0cb2 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -628,7 +628,7 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder(op_sig.output, 1, op_sig.extension_reqs.clone())? + h.simple_entry_builder_exts(op_sig.output, 1, op_sig.extension_reqs.clone())? } else { h.simple_block_builder(op_sig, 1)? }; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index a13e591463..f00aabe1a6 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -963,7 +963,7 @@ fn cfg_connections() -> Result<(), Box> { let mut hugr = CFGBuilder::new(FunctionType::new_endo(USIZE_T))?; let unary_pred = hugr.add_constant(Value::unary_unit_sum()); - let mut entry = hugr.simple_entry_builder(type_row![USIZE_T], 1, ExtensionSet::new())?; + let mut entry = hugr.simple_entry_builder_exts(type_row![USIZE_T], 1, ExtensionSet::new())?; let p = entry.load_const(&unary_pred); let ins = entry.input_wires(); let entry = entry.finish_with_outputs(p, ins)?; @@ -1068,7 +1068,7 @@ mod extension_tests { let mut cfg = CFGBuilder::new( FunctionType::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone()), )?; - let mut bb = cfg.simple_entry_builder(USIZE_T.into(), 1, XB.into())?; + let mut bb = cfg.simple_entry_builder_exts(USIZE_T.into(), 1, XB)?; let pred = bb.add_load_value(Value::unary_unit_sum()); let inputs = bb.input_wires(); let blk = bb.finish_with_outputs(pred, inputs)?; diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index ba0f4fd97e..e9b715f805 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -161,9 +161,9 @@ mod test { use itertools::Itertools; use rstest::rstest; - use hugr_core::builder::{ft2, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; + use hugr_core::builder::{ft1, ft2, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; - use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY, TO_BE_INFERRED}; + use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; use hugr_core::ops::handle::CfgID; @@ -224,14 +224,13 @@ mod test { let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?; let mut h = CFGBuilder::new(ft2(loop_variants.clone(), exit_types.clone()))?; - let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, PRELUDE_ID.into())?; + let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; let mut test_block = h.block_builder( loop_variants.clone(), vec![loop_variants.clone(), exit_types], - TO_BE_INFERRED.into(), // TODO infer by default type_row![], )?; let [test_input] = test_block.input_wires_arr(); @@ -243,10 +242,7 @@ mod test { let loop_backedge_target = if self_loop { no_b1 } else { - let mut no_b2 = h.simple_block_builder( - FunctionType::new_endo(loop_variants).with_extension_delta(TO_BE_INFERRED), // TODO infer by default - 1, - )?; + let mut no_b2 = h.simple_block_builder(ft1(loop_variants), 1)?; let n = no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b2); let nid = no_b2.finish_with_outputs(br, n.outputs())?; @@ -322,14 +318,8 @@ mod test { .into_owned() .try_into() .unwrap(); - let mut h = CFGBuilder::new( - ft2(QB_T, res_t.clone()) - )?; - let mut bb1 = h.simple_entry_builder( - type_row![USIZE_T, QB_T], - 1, - TO_BE_INFERRED.into(), // TODO by default - )?; + let mut h = CFGBuilder::new(ft2(QB_T, res_t.clone()))?; + let mut bb1 = h.simple_entry_builder(type_row![USIZE_T, QB_T], 1)?; let [inw] = bb1.input_wires_arr(); let load_cst = bb1.add_load_value(ConstUsize::new(1)); let pred = lifted_unary_unit_sum(&mut bb1); @@ -338,7 +328,6 @@ mod test { let mut bb2 = h.block_builder( type_row![USIZE_T, QB_T], vec![type_row![]], - TO_BE_INFERRED.into(), type_row![QB_T, USIZE_T], )?; let [u, q] = bb2.input_wires_arr(); @@ -348,7 +337,6 @@ mod test { let mut bb3 = h.block_builder( type_row![QB_T, USIZE_T], vec![type_row![]], - TO_BE_INFERRED.into(), res_t.clone().into(), )?; let [q, u] = bb3.input_wires_arr(); diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 8753464d0f..8dc0e25a3c 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -575,7 +575,7 @@ impl EdgeClassifier { pub(crate) mod test { use super::*; use hugr_core::builder::{ - BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, + ft1, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::extension::{prelude::USIZE_T, ExtensionSet}; @@ -614,17 +614,17 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, + cfg_builder.simple_entry_builder_exts(type_row![NAT], 1, ExtensionSet::new())?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; cfg_builder.branch(&entry, 0, &split)?; let head = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + cfg_builder.simple_block_builder(ft1(NAT), 1)?, &const_unit, )?; let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + cfg_builder.simple_block_builder(ft1(NAT), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &head)?; @@ -851,10 +851,7 @@ pub(crate) mod test { const_pred: &ConstID, unit_const: &ConstID, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { - let split = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 2)?, - const_pred, - )?; + let split = n_identity(cfg.simple_block_builder(ft1(NAT), 2)?, const_pred)?; let merge = build_then_else_merge_from_if(cfg, unit_const, split)?; Ok((split, merge)) } @@ -864,18 +861,9 @@ pub(crate) mod test { unit_const: &ConstID, split: BasicBlockID, ) -> Result { - let merge = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; - let left = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; - let right = n_identity( - cfg.simple_block_builder(FunctionType::new_endo(NAT), 1)?, - unit_const, - )?; + let merge = n_identity(cfg.simple_block_builder(ft1(NAT), 1)?, unit_const)?; + let left = n_identity(cfg.simple_block_builder(ft1(NAT), 1)?, unit_const)?; + let right = n_identity(cfg.simple_block_builder(ft1(NAT), 1)?, unit_const)?; cfg.branch(&split, 0, &left)?; cfg.branch(&split, 1, &right)?; cfg.branch(&left, 0, &merge)?; @@ -893,13 +881,13 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?, + cfg_builder.simple_entry_builder(type_row![NAT], 2)?, &pred_const, )?; let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?; // The merge block is also the loop header (so it merges three incoming control-flow edges) let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + cfg_builder.simple_block_builder(ft1(NAT), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &merge)?; @@ -929,14 +917,14 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?, + cfg_builder.simple_entry_builder(type_row![NAT], 1)?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(cfg_builder, &pred_const, &const_unit)?; let head = if separate_headers { let head = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 1)?, + cfg_builder.simple_block_builder(ft1(NAT), 1)?, &const_unit, )?; cfg_builder.branch(&head, 0, &split)?; @@ -946,7 +934,7 @@ pub(crate) mod test { split }; let tail = n_identity( - cfg_builder.simple_block_builder(FunctionType::new_endo(NAT), 2)?, + cfg_builder.simple_block_builder(ft1(NAT), 2)?, &pred_const, )?; cfg_builder.branch(&tail, 1, &head)?;