Skip to content

Commit

Permalink
(simple_)block_builder variants, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Jun 25, 2024
1 parent e10a94e commit c0d6432
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 74 deletions.
116 changes: 90 additions & 26 deletions hugr-core/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
use crate::{
extension::TO_BE_INFERRED,
ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType},
};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
Expand Down Expand Up @@ -43,7 +46,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// +------------+
/// */
/// use hugr::{
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, ft1, ft2},
/// extension::{prelude, ExtensionSet},
/// ops, type_row,
/// types::{FunctionType, SumType, Type},
Expand All @@ -62,8 +65,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
///
/// // The second argument says what types will be passed through to every
/// // successor, in addition to the appropriate `sum_variants` type.
/// let mut entry_b =
/// cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
///
/// let [inw] = entry_b.input_wires_arr();
/// let entry = {
Expand All @@ -82,7 +84,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// ft2(type_row![NAT, NAT], NAT),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
Expand All @@ -96,8 +98,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let mut successor_builder = cfg_builder.simple_block_builder(ft1(NAT), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
Expand Down Expand Up @@ -197,7 +198,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -206,18 +207,40 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, false)
self.block_builder_exts(inputs, sum_rows, TO_BE_INFERRED, other_outputs)
}

fn any_block_builder(
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn block_builder_exts(
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
extension_delta: impl Into<ExtensionSet>,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
false,
)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
extension_delta: ExtensionSet,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let sum_rows: Vec<_> = sum_rows.into_iter().collect();
Expand All @@ -241,7 +264,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
/// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type
/// (a Sum of `n_cases` unit types) to select the successor.
///
/// # Errors
///
Expand All @@ -251,32 +275,53 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
signature: FunctionType,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(
self.block_builder_exts(
signature.input,
vec![type_row![]; n_cases],
signature.extension_reqs,
signature.output,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with
/// `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
///
/// # Errors
///
/// This function will return an error if an entry block has already been built.
pub fn entry_builder(
&mut self,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(TO_BE_INFERRED, sum_rows, other_outputs)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
///
/// # Errors
///
/// This function will return an error if an entry block has already been built.
pub fn entry_builder(
pub fn entry_builder_exts(
&mut self,
extension_delta: impl Into<ExtensionSet>,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, true)
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
true,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
Expand All @@ -289,9 +334,23 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta)
self.entry_builder(vec![type_row![]; n_cases], outputs)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn simple_entry_builder_exts(
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(extension_delta, vec![type_row![]; n_cases], outputs)
}

/// Returns the exit block of this [`CFGBuilder`].
Expand Down Expand Up @@ -439,8 +498,11 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let mut entry_b =
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
ExtensionSet::new(),
sum2_variants.clone(),
type_row![],
)?;
let entry = {
let [inw] = entry_b.input_wires_arr();

Expand All @@ -466,16 +528,19 @@ pub(crate) mod test {
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
ExtensionSet::new(),
sum_variants.clone(),
type_row![],
)?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);

entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?;
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.load_const(&sum_tuple_const);
middle_b.finish_with_outputs(c, [inw])?
Expand All @@ -501,8 +566,7 @@ pub(crate) mod test {
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
// entry block uses wire from middle block even though middle block
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ mod test {
HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::USIZE_T;
use crate::extension::{ExtensionSet, PRELUDE_REGISTRY};
use crate::extension::PRELUDE_REGISTRY;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::constant::Value;
Expand Down Expand Up @@ -295,7 +295,7 @@ mod test {
};

let entry = n_identity(
cfg_builder.simple_entry_builder(USIZE_T.into(), 2, ExtensionSet::new())?,
cfg_builder.simple_entry_builder(USIZE_T.into(), 2)?,
&pred_const,
)?;

Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ mod test {
},
op_sig.input()
);
h.simple_entry_builder(op_sig.output, 1, op_sig.extension_reqs.clone())?
h.simple_entry_builder_exts(op_sig.output, 1, op_sig.extension_reqs.clone())?
} else {
h.simple_block_builder(op_sig, 1)?
};
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ fn cfg_connections() -> Result<(), Box<dyn std::error::Error>> {

let mut hugr = CFGBuilder::new(FunctionType::new_endo(USIZE_T))?;
let unary_pred = hugr.add_constant(Value::unary_unit_sum());
let mut entry = hugr.simple_entry_builder(type_row![USIZE_T], 1, ExtensionSet::new())?;
let mut entry = hugr.simple_entry_builder_exts(type_row![USIZE_T], 1, ExtensionSet::new())?;
let p = entry.load_const(&unary_pred);
let ins = entry.input_wires();
let entry = entry.finish_with_outputs(p, ins)?;
Expand Down Expand Up @@ -1068,7 +1068,7 @@ mod extension_tests {
let mut cfg = CFGBuilder::new(
FunctionType::new_endo(USIZE_T).with_extension_delta(parent_extensions.clone()),
)?;
let mut bb = cfg.simple_entry_builder(USIZE_T.into(), 1, XB.into())?;
let mut bb = cfg.simple_entry_builder_exts(USIZE_T.into(), 1, XB)?;
let pred = bb.add_load_value(Value::unary_unit_sum());
let inputs = bb.input_wires();
let blk = bb.finish_with_outputs(pred, inputs)?;
Expand Down
24 changes: 6 additions & 18 deletions hugr-passes/src/merge_bbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ mod test {
use itertools::Itertools;
use rstest::rstest;

use hugr_core::builder::{ft2, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder};
use hugr_core::builder::{ft1, ft2, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder};
use hugr_core::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T};
use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY, TO_BE_INFERRED};
use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY};
use hugr_core::hugr::views::sibling::SiblingMut;
use hugr_core::ops::constant::Value;
use hugr_core::ops::handle::CfgID;
Expand Down Expand Up @@ -224,14 +224,13 @@ mod test {
let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?;
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?;
let mut h = CFGBuilder::new(ft2(loop_variants.clone(), exit_types.clone()))?;
let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, PRELUDE_ID.into())?;
let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?;
let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?;
let br = lifted_unary_unit_sum(&mut no_b1);
let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?;
let mut test_block = h.block_builder(
loop_variants.clone(),
vec![loop_variants.clone(), exit_types],
TO_BE_INFERRED.into(), // TODO infer by default
type_row![],
)?;
let [test_input] = test_block.input_wires_arr();
Expand All @@ -243,10 +242,7 @@ mod test {
let loop_backedge_target = if self_loop {
no_b1
} else {
let mut no_b2 = h.simple_block_builder(
FunctionType::new_endo(loop_variants).with_extension_delta(TO_BE_INFERRED), // TODO infer by default
1,
)?;
let mut no_b2 = h.simple_block_builder(ft1(loop_variants), 1)?;
let n = no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?;
let br = lifted_unary_unit_sum(&mut no_b2);
let nid = no_b2.finish_with_outputs(br, n.outputs())?;
Expand Down Expand Up @@ -322,14 +318,8 @@ mod test {
.into_owned()
.try_into()
.unwrap();
let mut h = CFGBuilder::new(
ft2(QB_T, res_t.clone())
)?;
let mut bb1 = h.simple_entry_builder(
type_row![USIZE_T, QB_T],
1,
TO_BE_INFERRED.into(), // TODO by default
)?;
let mut h = CFGBuilder::new(ft2(QB_T, res_t.clone()))?;
let mut bb1 = h.simple_entry_builder(type_row![USIZE_T, QB_T], 1)?;
let [inw] = bb1.input_wires_arr();
let load_cst = bb1.add_load_value(ConstUsize::new(1));
let pred = lifted_unary_unit_sum(&mut bb1);
Expand All @@ -338,7 +328,6 @@ mod test {
let mut bb2 = h.block_builder(
type_row![USIZE_T, QB_T],
vec![type_row![]],
TO_BE_INFERRED.into(),
type_row![QB_T, USIZE_T],
)?;
let [u, q] = bb2.input_wires_arr();
Expand All @@ -348,7 +337,6 @@ mod test {
let mut bb3 = h.block_builder(
type_row![QB_T, USIZE_T],
vec![type_row![]],
TO_BE_INFERRED.into(),
res_t.clone().into(),
)?;
let [q, u] = bb3.input_wires_arr();
Expand Down
Loading

0 comments on commit c0d6432

Please sign in to comment.