Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make extension delta a parameter of CFG builders #514

Merged
merged 4 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,12 @@ impl<T: Copy + Clone + PartialEq + Eq + Hash> EdgeClassifier<T> {
pub(crate) mod test {
use super::*;
use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder};
use crate::extension::prelude::USIZE_T;
use crate::extension::{prelude::USIZE_T, ExtensionSet};

use crate::hugr::views::{HierarchyView, SiblingGraph};
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::Const;
use crate::types::Type;
use crate::types::{FunctionType, Type};
use crate::{type_row, Hugr};
const NAT: Type = USIZE_T;

Expand All @@ -426,13 +426,13 @@ pub(crate) mod test {
// /-> left --\
// entry -> split > merge -> head -> tail -> exit
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
Expand Down Expand Up @@ -611,7 +611,7 @@ pub(crate) mod test {
unit_const: &ConstID,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let split = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?,
const_pred,
)?;
let merge = build_then_else_merge_from_if(cfg, unit_const, split)?;
Expand All @@ -624,15 +624,15 @@ pub(crate) mod test {
split: BasicBlockID,
) -> Result<BasicBlockID, BuildError> {
let merge = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let left = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let right = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
cfg.branch(&split, 0, &left)?;
Expand All @@ -649,7 +649,7 @@ pub(crate) mod test {
header: BasicBlockID,
) -> Result<BasicBlockID, BuildError> {
let tail = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 2)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 2)?,
const_pred,
)?;
cfg.branch(&tail, 1, &header)?;
Expand All @@ -663,7 +663,7 @@ pub(crate) mod test {
unit_const: &ConstID,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let header = n_identity(
cfg.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
unit_const,
)?;
let tail = build_loop_from_header(cfg, const_pred, header)?;
Expand All @@ -674,18 +674,19 @@ pub(crate) mod test {
pub fn build_cond_then_loop_cfg(
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2)?,
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
&pred_const,
)?;
let merge = build_then_else_merge_from_if(&mut cfg_builder, &const_unit, entry)?;
let head = if separate {
let h = n_identity(
cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?,
cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?,
&const_unit,
)?;
cfg_builder.branch(&merge, 0, &h)?;
Expand All @@ -708,13 +709,13 @@ pub(crate) mod test {
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
//let sum2_type = Type::new_predicate(2);

let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const = cfg_builder.add_constant(Const::simple_predicate(0, 2))?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1)?,
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
&const_unit,
)?;
let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?;
Expand Down
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();

Expand All @@ -325,8 +326,7 @@ pub trait Dataflow: Container {
NodeType::open_extensions(ops::CFG {
inputs: inputs.clone(),
outputs: output_types.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
}),
input_wires,
)?;
Expand Down
69 changes: 44 additions & 25 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,16 @@ impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {

impl CFGBuilder<Hugr> {
/// New CFG rooted HUGR builder
pub fn new(input: impl Into<TypeRow>, output: impl Into<TypeRow>) -> Result<Self, BuildError> {
let input = input.into();
let output = output.into();
pub fn new(signature: FunctionType) -> Result<Self, BuildError> {
let cfg_op = ops::CFG {
inputs: input.clone(),
outputs: output.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
inputs: signature.input.clone(),
outputs: signature.output.clone(),
extension_delta: signature.extension_reqs,
};

// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(cfg_op));
let cfg_node = base.root();
CFGBuilder::create(base, cfg_node, input, output)
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
}
}

Expand Down Expand Up @@ -119,24 +115,31 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, predicate_variants, other_outputs, false)
self.any_block_builder(
inputs,
predicate_variants,
other_outputs,
extension_delta,
false,
)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let op = OpType::BasicBlock(BasicBlock::DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
});
let parent = self.container_node();
let block_n = if entry {
Expand Down Expand Up @@ -165,11 +168,15 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
/// This function will return an error if there is an error adding the node.
pub fn simple_block_builder(
&mut self,
inputs: TypeRow,
outputs: TypeRow,
signature: FunctionType,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(inputs, vec![type_row![]; n_cases], outputs)
self.block_builder(
signature.input,
vec![type_row![]; n_cases],
signature.extension_reqs,
signature.output,
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
Expand All @@ -183,12 +190,19 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
predicate_variants: Vec<TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(inputs, predicate_variants, other_outputs, true)
self.any_block_builder(
inputs,
predicate_variants,
other_outputs,
extension_delta,
true,
)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
Expand All @@ -201,8 +215,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs)
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta)
}

/// Returns the exit block of this [`CFGBuilder`].
Expand Down Expand Up @@ -276,6 +291,7 @@ impl BlockBuilder<Hugr> {
inputs: impl Into<TypeRow>,
predicate_variants: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
let inputs = inputs.into();
let predicate_variants: Vec<_> = predicate_variants.into_iter().collect();
Expand All @@ -284,11 +300,9 @@ impl BlockBuilder<Hugr> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: make this a parameter
extension_delta: ExtensionSet::new(),
extension_delta,
};

// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::open_extensions(op));
let root = base.root();
Self::create(base, root, predicate_variants, other_outputs, inputs)
Expand Down Expand Up @@ -326,8 +340,11 @@ mod test {
let [int] = func_builder.input_wires_arr();

let cfg_id = {
let mut cfg_builder =
func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?;
let mut cfg_builder = func_builder.cfg_builder(
vec![(NAT, int)],
type_row![NAT],
ExtensionSet::new(),
)?;
build_basic_cfg(&mut cfg_builder)?;

cfg_builder.finish_sub_container()?
Expand All @@ -344,7 +361,7 @@ mod test {
}
#[test]
fn basic_cfg_hugr() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
build_basic_cfg(&mut cfg_builder)?;
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Expand All @@ -355,14 +372,16 @@ mod test {
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?;
let mut entry_b =
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?;
let entry = {
let [inw] = entry_b.input_wires_arr();

let sum = entry_b.make_predicate(1, sum2_variants, [inw])?;
entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b = cfg_builder.simple_block_builder(type_row![NAT], type_row![NAT], 1)?;
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Const::simple_unary_predicate())?;
let [inw] = middle_b.input_wires_arr();
Expand Down
34 changes: 24 additions & 10 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use itertools::Itertools;
use thiserror::Error;

use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::extension::PRELUDE_REGISTRY;
use crate::extension::{ExtensionSet, PRELUDE_REGISTRY};
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
Expand All @@ -26,10 +26,13 @@ impl OutlineCfg {
}
}

fn compute_entry_exit_outside(
/// Compute the entry and exit nodes of the CFG which contains
/// [`self.blocks`], along with the output neighbour its parent graph and
/// the combined extension_deltas of all of the blocks.
fn compute_entry_exit_outside_extensions(
&self,
h: &impl HugrView,
) -> Result<(Node, Node, Node), OutlineCfgError> {
) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> {
let cfg_n = match self
.blocks
.iter()
Expand All @@ -47,6 +50,7 @@ impl OutlineCfg {
let cfg_entry = h.children(cfg_n).next().unwrap();
let mut entry = None;
let mut exit_succ = None;
let mut extension_delta = ExtensionSet::new();
for &n in self.blocks.iter() {
if n == cfg_entry
|| h.input_neighbours(n)
Expand All @@ -61,6 +65,7 @@ impl OutlineCfg {
}
}
}
extension_delta = extension_delta.union(&o.signature().extension_reqs);
let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
match external_succs.at_most_one() {
Ok(None) => (), // No external successors
Expand All @@ -76,7 +81,7 @@ impl OutlineCfg {
};
}
match (entry, exit_succ) {
(Some(e), Some((x, o))) => Ok((e, x, o)),
(Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)),
(None, _) => Err(OutlineCfgError::NoEntryNode),
(_, None) => Err(OutlineCfgError::NoExitNode),
}
Expand All @@ -89,11 +94,12 @@ impl Rewrite for OutlineCfg {

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
self.compute_entry_exit_outside(h)?;
self.compute_entry_exit_outside_extensions(h)?;
Ok(())
}
fn apply(self, h: &mut impl HugrMut) -> Result<(), OutlineCfgError> {
let (entry, exit, outside) = self.compute_entry_exit_outside(h)?;
let (entry, exit, outside, extension_delta) =
self.compute_entry_exit_outside_extensions(h)?;
// 1. Compute signature
// These panic()s only happen if the Hugr would not have passed validate()
let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else {
Expand All @@ -109,11 +115,19 @@ impl Rewrite for OutlineCfg {

// 2. new_block contains input node, sub-cfg, exit node all connected
let new_block = {
let mut new_block_bldr =
BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
let mut new_block_bldr = BlockBuilder::new(
inputs.clone(),
vec![type_row![]],
outputs.clone(),
extension_delta.clone(),
)
.unwrap();
let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
cfg.exit_block(); // Makes inner exit block (but no entry block)
// N.B. By invoking the cfg_builder, we're forgetting any input
// extensions that may have existed on the original CFG.
let cfg = new_block_bldr
.cfg_builder(wires_in, outputs, extension_delta)
.unwrap();
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block_bldr
.add_constant(ops::Const::simple_unary_predicate())
Expand Down