diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index f3e8ffda3..4f7ddc5a4 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -30,6 +30,8 @@ pub(crate) fn collect_op_extensions( op: &OpType, ) -> Result>, ExtensionCollectionError> { let OpType::ExtensionOp(ext_op) = op else { + // TODO: Extract the extension when the operation is a `Const`. + // https://github.com/CQCL/hugr/issues/1742 return Ok(None); }; let ext = ext_op.def().extension(); diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 73ebf994e..71285f3d9 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -1,13 +1,24 @@ //! Tests for extension resolution. +use core::panic; +use std::sync::Arc; + +use itertools::Itertools; use rstest::rstest; -use crate::extension::resolution::{update_op_extensions, update_op_types_extensions}; -use crate::extension::ExtensionRegistry; -use crate::ops::{Input, OpType}; +use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; +use crate::extension::prelude::{bool_t, ConstUsize}; +use crate::extension::resolution::{ + collect_op_extensions, collect_op_types_extensions, update_op_extensions, + update_op_types_extensions, +}; +use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; +use crate::ops::{CallIndirect, ExtensionOp, Input, OpType, Tag, Value}; +use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::arithmetic::int_ops; -use crate::std_extensions::arithmetic::int_types; -use crate::type_row; +use crate::std_extensions::arithmetic::int_types::{self, int_type}; +use crate::types::{Signature, Type}; +use crate::{type_row, Extension, HugrView}; #[rstest] #[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] @@ -47,3 +58,145 @@ fn resolve_type_extensions(#[case] op: impl Into, #[case] extensions: Ex "{deser_extensions} != {extensions}" ); } + +/// Create a new test extension with a single operation. +/// +/// Returns an instance of the defined op. +fn make_extension(name: &str, op_name: &str) -> (Arc, OpType) { + let ext = Extension::new_test_arc(ExtensionId::new_unchecked(name), |ext, extension_ref| { + ext.add_op( + op_name.into(), + "".to_string(), + Signature::new_endo(vec![bool_t()]), + extension_ref, + ) + .unwrap(); + }); + let op_def = ext.get_op(op_name).unwrap(); + let op = ExtensionOp::new(op_def.clone(), vec![], &ExtensionRegistry::default()).unwrap(); + (ext, op.into()) +} + +/// Build a hugr with all possible op nodes and resolve the extensions. +#[rstest] +fn resolve_hugr_extensions() { + let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let (ext_b, op_b) = make_extension("dummy.b", "op_b"); + let (ext_c, op_c) = make_extension("dummy.c", "op_c"); + let (ext_d, op_d) = make_extension("dummy.d", "op_d"); + + let mut module = ModuleBuilder::new(); + + // A constant op using the prelude extension. + module.add_constant(Value::extension(ConstUsize::new(42))); + + // A function declaration using the floats extension in its signature. + let decl = module + .declare( + "dummy_declaration", + Signature::new_endo(vec![float64_type()]).into(), + ) + .unwrap(); + + // A function definition using the int_types and float_types extension in its body. + let mut func = module + .define_function( + "dummy_fn", + Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( + [ext_a.name(), ext_b.name(), ext_c.name(), ext_d.name()] + .into_iter() + .cloned() + .collect::(), + ), + ) + .unwrap(); + let [func_i0, func_i1] = func.input_wires_arr(); + + // Call the function declaration directly, and load & call indirectly. + func.call(&decl, &[], vec![func_i0]).unwrap(); + let loaded_func = func.load_func(&decl, &[]).unwrap(); + func.add_dataflow_op( + CallIndirect { + signature: Signature::new_endo(vec![float64_type()]), + }, + vec![loaded_func, func_i0], + ) + .unwrap(); + + // Add one of the custom ops. + func.add_dataflow_op(op_a, vec![func_i1]).unwrap(); + + // A nested dataflow region. + let mut dfg = func.dfg_builder_endo([(bool_t(), func_i1)]).unwrap(); + let dfg_inputs = dfg.input_wires().collect_vec(); + dfg.add_dataflow_op(op_b, dfg_inputs.clone()).unwrap(); + dfg.finish_with_outputs(dfg_inputs).unwrap(); + + // A tag + func.add_dataflow_op( + Tag::new(0, vec![vec![bool_t()].into(), vec![int_type(4)].into()]), + vec![func_i1], + ) + .unwrap(); + + // Dfg control flow. + let mut tail_loop = func + .tail_loop_builder([(bool_t(), func_i1)], [], vec![].into()) + .unwrap(); + let tl_inputs = tail_loop.input_wires().collect_vec(); + tail_loop.add_dataflow_op(op_c, tl_inputs).unwrap(); + let tl_tag = tail_loop.add_load_const(Value::true_val()); + let tl_tag = tail_loop + .add_dataflow_op( + Tag::new(0, vec![vec![Type::new_unit_sum(2)].into(), vec![].into()]), + vec![tl_tag], + ) + .unwrap() + .out_wire(0); + tail_loop.finish_with_outputs(tl_tag, vec![]).unwrap(); + + // Cfg control flow. + let mut cfg = func + .cfg_builder([(bool_t(), func_i1)], vec![].into()) + .unwrap(); + let mut cfg_entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); + let [cfg_i0] = cfg_entry.input_wires_arr(); + cfg_entry.add_dataflow_op(op_d, [cfg_i0]).unwrap(); + let cfg_tag = cfg_entry.add_load_const(Value::unary_unit_sum()); + let cfg_entry_wire = cfg_entry.finish_with_outputs(cfg_tag, []).unwrap(); + let cfg_exit = cfg.exit_block(); + cfg.branch(&cfg_entry_wire, 0, &cfg_exit).unwrap(); + + // -------------------------------------------------- + + // Finally, finish the hugr and ensure it's using the right extensions. + func.finish_with_outputs(vec![]).unwrap(); + let mut hugr = module.finish_hugr().unwrap_or_else(|e| panic!("{e}")); + + let build_extensions = hugr.extensions().clone(); + assert!(build_extensions.contains(ext_a.name())); + assert!(build_extensions.contains(ext_b.name())); + assert!(build_extensions.contains(ext_c.name())); + assert!(build_extensions.contains(ext_d.name())); + + // Check that the read-only methods collect the same extensions. + let mut collected_exts = ExtensionRegistry::default(); + for node in hugr.nodes() { + let op = hugr.get_optype(node); + collected_exts.extend(collect_op_extensions(Some(node), op).unwrap()); + collected_exts.extend(collect_op_types_extensions(Some(node), op).unwrap()); + } + assert_eq!( + collected_exts, build_extensions, + "{collected_exts} != {build_extensions}" + ); + + // Check that the mutable methods collect the same extensions. + hugr.resolve_extension_defs(&build_extensions).unwrap(); + assert_eq!( + hugr.extensions(), + &build_extensions, + "{} != {build_extensions}", + hugr.extensions() + ); +} diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index a970447f8..56f87850a 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -43,8 +43,9 @@ pub fn collect_op_types_extensions( } OpType::FuncDefn(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing), OpType::FuncDecl(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing), - OpType::Const(_c) => { - // TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value + OpType::Const(c) => { + let typ = c.get_type(); + collect_type_exts(&typ, &mut used, &mut missing); } OpType::Input(inp) => collect_type_row_exts(&inp.types, &mut used, &mut missing), OpType::Output(out) => collect_type_row_exts(&out.types, &mut used, &mut missing), @@ -153,7 +154,7 @@ fn collect_type_row_exts( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -fn collect_type_exts( +pub(super) fn collect_type_exts( typ: &TypeBase, used_extensions: &mut ExtensionRegistry, missing_extensions: &mut HashSet, diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index c038c1eab..ad0b009c8 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -3,8 +3,10 @@ //! //! For a non-mutating option see [`super::collect_op_types_extensions`]. +use std::collections::HashSet; use std::sync::Arc; +use super::types::collect_type_exts; use super::{ExtensionRegistry, ExtensionResolutionError}; use crate::ops::OpType; use crate::types::type_row::TypeRowBase; @@ -35,8 +37,16 @@ pub fn update_op_types_extensions( OpType::FuncDecl(f) => { update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? } - OpType::Const(_c) => { - // TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value? + OpType::Const(c) => { + let typ = c.get_type(); + let mut missing = HashSet::new(); + collect_type_exts(&typ, used_extensions, &mut missing); + // We expect that the `CustomConst::get_type` binary calls always return valid extensions. + // As we cannot update the `CustomConst` type, we ignore the result. + // + // Some exotic consts may need https://github.com/CQCL/hugr/issues/1742 to be implemented + // to pass this test. + //assert!(missing.is_empty()); } OpType::Input(inp) => { update_type_row_exts(node, &mut inp.types, extensions, used_extensions)? diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index b42622745..6fdbce98c 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -19,7 +19,7 @@ pub use ident::{IdentList, InvalidIdentifier}; pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; -use portgraph::{Hierarchy, PortMut, UnmanagedDenseMap}; +use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; use thiserror::Error; pub use self::views::{HugrView, RootTagged}; @@ -213,7 +213,7 @@ impl Hugr { // // This is not something we want to expose it the API, so we manually // iterate instead of writing it as a method. - for n in 0..self.node_count() { + for n in 0..self.graph.node_capacity() { let pg_node = portgraph::NodeIndex::new(n); let node: Node = pg_node.into(); if !self.contains_node(node) { diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index 16eed59e6..e12e29445 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -215,18 +215,21 @@ impl ChildrenValidationError { #[non_exhaustive] pub enum EdgeValidationError { /// The dataflow signature of two connected basic blocks does not match. - #[error("The dataflow signature of two connected basic blocks does not match. Output signature: {source_op}, input signature: {target_op}", - source_op = edge.source_op, - target_op = edge.target_op + #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}", + source_ty = source_types.clone().unwrap_or_default(), )] - CFGEdgeSignatureMismatch { edge: ChildrenEdgeData }, + CFGEdgeSignatureMismatch { + edge: ChildrenEdgeData, + source_types: Option, + target_types: TypeRow, + }, } impl EdgeValidationError { /// Returns information on the edge that caused the error. pub fn edge(&self) -> &ChildrenEdgeData { match self { - EdgeValidationError::CFGEdgeSignatureMismatch { edge } => edge, + EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge, } } } @@ -342,8 +345,14 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> _ => panic!("CFG sibling graphs can only contain basic block operations."), }; - if source.successor_input(edge.source_port.index()).as_ref() != Some(target_input) { - return Err(EdgeValidationError::CFGEdgeSignatureMismatch { edge }); + let source_types = source.successor_input(edge.source_port.index()); + if source_types.as_ref() != Some(target_input) { + let target_types = target_input.clone(); + return Err(EdgeValidationError::CFGEdgeSignatureMismatch { + edge, + source_types, + target_types, + }); } Ok(())