Skip to content

Commit

Permalink
feat: static_targets and OpType::static_output_port
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Nov 13, 2023
1 parent 110ce33 commit 310de5b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,12 @@ pub trait HugrView: sealed::HugrInternals {
.map(|(n, _)| n)
}

#[rustversion::since(1.75)] // uses impl in return position
/// If a node has a static output, return the targets.
fn static_targets(&self, node: Node) -> Option<impl Iterator<Item = (Node, IncomingPort)>> {
Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?))
}

/// Get the "signature" (incoming and outgoing types) of a node, non-Value
/// kind edges will be missing.
fn signature(&self, node: Node) -> FunctionType {
Expand Down
23 changes: 23 additions & 0 deletions src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use itertools::Itertools;
use portgraph::PortOffset;
use rstest::{fixture, rstest};

Expand Down Expand Up @@ -147,3 +148,25 @@ fn value_types() {

assert_eq!(&out_types[..], &[(0.into(), BOOL_T), (1.into(), QB_T)]);
}

#[rustversion::since(1.75)] // uses impl in return position
#[test]
fn static_targets() {
use crate::extension::prelude::{ConstUsize, USIZE_T};
let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap();

let load = dfg.load_const(&c).unwrap();

let h = dfg
.finish_hugr_with_outputs([load], &crate::extension::PRELUDE_REGISTRY)
.unwrap();

assert_eq!(h.static_source(load.node()), Some(c.node()));

assert_eq!(
&h.static_targets(c.node()).unwrap().collect_vec()[..],
&[(load.node(), 0.into())]
)
}
9 changes: 8 additions & 1 deletion src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod module;
pub mod tag;
pub mod validate;
use crate::types::{EdgeKind, FunctionType, Type};
use crate::{Direction, Port};
use crate::{Direction, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};

use portgraph::NodeIndex;
Expand Down Expand Up @@ -142,6 +142,13 @@ impl OpType {
}
}

/// If the op has a static output (Const, FuncDefn, FuncDecl), the port of that output.
pub fn static_output_port(&self) -> Option<OutgoingPort> {
OpTag::StaticOutput
.is_superset(self.tag())
.then_some(0.into())
}

/// Returns the number of ports for the given direction.
pub fn port_count(&self, dir: Direction) -> usize {
let signature = self.signature();
Expand Down
8 changes: 6 additions & 2 deletions src/ops/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub enum OpTag {
Output,
/// Dataflow node that has a static input
StaticInput,
/// Node that has a static output
StaticOutput,
/// A function call.
FnCall,
/// A constant load operation.
Expand Down Expand Up @@ -106,14 +108,14 @@ impl OpTag {
OpTag::DataflowChild => &[OpTag::Any],
OpTag::Input => &[OpTag::DataflowChild],
OpTag::Output => &[OpTag::DataflowChild],
OpTag::Function => &[OpTag::ModuleOp],
OpTag::Function => &[OpTag::ModuleOp, OpTag::StaticOutput],
OpTag::Alias => &[OpTag::ScopedDefn],
OpTag::FuncDefn => &[OpTag::Function, OpTag::ScopedDefn, OpTag::DataflowParent],
OpTag::BasicBlock => &[OpTag::ControlFlowChild, OpTag::DataflowParent],
OpTag::BasicBlockExit => &[OpTag::BasicBlock],
OpTag::Case => &[OpTag::Any, OpTag::DataflowParent],
OpTag::ModuleRoot => &[OpTag::Any],
OpTag::Const => &[OpTag::ScopedDefn],
OpTag::Const => &[OpTag::ScopedDefn, OpTag::StaticOutput],
OpTag::Dfg => &[OpTag::DataflowChild, OpTag::DataflowParent],
OpTag::Cfg => &[OpTag::DataflowChild],
OpTag::ScopedDefn => &[
Expand All @@ -124,6 +126,7 @@ impl OpTag {
OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent],
OpTag::Conditional => &[OpTag::DataflowChild],
OpTag::StaticInput => &[OpTag::DataflowChild],
OpTag::StaticOutput => &[OpTag::ModuleOp],
OpTag::FnCall => &[OpTag::StaticInput],
OpTag::LoadConst => &[OpTag::StaticInput],
OpTag::Leaf => &[OpTag::DataflowChild],
Expand Down Expand Up @@ -154,6 +157,7 @@ impl OpTag {
OpTag::TailLoop => "Tail-recursive loop",
OpTag::Conditional => "Conditional operation",
OpTag::StaticInput => "Dataflow child with static input (LoadConst or FnCall)",
OpTag::StaticOutput => "Node with static input (FuncDefn, FuncDecl, Const)",
OpTag::FnCall => "Function call",
OpTag::LoadConst => "Constant load operation",
OpTag::Leaf => "Leaf operation",
Expand Down

0 comments on commit 310de5b

Please sign in to comment.