Skip to content

Commit

Permalink
test test test
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 18, 2024
1 parent 7fc6306 commit ec182dc
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 39 deletions.
1 change: 1 addition & 0 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ impl OpDef {
}

/// Iterate over all miscellaneous data in the [OpDef].
#[allow(unused)] // Unused when no features are enabled
pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
self.misc.iter().map(|(k, v)| (k.as_str(), v))
}
Expand Down
27 changes: 27 additions & 0 deletions hugr-core/src/hugr/views/descendants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,35 @@ pub(super) mod test {
Some(Signature::new(type_row![NAT], type_row![NAT]))
);
assert_eq!(inner_region.node_count(), 3);
assert_eq!(inner_region.edge_count(), 2);
assert_eq!(inner_region.children(inner).count(), 2);
assert_eq!(inner_region.children(hugr.root()).count(), 0);
assert_eq!(
inner_region.num_ports(inner, Direction::Outgoing),
inner_region.node_ports(inner, Direction::Outgoing).count()
);
assert_eq!(
inner_region.num_ports(inner, Direction::Incoming)
+ inner_region.num_ports(inner, Direction::Outgoing),
inner_region.all_node_ports(inner).count()
);

// The inner region filters out the connections to the main function I/O nodes,
// while the outer region includes them.
assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
assert_eq!(
inner_region
.linked_ports(inner, IncomingPort::from(0))
.count(),
0
);
assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
assert_eq!(
inner_region.neighbours(inner, Direction::Outgoing).count(),
0
);
assert_eq!(inner_region.all_neighbours(inner).count(), 0);
assert_eq!(
inner_region
.linked_ports(inner, IncomingPort::from(0))
Expand Down
153 changes: 114 additions & 39 deletions hugr-core/src/hugr/views/sibling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::iter;

use itertools::Itertools;
use itertools::{Either, Itertools};
use portgraph::{LinkView, MultiPortGraph, PortView};

use crate::hugr::internal::HugrMutInternals;
Expand Down Expand Up @@ -280,49 +280,33 @@ impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> {
node: Node,
port: impl Into<Port>,
) -> impl Iterator<Item = (Node, Port)> + Clone {
let port = self
.hugr
.graph
.port_index(node.pg_index(), port.into().pg_offset())
.unwrap();
self.hugr
.graph
.port_links(port)
.map(|(_, link)| {
let node: Node = self.hugr.graph.port_node(link).unwrap().into();
let offset: Port = self.hugr.graph.port_offset(link).unwrap().into();
(node, offset)
})
.linked_ports(node, port)
.filter(|(n, _)| self.contains_node(*n))
}

fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
self.hugr
.graph
.get_connections(node.pg_index(), other.pg_index())
.map(|(p1, p2)| [p1, p2].map(|link| self.hugr.graph.port_offset(link).unwrap().into()))
match self.contains_node(node) && self.contains_node(other) {
// The nodes are not in the sibling graph
false => Either::Left(iter::empty()),
// The nodes are in the sibling graph
true => Either::Right(self.hugr.node_connections(node, other)),
}
}

fn num_ports(&self, node: Node, dir: Direction) -> usize {
match self.contains_node(node) {
true => self.base_hugr().num_ports(node, dir),
false => 0,
}
self.base_hugr().num_ports(node, dir)
}

fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
self.hugr
.graph
.neighbours(node.pg_index(), dir)
.map_into()
.neighbours(node, dir)
.filter(|n| self.contains_node(*n))
}

fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
self.hugr
.graph
.all_neighbours(node.pg_index())
.map_into()
.all_neighbours(node)
.filter(|n| self.contains_node(*n))
}
}
Expand All @@ -349,28 +333,116 @@ mod test {
use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID};
use crate::ops::{dataflow::IOTrait, Input, OpTag, Output};
use crate::ops::{OpTrait, OpType};
use crate::type_row;
use crate::types::{Signature, Type};
use crate::utils::test_quantum_extension::EXTENSION_ID;
use crate::{type_row, IncomingPort};

const NAT: Type = crate::extension::prelude::USIZE_T;
const QB: Type = crate::extension::prelude::QB_T;

use super::super::descendants::test::make_module_hgr;
use super::*;

#[test]
fn flat_region() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;

let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;
fn test_properties<T>(
hugr: &Hugr,
def: Node,
inner: Node,
region: T,
inner_region: T,
) -> Result<(), Box<dyn std::error::Error>>
where
T: HugrView + Sized,
{
let def_io = region.get_io(def).unwrap();

assert_eq!(region.node_count(), 5);
assert!(region
.nodes()
.all(|n| n == def || hugr.get_parent(n) == Some(def)));
assert_eq!(region.portgraph().node_count(), 5);
assert!(region.nodes().all(|n| n == def
|| hugr.get_parent(n) == Some(def)
|| hugr.get_parent(n) == Some(inner)));
assert_eq!(region.children(inner).count(), 0);

assert_eq!(
region.poly_func_type(),
Some(
Signature::new_endo(type_row![NAT, QB])
.with_extension_delta(EXTENSION_ID)
.into()
)
);

assert_eq!(
inner_region.inner_function_type(),
Some(Signature::new(type_row![NAT], type_row![NAT]))
);
assert_eq!(inner_region.node_count(), 3);
assert_eq!(inner_region.edge_count(), 1);
assert_eq!(inner_region.children(inner).count(), 2);
assert_eq!(inner_region.children(hugr.root()).count(), 0);
assert_eq!(
inner_region.num_ports(inner, Direction::Outgoing),
inner_region.node_ports(inner, Direction::Outgoing).count()
);
assert_eq!(
inner_region.num_ports(inner, Direction::Incoming)
+ inner_region.num_ports(inner, Direction::Outgoing),
inner_region.all_node_ports(inner).count()
);

// The inner region filters out the connections to the main function I/O nodes,
// while the outer region includes them.
assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
assert_eq!(
inner_region
.linked_ports(inner, IncomingPort::from(0))
.count(),
0
);
assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
assert_eq!(
inner_region.neighbours(inner, Direction::Outgoing).count(),
0
);
assert_eq!(inner_region.all_neighbours(inner).count(), 0);
assert_eq!(
inner_region
.linked_ports(inner, IncomingPort::from(0))
.count(),
0
);

Ok(())
}

const NAT: Type = crate::extension::prelude::USIZE_T;
#[rstest]
fn sibling_graph_properties() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;

test_properties::<SiblingGraph>(
&hugr,
def,
inner,
SiblingGraph::try_new(&hugr, def).unwrap(),
SiblingGraph::try_new(&hugr, inner).unwrap(),
)
}

#[rstest]
fn sibling_mut_properties() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;
let mut def_region_hugr = hugr.clone();
let mut inner_region_hugr = hugr.clone();

test_properties::<SiblingMut>(
&hugr,
def,
inner,
SiblingMut::try_new(&mut def_region_hugr, def).unwrap(),
SiblingMut::try_new(&mut inner_region_hugr, inner).unwrap(),
)
}

#[test]
fn nested_flat() -> Result<(), Box<dyn std::error::Error>> {
let mut module_builder = ModuleBuilder::new();
Expand All @@ -382,11 +454,13 @@ mod test {
let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?;
let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?;
let sub_dfg = sub_dfg.node();
// Can create a view from a child or grandchild of a hugr:

// We can create a view from a child or grandchild of a hugr:
let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?;
let fun_view: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&h, fun.node())?;
assert_eq!(fun_view.children(sub_dfg).count(), 0);
// And can create a view from a child of another SiblingGraph

// And also create a view from a child of another SiblingGraph
let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?;

// Both ways work:
Expand All @@ -404,6 +478,7 @@ mod test {
Ok(())
}

/// Mutate a SiblingMut wrapper
#[rstest]
fn flat_mut(mut simple_dfg_hugr: Hugr) {
simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap();
Expand Down
4 changes: 4 additions & 0 deletions hugr-core/src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflow
let (h, n1, n2) = sample_hugr;

let all_output_ports = h.all_linked_outputs(n2.node()).collect_vec();
let all_ports = h.all_node_ports(n2.node()).collect_vec();

assert_eq!(
&all_output_ports[..],
Expand All @@ -108,6 +109,9 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflow
(n1.node(), 2.into()),
]
);
assert!(all_output_ports
.iter()
.all(|&(_, p)| all_ports.contains(&p.into())));

let all_linked_inputs = h.all_linked_inputs(n1.node()).collect_vec();
assert_eq!(
Expand Down

0 comments on commit ec182dc

Please sign in to comment.