Skip to content

Commit

Permalink
Add RootChecked to guarantee root node type (#548)
Browse files Browse the repository at this point in the history
* Factor out RootTagged as a subtrait of HugrView to allow separate
implementation. Most things can still just use HugrView (as it happens
everything that implements HugrView also implements RootTagged, but this
is not required).
    * HugrMut a subtrait of RootTagged not just HugrView
* Then add a new RootChecked struct storing any `AsRef<Hugr>` with a
fixed RootHandle.
* It allows `.as_ref()`, throwing away the extra information in the
RootHandle
     * If the underlying view is a HugrMut then it is too
* Use trait-default implementations of Hugr(View,(Mut)Internals), as the
latter check that `replace_op` does not break the bound on root-type
* But do not provide `.as_mut()` as that would allow bypassing and
invalidating the extra info in the RootHandle
* Like SiblingMut, check that nested instances only narrow the bound.
* Change the overridden-for-&(mut) Hugr impls of Hugr(Mut)Internals to
only work for `RootHandle=Node` to ensure the lack of checking in
replace_op there is safe.
  • Loading branch information
acl-cqc authored Oct 10, 2023
1 parent 23754fb commit 31736d3
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 53 deletions.
5 changes: 2 additions & 3 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr};

pub use self::views::HugrView;
pub use self::views::{HugrView, RootTagged};
use crate::extension::{
infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError,
};
Expand Down Expand Up @@ -570,7 +570,7 @@ pub enum HugrError {
#[error("Invalid node {0:?}.")]
InvalidNode(Node),
/// The node was not of the required [OpTag]
/// (e.g. to conform to a [HugrView::RootHandle])
/// (e.g. to conform to the [RootTagged::RootHandle] of a [HugrView])
#[error("Invalid tag: required a tag in {required} but found {actual}")]
#[allow(missing_docs)]
InvalidTag { required: OpTag, actual: OpTag },
Expand Down Expand Up @@ -617,7 +617,6 @@ mod test {
#[test]
fn io_node() {
use crate::builder::test::simple_dfg_hugr;
use crate::hugr::views::HugrView;
use cool_asserts::assert_matches;

let hugr = simple_dfg_hugr();
Expand Down
19 changes: 6 additions & 13 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::ops::Range;
use portgraph::view::{NodeFilter, NodeFiltered};
use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap};

use crate::hugr::{Direction, HugrError, HugrView, Node, NodeType};
use crate::hugr::{Direction, HugrError, HugrView, Node, NodeType, RootTagged};
use crate::ops::OpType;

use crate::{Hugr, Port};
Expand All @@ -17,7 +17,7 @@ use super::views::SiblingSubgraph;
use super::{IncomingPort, NodeMetadata, OutgoingPort, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
pub trait HugrMut: HugrView + HugrMutInternals {
pub trait HugrMut: HugrMutInternals {
/// Returns the metadata associated with a node.
fn get_metadata_mut(&mut self, node: Node) -> Result<&mut NodeMetadata, HugrError> {
self.valid_node(node)?;
Expand Down Expand Up @@ -216,10 +216,7 @@ impl InsertionResult {
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
impl<T> HugrMut for T
where
T: HugrView + AsMut<Hugr>,
{
impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
fn add_node_with_parent(&mut self, parent: Node, node: NodeType) -> Result<Node, HugrError> {
let node = self.as_mut().add_node(node);
self.as_mut()
Expand Down Expand Up @@ -445,7 +442,7 @@ pub(crate) mod sealed {
///
/// Specifically, this trait lets you apply arbitrary modifications that may
/// invalidate the HUGR.
pub trait HugrMutInternals: HugrView {
pub trait HugrMutInternals: RootTagged {
/// Returns the Hugr at the base of a chain of views.
fn hugr_mut(&mut self) -> &mut Hugr;

Expand Down Expand Up @@ -518,10 +515,7 @@ pub(crate) mod sealed {
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
impl<T> HugrMutInternals for T
where
T: HugrView + AsMut<Hugr>,
{
impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMutInternals for T {
fn hugr_mut(&mut self) -> &mut Hugr {
self.as_mut()
}
Expand Down Expand Up @@ -577,7 +571,7 @@ pub(crate) mod sealed {
}

fn replace_op(&mut self, node: Node, op: NodeType) -> Result<NodeType, HugrError> {
// No possibility of failure here since Self::RootHandle == Any
// We know RootHandle=Node here so no need to check
let cur = self.hugr_mut().op_types.get_mut(node.index);
Ok(std::mem::replace(cur, op))
}
Expand All @@ -589,7 +583,6 @@ mod test {
use crate::{
extension::prelude::USIZE_T,
extension::PRELUDE_REGISTRY,
hugr::HugrView,
macros::type_row,
ops::{self, dataflow::IOTrait, LeafOp},
types::{FunctionType, Type},
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ mod tests {
algorithm::nest_cfgs::test::build_conditional_in_loop_cfg,
extension::{prelude::QB_T, PRELUDE_REGISTRY},
ops::handle::NodeHandle,
Hugr,
Hugr, HugrView,
};

#[rstest]
Expand Down
42 changes: 26 additions & 16 deletions src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pub mod descendants;
pub mod petgraph;
mod root_checked;
pub mod sibling;
pub mod sibling_subgraph;

Expand All @@ -10,6 +11,7 @@ mod tests;

pub use self::petgraph::PetgraphWrapper;
pub use descendants::DescendantsGraph;
pub use root_checked::RootChecked;
pub use sibling::SiblingGraph;
pub use sibling_subgraph::SiblingSubgraph;

Expand All @@ -27,12 +29,6 @@ use crate::{Direction, Node, Port};
/// A trait for inspecting HUGRs.
/// For end users we intend this to be superseded by region-specific APIs.
pub trait HugrView: sealed::HugrInternals {
/// The kind of handle that can be used to refer to the root node.
///
/// The handle is guaranteed to be able to contain the operation returned by
/// [`HugrView::root_type`].
type RootHandle: NodeHandle;

/// An Iterator over the nodes in a Hugr(View)
type Nodes<'a>: Iterator<Item = Node>
where
Expand Down Expand Up @@ -73,7 +69,8 @@ pub trait HugrView: sealed::HugrInternals {
#[inline]
fn root_type(&self) -> &NodeType {
let node_type = self.get_nodetype(self.root());
debug_assert!(Self::RootHandle::can_hold(node_type.tag()));
// Sadly no way to do this at present
// debug_assert!(Self::RootHandle::can_hold(node_type.tag()));
node_type
}

Expand Down Expand Up @@ -303,8 +300,17 @@ pub trait HugrView: sealed::HugrInternals {
}
}

/// Trait for views that provides a guaranteed bound on the type of the root node.
pub trait RootTagged: HugrView {
/// The kind of handle that can be used to refer to the root node.
///
/// The handle is guaranteed to be able to contain the operation returned by
/// [`HugrView::root_type`].
type RootHandle: NodeHandle;
}

/// A common trait for views of a HUGR hierarchical subgraph.
pub trait HierarchyView<'a>: HugrView + Sized {
pub trait HierarchyView<'a>: RootTagged + Sized {
/// Create a hierarchical view of a HUGR given a root node.
///
/// # Errors
Expand All @@ -322,12 +328,19 @@ fn check_tag<Required: NodeHandle>(hugr: &impl HugrView, node: Node) -> Result<(
Ok(())
}

impl<T> HugrView for T
where
T: AsRef<Hugr>,
{
impl RootTagged for Hugr {
type RootHandle = Node;
}

impl RootTagged for &Hugr {
type RootHandle = Node;
}

impl RootTagged for &mut Hugr {
type RootHandle = Node;
}

impl<T: AsRef<Hugr>> HugrView for T {
/// An Iterator over the nodes in a Hugr(View)
type Nodes<'a> = MapInto<multiportgraph::Nodes<'a>, Node> where Self: 'a;

Expand Down Expand Up @@ -451,10 +464,7 @@ pub(crate) mod sealed {
fn root_node(&self) -> Node;
}

impl<T> HugrInternals for T
where
T: AsRef<super::Hugr>,
{
impl<T: AsRef<Hugr>> HugrInternals for T {
type Portgraph<'p> = &'p MultiPortGraph where Self: 'p;

#[inline]
Expand Down
13 changes: 5 additions & 8 deletions src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::hugr::HugrError;
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};

use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView};
use super::{check_tag, sealed::HugrInternals, HierarchyView, HugrView, RootTagged};

type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;

Expand Down Expand Up @@ -39,13 +39,7 @@ pub struct DescendantsGraph<'g, Root = Node> {
/// The operation handle of the root node.
_phantom: std::marker::PhantomData<Root>,
}

impl<'g, Root> HugrView for DescendantsGraph<'g, Root>
where
Root: NodeHandle,
{
type RootHandle = Root;

impl<'g, Root: NodeHandle> HugrView for DescendantsGraph<'g, Root> {
type Nodes<'a> = MapInto<<RegionGraph<'g> as PortView>::Nodes<'a>, Node>
where
Self: 'a;
Expand Down Expand Up @@ -154,6 +148,9 @@ where
self.graph.all_neighbours(node.index).map_into()
}
}
impl<'g, Root: NodeHandle> RootTagged for DescendantsGraph<'g, Root> {
type RootHandle = Root;
}

impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
where
Expand Down
127 changes: 127 additions & 0 deletions src/hugr/views/root_checked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::marker::PhantomData;

use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::{Hugr, Node};

use super::{check_tag, RootTagged};

/// A view of the whole Hugr.
/// (Just provides static checking of the type of the root node)
pub struct RootChecked<H, Root = Node>(H, PhantomData<Root>);

impl<H: RootTagged + AsRef<Hugr>, Root: NodeHandle> RootChecked<H, Root> {
/// Create a hierarchical view of a whole HUGR
///
/// # Errors
/// Returns [`HugrError::InvalidTag`] if the root isn't a node of the required [`OpTag`]
///
/// [`OpTag`]: crate::ops::OpTag
pub fn try_new(hugr: H) -> Result<Self, HugrError> {
if !H::RootHandle::TAG.is_superset(Root::TAG) {
return Err(HugrError::InvalidTag {
required: H::RootHandle::TAG,
actual: Root::TAG,
});
}
check_tag::<Root>(&hugr, hugr.root())?;
Ok(Self(hugr, PhantomData))
}
}

impl<Root> RootChecked<Hugr, Root> {
/// Extracts the underlying (owned) Hugr
pub fn into_hugr(self) -> Hugr {
self.0
}
}

impl<H: AsRef<Hugr>, Root: NodeHandle> RootTagged for RootChecked<H, Root> {
type RootHandle = Root;
}

impl<H: AsRef<Hugr>, Root> AsRef<Hugr> for RootChecked<H, Root> {
fn as_ref(&self) -> &Hugr {
self.0.as_ref()
}
}

impl<H: HugrMutInternals + AsRef<Hugr>, Root> HugrMutInternals for RootChecked<H, Root>
where
Root: NodeHandle,
{
#[inline(always)]
fn hugr_mut(&mut self) -> &mut Hugr {
self.0.hugr_mut()
}
}

impl<H: HugrMutInternals + AsRef<Hugr>, Root: NodeHandle> HugrMut for RootChecked<H, Root> {}

#[cfg(test)]
mod test {
use super::RootChecked;
use crate::extension::ExtensionSet;
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut, NodeType};
use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID};
use crate::ops::{BasicBlock, LeafOp, OpTag};
use crate::{ops, type_row, types::FunctionType, Hugr, HugrView};

#[test]
fn root_checked() {
let root_type = NodeType::pure(ops::DFG {
signature: FunctionType::new(vec![], vec![]),
});
let mut h = Hugr::new(root_type.clone());
let cfg_v = RootChecked::<&Hugr, CfgID>::try_new(&h);
assert_eq!(
cfg_v.err(),
Some(HugrError::InvalidTag {
required: OpTag::Cfg,
actual: OpTag::Dfg
})
);
let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap();
// That is a HugrMutInternal, so we can try:
let root = dfg_v.root();
let bb = NodeType::pure(BasicBlock::DFB {
inputs: type_row![],
other_outputs: type_row![],
predicate_variants: vec![type_row![]],
extension_delta: ExtensionSet::new(),
});
let r = dfg_v.replace_op(root, bb.clone());
assert_eq!(
r,
Err(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: ops::OpTag::BasicBlock
})
);
// That didn't do anything:
assert_eq!(dfg_v.get_nodetype(root), &root_type);

// Make a RootChecked that allows any DataflowParent
// We won't be able to do this by widening the bound:
assert_eq!(
RootChecked::<_, DataflowParentID>::try_new(dfg_v).err(),
Some(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: OpTag::DataflowParent
})
);

let mut dfp_v = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut h).unwrap();
let r = dfp_v.replace_op(root, bb.clone());
assert_eq!(r, Ok(root_type));
assert_eq!(dfp_v.get_nodetype(root), &bb);
// Just check we can create a nested instance (narrowing the bound)
let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap();

// And it's a HugrMut:
let nodetype = NodeType::pure(LeafOp::MakeTuple { tys: type_row![] });
bb_v.add_node_with_parent(bb_v.root(), nodetype).unwrap();
}
}
Loading

0 comments on commit 31736d3

Please sign in to comment.