diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 7d100d1ab..bfa6146a2 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -179,8 +179,13 @@ where Root: NodeHandle, { fn new(hugr: &'a impl HugrView, root: Node) -> Result { + println!("ALAN checking node {:?} against {}", root, Root::TAG); hugr.valid_node(root)?; if !Root::TAG.is_superset(hugr.get_optype(root).tag()) { + println!( + "ALAN check failed, optype was {}", + hugr.get_optype(root).tag() + ); return Err(HugrError::InvalidNode(root)); } let hugr = hugr.base_hugr(); @@ -217,6 +222,13 @@ where #[cfg(test)] mod test { + use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; + use crate::extension::PRELUDE_REGISTRY; + use crate::ops::handle::{DfgID, FuncID, ModuleRootID}; + use crate::ops::{dataflow::IOTrait, Input, Output}; + use crate::type_row; + use crate::types::{FunctionType, Type}; + use super::super::descendants::test::make_module_hgr; use super::*; @@ -234,4 +246,45 @@ mod test { Ok(()) } + + const NAT: Type = crate::extension::prelude::USIZE_T; + #[test] + fn nested_flat() -> Result<(), Box> { + let mut module_builder = ModuleBuilder::new(); + let fty = FunctionType::new(type_row![NAT], type_row![NAT]); + let mut fbuild = module_builder.define_function("main", fty.clone().pure())?; + let dfg = fbuild.dfg_builder(fty, None, fbuild.input_wires())?; + let ins = dfg.input_wires(); + let sub_dfg = dfg.finish_with_outputs(ins)?; + let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; + let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?; + let sub_dfg = sub_dfg.node(); + // Can create a view from a child or grandchild of a hugr: + let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::new(&h, sub_dfg)?; + let fun_view: SiblingGraph<'_, FuncID> = SiblingGraph::new(&h, fun.node())?; + assert_eq!(fun_view.children(sub_dfg).len(), 0); + // And can create a view from a child of another SiblingGraph + let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::new(&fun_view, sub_dfg)?; + + // Both ways work: + let just_io = vec![ + Input::new(type_row![NAT]).into(), + Output::new(type_row![NAT]).into(), + ]; + for d in [dfg_view, nested_dfg_view] { + assert_eq!( + d.children(sub_dfg).map(|n| d.get_optype(n)).collect_vec(), + just_io.iter().collect_vec() + ); + } + + // But cannot create a view directly as a grandchild of another SiblingGraph + let root_view: SiblingGraph<'_, ModuleRootID> = SiblingGraph::new(&h, h.root()).unwrap(); + assert_eq!( + SiblingGraph::<'_, DfgID>::new(&root_view, sub_dfg.node()).err(), + Some(HugrError::InvalidNode(sub_dfg.node())) + ); + + Ok(()) + } }