diff --git a/src/hugr/views.rs b/src/hugr/views.rs index b5b300619..d0b8d7af4 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -590,6 +590,19 @@ impl> HugrView for T { } } +/// Filter an iterator of node-ports to only dataflow dependency specifying +/// ports (Value and StateOrder) +pub fn dataflow_ports_only<'i, 'a: 'i, P: Into + Copy>( + hugr: &'a impl HugrView, + it: impl Iterator + 'i, +) -> impl Iterator + 'i { + it.filter(move |(n, p)| { + matches!( + hugr.get_optype(*n).port_kind(*p), + Some(EdgeKind::Value(_) | EdgeKind::StateOrder) + ) + }) +} pub(crate) mod sealed { use super::*; diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index c51dd6333..3b544d30b 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -1,4 +1,3 @@ -use itertools::Itertools; use portgraph::PortOffset; use rstest::{fixture, rstest}; @@ -153,6 +152,8 @@ fn value_types() { #[test] fn static_targets() { use crate::extension::prelude::{ConstUsize, USIZE_T}; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![], type_row![USIZE_T])).unwrap(); let c = dfg.add_constant(ConstUsize::new(1).into(), None).unwrap(); @@ -170,3 +171,44 @@ fn static_targets() { &[(load.node(), 0.into())] ) } + +#[rustversion::since(1.75)] // uses impl in return position +#[test] +fn test_dataflow_ports_only() { + use crate::builder::DataflowSubContainer; + use crate::extension::prelude::BOOL_T; + use crate::hugr::views::dataflow_ports_only; + use crate::std_extensions::logic::test::not_op; + use itertools::Itertools; + let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let local_and = { + let local_and = dfg + .define_function( + "and", + FunctionType::new(type_row![BOOL_T; 2], type_row![BOOL_T]).pure(), + ) + .unwrap(); + let first_input = local_and.input().out_wire(0); + local_and.finish_with_outputs([first_input]).unwrap() + }; + let [in_bool] = dfg.input_wires_arr(); + + let not = dfg.add_dataflow_op(not_op(), [in_bool]).unwrap(); + let call = dfg.call(local_and.handle(), [not.out_wire(0); 2]).unwrap(); + dfg.add_other_wire(not.node(), call.node()).unwrap(); + let h = dfg + .finish_hugr_with_outputs(not.outputs(), &crate::extension::PRELUDE_REGISTRY) + .unwrap(); + let filtered_ports = dataflow_ports_only(&h, h.all_linked_outputs(call.node())).collect_vec(); + + // should ignore the static input in to call, but report the two value ports + // and the order port. + assert_eq!( + &filtered_ports[..], + &[ + (not.node(), 0.into()), + (not.node(), 0.into()), + (not.node(), 1.into()) + ] + ) +}