Skip to content

Commit

Permalink
fix: add a MutableIO trait on Alan's suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
qartik committed Jan 10, 2025
1 parent 577971c commit bdc709b
Showing 1 changed file with 118 additions and 103 deletions.
221 changes: 118 additions & 103 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use delegate::delegate;
use itertools::Itertools;

use super::build_traits::{HugrBuilder, SubContainer};
use super::handle::BuildHandle;
use super::{BuildError, Container, Dataflow, DfgID, FuncID};

use itertools::Itertools;
use std::marker::PhantomData;

use crate::hugr::internal::HugrMutInternals;
Expand All @@ -18,77 +16,21 @@ use crate::types::{PolyFuncType, Signature, Type};
use crate::Node;
use crate::{hugr::HugrMut, Hugr};

/// Builder for a [`ops::DFG`] node.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
pub(crate) base: T,
pub(crate) dfg_node: Node,
pub(crate) num_in_wires: usize,
pub(crate) num_out_wires: usize,
}

impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
pub(super) fn create_with_io(
mut base: T,
parent: Node,
signature: Signature,
) -> Result<Self, BuildError> {
let num_in_wires = signature.input().len();
let num_out_wires = signature.output().len();
/* For a given dataflow graph with extension requirements IR -> IR + dR,
- The output node's extension requirements are IR + dR -> IR + dR
(but we expect no output wires)
- The input node's extension requirements are IR -> IR, though we
expect no input wires. We must avoid the case where the difference
in extensions is an open variable, as it would be if the requirements
were 0 -> IR.
N.B. This means that for input nodes, we can't infer the extensions
from the input wires as we normally expect, but have to infer the
output wires and make use of the equality between the two.
*/
let input = ops::Input {
types: signature.input().clone(),
};
let output = ops::Output {
types: signature.output().clone(),
};
base.as_mut().add_node_with_parent(parent, input);
base.as_mut().add_node_with_parent(parent, output);

Ok(Self {
base,
dfg_node: parent,
num_in_wires,
num_out_wires,
})
}
}
pub trait MutableIO: Dataflow {
fn num_in_wires_mut(&mut self) -> &mut usize;
fn num_out_wires_mut(&mut self) -> &mut usize;

impl DFGBuilder<Hugr> {
/// Begin building a new DFG-rooted HUGR given its inputs, outputs,
/// and extension delta.
///
/// # Errors
///
/// Error in adding DFG child nodes.
pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
let dfg_op = ops::DFG {
signature: signature.clone(),
};
let base = Hugr::new(dfg_op);
let root = base.root();
DFGBuilder::create_with_io(base, root, signature)
}
// TODO: define a generic signature for this function
fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn;

/// Add a new input to the DFG being constructed.
///
/// Returns the new wire from the input node.
pub fn add_input(&mut self, input_type: Type) -> Wire {
fn add_input(&mut self, input_type: Type) -> Wire {
let [inp_node, _] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
let new_optype = self.update_signature(|mut s| {
s.input.to_mut().push(input_type);
s
});
Expand All @@ -115,17 +57,17 @@ impl DFGBuilder<Hugr> {
}

// Update the builder metadata
self.num_in_wires += 1;
*self.num_in_wires_mut() += 1;

self.input_wires().last().unwrap()
}

/// Add a new output to the DFG being constructed.
pub fn add_output(&mut self, output_type: Type) {
fn add_output(&mut self, output_type: Type) {
let [_, out_node] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
let new_optype = self.update_signature(|mut s| {
s.output.to_mut().push(output_type);
s
});
Expand All @@ -152,37 +94,84 @@ impl DFGBuilder<Hugr> {
}

// Update the builder metadata
self.num_out_wires += 1;
*self.num_out_wires_mut() += 1;
}
}

/// Update the DFG builder's parent signature.
///
/// Internal function used in [add_input] and [add_output].
/// Builder for a [`ops::DFG`] node.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
pub(crate) base: T,
pub(crate) dfg_node: Node,
pub(crate) num_in_wires: usize,
pub(crate) num_out_wires: usize,
}

impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
pub(super) fn create_with_io(
mut base: T,
parent: Node,
signature: Signature,
) -> Result<Self, BuildError> {
let num_in_wires = signature.input().len();
let num_out_wires = signature.output().len();
/* For a given dataflow graph with extension requirements IR -> IR + dR,
- The output node's extension requirements are IR + dR -> IR + dR
(but we expect no output wires)
- The input node's extension requirements are IR -> IR, though we
expect no input wires. We must avoid the case where the difference
in extensions is an open variable, as it would be if the requirements
were 0 -> IR.
N.B. This means that for input nodes, we can't infer the extensions
from the input wires as we normally expect, but have to infer the
output wires and make use of the equality between the two.
*/
let input = ops::Input {
types: signature.input().clone(),
};
let output = ops::Output {
types: signature.output().clone(),
};
base.as_mut().add_node_with_parent(parent, input);
base.as_mut().add_node_with_parent(parent, output);

Ok(Self {
base,
dfg_node: parent,
num_in_wires,
num_out_wires,
})
}
}

impl DFGBuilder<Hugr> {
/// Begin building a new DFG-rooted HUGR given its inputs, outputs,
/// and extension delta.
///
/// Does not update the input and output nodes.
/// # Errors
///
/// Returns a reference to the new optype.
fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
let parent = self.container_node();
let old_optype = self
.hugr()
.get_optype(parent)
.as_func_defn()
.expect("FunctionBuilder node must be a FuncDefn");
let signature = old_optype.inner_signature().into_owned();
let name = old_optype.name.clone();
self.hugr_mut()
.replace_op(
parent,
ops::FuncDefn {
signature: f(signature).into(),
name,
},
)
.expect("Could not replace FunctionBuilder operation");
/// Error in adding DFG child nodes.
pub fn new(signature: Signature) -> Result<DFGBuilder<Hugr>, BuildError> {
let dfg_op = ops::DFG {
signature: signature.clone(),
};
let base = Hugr::new(dfg_op);
let root = base.root();
DFGBuilder::create_with_io(base, root, signature)
}
}

self.hugr().get_optype(parent).as_func_defn().unwrap()
impl MutableIO for DFGBuilder<Hugr> {
fn num_in_wires_mut(&mut self) -> &mut usize {
&mut self.num_in_wires
}

fn num_out_wires_mut(&mut self) -> &mut usize {
&mut self.num_out_wires
}

// fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {}
}

impl HugrBuilder for DFGBuilder<Hugr> {
Expand Down Expand Up @@ -263,17 +252,43 @@ impl FunctionBuilder<Hugr> {
let db = DFGBuilder::create_with_io(base, root, body)?;
Ok(Self::from_dfg_builder(db))
}
}

delegate! {
to self.0 {
/// Add a new input to the function being constructed.
///
/// Returns the new wire from the input node.
pub fn add_input(&mut self, input_type: Type) -> Wire;
impl MutableIO for FunctionBuilder<Hugr> {
fn num_in_wires_mut(&mut self) -> &mut usize {
&mut self.0.num_in_wires
}
fn num_out_wires_mut(&mut self) -> &mut usize {
&mut self.0.num_out_wires
}

/// Add a new input to the function being constructed.
pub fn add_output(&mut self, output_type: Type);
}
/// Update the function builder's parent signature.
///
/// Internal function used in [add_input] and [add_output].
///
/// Does not update the input and output nodes.
///
/// Returns a reference to the new optype.
fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
let parent = self.container_node();
let old_optype = self
.hugr()
.get_optype(parent)
.as_func_defn()
.expect("FunctionBuilder node must be a FuncDefn");
let signature = old_optype.inner_signature().into_owned();
let name = old_optype.name.clone();
self.hugr_mut()
.replace_op(
parent,
ops::FuncDefn {
signature: f(signature).into(),
name,
},
)
.expect("Could not replace FunctionBuilder operation");

self.hugr().get_optype(parent).as_func_defn().unwrap()
}
}

Expand Down

0 comments on commit bdc709b

Please sign in to comment.