Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 3, 2024
1 parent 6d1acdb commit 16d0fca
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
3 changes: 3 additions & 0 deletions hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,6 @@ impl ExtensionCollectionError {
}
}
}

#[cfg(test)]
mod test;
49 changes: 49 additions & 0 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Tests for extension resolution.
use rstest::rstest;

use crate::extension::resolution::{update_op_extensions, update_op_types_extensions};
use crate::extension::ExtensionRegistry;
use crate::ops::{Input, OpType};
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types;
use crate::type_row;

#[rstest]
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
// A type with extra extensions in its instantiated type arguments.
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()]
))]
fn collect_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
let op = op.into();
let resolved = op.used_extensions().unwrap();
assert_eq!(resolved, extensions);
}

#[rstest]
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
// A type with extra extensions in its instantiated type arguments.
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()]
))]
fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
let op = op.into();

// Ensure that all the `Weak` pointers get invalidated by round-tripping via serialization.
let ser = serde_json::to_string(&op).unwrap();
let mut deser_op: OpType = serde_json::from_str(&ser).unwrap();

let dummy_node = portgraph::NodeIndex::new(0).into();

let mut used_exts = ExtensionRegistry::default();
update_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap();
update_op_types_extensions(dummy_node, &mut deser_op, &extensions, &mut used_exts).unwrap();

let deser_extensions = deser_op.used_extensions().unwrap();

assert_eq!(
deser_extensions, extensions,
"{deser_extensions} != {extensions}"
);
}

0 comments on commit 16d0fca

Please sign in to comment.