Skip to content

Commit

Permalink
feat!: HugrView API improvements (#680)
Browse files Browse the repository at this point in the history
- [x] Closes #495
- [x] Closes #521
- [x] Closes #655
- [x] Closes #499
- [x] Closes #506 
- [x] Closes #653
- [x] `OpType::signature` returns option (non-dataflow ops don't return
signature)
- [x] Implement `try_into` from OpType references in to inner
references.

Doesn't necessarily do exactly as those issues specify - instead
considers them holistically for a more unified interface. Easiest to
review commit by commit.

Uses rust_version crate to use return position impl for the new
`HugrView` methods that return iterators. This will be stable with 1.75
(which enters beta in a few days).

BREAKING_CHANGES: `OpType` and `FunctionType` methods renamed for
clarity; `OpType::signature` returns `Option<FuncType>`.
  • Loading branch information
ss2165 authored Nov 15, 2023
1 parent 3938883 commit e943fdc
Show file tree
Hide file tree
Showing 28 changed files with 664 additions and 251 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ petgraph = { version = "0.6.3", default-features = false }
context-iterators = "0.2.0"
serde_json = "1.0.97"
delegate = "0.10.0"
rustversion = "1.0.14"
paste = "1.0"

[features]
pyo3 = ["dep:pyo3"]
Expand All @@ -61,7 +63,6 @@ rmp-serde = "1.1.1"
webbrowser = "0.8.10"
urlencoding = "2.1.2"
cool_asserts = "2.0.3"
paste = "1.0"
insta = { version = "1.34.0", features = ["yaml"] }

[[bench]]
Expand All @@ -71,4 +72,4 @@ harness = false

[profile.dev.package]
insta.opt-level = 3
similar.opt-level = 3
similar.opt-level = 3
15 changes: 8 additions & 7 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ pub trait Dataflow: Container {
hugr: Hugr,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let num_outputs = hugr.get_optype(hugr.root()).value_output_count();
let node = self.add_hugr(hugr)?.new_root;

let inputs = input_wires.into_iter().collect();
Expand All @@ -257,8 +257,8 @@ pub trait Dataflow: Container {
hugr: &impl HugrView,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr_view(hugr)?.new_root;
let num_outputs = hugr.get_optype(hugr.root()).value_output_count();

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand Down Expand Up @@ -612,8 +612,9 @@ pub trait Dataflow: Container {
})
}
};
let const_in_port = signature.output.len();
let op_id = self.add_dataflow_op(ops::Call { signature }, input_wires)?;
let op: OpType = ops::Call { signature }.into();
let const_in_port = op.static_input_port().unwrap();
let op_id = self.add_dataflow_op(op, input_wires)?;
let src_port = self.hugr_mut().num_outputs(function.node()) - 1;

self.hugr_mut()
Expand All @@ -633,13 +634,13 @@ fn add_node_with_wires<T: Dataflow + ?Sized>(
nodetype: impl Into<NodeType>,
inputs: Vec<Wire>,
) -> Result<(Node, usize), BuildError> {
let nodetype = nodetype.into();
let sig = nodetype.op_signature();
let nodetype: NodeType = nodetype.into();
let num_outputs = nodetype.op().value_output_count();
let op_node = data_builder.add_child_node(nodetype)?;

wire_up_inputs(inputs, op_node, data_builder)?;

Ok((op_node, sig.output().len()))
Ok((op_node, num_outputs))
}

fn wire_up_inputs<T: Dataflow + ?Sized>(
Expand Down
5 changes: 3 additions & 2 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::extension::ExtensionRegistry;
use crate::hugr::views::HugrView;
use crate::ops::dataflow::DataflowOpTrait;
use crate::types::{FunctionType, TypeRow};

use crate::ops;
use crate::ops::handle::CaseID;
use crate::ops::{self, OpTrait};

use super::build_traits::SubContainer;
use super::handle::BuildHandle;
Expand Down Expand Up @@ -104,12 +105,12 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
let conditional = self.conditional_node;
let control_op = self.hugr().get_optype(self.conditional_node);
let extension_delta = control_op.signature().extension_reqs;

let cond: ops::Conditional = control_op
.clone()
.try_into()
.expect("Parent node does not have Conditional optype.");
let extension_delta = cond.signature().extension_reqs;
let inputs = cond
.case_input_row(case)
.ok_or(ConditionalBuildError::NotCase { conditional, case })?;
Expand Down
37 changes: 18 additions & 19 deletions src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use crate::{
};

use crate::ops::handle::{AliasID, FuncID, NodeHandle};
use crate::ops::OpType;

use crate::types::Signature;

Expand Down Expand Up @@ -72,22 +71,22 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
/// # Errors
///
/// This function will return an error if there is an error in adding the
/// [`OpType::FuncDefn`] node.
/// [`crate::ops::OpType::FuncDefn`] node.
pub fn define_declaration(
&mut self,
f_id: &FuncID<false>,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
let f_node = f_id.node();
let (signature, name) = if let OpType::FuncDecl(ops::FuncDecl { signature, name }) =
self.hugr().get_optype(f_node)
{
(signature.clone(), name.clone())
} else {
return Err(BuildError::UnexpectedType {
let ops::FuncDecl { signature, name } = self
.hugr()
.get_optype(f_node)
.as_func_decl()
.ok_or(BuildError::UnexpectedType {
node: f_node,
op_desc: "OpType::FuncDecl",
});
};
op_desc: "crate::ops::OpType::FuncDecl",
})?
.clone();

self.hugr_mut().replace_op(
f_node,
NodeType::new_pure(ops::FuncDefn {
Expand All @@ -105,7 +104,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
/// # Errors
///
/// This function will return an error if there is an error in adding the
/// [`OpType::FuncDecl`] node.
/// [`crate::ops::OpType::FuncDecl`] node.
pub fn declare(
&mut self,
name: impl Into<String>,
Expand All @@ -124,11 +123,11 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
Ok(declare_n.into())
}

/// Add a [`OpType::AliasDefn`] node and return a handle to the Alias.
/// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias.
///
/// # Errors
///
/// Error in adding [`OpType::AliasDefn`] child node.
/// Error in adding [`crate::ops::OpType::AliasDefn`] child node.
pub fn add_alias_def(
&mut self,
name: impl Into<SmolStr>,
Expand All @@ -149,10 +148,10 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
Ok(AliasID::new(node, name, bound))
}

/// Add a [`OpType::AliasDecl`] node and return a handle to the Alias.
/// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias.
/// # Errors
///
/// Error in adding [`OpType::AliasDecl`] child node.
/// Error in adding [`crate::ops::OpType::AliasDecl`] child node.
pub fn add_alias_declare(
&mut self,
name: impl Into<SmolStr>,
Expand Down Expand Up @@ -233,14 +232,14 @@ mod test {

let mut f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
)?;
let local_build = f_build.define_function(
"local",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
)?;
let [wire] = local_build.input_wires_arr();
let f_id = local_build.finish_with_outputs([wire])?;
let f_id = local_build.finish_with_outputs([wire, wire])?;

let call = f_build.call(f_id.handle(), f_build.input_wires())?;

Expand Down
11 changes: 5 additions & 6 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ops::{self, OpType};
use crate::ops;

use crate::hugr::{views::HugrView, NodeType};
use crate::types::{FunctionType, TypeRow};
Expand Down Expand Up @@ -38,14 +38,13 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
/// Get a reference to the [`ops::TailLoop`]
/// that defines the signature of the [`ops::TailLoop`]
pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> {
if let OpType::TailLoop(tail_loop) = self.hugr().get_optype(self.container_node()) {
Ok(tail_loop)
} else {
Err(BuildError::UnexpectedType {
self.hugr()
.get_optype(self.container_node())
.as_tail_loop()
.ok_or(BuildError::UnexpectedType {
node: self.container_node(),
op_desc: "crate::ops::TailLoop",
})
}
}

/// The output types of the child graph, including the TupleSum as the first.
Expand Down
12 changes: 8 additions & 4 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,15 @@ impl UnificationContext {
match node_type.signature() {
// Input extensions are open
None => {
let delta = node_type.op_signature().extension_reqs;
let c = if delta.is_empty() {
Constraint::Equal(m_input)
let c = if let Some(sig) = node_type.op_signature() {
let delta = sig.extension_reqs;
if delta.is_empty() {
Constraint::Equal(m_input)
} else {
Constraint::Plus(delta, m_input)
}
} else {
Constraint::Plus(delta, m_input)
Constraint::Equal(m_input)
};
self.add_constraint(m_output, c);
}
Expand Down
2 changes: 1 addition & 1 deletion src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
hugr,
conditional_node,
op.clone(),
Into::<OpType>::into(op).signature(),
Into::<OpType>::into(op).dataflow_signature().unwrap(),
)?;

let lift1 = hugr.add_node_with_parent(
Expand Down
9 changes: 7 additions & 2 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ impl ExtensionValidator {
pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self {
let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new();
for (node, incoming_sol) in closure.into_iter() {
let op_signature = hugr.get_nodetype(node).op_signature();
let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol);
let extension_reqs = hugr
.get_nodetype(node)
.op_signature()
.map(|s| s.extension_reqs)
.unwrap_or_default();

let outgoing_sol = extension_reqs.union(&incoming_sol);

extensions.insert((node, Direction::Incoming), incoming_sol);
extensions.insert((node, Direction::Outgoing), outgoing_sol);
Expand Down
13 changes: 8 additions & 5 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ impl NodeType {

/// Use the input extensions to calculate the concrete signature of the node
pub fn signature(&self) -> Option<Signature> {
self.input_extensions
.as_ref()
.map(|rs| self.op.signature().with_input_extensions(rs.clone()))
self.input_extensions.as_ref().map(|rs| {
self.op
.dataflow_signature()
.unwrap_or_default()
.with_input_extensions(rs.clone())
})
}

/// Get the function type from the embedded op
pub fn op_signature(&self) -> FunctionType {
self.op.signature()
pub fn op_signature(&self) -> Option<FunctionType> {
self.op.dataflow_signature()
}

/// The input extensions defined for this node.
Expand Down
4 changes: 2 additions & 2 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,12 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
) -> Result<(OutgoingPort, IncomingPort), HugrError> {
let src_port = self
.get_optype(src)
.other_port_index(Direction::Outgoing)
.other_output_port()
.expect("Source operation has no non-dataflow outgoing edges")
.as_outgoing()?;
let dst_port = self
.get_optype(dst)
.other_port_index(Direction::Incoming)
.other_input_port()
.expect("Destination operation has no non-dataflow incoming edges")
.as_incoming()?;
self.connect(src, src_port, dst, dst_port)?;
Expand Down
5 changes: 1 addition & 4 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{HugrView, IncomingPort};

use super::Rewrite;

use itertools::Itertools;
use thiserror::Error;

/// Specification of a identity-insertion operation.
Expand Down Expand Up @@ -73,9 +72,7 @@ impl Rewrite for IdentityInsertion {
};

let (pre_node, pre_port) = h
.linked_outputs(self.post_node, self.post_port)
.exactly_one()
.ok()
.single_linked_output(self.post_node, self.post_port)
.expect("Value kind input can only have one connection.");

h.disconnect(self.post_node, self.post_port).unwrap();
Expand Down
7 changes: 4 additions & 3 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::{HugrMut, HugrView};
use crate::ops;
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpTrait, OpType};
use crate::ops::{BasicBlock, OpType};
use crate::PortIndex;
use crate::{type_row, Node};

Expand Down Expand Up @@ -49,7 +50,7 @@ impl OutlineCfg {
_ => return Err(OutlineCfgError::NotSiblings),
};
let o = h.get_optype(cfg_n);
if !matches!(o, OpType::CFG(_)) {
let OpType::CFG(o) = o else {
return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone()));
};
let cfg_entry = h.children(cfg_n).next().unwrap();
Expand Down Expand Up @@ -177,7 +178,7 @@ impl Rewrite for OutlineCfg {
let exit_port = h
.node_outputs(exit)
.filter(|p| {
let (t, p2) = h.linked_ports(exit, *p).exactly_one().ok().unwrap();
let (t, p2) = h.single_linked_input(exit, *p).unwrap();
assert!(p2.index() == 0);
t == outside
})
Expand Down
Loading

0 comments on commit e943fdc

Please sign in to comment.