Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow add_input/add_output on all DFGs #1824

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 102 additions & 74 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use itertools::Itertools;

use super::build_traits::{HugrBuilder, SubContainer};
use super::handle::BuildHandle;
use super::{BuildError, Container, Dataflow, DfgID, FuncID};

use itertools::Itertools;
use std::marker::PhantomData;

use crate::hugr::internal::HugrMutInternals;
Expand All @@ -17,6 +16,88 @@ use crate::types::{PolyFuncType, Signature, Type};
use crate::Node;
use crate::{hugr::HugrMut, Hugr};

pub trait MutableIO: Dataflow {
fn num_in_wires_mut(&mut self) -> &mut usize;
fn num_out_wires_mut(&mut self) -> &mut usize;

// TODO: define a generic signature for this function
fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn;

/// Add a new input to the DFG being constructed.
///
/// Returns the new wire from the input node.
fn add_input(&mut self, input_type: Type) -> Wire {
let [inp_node, _] = self.io();

// Update the parent's root type
let new_optype = self.update_signature(|mut s| {
s.input.to_mut().push(input_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().input.clone();
self.hugr_mut()
.replace_op(inp_node, Input { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: OutgoingPort = (new_port - 1).into();
let new_order_port: OutgoingPort = new_port.into();
let order_edge_targets = self
.hugr()
.linked_inputs(inp_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(inp_node, new_value_port);
for (tgt_node, tgt_port) in order_edge_targets {
self.hugr_mut()
.connect(inp_node, new_order_port, tgt_node, tgt_port);
}

// Update the builder metadata
*self.num_in_wires_mut() += 1;

self.input_wires().last().unwrap()
}

/// Add a new output to the DFG being constructed.
fn add_output(&mut self, output_type: Type) {
let [_, out_node] = self.io();

// Update the parent's root type
let new_optype = self.update_signature(|mut s| {
s.output.to_mut().push(output_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().output.clone();
self.hugr_mut()
.replace_op(out_node, Output { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: IncomingPort = (new_port - 1).into();
let new_order_port: IncomingPort = new_port.into();
let order_edge_sources = self
.hugr()
.linked_outputs(out_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(out_node, new_value_port);
for (src_node, src_port) in order_edge_sources {
self.hugr_mut()
.connect(src_node, src_port, out_node, new_order_port);
}

// Update the builder metadata
*self.num_out_wires_mut() += 1;
}
}

/// Builder for a [`ops::DFG`] node.
#[derive(Debug, Clone, PartialEq)]
pub struct DFGBuilder<T> {
Expand Down Expand Up @@ -81,6 +162,18 @@ impl DFGBuilder<Hugr> {
}
}

impl MutableIO for DFGBuilder<Hugr> {
fn num_in_wires_mut(&mut self) -> &mut usize {
&mut self.num_in_wires
}

fn num_out_wires_mut(&mut self) -> &mut usize {
&mut self.num_out_wires
}

// fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {}
}

impl HugrBuilder for DFGBuilder<Hugr> {
fn finish_hugr(mut self) -> Result<Hugr, ValidationError> {
if cfg!(feature = "extension_inference") {
Expand Down Expand Up @@ -159,79 +252,14 @@ impl FunctionBuilder<Hugr> {
let db = DFGBuilder::create_with_io(base, root, body)?;
Ok(Self::from_dfg_builder(db))
}
}

/// Add a new input to the function being constructed.
///
/// Returns the new wire from the input node.
pub fn add_input(&mut self, input_type: Type) -> Wire {
let [inp_node, _] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.input.to_mut().push(input_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().input.clone();
self.hugr_mut()
.replace_op(inp_node, Input { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: OutgoingPort = (new_port - 1).into();
let new_order_port: OutgoingPort = new_port.into();
let order_edge_targets = self
.hugr()
.linked_inputs(inp_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(inp_node, new_value_port);
for (tgt_node, tgt_port) in order_edge_targets {
self.hugr_mut()
.connect(inp_node, new_order_port, tgt_node, tgt_port);
}

// Update the builder metadata
self.0.num_in_wires += 1;

self.input_wires().last().unwrap()
impl MutableIO for FunctionBuilder<Hugr> {
fn num_in_wires_mut(&mut self) -> &mut usize {
&mut self.0.num_in_wires
}

/// Add a new output to the function being constructed.
pub fn add_output(&mut self, output_type: Type) {
let [_, out_node] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.output.to_mut().push(output_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().output.clone();
self.hugr_mut()
.replace_op(out_node, Output { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: IncomingPort = (new_port - 1).into();
let new_order_port: IncomingPort = new_port.into();
let order_edge_sources = self
.hugr()
.linked_outputs(out_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(out_node, new_value_port);
for (src_node, src_port) in order_edge_sources {
self.hugr_mut()
.connect(src_node, src_port, out_node, new_order_port);
}

// Update the builder metadata
self.0.num_out_wires += 1;
fn num_out_wires_mut(&mut self) -> &mut usize {
&mut self.0.num_out_wires
}

/// Update the function builder's parent signature.
Expand All @@ -241,7 +269,7 @@ impl FunctionBuilder<Hugr> {
/// Does not update the input and output nodes.
///
/// Returns a reference to the new optype.
fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
fn update_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
let parent = self.container_node();
let old_optype = self
.hugr()
Expand Down
Loading