Skip to content

Commit

Permalink
pre-extensions in hugrs lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 4, 2024
1 parent 3be18e9 commit c0caf52
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
48 changes: 32 additions & 16 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use crate::extension::resolution::{
collect_op_extensions, collect_op_types_extensions, update_op_extensions,
update_op_types_extensions,
};
use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet};
use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::ops::{CallIndirect, ExtensionOp, Input, OpType, Tag, Value};
use crate::std_extensions::arithmetic::float_types::float64_type;
use crate::std_extensions::arithmetic::float_types::{self, float64_type};
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types::{self, int_type};
use crate::types::{Signature, Type};
Expand Down Expand Up @@ -85,6 +85,16 @@ fn resolve_hugr_extensions() {
let (ext_c, op_c) = make_extension("dummy.c", "op_c");
let (ext_d, op_d) = make_extension("dummy.d", "op_d");

let build_extensions = ExtensionRegistry::new([
PRELUDE.to_owned(),
ext_a.clone(),
ext_b.clone(),
ext_c.clone(),
ext_d.clone(),
float_types::EXTENSION.to_owned(),
int_types::EXTENSION.to_owned(),
]);

let mut module = ModuleBuilder::new();

// A constant op using the prelude extension.
Expand Down Expand Up @@ -113,8 +123,20 @@ fn resolve_hugr_extensions() {
let [func_i0, func_i1] = func.input_wires_arr();

// Call the function declaration directly, and load & call indirectly.
func.call(&decl, &[], vec![func_i0]).unwrap();
let loaded_func = func.load_func(&decl, &[]).unwrap();
func.call(
&decl,
&[],
vec![func_i0],
&ExtensionRegistry::new([float_types::EXTENSION.to_owned()]),
)
.unwrap();
let loaded_func = func
.load_func(
&decl,
&[],
&ExtensionRegistry::new([float_types::EXTENSION.to_owned()]),
)
.unwrap();
func.add_dataflow_op(
CallIndirect {
signature: Signature::new_endo(vec![float64_type()]),
Expand Down Expand Up @@ -171,13 +193,9 @@ fn resolve_hugr_extensions() {

// Finally, finish the hugr and ensure it's using the right extensions.
func.finish_with_outputs(vec![]).unwrap();
let mut hugr = module.finish_hugr().unwrap_or_else(|e| panic!("{e}"));

let build_extensions = hugr.extensions().clone();
assert!(build_extensions.contains(ext_a.name()));
assert!(build_extensions.contains(ext_b.name()));
assert!(build_extensions.contains(ext_c.name()));
assert!(build_extensions.contains(ext_d.name()));
let mut hugr = module
.finish_hugr(&build_extensions)
.unwrap_or_else(|e| panic!("{e}"));

// Check that the read-only methods collect the same extensions.
let mut collected_exts = ExtensionRegistry::default();
Expand All @@ -192,11 +210,9 @@ fn resolve_hugr_extensions() {
);

// Check that the mutable methods collect the same extensions.
hugr.resolve_extension_defs(&build_extensions).unwrap();
let resolved = hugr.resolve_extension_defs(&build_extensions).unwrap();
assert_eq!(
hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
hugr.extensions()
&resolved, &build_extensions,
"{resolved} != {build_extensions}"
);
}
2 changes: 1 addition & 1 deletion hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ pub(in crate::hugr::rewrite) mod test {
};
rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}"));

assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(()));
assert_eq!(hugr.update_validate(&test_quantum_extension::REG), Ok(()));
assert_eq!(hugr.node_count(), 4);
}

Expand Down

0 comments on commit c0caf52

Please sign in to comment.