From 64b91997e7cfee61264ab93179cdff9150c768fa Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 20 Dec 2023 17:03:39 +0000 Subject: [PATCH] refactor: move hugr equality check out for reuse --- src/hugr.rs | 44 +++++++++++++++++++++++++++++++++++++++++-- src/hugr/serialize.rs | 37 ++---------------------------------- 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/src/hugr.rs b/src/hugr.rs index 9672f3dbb..2efb8e867 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -343,12 +343,16 @@ pub enum HugrError { } #[cfg(test)] -mod test { +pub(crate) mod test { + use itertools::Itertools; + use portgraph::{LinkView, PortView}; + use super::{Hugr, HugrView}; use crate::builder::test::closed_dfg_root_hugr; use crate::extension::ExtensionSet; use crate::hugr::HugrMut; - use crate::ops; + use crate::ops::LeafOp; + use crate::ops::{self, OpType}; use crate::type_row; use crate::types::{FunctionType, Type}; @@ -398,4 +402,40 @@ mod test { assert_eq!(hugr.get_nodetype(output).input_extensions().unwrap(), &r); Ok(()) } + + pub(crate) fn assert_hugr_equality(hugr: &Hugr, other: &Hugr) { + assert_eq!(other.root, hugr.root); + assert_eq!(other.hierarchy, hugr.hierarchy); + assert_eq!(other.metadata, hugr.metadata); + + // Extension operations may have been downgraded to opaque operations. + for node in other.nodes() { + let new_op = other.get_optype(node); + let old_op = hugr.get_optype(node); + if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { + if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { + assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); + } else { + panic!("Expected old_op to be a custom op"); + } + } else { + assert_eq!(new_op, old_op); + } + } + + // Check that the graphs are equivalent up to port renumbering. + let new_graph = &other.graph; + let old_graph = &hugr.graph; + assert_eq!(new_graph.node_count(), old_graph.node_count()); + assert_eq!(new_graph.port_count(), old_graph.port_count()); + assert_eq!(new_graph.link_count(), old_graph.link_count()); + for n in old_graph.nodes_iter() { + assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); + assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); + assert_eq!( + new_graph.output_neighbours(n).collect_vec(), + old_graph.output_neighbours(n).collect_vec() + ); + } + } } diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index b49549b0f..5f9b236f8 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -260,6 +260,7 @@ pub mod test { use crate::extension::simple_op::MakeRegisteredOp; use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; + use crate::hugr::test::assert_hugr_equality; use crate::hugr::NodeType; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{dataflow::IOTrait, Input, LeafOp, Module, Output, DFG}; @@ -267,7 +268,6 @@ pub mod test { use crate::types::{FunctionType, Type}; use crate::OutgoingPort; use itertools::Itertools; - use portgraph::LinkView; use portgraph::{ multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap, }; @@ -298,40 +298,7 @@ pub mod test { // The internal port indices may still be different. let mut h_canon = hugr.clone(); h_canon.canonicalize_nodes(|_, _| {}); - - assert_eq!(new_hugr.root, h_canon.root); - assert_eq!(new_hugr.hierarchy, h_canon.hierarchy); - assert_eq!(new_hugr.metadata, h_canon.metadata); - - // Extension operations may have been downgraded to opaque operations. - for node in new_hugr.nodes() { - let new_op = new_hugr.get_optype(node); - let old_op = h_canon.get_optype(node); - if let OpType::LeafOp(LeafOp::CustomOp(new_ext)) = new_op { - if let OpType::LeafOp(LeafOp::CustomOp(old_ext)) = old_op { - assert_eq!(new_ext.clone().as_opaque(), old_ext.clone().as_opaque()); - } else { - panic!("Expected old_op to be a custom op"); - } - } else { - assert_eq!(new_op, old_op); - } - } - - // Check that the graphs are equivalent up to port renumbering. - let new_graph = &new_hugr.graph; - let old_graph = &h_canon.graph; - assert_eq!(new_graph.node_count(), old_graph.node_count()); - assert_eq!(new_graph.port_count(), old_graph.port_count()); - assert_eq!(new_graph.link_count(), old_graph.link_count()); - for n in old_graph.nodes_iter() { - assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n)); - assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n)); - assert_eq!( - new_graph.output_neighbours(n).collect_vec(), - old_graph.output_neighbours(n).collect_vec() - ); - } + assert_hugr_equality(&h_canon, &new_hugr); new_hugr }