diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index dcb773efb..9237046b8 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -246,7 +246,13 @@ pub struct SiblingMut<'g, Root = Node> { impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. - pub fn try_new(hugr: &'g mut impl HugrMut, root: Node) -> Result { + pub fn try_new(hugr: &'g mut Base, root: Node) -> Result { + if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) { + return Err(HugrError::InvalidTag { + required: Base::RootHandle::TAG, + actual: Root::TAG, + }); + } check_tag::(hugr, root)?; Ok(Self { hugr: hugr.hugr_mut(), @@ -366,7 +372,7 @@ mod test { use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::PRELUDE_REGISTRY; use crate::hugr::NodeType; - use crate::ops::handle::{CfgID, DfgID, FuncID, ModuleRootID}; + use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID, ModuleRootID}; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -462,4 +468,24 @@ mod test { simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err()); } + + #[rstest] + fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { + let root = simple_dfg_hugr.root(); + let case_nodetype = NodeType::open_extensions(crate::ops::Case { + signature: simple_dfg_hugr.root_type().op_signature(), + }); + let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + // As expected, we cannot replace the root with a Case + assert_eq!( + sib_mut.replace_op(root, case_nodetype.clone()), + Err(HugrError::InvalidTag { + required: OpTag::Dfg, + actual: OpTag::Case + }) + ); + + let nested_sib_mut = SiblingMut::::try_new(&mut sib_mut, root); + assert!(nested_sib_mut.is_err()); + } }