Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 4, 2024
1 parent 16d0fca commit 3be18e9
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 19 deletions.
2 changes: 2 additions & 0 deletions hugr-core/src/extension/resolution/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub(crate) fn collect_op_extensions(
op: &OpType,
) -> Result<Option<Arc<Extension>>, ExtensionCollectionError> {
let OpType::ExtensionOp(ext_op) = op else {
// TODO: Extract the extension when the operation is a `Const`.
// https://github.com/CQCL/hugr/issues/1742
return Ok(None);
};
let ext = ext_op.def().extension();
Expand Down
163 changes: 158 additions & 5 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
//! Tests for extension resolution.
use core::panic;
use std::sync::Arc;

use itertools::Itertools;
use rstest::rstest;

use crate::extension::resolution::{update_op_extensions, update_op_types_extensions};
use crate::extension::ExtensionRegistry;
use crate::ops::{Input, OpType};
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use crate::extension::prelude::{bool_t, ConstUsize};
use crate::extension::resolution::{
collect_op_extensions, collect_op_types_extensions, update_op_extensions,
update_op_types_extensions,
};
use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet};
use crate::ops::{CallIndirect, ExtensionOp, Input, OpType, Tag, Value};
use crate::std_extensions::arithmetic::float_types::float64_type;
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types;
use crate::type_row;
use crate::std_extensions::arithmetic::int_types::{self, int_type};
use crate::types::{Signature, Type};
use crate::{type_row, Extension, HugrView};

#[rstest]
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
Expand Down Expand Up @@ -47,3 +58,145 @@ fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: Ex
"{deser_extensions} != {extensions}"
);
}

/// Create a new test extension with a single operation.
///
/// Returns an instance of the defined op.
fn make_extension(name: &str, op_name: &str) -> (Arc<Extension>, OpType) {
let ext = Extension::new_test_arc(ExtensionId::new_unchecked(name), |ext, extension_ref| {
ext.add_op(
op_name.into(),
"".to_string(),
Signature::new_endo(vec![bool_t()]),
extension_ref,
)
.unwrap();
});
let op_def = ext.get_op(op_name).unwrap();
let op = ExtensionOp::new(op_def.clone(), vec![], &ExtensionRegistry::default()).unwrap();
(ext, op.into())
}

/// Build a hugr with all possible op nodes and resolve the extensions.
#[rstest]
fn resolve_hugr_extensions() {
let (ext_a, op_a) = make_extension("dummy.a", "op_a");
let (ext_b, op_b) = make_extension("dummy.b", "op_b");
let (ext_c, op_c) = make_extension("dummy.c", "op_c");
let (ext_d, op_d) = make_extension("dummy.d", "op_d");

let mut module = ModuleBuilder::new();

// A constant op using the prelude extension.
module.add_constant(Value::extension(ConstUsize::new(42)));

// A function declaration using the floats extension in its signature.
let decl = module
.declare(
"dummy_declaration",
Signature::new_endo(vec![float64_type()]).into(),
)
.unwrap();

// A function definition using the int_types and float_types extension in its body.
let mut func = module
.define_function(
"dummy_fn",
Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta(
[ext_a.name(), ext_b.name(), ext_c.name(), ext_d.name()]
.into_iter()
.cloned()
.collect::<ExtensionSet>(),
),
)
.unwrap();
let [func_i0, func_i1] = func.input_wires_arr();

// Call the function declaration directly, and load & call indirectly.
func.call(&decl, &[], vec![func_i0]).unwrap();
let loaded_func = func.load_func(&decl, &[]).unwrap();
func.add_dataflow_op(
CallIndirect {
signature: Signature::new_endo(vec![float64_type()]),
},
vec![loaded_func, func_i0],
)
.unwrap();

// Add one of the custom ops.
func.add_dataflow_op(op_a, vec![func_i1]).unwrap();

// A nested dataflow region.
let mut dfg = func.dfg_builder_endo([(bool_t(), func_i1)]).unwrap();
let dfg_inputs = dfg.input_wires().collect_vec();
dfg.add_dataflow_op(op_b, dfg_inputs.clone()).unwrap();
dfg.finish_with_outputs(dfg_inputs).unwrap();

// A tag
func.add_dataflow_op(
Tag::new(0, vec![vec![bool_t()].into(), vec![int_type(4)].into()]),
vec![func_i1],
)
.unwrap();

// Dfg control flow.
let mut tail_loop = func
.tail_loop_builder([(bool_t(), func_i1)], [], vec![].into())
.unwrap();
let tl_inputs = tail_loop.input_wires().collect_vec();
tail_loop.add_dataflow_op(op_c, tl_inputs).unwrap();
let tl_tag = tail_loop.add_load_const(Value::true_val());
let tl_tag = tail_loop
.add_dataflow_op(
Tag::new(0, vec![vec![Type::new_unit_sum(2)].into(), vec![].into()]),
vec![tl_tag],
)
.unwrap()
.out_wire(0);
tail_loop.finish_with_outputs(tl_tag, vec![]).unwrap();

// Cfg control flow.
let mut cfg = func
.cfg_builder([(bool_t(), func_i1)], vec![].into())
.unwrap();
let mut cfg_entry = cfg.entry_builder([type_row![]], type_row![]).unwrap();
let [cfg_i0] = cfg_entry.input_wires_arr();
cfg_entry.add_dataflow_op(op_d, [cfg_i0]).unwrap();
let cfg_tag = cfg_entry.add_load_const(Value::unary_unit_sum());
let cfg_entry_wire = cfg_entry.finish_with_outputs(cfg_tag, []).unwrap();
let cfg_exit = cfg.exit_block();
cfg.branch(&cfg_entry_wire, 0, &cfg_exit).unwrap();

// --------------------------------------------------

// Finally, finish the hugr and ensure it's using the right extensions.
func.finish_with_outputs(vec![]).unwrap();
let mut hugr = module.finish_hugr().unwrap_or_else(|e| panic!("{e}"));

let build_extensions = hugr.extensions().clone();
assert!(build_extensions.contains(ext_a.name()));
assert!(build_extensions.contains(ext_b.name()));
assert!(build_extensions.contains(ext_c.name()));
assert!(build_extensions.contains(ext_d.name()));

// Check that the read-only methods collect the same extensions.
let mut collected_exts = ExtensionRegistry::default();
for node in hugr.nodes() {
let op = hugr.get_optype(node);
collected_exts.extend(collect_op_extensions(Some(node), op).unwrap());
collected_exts.extend(collect_op_types_extensions(Some(node), op).unwrap());
}
assert_eq!(
collected_exts, build_extensions,
"{collected_exts} != {build_extensions}"
);

// Check that the mutable methods collect the same extensions.
hugr.resolve_extension_defs(&build_extensions).unwrap();
assert_eq!(
hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
hugr.extensions()
);
}
7 changes: 4 additions & 3 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ pub fn collect_op_types_extensions(
}
OpType::FuncDefn(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing),
OpType::FuncDecl(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing),
OpType::Const(_c) => {
// TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value
OpType::Const(c) => {
let typ = c.get_type();
collect_type_exts(&typ, &mut used, &mut missing);
}
OpType::Input(inp) => collect_type_row_exts(&inp.types, &mut used, &mut missing),
OpType::Output(out) => collect_type_row_exts(&out.types, &mut used, &mut missing),
Expand Down Expand Up @@ -153,7 +154,7 @@ fn collect_type_row_exts<RV: MaybeRV>(
/// - `used_extensions`: A The registry where to store the used extensions.
/// - `missing_extensions`: A set of `ExtensionId`s of which the
/// `Weak<Extension>` pointer has been invalidated.
fn collect_type_exts<RV: MaybeRV>(
pub(super) fn collect_type_exts<RV: MaybeRV>(
typ: &TypeBase<RV>,
used_extensions: &mut ExtensionRegistry,
missing_extensions: &mut HashSet<ExtensionId>,
Expand Down
14 changes: 12 additions & 2 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//!
//! For a non-mutating option see [`super::collect_op_types_extensions`].
use std::collections::HashSet;
use std::sync::Arc;

use super::types::collect_type_exts;
use super::{ExtensionRegistry, ExtensionResolutionError};
use crate::ops::OpType;
use crate::types::type_row::TypeRowBase;
Expand Down Expand Up @@ -35,8 +37,16 @@ pub fn update_op_types_extensions(
OpType::FuncDecl(f) => {
update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)?
}
OpType::Const(_c) => {
// TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value?
OpType::Const(c) => {
let typ = c.get_type();
let mut missing = HashSet::new();
collect_type_exts(&typ, used_extensions, &mut missing);
// We expect that the `CustomConst::get_type` binary calls always return valid extensions.
// As we cannot update the `CustomConst` type, we ignore the result.
//
// Some exotic consts may need https://github.com/CQCL/hugr/issues/1742 to be implemented
// to pass this test.
//assert!(missing.is_empty());
}
OpType::Input(inp) => {
update_type_row_exts(node, &mut inp.types, extensions, used_extensions)?
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use ident::{IdentList, InvalidIdentifier};
pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError};

use portgraph::multiportgraph::MultiPortGraph;
use portgraph::{Hierarchy, PortMut, UnmanagedDenseMap};
use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap};
use thiserror::Error;

pub use self::views::{HugrView, RootTagged};
Expand Down Expand Up @@ -213,7 +213,7 @@ impl Hugr {
//
// This is not something we want to expose it the API, so we manually
// iterate instead of writing it as a method.
for n in 0..self.node_count() {
for n in 0..self.graph.node_capacity() {
let pg_node = portgraph::NodeIndex::new(n);
let node: Node = pg_node.into();
if !self.contains_node(node) {
Expand Down
23 changes: 16 additions & 7 deletions hugr-core/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,21 @@ impl ChildrenValidationError {
#[non_exhaustive]
pub enum EdgeValidationError {
/// The dataflow signature of two connected basic blocks does not match.
#[error("The dataflow signature of two connected basic blocks does not match. Output signature: {source_op}, input signature: {target_op}",
source_op = edge.source_op,
target_op = edge.target_op
#[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
source_ty = source_types.clone().unwrap_or_default(),
)]
CFGEdgeSignatureMismatch { edge: ChildrenEdgeData },
CFGEdgeSignatureMismatch {
edge: ChildrenEdgeData,
source_types: Option<TypeRow>,
target_types: TypeRow,
},
}

impl EdgeValidationError {
/// Returns information on the edge that caused the error.
pub fn edge(&self) -> &ChildrenEdgeData {
match self {
EdgeValidationError::CFGEdgeSignatureMismatch { edge } => edge,
EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
}
}
}
Expand Down Expand Up @@ -342,8 +345,14 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError>
_ => panic!("CFG sibling graphs can only contain basic block operations."),
};

if source.successor_input(edge.source_port.index()).as_ref() != Some(target_input) {
return Err(EdgeValidationError::CFGEdgeSignatureMismatch { edge });
let source_types = source.successor_input(edge.source_port.index());
if source_types.as_ref() != Some(target_input) {
let target_types = target_input.clone();
return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
edge,
source_types,
target_types,
});
}

Ok(())
Expand Down

0 comments on commit 3be18e9

Please sign in to comment.