From f2215e930d111856217789c811d4bad12cc4e5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:56:16 +0000 Subject: [PATCH] feat!: Hugrs now keep a `ExtensionRegistry` with their requirements (#1738) Closes #1613. Depends on #1739. Hugrs now keep an `Extensions` registry that is automatically computed when using the builder or deserializing from json (using the new `Hugr::load_json`). This set can contain unneeded extensions, but it should always be sufficient to validate the HUGR definition. Note that this is **not** runtime extensions (see #426, #1734). A big chunk of the diff is removing the extension registry when finishing building a hugr. The extension tracking is now done automatically while adding operations. drive-by: Remove unneeded `set_num_ports` call in `insert_hugr_internal`. BREAKING CHANGE: Removed `update_validate`. The hugr extensions should be resolved at load time, so we can use `validate` instead. BREAKING CHANGE: The builder `finish_hugr` function family no longer takes a registry as parameter, and the `_prelude` variants have been removed. --- hugr-cli/src/lib.rs | 52 ++- hugr-cli/src/main.rs | 16 +- hugr-cli/src/mermaid.rs | 7 +- hugr-cli/src/validate.rs | 27 +- hugr-cli/tests/validate.rs | 12 +- hugr-core/src/builder.rs | 15 +- hugr-core/src/builder/build_traits.rs | 80 +++-- hugr-core/src/builder/cfg.rs | 37 +- hugr-core/src/builder/circuit.rs | 2 +- hugr-core/src/builder/conditional.rs | 14 +- hugr-core/src/builder/dataflow.rs | 51 ++- hugr-core/src/builder/module.rs | 31 +- hugr-core/src/builder/tail_loop.rs | 4 +- hugr-core/src/extension/op_def.rs | 2 +- hugr-core/src/extension/prelude.rs | 16 +- hugr-core/src/extension/prelude/array.rs | 2 +- .../src/extension/prelude/unwrap_builder.rs | 2 +- hugr-core/src/extension/resolution/test.rs | 61 +--- hugr-core/src/hugr.rs | 113 +++--- hugr-core/src/hugr/hugrmut.rs | 68 +++- hugr-core/src/hugr/rewrite/consts.rs | 6 +- hugr-core/src/hugr/rewrite/inline_dfg.rs | 30 +- hugr-core/src/hugr/rewrite/insert_identity.rs | 3 +- hugr-core/src/hugr/rewrite/outline_cfg.rs | 9 +- hugr-core/src/hugr/rewrite/replace.rs | 8 +- hugr-core/src/hugr/rewrite/simple_replace.rs | 51 +-- hugr-core/src/hugr/serialize/test.rs | 45 +-- hugr-core/src/hugr/serialize/upgrade/test.rs | 5 +- hugr-core/src/hugr/validate.rs | 51 +-- hugr-core/src/hugr/validate/test.rs | 164 ++++----- hugr-core/src/hugr/views.rs | 22 +- hugr-core/src/hugr/views/descendants.rs | 5 +- hugr-core/src/hugr/views/sibling.rs | 14 +- hugr-core/src/hugr/views/sibling_subgraph.rs | 29 +- hugr-core/src/hugr/views/tests.rs | 22 +- hugr-core/src/ops/constant.rs | 13 +- hugr-core/src/ops/custom.rs | 3 + hugr-core/src/ops/validate.rs | 23 +- hugr-core/src/package.rs | 246 ++++++------- hugr-core/src/std_extensions/ptr.rs | 9 +- hugr-llvm/src/emit/test.rs | 30 +- hugr-llvm/src/extension/prelude/array.rs | 12 +- hugr-llvm/src/utils/array_op_builder.rs | 2 +- .../src/utils/inline_constant_functions.rs | 8 +- hugr-passes/src/const_fold.rs | 74 ++-- hugr-passes/src/const_fold/test.rs | 334 ++++++------------ hugr-passes/src/dataflow/test.rs | 41 +-- hugr-passes/src/force_order.rs | 15 +- hugr-passes/src/lower.rs | 4 +- hugr-passes/src/merge_bbs.rs | 18 +- hugr-passes/src/nest_cfgs.rs | 11 +- hugr-passes/src/non_local.rs | 13 +- hugr-passes/src/validation.rs | 23 +- hugr/benches/benchmarks/hugr/examples.rs | 10 +- hugr/src/hugr.rs | 2 +- hugr/src/lib.rs | 4 +- 56 files changed, 871 insertions(+), 1100 deletions(-) diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 24cc40e2b..9ce393bb3 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -5,8 +5,9 @@ use clap_verbosity_flag::{InfoLevel, Verbosity}; use clio::Input; use derive_more::{Display, Error, From}; use hugr::extension::ExtensionRegistry; -use hugr::package::PackageValidationError; +use hugr::package::{PackageEncodingError, PackageValidationError}; use hugr::Hugr; +use std::io::{Cursor, Read, Seek, SeekFrom}; use std::{ffi::OsString, path::PathBuf}; pub mod extensions; @@ -46,6 +47,9 @@ pub enum CliError { /// Error parsing input. #[display("Error parsing package: {_0}")] Parse(serde_json::Error), + /// Hugr load error. + #[display("Error parsing package: {_0}")] + HUGRLoad(PackageEncodingError), #[display("Error validating HUGR: {_0}")] /// Errors produced by the `validate` subcommand. Validate(PackageValidationError), @@ -96,15 +100,10 @@ impl PackageOrHugr { } /// Validates the package or hugr. - /// - /// Updates the extension registry with any new extensions defined in the package. - pub fn update_validate( - &mut self, - reg: &mut ExtensionRegistry, - ) -> Result<(), PackageValidationError> { + pub fn validate(&self) -> Result<(), PackageValidationError> { match self { - PackageOrHugr::Package(pkg) => pkg.update_validate(reg), - PackageOrHugr::Hugr(hugr) => hugr.update_validate(reg).map_err(Into::into), + PackageOrHugr::Package(pkg) => pkg.validate(), + PackageOrHugr::Hugr(hugr) => Ok(hugr.validate()?), } } } @@ -120,13 +119,21 @@ impl AsRef<[Hugr]> for PackageOrHugr { impl HugrArgs { /// Read either a package or a single hugr from the input. - pub fn get_package_or_hugr(&mut self) -> Result { - let val: serde_json::Value = serde_json::from_reader(&mut self.input)?; - if let Ok(hugr) = serde_json::from_value::(val.clone()) { - return Ok(PackageOrHugr::Hugr(hugr)); + pub fn get_package_or_hugr( + &mut self, + extensions: &ExtensionRegistry, + ) -> Result { + // We need to read the input twice; once to try to load it as a HUGR, and if that fails, as a package. + // If `input` is a file, we can reuse the reader by seeking back to the start. + // Else, we need to read the file into a buffer. + match self.input.can_seek() { + true => get_package_or_hugr_seek(&mut self.input, extensions), + false => { + let mut buffer = Vec::new(); + self.input.read_to_end(&mut buffer)?; + get_package_or_hugr_seek(Cursor::new(buffer), extensions) + } } - let pkg = serde_json::from_value::(val.clone())?; - Ok(PackageOrHugr::Package(pkg)) } /// Read either a package from the input. @@ -142,3 +149,18 @@ impl HugrArgs { Ok(pkg) } } + +/// Load a package or hugr from a seekable input. +fn get_package_or_hugr_seek( + mut input: I, + extensions: &ExtensionRegistry, +) -> Result { + match Hugr::load_json(&mut input, extensions) { + Ok(hugr) => Ok(PackageOrHugr::Hugr(hugr)), + Err(_) => { + input.seek(SeekFrom::Start(0))?; + let pkg = Package::from_json_reader(input, extensions)?; + Ok(PackageOrHugr::Package(pkg)) + } + } +} diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 445930b24..4b489a952 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -2,7 +2,7 @@ use clap::Parser as _; -use hugr_cli::{validate, CliArgs}; +use hugr_cli::{mermaid, validate, CliArgs}; use clap_verbosity_flag::log::Level; @@ -10,7 +10,7 @@ fn main() { match CliArgs::parse() { CliArgs::Validate(args) => run_validate(args), CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), - CliArgs::Mermaid(mut args) => args.run_print().unwrap(), + CliArgs::Mermaid(args) => run_mermaid(args), CliArgs::External(_) => { // TODO: Implement support for external commands. // Running `hugr COMMAND` would look for `hugr-COMMAND` in the path @@ -36,3 +36,15 @@ fn run_validate(mut args: validate::ValArgs) { std::process::exit(1); } } + +/// Run the `mermaid` subcommand. +fn run_mermaid(mut args: mermaid::MermaidArgs) { + let result = args.run_print(); + + if let Err(e) = result { + if args.hugr_args.verbosity(Level::Error) { + eprintln!("{}", e); + } + std::process::exit(1); + } +} diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs index ee3bbabd1..5974dbd16 100644 --- a/hugr-cli/src/mermaid.rs +++ b/hugr-cli/src/mermaid.rs @@ -30,9 +30,12 @@ impl MermaidArgs { /// Write the mermaid diagram to the output. pub fn run_print(&mut self) -> Result<(), crate::CliError> { let hugrs = if self.validate { - self.hugr_args.validate()?.0 + self.hugr_args.validate()? } else { - self.hugr_args.get_package_or_hugr()?.into_hugrs() + let extensions = self.hugr_args.extensions()?; + self.hugr_args + .get_package_or_hugr(&extensions)? + .into_hugrs() }; for hugr in hugrs { diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index 5205006d7..a47f73191 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -6,14 +6,6 @@ use hugr::{extension::ExtensionRegistry, Extension, Hugr}; use crate::{CliError, HugrArgs}; -// TODO: Deprecated re-export. Remove on a breaking release. -#[doc(inline)] -#[deprecated( - since = "0.13.2", - note = "Use `hugr::package::PackageValidationError` instead." -)] -pub use hugr::package::PackageValidationError as ValError; - /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] #[clap(version = "1.0", long_about = None)] @@ -31,7 +23,7 @@ pub const VALID_PRINT: &str = "HUGR valid!"; impl ValArgs { /// Run the HUGR cli and validate against an extension registry. - pub fn run(&mut self) -> Result<(Vec, ExtensionRegistry), CliError> { + pub fn run(&mut self) -> Result, CliError> { let result = self.hugr_args.validate()?; if self.verbosity(Level::Info) { eprintln!("{}", VALID_PRINT); @@ -50,24 +42,29 @@ impl HugrArgs { /// /// Returns the validated modules and the extension registry the modules /// were validated against. - pub fn validate(&mut self) -> Result<(Vec, ExtensionRegistry), CliError> { - let mut package = self.get_package_or_hugr()?; + pub fn validate(&mut self) -> Result, CliError> { + let reg = self.extensions()?; + let package = self.get_package_or_hugr(®)?; + + package.validate()?; + Ok(package.into_hugrs()) + } - let mut reg: ExtensionRegistry = if self.no_std { + /// Return a register with the selected extensions. + pub fn extensions(&self) -> Result { + let mut reg = if self.no_std { hugr::extension::PRELUDE_REGISTRY.to_owned() } else { hugr::std_extensions::STD_REG.to_owned() }; - // register external extensions for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; reg.register_updated(ext); } - package.update_validate(&mut reg)?; - Ok((package.into_hugrs(), reg)) + Ok(reg) } /// Test whether a `level` message should be output. diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index f3fc35f83..be85f67bd 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -4,8 +4,6 @@ //! calling the CLI binary, which Miri doesn't support. #![cfg(all(test, not(miri)))] -use std::sync::Arc; - use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder}; @@ -49,9 +47,7 @@ fn test_package(#[default(bool_t())] id_type: Type) -> Package { df.finish_with_outputs([i]).unwrap(); let hugr = module.hugr().clone(); // unvalidated - let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap(); - let float_ext: Arc = serde_json::from_reader(rdr).unwrap(); - Package::new(vec![hugr], vec![float_ext]).unwrap() + Package::new(vec![hugr]).unwrap() } /// A DFG-rooted HUGR. @@ -130,7 +126,9 @@ fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) { cmd.arg("mermaid"); cmd.arg("--validate"); cmd.write_stdin(bad_hugr_string); - cmd.assert().failure().stderr(contains("UnconnectedPort")); + cmd.assert() + .failure() + .stderr(contains("has an unconnected port")); } #[rstest] @@ -141,7 +139,7 @@ fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("Error validating HUGR").and(contains("unconnected port"))); + .stderr(contains("Node(1)").and(contains("unconnected port"))); } #[rstest] diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index a02c38816..ae04d00b4 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -75,11 +75,11 @@ //! // Finish building the HUGR, consuming the builder. //! // //! // Requires a registry with all the extensions used in the module. -//! module_builder.finish_hugr(&LOGIC_REG) +//! module_builder.finish_hugr() //! }?; //! //! // The built HUGR is always valid. -//! hugr.validate(&LOGIC_REG).unwrap_or_else(|e| { +//! hugr.validate().unwrap_or_else(|e| { //! panic!("HUGR validation failed: {e}"); //! }); //! # Ok(()) @@ -242,7 +242,6 @@ pub(crate) mod test { use crate::hugr::{views::HugrView, HugrMut}; use crate::ops; use crate::types::{PolyFuncType, Signature}; - use crate::utils::test_quantum_extension; use crate::Hugr; use super::handle::BuildHandle; @@ -269,14 +268,14 @@ pub(crate) mod test { f(f_builder)?; - Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) + Ok(module_builder.finish_hugr()?) } #[fixture] pub(crate) fn simple_dfg_hugr() -> Hugr { let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); - dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() + dfg_builder.finish_hugr_with_outputs([i1]).unwrap() } #[fixture] @@ -284,7 +283,7 @@ pub(crate) mod test { let fn_builder = FunctionBuilder::new("test", Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = fn_builder.input_wires_arr(); - fn_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() + fn_builder.finish_hugr_with_outputs([i1]).unwrap() } #[fixture] @@ -292,7 +291,7 @@ pub(crate) mod test { let mut builder = ModuleBuilder::new(); let sig = Signature::new(vec![bool_t()], vec![bool_t()]); builder.declare("test", sig.into()).unwrap(); - builder.finish_prelude_hugr().unwrap() + builder.finish_hugr().unwrap() } #[fixture] @@ -300,7 +299,7 @@ pub(crate) mod test { let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); super::cfg::test::build_basic_cfg(&mut cfg_builder).unwrap(); - cfg_builder.finish_prelude_hugr().unwrap() + cfg_builder.finish_hugr().unwrap() } /// A helper method which creates a DFG rooted hugr with Input and Output node diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index cb05f2552..44f5a00ac 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -4,9 +4,10 @@ use crate::hugr::views::HugrView; use crate::hugr::{NodeMetadata, ValidationError}; use crate::ops::{self, OpTag, OpTrait, OpType, Tag, TailLoop}; use crate::utils::collect_array; -use crate::{IncomingPort, Node, OutgoingPort}; +use crate::{Extension, IncomingPort, Node, OutgoingPort}; use std::iter; +use std::sync::Arc; use super::{ handle::{BuildHandle, Outputs}, @@ -19,7 +20,7 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED}; +use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -45,7 +46,17 @@ pub trait Container { /// Immutable reference to HUGR being built fn hugr(&self) -> &Hugr; /// Add an [`OpType`] as the final child of the container. + /// + /// Adds the extensions required by the op to the HUGR, if they are not already present. fn add_child_node(&mut self, node: impl Into) -> Node { + let node: OpType = node.into(); + + // Add the extension the operation is defined in to the HUGR. + let used_extensions = node + .used_extensions() + .unwrap_or_else(|e| panic!("Build-time signatures should have valid extensions. {e}")); + self.use_extensions(used_extensions); + let parent = self.container_node(); self.hugr_mut().add_node_with_parent(parent, node) } @@ -61,6 +72,8 @@ pub trait Container { /// Add a constant value to the container and return a handle to it. /// + /// Adds the extensions required by the op to the HUGR, if they are not already present. + /// /// # Errors /// /// This function will return an error if there is an error in adding the @@ -88,6 +101,13 @@ pub trait Container { signature, }); + // Add the extensions used by the function types. + self.use_extensions( + body.used_extensions().unwrap_or_else(|e| { + panic!("Build-time signatures should have valid extensions. {e}") + }), + ); + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) } @@ -122,24 +142,26 @@ pub trait Container { ) { self.hugr_mut().set_metadata(child, key, meta); } + + /// Add an extension to the set of extensions used by the hugr. + fn use_extension(&mut self, ext: impl Into>) { + self.hugr_mut().use_extension(ext); + } + + /// Extend the set of extensions used by the hugr with the extensions in the registry. + fn use_extensions(&mut self, registry: impl IntoIterator) + where + ExtensionRegistry: Extend, + { + self.hugr_mut().extensions_mut().extend(registry); + } } /// Types implementing this trait can be used to build complete HUGRs /// (with varying root node types) pub trait HugrBuilder: Container { /// Finish building the HUGR, perform any validation checks and return it. - fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result; - - /// Finish building the HUGR (as [HugrBuilder::finish_hugr]), - /// validating against the [prelude] extension only - /// - /// [prelude]: crate::extension::prelude - fn finish_prelude_hugr(self) -> Result - where - Self: Sized, - { - self.finish_hugr(&PRELUDE_REGISTRY) - } + fn finish_hugr(self) -> Result; } /// Types implementing this trait build a container graph region by borrowing a HUGR @@ -179,6 +201,8 @@ pub trait Dataflow: Container { /// Add a dataflow [`OpType`] to the sibling graph, wiring up the `input_wires` to the /// incoming ports of the resulting node. /// + /// Adds the extensions required by the op to the HUGR, if they are not already present. + /// /// # Errors /// /// Returns a [`BuildError::OperationWiring`] error if the `input_wires` cannot be connected. @@ -398,8 +422,6 @@ pub trait Dataflow: Container { &mut self, fid: &FuncID, type_args: &[TypeArg], - // Sadly required as we substituting in type_args may result in recomputing bounds of types: - exts: &ExtensionRegistry, ) -> Result { let func_node = fid.node(); let func_op = self.hugr().get_optype(func_node); @@ -415,7 +437,7 @@ pub trait Dataflow: Container { }; let load_n = self.add_dataflow_op( - ops::LoadFunction::try_new(func_sig, type_args, exts)?, + ops::LoadFunction::try_new(func_sig, type_args, self.hugr().extensions())?, // Static wire from the function node vec![Wire::new(func_node, func_op.static_output_port().unwrap())], )?; @@ -664,8 +686,6 @@ pub trait Dataflow: Container { function: &FuncID, type_args: &[TypeArg], input_wires: impl IntoIterator, - // Sadly required as we substituting in type_args may result in recomputing bounds of types: - exts: &ExtensionRegistry, ) -> Result, BuildError> { let hugr = self.hugr(); let def_op = hugr.get_optype(function.node()); @@ -679,7 +699,8 @@ pub trait Dataflow: Container { }) } }; - let op: OpType = ops::Call::try_new(type_scheme, type_args, exts)?.into(); + let op: OpType = + ops::Call::try_new(type_scheme, type_args, self.hugr().extensions())?.into(); let const_in_port = op.static_input_port().unwrap(); let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; @@ -698,6 +719,8 @@ pub trait Dataflow: Container { /// Add a node to the graph, wiring up the `inputs` to the input ports of the resulting node. /// +/// Adds the extensions required by the op to the HUGR, if they are not already present. +/// /// # Errors /// /// Returns a [`BuildError::OperationWiring`] if any of the connections produces an @@ -826,27 +849,12 @@ pub trait DataflowHugr: HugrBuilder + Dataflow { fn finish_hugr_with_outputs( mut self, outputs: impl IntoIterator, - extension_registry: &ExtensionRegistry, ) -> Result where Self: Sized, { self.set_outputs(outputs)?; - Ok(self.finish_hugr(extension_registry)?) - } - - /// Sets the outputs of a dataflow Hugr, validates against - /// the [prelude] extension only, and return the Hugr - /// - /// [prelude]: crate::extension::prelude - fn finish_prelude_hugr_with_outputs( - self, - outputs: impl IntoIterator, - ) -> Result - where - Self: Sized, - { - self.finish_hugr_with_outputs(outputs, &PRELUDE_REGISTRY) + Ok(self.finish_hugr()?) } } diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index fb9199df5..b25b6ef33 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -5,14 +5,9 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::{ - extension::TO_BE_INFERRED, - ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}, -}; -use crate::{ - extension::{ExtensionRegistry, ExtensionSet}, - types::Signature, -}; +use crate::extension::TO_BE_INFERRED; +use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; +use crate::{extension::ExtensionSet, types::Signature}; use crate::{hugr::views::HugrView, types::TypeRow}; use crate::Node; @@ -108,7 +103,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// cfg_builder.branch(&entry, 1, &successor_b)?; // branch 1 goes to successor_b /// cfg_builder.branch(&successor_a, 0, &exit)?; /// cfg_builder.branch(&successor_b, 0, &exit)?; -/// let hugr = cfg_builder.finish_prelude_hugr()?; +/// let hugr = cfg_builder.finish_hugr()?; /// Ok(hugr) /// }; /// #[cfg(not(feature = "extension_inference"))] @@ -162,11 +157,11 @@ impl CFGBuilder { } impl HugrBuilder for CFGBuilder { - fn finish_hugr( - mut self, - extension_registry: &ExtensionRegistry, - ) -> Result { - self.base.update_validate(extension_registry)?; + fn finish_hugr(mut self) -> Result { + if cfg!(feature = "extension_inference") { + self.base.infer_extensions(false)?; + } + self.base.validate()?; Ok(self.base) } } @@ -455,11 +450,9 @@ impl BlockBuilder { mut self, branch_wire: Wire, outputs: impl IntoIterator, - extension_registry: &ExtensionRegistry, ) -> Result { self.set_outputs(branch_wire, outputs)?; - self.finish_hugr(extension_registry) - .map_err(BuildError::InvalidHUGR) + self.finish_hugr().map_err(BuildError::InvalidHUGR) } } @@ -493,10 +486,10 @@ pub(crate) mod test { func_builder.finish_with_outputs(cfg_id.outputs())? }; - module_builder.finish_prelude_hugr() + module_builder.finish_hugr() }; - assert_eq!(build_result.err(), None); + assert!(build_result.is_ok(), "{}", build_result.unwrap_err()); Ok(()) } @@ -504,7 +497,7 @@ pub(crate) mod test { fn basic_cfg_hugr() -> Result<(), BuildError> { let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?; build_basic_cfg(&mut cfg_builder)?; - assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_)); + assert_matches!(cfg_builder.finish_hugr(), Ok(_)); Ok(()) } @@ -564,7 +557,7 @@ pub(crate) mod test { let exit = cfg_builder.exit_block(); cfg_builder.branch(&entry, 0, &middle)?; cfg_builder.branch(&middle, 0, &exit)?; - assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_)); + assert_matches!(cfg_builder.finish_hugr(), Ok(_)); Ok(()) } @@ -594,7 +587,7 @@ pub(crate) mod test { cfg_builder.branch(&entry, 0, &middle)?; cfg_builder.branch(&middle, 0, &exit)?; assert_matches!( - cfg_builder.finish_prelude_hugr(), + cfg_builder.finish_hugr(), Err(ValidationError::InterGraphEdgeError( InterGraphEdgeError::NonDominatedAncestor { .. } )) diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 5a2b18a04..338c1b260 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -346,7 +346,7 @@ mod test { let mut registry = test_quantum_extension::REG.clone(); registry.register(my_ext).unwrap(); - let build_res = module_builder.finish_hugr(®istry); + let build_res = module_builder.finish_hugr(); assert_matches!(build_res, Ok(_)); } diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index 4e8039992..4aec60c2a 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -1,4 +1,4 @@ -use crate::extension::{ExtensionRegistry, TO_BE_INFERRED}; +use crate::extension::TO_BE_INFERRED; use crate::hugr::views::HugrView; use crate::ops::dataflow::DataflowOpTrait; use crate::types::{Signature, TypeRow}; @@ -142,11 +142,11 @@ impl + AsRef> ConditionalBuilder { } impl HugrBuilder for ConditionalBuilder { - fn finish_hugr( - mut self, - extension_registry: &ExtensionRegistry, - ) -> Result { - self.base.update_validate(extension_registry)?; + fn finish_hugr(mut self) -> Result { + if cfg!(feature = "extension_inference") { + self.base.infer_extensions(false)?; + } + self.base.validate()?; Ok(self.base) } } @@ -264,7 +264,7 @@ mod test { let [int] = conditional_id.outputs_arr(); fbuild.finish_with_outputs([int])? }; - Ok(module_builder.finish_prelude_hugr()?) + Ok(module_builder.finish_hugr()?) }; assert_matches!(build_result, Ok(_)); diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 5d0c7e0c2..ce38650c0 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -14,7 +14,6 @@ use crate::{Direction, IncomingPort, OutgoingPort, Wire}; use crate::types::{PolyFuncType, Signature, Type}; -use crate::extension::ExtensionRegistry; use crate::Node; use crate::{hugr::HugrMut, Hugr}; @@ -83,11 +82,11 @@ impl DFGBuilder { } impl HugrBuilder for DFGBuilder { - fn finish_hugr( - mut self, - extension_registry: &ExtensionRegistry, - ) -> Result { - self.base.update_validate(extension_registry)?; + fn finish_hugr(mut self) -> Result { + if cfg!(feature = "extension_inference") { + self.base.infer_extensions(false)?; + } + self.base.validate()?; Ok(self.base) } } @@ -299,8 +298,8 @@ impl + AsRef, T: From>> SubContainer for } impl HugrBuilder for DFGWrapper { - fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result { - self.0.finish_hugr(extension_registry) + fn finish_hugr(self) -> Result { + self.0.finish_hugr() } } @@ -317,7 +316,7 @@ pub(crate) mod test { }; use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::prelude::{Lift, Noop}; - use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionId, SignatureError, PRELUDE_REGISTRY}; use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, OpTag}; use crate::ops::{OpTrait, Value}; @@ -325,7 +324,7 @@ pub(crate) mod test { use crate::std_extensions::logic::test::and_op; use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, FuncValueType, RowVariable, Signature, Type, TypeBound, TypeRV}; - use crate::utils::test_quantum_extension::{self, h_gate}; + use crate::utils::test_quantum_extension::h_gate; use crate::{builder::test::n_identity, type_row, Wire}; use super::super::test::simple_dfg_hugr; @@ -342,10 +341,7 @@ pub(crate) mod test { let inner_builder = outer_builder.dfg_builder_endo([(usize_t(), int)])?; let inner_id = n_identity(inner_builder)?; - outer_builder.finish_hugr_with_outputs( - inner_id.outputs().chain(q_out.outputs()), - &test_quantum_extension::REG, - ) + outer_builder.finish_hugr_with_outputs(inner_id.outputs().chain(q_out.outputs())) }; assert_eq!(build_result.err(), None); @@ -363,7 +359,7 @@ pub(crate) mod test { f(&mut builder)?; - builder.finish_hugr(&test_quantum_extension::REG) + builder.finish_hugr() }; assert_matches!(build_result, Ok(_), "Failed on example: {}", msg); @@ -412,7 +408,7 @@ pub(crate) mod test { let [q1] = f_build.input_wires_arr(); f_build.finish_with_outputs([q1, q1])?; - Ok(module_builder.finish_prelude_hugr()?) + Ok(module_builder.finish_hugr()?) }; assert_matches!( @@ -446,7 +442,7 @@ pub(crate) mod test { let nested = nested.finish_with_outputs([id.out_wire(0)])?; - f_build.finish_prelude_hugr_with_outputs([nested.out_wire(0)]) + f_build.finish_hugr_with_outputs([nested.out_wire(0)]) }; assert_matches!(builder(), Ok(_)); @@ -473,8 +469,7 @@ pub(crate) mod test { let i1 = f_build.add_input(qb_t()); let noop1 = f_build.add_dataflow_op(Noop(qb_t()), [i1])?; - let hugr = - f_build.finish_prelude_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?; + let hugr = f_build.finish_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?; Ok((hugr, f_node)) }; @@ -527,7 +522,7 @@ pub(crate) mod test { let mut dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = dfg_builder.input_wires_arr(); dfg_builder.set_metadata("x", 42); - let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1], &EMPTY_REG)?; + let dfg_hugr = dfg_builder.finish_hugr_with_outputs([i1])?; // Create a module, and insert the DFG into it let mut module_builder = ModuleBuilder::new(); @@ -543,7 +538,7 @@ pub(crate) mod test { (dfg.node(), f.node()) }; - let hugr = module_builder.finish_hugr(&EMPTY_REG)?; + let hugr = module_builder.finish_hugr()?; assert_eq!(hugr.node_count(), 7); assert_eq!(hugr.get_metadata(hugr.root(), "x"), None); @@ -585,7 +580,7 @@ pub(crate) mod test { let add_c = add_c.finish_with_outputs(wires)?; let [w] = add_c.outputs_arr(); - parent.finish_hugr_with_outputs([w], &test_quantum_extension::REG)?; + parent.finish_hugr_with_outputs([w])?; Ok(()) } @@ -603,7 +598,7 @@ pub(crate) mod test { // CFGs let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?; - let res = b.finish_prelude_hugr_with_outputs([b_child_2_handle.out_wire(0)]); + let res = b.finish_hugr_with_outputs([b_child_2_handle.out_wire(0)]); assert_matches!( res, @@ -685,17 +680,13 @@ pub(crate) mod test { .unwrap(); let load_constant = builder.add_load_value(Value::true_val()); let [r] = builder - .call(&func, &[], [load_constant], &EMPTY_REG) + .call(&func, &[], [load_constant]) .unwrap() .outputs_arr(); builder.finish_with_outputs([r]).unwrap(); (load_constant.node(), r.node()) }; - ( - builder.finish_hugr(&EMPTY_REG).unwrap(), - load_constant, - call, - ) + (builder.finish_hugr().unwrap(), load_constant, call) }; let lc_optype = hugr.get_optype(load_constant); @@ -712,6 +703,6 @@ pub(crate) mod test { call_optype.other_input_port().unwrap(), ); - hugr.validate(&EMPTY_REG).unwrap(); + hugr.validate().unwrap(); } } diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 1df328d83..18390926e 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -4,7 +4,6 @@ use super::{ BuildError, Container, }; -use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::HugrView; use crate::hugr::ValidationError; @@ -51,11 +50,11 @@ impl Default for ModuleBuilder { } impl HugrBuilder for ModuleBuilder { - fn finish_hugr( - mut self, - extension_registry: &ExtensionRegistry, - ) -> Result { - self.0.update_validate(extension_registry)?; + fn finish_hugr(mut self) -> Result { + if cfg!(feature = "extension_inference") { + self.0.infer_extensions(false)?; + } + self.0.validate()?; Ok(self.0) } } @@ -102,12 +101,20 @@ impl + AsRef> ModuleBuilder { name: impl Into, signature: PolyFuncType, ) -> Result, BuildError> { + let body = signature.body().clone(); // TODO add param names to metadata let declare_n = self.add_child_node(ops::FuncDecl { signature, name: name.into(), }); + // Add the extensions used by the function types. + self.use_extensions( + body.used_extensions().unwrap_or_else(|e| { + panic!("Build-time signatures should have valid extensions. {e}") + }), + ); + Ok(declare_n.into()) } @@ -162,7 +169,6 @@ mod test { use crate::extension::prelude::usize_t; use crate::{ builder::{test::n_identity, Dataflow, DataflowSubContainer}, - extension::{EMPTY_REG, PRELUDE_REGISTRY}, types::Signature, }; @@ -178,10 +184,10 @@ mod test { )?; let mut f_build = module_builder.define_declaration(&f_id)?; - let call = f_build.call(&f_id, &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; + let call = f_build.call(&f_id, &[], f_build.input_wires())?; f_build.finish_with_outputs(call.outputs())?; - module_builder.finish_prelude_hugr() + module_builder.finish_hugr() }; assert_matches!(build_result, Ok(_)); Ok(()) @@ -203,7 +209,7 @@ mod test { ), )?; n_identity(f_build)?; - module_builder.finish_hugr(&EMPTY_REG) + module_builder.finish_hugr() }; assert_matches!(build_result, Ok(_)); Ok(()) @@ -225,11 +231,10 @@ mod test { let [wire] = local_build.input_wires_arr(); let f_id = local_build.finish_with_outputs([wire, wire])?; - let call = - f_build.call(f_id.handle(), &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; + let call = f_build.call(f_id.handle(), &[], f_build.input_wires())?; f_build.finish_with_outputs(call.outputs())?; - module_builder.finish_prelude_hugr() + module_builder.finish_hugr() }; assert_matches!(build_result, Ok(_)); Ok(()) diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index b3051c72e..b83dcdfb1 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -127,7 +127,7 @@ mod test { let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; - loop_b.finish_prelude_hugr() + loop_b.finish_hugr() }; assert_matches!(build_result, Ok(_)); @@ -191,7 +191,7 @@ mod test { }; fbuild.finish_with_outputs(loop_id.outputs())? }; - module_builder.finish_prelude_hugr() + module_builder.finish_hugr() }; assert_matches!(build_result, Ok(_)); diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index b060c7ae6..276bd7025 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -670,7 +670,7 @@ pub(super) mod test { .unwrap(), dfg.input_wires(), )?; - dfg.finish_hugr_with_outputs(rev.outputs(), ®)?; + dfg.finish_hugr_with_outputs(rev.outputs())?; Ok(()) } diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 180f2dfc7..86cae4e1c 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -958,9 +958,7 @@ impl MakeRegisteredOp for Lift { #[cfg(test)] mod test { use crate::builder::inout_sig; - use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; - use crate::utils::test_quantum_extension; use crate::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}, utils::test_quantum_extension::cx_gate, @@ -1049,7 +1047,7 @@ mod test { let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); - b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap(); + b.finish_hugr_with_outputs(out.outputs()).unwrap(); } #[test] @@ -1063,7 +1061,7 @@ mod test { let some = b.add_load_value(const_val1); let none = b.add_load_value(const_val2); - b.finish_prelude_hugr_with_outputs([some, none]).unwrap(); + b.finish_hugr_with_outputs([some, none]).unwrap(); } #[test] @@ -1077,8 +1075,7 @@ mod test { let bool = b.add_load_value(const_bool); let float = b.add_load_value(const_float); - b.finish_hugr_with_outputs([bool, float], &FLOAT_OPS_REGISTRY) - .unwrap(); + b.finish_hugr_with_outputs([bool, float]).unwrap(); } #[test] @@ -1121,7 +1118,7 @@ mod test { b.add_dataflow_op(op, [err]).unwrap(); - b.finish_prelude_hugr_with_outputs([]).unwrap(); + b.finish_hugr_with_outputs([]).unwrap(); } #[test] @@ -1151,8 +1148,7 @@ mod test { .add_dataflow_op(panic_op, [err, q0, q1]) .unwrap() .outputs_arr(); - b.finish_hugr_with_outputs([q0, q1], &test_quantum_extension::REG) - .unwrap(); + b.finish_hugr_with_outputs([q0, q1]).unwrap(); } #[test] @@ -1186,7 +1182,7 @@ mod test { .instantiate_extension_op(&PRINT_OP_ID, [], &PRELUDE_REGISTRY) .unwrap(); b.add_dataflow_op(print_op, [greeting_out]).unwrap(); - b.finish_prelude_hugr_with_outputs([]).unwrap(); + b.finish_hugr_with_outputs([]).unwrap(); } #[test] diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index 4ae3023a8..eaa2a95a3 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -764,7 +764,7 @@ mod tests { let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); - b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap(); + b.finish_hugr_with_outputs(out.outputs()).unwrap(); } #[test] diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 753533f1f..eba1aa0c7 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -104,6 +104,6 @@ mod tests { let [res] = builder .build_unwrap_sum(&PRELUDE_REGISTRY, 1, option_type(bool_t()), opt) .unwrap(); - builder.finish_prelude_hugr_with_outputs([res]).unwrap(); + builder.finish_hugr_with_outputs([res]).unwrap(); } } diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 7d02d1e0c..f435633dc 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -14,9 +14,9 @@ use crate::extension::prelude::{bool_t, ConstUsize}; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, }; -use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE}; +use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value}; -use crate::std_extensions::arithmetic::float_types::{self, float64_type}; +use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::arithmetic::int_ops; use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::types::{Signature, Type}; @@ -88,17 +88,6 @@ fn resolve_hugr_extensions() { let (ext_d, op_d) = make_extension("dummy.d", "op_d"); let (ext_e, op_e) = make_extension("dummy.e", "op_e"); - let build_extensions = ExtensionRegistry::new([ - PRELUDE.to_owned(), - ext_a.clone(), - ext_b.clone(), - ext_c.clone(), - ext_d.clone(), - ext_e.clone(), - float_types::EXTENSION.to_owned(), - int_types::EXTENSION.to_owned(), - ]); - let mut module = ModuleBuilder::new(); // A constant op using the prelude extension. @@ -133,20 +122,8 @@ fn resolve_hugr_extensions() { let [func_i0, func_i1] = func.input_wires_arr(); // Call the function declaration directly, and load & call indirectly. - func.call( - &decl, - &[], - vec![func_i0], - &ExtensionRegistry::new([float_types::EXTENSION.to_owned()]), - ) - .unwrap(); - let loaded_func = func - .load_func( - &decl, - &[], - &ExtensionRegistry::new([float_types::EXTENSION.to_owned()]), - ) - .unwrap(); + func.call(&decl, &[], vec![func_i0]).unwrap(); + let loaded_func = func.load_func(&decl, &[]).unwrap(); func.add_dataflow_op( CallIndirect { signature: Signature::new_endo(vec![float64_type()]), @@ -212,9 +189,13 @@ fn resolve_hugr_extensions() { // Finally, finish the hugr and ensure it's using the right extensions. func.finish_with_outputs(vec![]).unwrap(); - let mut hugr = module - .finish_hugr(&build_extensions) - .unwrap_or_else(|e| panic!("{e}")); + let mut hugr = module.finish_hugr().unwrap_or_else(|e| panic!("{e}")); + + let build_extensions = hugr.extensions().clone(); + assert!(build_extensions.contains(ext_a.name())); + assert!(build_extensions.contains(ext_b.name())); + assert!(build_extensions.contains(ext_c.name())); + assert!(build_extensions.contains(ext_d.name())); // Check that the read-only methods collect the same extensions. let mut collected_exts = ExtensionRegistry::default(); @@ -228,14 +209,12 @@ fn resolve_hugr_extensions() { ); // Check that the mutable methods collect the same extensions. - assert_matches!( - hugr.resolve_extension_defs(&ExtensionRegistry::default()), - Err(_) - ); - let resolved = hugr.resolve_extension_defs(&build_extensions).unwrap(); + hugr.resolve_extension_defs(&build_extensions).unwrap(); assert_eq!( - &resolved, &build_extensions, - "{resolved} != {build_extensions}" + hugr.extensions(), + &build_extensions, + "{} != {build_extensions}", + hugr.extensions() ); } @@ -243,12 +222,6 @@ fn resolve_hugr_extensions() { #[rstest] fn dropped_weak_extensions() { let (ext_a, op_a) = make_extension("dummy.a", "op_a"); - let build_extensions = ExtensionRegistry::new([ - PRELUDE.to_owned(), - ext_a.clone(), - float_types::EXTENSION.to_owned(), - ]); - let mut func = FunctionBuilder::new( "dummy_fn", Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( @@ -262,7 +235,7 @@ fn dropped_weak_extensions() { let [_func_i0, func_i1] = func.input_wires_arr(); func.add_dataflow_op(op_a, vec![func_i1]).unwrap(); - let hugr = func.finish_hugr(&build_extensions).unwrap(); + let hugr = func.finish_hugr().unwrap(); // Do a serialization roundtrip to drop the references. let ser = serde_json::to_string(&hugr).unwrap(); diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index c3412a33b..99accb9ac 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -10,6 +10,7 @@ pub mod validate; pub mod views; use std::collections::VecDeque; +use std::io::Read; use std::iter; pub(crate) use self::hugrmut::HugrMut; @@ -49,6 +50,9 @@ pub struct Hugr { /// Node metadata metadata: UnmanagedDenseMap>, + + /// Extensions used by the operations in the Hugr. + extensions: ExtensionRegistry, } impl Default for Hugr { @@ -84,19 +88,28 @@ impl Hugr { Self::with_capacity(root_node.into(), 0, 0) } - /// Resolve extension ops, infer extensions used, and pass the closure into validation - pub fn update_validate( - &mut self, + /// Load a Hugr from a json reader. + /// + /// Validates the Hugr against the provided extension registry, ensuring all + /// operations are resolved. + /// + /// If the feature `extension_inference` is enabled, we will ensure every function + /// correctly specifies the extensions required by its contained ops. + pub fn load_json( + reader: impl Read, extension_registry: &ExtensionRegistry, - ) -> Result<(), ValidationError> { - self.resolve_extension_defs(extension_registry)?; - self.validate_no_extensions(extension_registry)?; - #[cfg(feature = "extension_inference")] - { - self.infer_extensions(false)?; - self.validate_extensions()?; + ) -> Result { + let mut hugr: Hugr = serde_json::from_reader(reader)?; + + hugr.resolve_extension_defs(extension_registry)?; + hugr.validate_no_extensions()?; + + if cfg!(feature = "extension_inference") { + hugr.infer_extensions(false)?; + hugr.validate_extensions()?; } - Ok(()) + + Ok(hugr) } /// Infers an extension-delta for any non-function container node @@ -186,7 +199,8 @@ impl Hugr { /// function signature by the `required_extensions` field and define the set /// of capabilities required by the runtime to execute each function. /// - /// Returns a new extension registry with the extensions used in the Hugr. + /// Updates the internal extension registry with the extensions used in the + /// definition. /// /// # Parameters /// @@ -204,7 +218,7 @@ impl Hugr { pub fn resolve_extension_defs( &mut self, extensions: &ExtensionRegistry, - ) -> Result { + ) -> Result<(), ExtensionResolutionError> { let mut used_extensions = ExtensionRegistry::default(); // Here we need to iterate the optypes in the hugr mutably, to avoid @@ -232,7 +246,8 @@ impl Hugr { resolve_op_types_extensions(node, op, extensions, &mut used_extensions)?; } - Ok(used_extensions) + self.extensions = used_extensions; + Ok(()) } } @@ -244,6 +259,7 @@ impl Hugr { let hierarchy = Hierarchy::new(); let mut op_types = UnmanagedDenseMap::with_capacity(nodes); let root = graph.add_node(root_node.input_count(), root_node.output_count()); + let extensions = root_node.used_extensions(); op_types[root] = root_node; Self { @@ -252,6 +268,7 @@ impl Hugr { root, op_types, metadata: UnmanagedDenseMap::with_capacity(nodes), + extensions: extensions.unwrap_or_default(), } } @@ -358,6 +375,24 @@ pub enum HugrError { InvalidPortDirection(Direction), } +/// Errors that can occur while loading and validating a Hugr json. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum LoadHugrError { + /// Error while loading the Hugr from JSON. + #[error("Error while loading the Hugr from JSON: {0}")] + Load(#[from] serde_json::Error), + /// Validation of the loaded Hugr failed. + #[error(transparent)] + Validation(#[from] ValidationError), + /// Error when resolving extension operations and types. + #[error(transparent)] + Extension(#[from] ExtensionResolutionError), + /// Error when inferring runtime extensions. + #[error(transparent)] + RuntimeInference(#[from] ExtensionError), +} + #[cfg(test)] mod test { use std::{fs::File, io::BufReader}; @@ -368,11 +403,10 @@ mod test { use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; use crate::extension::prelude::Lift; use crate::extension::prelude::PRELUDE_ID; - use crate::extension::{ - ExtensionId, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY, TO_BE_INFERRED, - }; + use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED}; use crate::types::{Signature, Type}; use crate::{const_extension_ids, ops, test_file, type_row}; + use cool_asserts::assert_matches; use rstest::rstest; #[test] @@ -395,53 +429,46 @@ mod test { #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri - #[should_panic] // issue 1225: In serialization we do not distinguish between unknown CustomConst serialized value invalid but known CustomConst serialized values" fn hugr_validation_0() { // https://github.com/CQCL/hugr/issues/1091 bad case - let mut hugr: Hugr = serde_json::from_reader(BufReader::new( - File::open(test_file!("hugr-0.json")).unwrap(), - )) - .unwrap(); - assert!( - hugr.update_validate(&PRELUDE_REGISTRY).is_err(), - "HUGR should not validate." + let hugr = Hugr::load_json( + BufReader::new(File::open(test_file!("hugr-0.json")).unwrap()), + &PRELUDE_REGISTRY, ); + assert_matches!(hugr, Err(_)); } #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn hugr_validation_1() { // https://github.com/CQCL/hugr/issues/1091 good case - let mut hugr: Hugr = serde_json::from_reader(BufReader::new( - File::open(test_file!("hugr-1.json")).unwrap(), - )) - .unwrap(); - assert!(hugr.update_validate(&PRELUDE_REGISTRY).is_ok()); + let hugr = Hugr::load_json( + BufReader::new(File::open(test_file!("hugr-1.json")).unwrap()), + &PRELUDE_REGISTRY, + ); + assert_matches!(&hugr, Ok(_)); } #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn hugr_validation_2() { // https://github.com/CQCL/hugr/issues/1185 bad case - let mut hugr: Hugr = serde_json::from_reader(BufReader::new( - File::open(test_file!("hugr-2.json")).unwrap(), - )) - .unwrap(); - assert!( - hugr.update_validate(&PRELUDE_REGISTRY).is_err(), - "HUGR should not validate." + let hugr = Hugr::load_json( + BufReader::new(File::open(test_file!("hugr-2.json")).unwrap()), + &PRELUDE_REGISTRY, ); + assert_matches!(hugr, Err(_)); } #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn hugr_validation_3() { // https://github.com/CQCL/hugr/issues/1185 good case - let mut hugr: Hugr = serde_json::from_reader(BufReader::new( - File::open(test_file!("hugr-3.json")).unwrap(), - )) - .unwrap(); - assert!(hugr.update_validate(&PRELUDE_REGISTRY).is_ok()); + let hugr = Hugr::load_json( + BufReader::new(File::open(test_file!("hugr-3.json")).unwrap()), + &PRELUDE_REGISTRY, + ); + assert_matches!(&hugr, Ok(_)); } const_extension_ids! { @@ -485,7 +512,7 @@ mod test { let backup = h.clone(); h.infer_extensions(false).unwrap(); assert_eq!(h, backup); // did nothing - let val_res = h.validate(&EMPTY_REG); + let val_res = h.validate(); let expected_err = ExtensionError { parent: h.root(), parent_extensions: XB.into(), diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 4753e1dec..4056f36e6 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -2,14 +2,16 @@ use core::panic; use std::collections::HashMap; +use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap}; +use crate::extension::ExtensionRegistry; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; -use crate::{Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; +use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; use super::NodeMetadataMap; @@ -245,6 +247,37 @@ pub trait HugrMut: HugrMutInternals { { rw.apply(self) } + + /// Registers a new extension in the set used by the hugr, keeping the one + /// most recent one if the extension already exists. + /// + /// These can be queried using [`HugrView::extensions`]. + /// + /// See [`ExtensionRegistry::register_updated`] for more information. + fn use_extension(&mut self, extension: impl Into>) { + self.hugr_mut().extensions.register_updated(extension); + } + + /// Extend the set of extensions used by the hugr with the extensions in the + /// registry. + /// + /// For each extension, keeps the most recent version if the id already + /// exists. + /// + /// These can be queried using [`HugrView::extensions`]. + /// + /// See [`ExtensionRegistry::register_updated`] for more information. + fn use_extensions(&mut self, registry: impl IntoIterator) + where + ExtensionRegistry: Extend, + { + self.hugr_mut().extensions.extend(registry); + } + + /// Returns a mutable reference to the extension registry for this hugr. + fn extensions_mut(&mut self) -> &mut ExtensionRegistry { + &mut self.hugr_mut().extensions + } } /// Records the result of inserting a Hugr or view @@ -349,6 +382,8 @@ impl + AsMut> HugrMut for T { fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); // Update the optypes and metadata, taking them from the other graph. + // + // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let optype = other.op_types.take(node); self.as_mut().op_types.set(new_node, optype); @@ -368,6 +403,8 @@ impl + AsMut> HugrMut for T { fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); // Update the optypes and metadata, copying them from the other graph. + // + // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let nodetype = other.get_optype(node.into()); self.as_mut().op_types.set(new_node, nodetype.clone()); @@ -404,6 +441,10 @@ impl + AsMut> HugrMut for T { self.as_mut().op_types.set(new_node, nodetype.clone()); let meta = other.base_hugr().metadata.get(node); self.as_mut().metadata.set(new_node, meta.clone()); + // Add the required extensions to the registry. + if let Ok(exts) = nodetype.used_extensions() { + self.use_extensions(exts); + } } translate_indices(node_map) } @@ -440,13 +481,8 @@ fn insert_hugr_internal( }); } - // The root node didn't have any ports. - let root_optype = other.get_optype(other.root()); - hugr.set_num_ports( - other_root.into(), - root_optype.input_count(), - root_optype.output_count(), - ); + // Merge the extension sets. + hugr.extensions.extend(other.extensions()); (other_root.into(), node_map) } @@ -533,11 +569,9 @@ pub(super) fn panic_invalid_port( #[cfg(test)] mod test { + use crate::extension::PRELUDE; use crate::{ - extension::{ - prelude::{usize_t, Noop}, - PRELUDE_REGISTRY, - }, + extension::prelude::{usize_t, Noop}, ops::{self, dataflow::IOTrait, FuncDefn, Input, Output}, types::Signature, }; @@ -547,6 +581,7 @@ mod test { #[test] fn simple_function() -> Result<(), Box> { let mut hugr = Hugr::default(); + hugr.use_extension(PRELUDE.to_owned()); // Create the root module definition let module: Node = hugr.root(); @@ -572,7 +607,7 @@ mod test { hugr.connect(noop, 0, f_out, 1); } - hugr.update_validate(&PRELUDE_REGISTRY)?; + hugr.validate()?; Ok(()) } @@ -599,6 +634,7 @@ mod test { #[test] fn remove_subtree() { let mut hugr = Hugr::default(); + hugr.use_extension(PRELUDE.to_owned()); let root = hugr.root(); let [foo, bar] = ["foo", "bar"].map(|name| { let fd = hugr.add_node_with_parent( @@ -613,15 +649,15 @@ mod test { hugr.connect(inp, 0, out, 0); fd }); - hugr.validate(&PRELUDE_REGISTRY).unwrap(); + hugr.validate().unwrap(); assert_eq!(hugr.node_count(), 7); hugr.remove_subtree(foo); - hugr.validate(&PRELUDE_REGISTRY).unwrap(); + hugr.validate().unwrap(); assert_eq!(hugr.node_count(), 4); hugr.remove_subtree(bar); - hugr.validate(&PRELUDE_REGISTRY).unwrap(); + hugr.validate().unwrap(); assert_eq!(hugr.node_count(), 1); } } diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/rewrite/consts.rs index 7980fee1f..855a61f48 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/rewrite/consts.rs @@ -117,7 +117,7 @@ mod test { use crate::extension::prelude::PRELUDE_ID; use crate::{ builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, - extension::{prelude::ConstUsize, PRELUDE_REGISTRY}, + extension::prelude::ConstUsize, ops::{handle::NodeHandle, Value}, type_row, types::Signature, @@ -136,7 +136,7 @@ mod test { let tup = dfg_build.make_tuple([load_1, load_2])?; dfg_build.finish_sub_container()?; - let mut h = build.finish_prelude_hugr()?; + let mut h = build.finish_hugr()?; // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple assert_eq!(h.node_count(), 8); let tup_node = tup.node(); @@ -194,7 +194,7 @@ mod test { assert_eq!(h.apply_rewrite(remove_con)?, h.root()); assert_eq!(h.node_count(), 4); - assert!(h.validate(&PRELUDE_REGISTRY).is_ok()); + assert!(h.validate().is_ok()); Ok(()) } } diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index 74327f970..c8b0c7448 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -137,7 +137,7 @@ mod test { SubContainer, }; use crate::extension::prelude::qb_t; - use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; + use crate::extension::ExtensionSet; use crate::hugr::rewrite::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; @@ -169,12 +169,6 @@ mod test { fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { use crate::extension::prelude::Lift; - let reg = ExtensionRegistry::new([ - PRELUDE.to_owned(), - int_ops::EXTENSION.to_owned(), - int_types::EXTENSION.to_owned(), - ]); - reg.validate()?; let int_ty = &int_types::INT_TYPES[6]; let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; @@ -204,7 +198,7 @@ mod test { let [a1] = inner.outputs_arr(); let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?; - let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), ®)?; + let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs())?; // Sanity checks assert_eq!( @@ -229,7 +223,7 @@ mod test { } outer.apply_rewrite(InlineDFG(*inner.handle()))?; - outer.validate(®)?; + outer.validate()?; assert_eq!(outer.nodes().count(), 8); assert_eq!(find_dfgs(&outer), vec![outer.root()]); let [_lift, add, sub] = extension_ops(&outer).try_into().unwrap(); @@ -256,14 +250,8 @@ mod test { }; let [q, p] = swap.outputs_arr(); let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?; - let reg = ExtensionRegistry::new([ - test_quantum_extension::EXTENSION.clone(), - PRELUDE.clone(), - float_types::EXTENSION.clone(), - ]); - reg.validate()?; - let mut h = h.finish_hugr_with_outputs(cx.outputs(), ®)?; + let mut h = h.finish_hugr_with_outputs(cx.outputs())?; assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]); assert_eq!(h.nodes().count(), 8); // Dfg+I+O, H, CX, Dfg+I+O // No permutation outside the swap DFG: @@ -333,12 +321,6 @@ mod test { * CX */ // Extension inference here relies on quantum ops not requiring their own test_quantum_extension - let reg = ExtensionRegistry::new([ - test_quantum_extension::EXTENSION.to_owned(), - float_types::EXTENSION.to_owned(), - PRELUDE.to_owned(), - ]); - reg.validate()?; let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [a, b] = outer.input_wires_arr(); let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?; @@ -370,10 +352,10 @@ mod test { test_quantum_extension::cx_gate(), h_a2.outputs().chain(inner.outputs()), )?; - let mut outer = outer.finish_hugr_with_outputs(cx.outputs(), ®)?; + let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?; outer.apply_rewrite(InlineDFG(*inner.handle()))?; - outer.validate(®)?; + outer.validate()?; let order_neighbours = |n, d| { let p = outer.get_optype(n).other_port(d).unwrap(); outer diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/rewrite/insert_identity.rs index 0c1dc872a..2114be8fd 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/rewrite/insert_identity.rs @@ -101,7 +101,6 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; - use crate::utils::test_quantum_extension; use crate::{extension::prelude::qb_t, Hugr}; #[rstest] @@ -127,6 +126,6 @@ mod tests { assert_eq!(noop, Noop(qb_t())); - h.update_validate(&test_quantum_extension::REG).unwrap(); + h.validate().unwrap(); } } diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 1b9a47a1a..b7b1da3bc 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -252,7 +252,6 @@ mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::extension::PRELUDE_REGISTRY; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; @@ -319,7 +318,7 @@ mod test { let exit = cfg_builder.exit_block(); cfg_builder.branch(&tail, 0, &exit)?; - let h = cfg_builder.finish_prelude_hugr()?; + let h = cfg_builder.finish_hugr()?; let (left, right) = (left.node(), right.node()); let (merge, head, tail) = (merge.node(), head.node(), tail.node()); Ok(Self { @@ -446,7 +445,7 @@ mod test { .add_hugr_with_wires(cond_then_loop_cfg.h, [i1]) .unwrap(); fbuild.finish_with_outputs(cfg.outputs()).unwrap(); - let mut h = module_builder.finish_prelude_hugr().unwrap(); + let mut h = module_builder.finish_hugr().unwrap(); // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually: let cfg = cfg.node(); let exit_node = h.children(cfg).nth(1).unwrap(); @@ -463,7 +462,7 @@ mod test { cfg, vec![head, tail], ); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.validate().unwrap(); } #[rstest] @@ -486,7 +485,7 @@ mod test { let root = h.root(); let (new_block, _, _) = outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.validate().unwrap(); assert_eq!(new_block, h.children(h.root()).next().unwrap()); assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]); } diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 43de652a3..a201e6809 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -491,7 +491,7 @@ mod test { cfg.branch(&entry, 0, &bb2)?; cfg.branch(&bb2, 0, &exit)?; - let mut h = cfg.finish_hugr(®).unwrap(); + let mut h = cfg.finish_hugr().unwrap(); { let pop = find_node(&h, "pop"); let push = find_node(&h, "push"); @@ -571,7 +571,7 @@ mod test { }], mu_new: vec![], })?; - h.update_validate(®)?; + h.validate()?; { let pop = find_node(&h, "pop"); let push = find_node(&h, "push"); @@ -684,9 +684,7 @@ mod test { let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); let cond = cond.finish_sub_container().unwrap(); - let h = h - .finish_hugr_with_outputs(cond.outputs(), ®istry) - .unwrap(); + let h = h.finish_hugr_with_outputs(cond.outputs()).unwrap(); let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone()); let r1 = r_hugr.add_node_with_parent( diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index a67ec6f2f..8424acbeb 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -243,7 +243,7 @@ pub(in crate::hugr::rewrite) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::extension::ExtensionSet; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; @@ -253,7 +253,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{self, cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -297,7 +297,7 @@ pub(in crate::hugr::rewrite) mod test { func_builder.finish_with_outputs(inner_graph.outputs().chain(q_out.outputs()))? }; - Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?) + Ok(module_builder.finish_hugr()?) } #[fixture] @@ -317,7 +317,7 @@ pub(in crate::hugr::rewrite) mod test { let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire45 = dfg_builder.add_dataflow_op(cx_gate(), wire2.outputs().chain(wire3.outputs()))?; - dfg_builder.finish_hugr_with_outputs(wire45.outputs(), &test_quantum_extension::REG) + dfg_builder.finish_hugr_with_outputs(wire45.outputs()) } #[fixture] @@ -337,7 +337,7 @@ pub(in crate::hugr::rewrite) mod test { let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; let wire2out = wire2.outputs().exactly_one().unwrap(); let wireoutvec = vec![wire0, wire2out]; - dfg_builder.finish_hugr_with_outputs(wireoutvec, &test_quantum_extension::REG) + dfg_builder.finish_hugr_with_outputs(wireoutvec) } #[fixture] @@ -371,9 +371,7 @@ pub(in crate::hugr::rewrite) mod test { let [b1] = not_1.outputs_arr(); ( - dfg_builder - .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) - .unwrap(), + dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(), vec![not_inp.node(), not_0.node(), not_1.node()], ) } @@ -403,9 +401,7 @@ pub(in crate::hugr::rewrite) mod test { let b1 = b; ( - dfg_builder - .finish_hugr_with_outputs([b0, b1], &test_quantum_extension::REG) - .unwrap(), + dfg_builder.finish_hugr_with_outputs([b0, b1]).unwrap(), vec![not_inp.node(), not_0.node()], ) } @@ -489,7 +485,7 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌─┴─┐ // ┤ H ├┤ H ├┤ X ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); + assert_eq!(h.validate(), Ok(())); } #[rstest] @@ -561,7 +557,7 @@ pub(in crate::hugr::rewrite) mod test { // ├───┤├───┤┌───┐ // ┤ H ├┤ H ├┤ H ├ // └───┘└───┘└───┘ - assert_eq!(h.update_validate(&test_quantum_extension::REG), Ok(())); + assert_eq!(h.validate(), Ok(())); } #[test] @@ -573,9 +569,7 @@ pub(in crate::hugr::rewrite) mod test { circ.append(cx_gate(), [1, 0]).unwrap(); let wires = circ.finish(); let [input, output] = builder.io(); - let mut h = builder - .finish_hugr_with_outputs(wires, &test_quantum_extension::REG) - .unwrap(); + let mut h = builder.finish_hugr_with_outputs(wires).unwrap(); let replacement = h.clone(); let orig = h.clone(); @@ -634,17 +628,13 @@ pub(in crate::hugr::rewrite) mod test { .unwrap() .outputs(); let [input, _] = builder.io(); - let mut h = builder - .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) - .unwrap(); + let mut h = builder.finish_hugr_with_outputs(outw).unwrap(); let mut builder = DFGBuilder::new(inout_sig(two_bit, one_bit)).unwrap(); let inw = builder.input_wires(); let outw = builder.add_dataflow_op(and_op(), inw).unwrap().outputs(); let [repl_input, repl_output] = builder.io(); - let repl = builder - .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) - .unwrap(); + let repl = builder.finish_hugr_with_outputs(outw).unwrap(); let orig = h.clone(); @@ -693,7 +683,7 @@ pub(in crate::hugr::rewrite) mod test { let b = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [w] = b.input_wires_arr(); - b.finish_prelude_hugr_with_outputs([w, w]).unwrap() + b.finish_hugr_with_outputs([w, w]).unwrap() }; let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap(); @@ -731,7 +721,7 @@ pub(in crate::hugr::rewrite) mod test { }; rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); - assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(hugr.validate(), Ok(())); assert_eq!(hugr.node_count(), 3); } @@ -752,11 +742,7 @@ pub(in crate::hugr::rewrite) mod test { let [w] = b.input_wires_arr(); let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); let [w_not] = not.outputs_arr(); - ( - b.finish_hugr_with_outputs([w, w_not], &test_quantum_extension::REG) - .unwrap(), - not.node(), - ) + (b.finish_hugr_with_outputs([w, w_not]).unwrap(), not.node()) }; let [_repl_input, repl_output] = replacement.get_io(replacement.root()).unwrap(); @@ -793,7 +779,7 @@ pub(in crate::hugr::rewrite) mod test { }; rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); - assert_eq!(hugr.update_validate(&test_quantum_extension::REG), Ok(())); + assert_eq!(hugr.validate(), Ok(())); assert_eq!(hugr.node_count(), 4); } @@ -814,7 +800,7 @@ pub(in crate::hugr::rewrite) mod test { let inner_dfg = n_identity(inner_build).unwrap(); let inner_dfg_node = inner_dfg.node(); let replacement = nest_build - .finish_prelude_hugr_with_outputs([inner_dfg.out_wire(0)]) + .finish_hugr_with_outputs([inner_dfg.out_wire(0)]) .unwrap(); let subgraph = SiblingSubgraph::try_from_nodes(vec![h_node], &h).unwrap(); let nu_inp = vec![( @@ -836,8 +822,7 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); - h.update_validate(&PRELUDE_REGISTRY) - .unwrap_or_else(|e| panic!("{e}")); + h.validate().unwrap_or_else(|e| panic!("{e}")); assert_eq!(h.node_count(), 6); } diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 63be8c109..452f15812 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -8,15 +8,13 @@ use crate::builder::{ use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; use crate::extension::simple_op::MakeRegisteredOp; -use crate::extension::ExtensionId; -use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; +use crate::extension::ExtensionRegistry; +use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; -use crate::hugr::ExtensionResolutionError; -use crate::ops::custom::{ExtensionOp, OpaqueOp}; +use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Output, Value, DFG}; use crate::std_extensions::arithmetic::float_types::float64_type; -use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use crate::std_extensions::logic::LogicOp; use crate::types::type_param::TypeParam; @@ -24,7 +22,6 @@ use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; -use crate::utils::test_quantum_extension; use crate::{type_row, OutgoingPort}; use itertools::Itertools; @@ -289,6 +286,7 @@ fn simpleser() { root, op_types, metadata: Default::default(), + extensions: ExtensionRegistry::default(), }; check_hugr_roundtrip(&hugr, true); @@ -317,7 +315,7 @@ fn weighted_hugr_ser() { f_build.set_metadata("val", 42); f_build.finish_with_outputs(outputs).unwrap(); - module_builder.finish_prelude_hugr().unwrap() + module_builder.finish_hugr().unwrap() }; check_hugr_roundtrip(&hugr, true); @@ -334,7 +332,7 @@ fn dfg_roundtrip() -> Result<(), Box> { .unwrap() .out_wire(0); } - let hugr = dfg.finish_hugr_with_outputs(params, &test_quantum_extension::REG)?; + let hugr = dfg.finish_hugr_with_outputs(params)?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -353,7 +351,7 @@ fn extension_ops() -> Result<(), Box> { .unwrap() .out_wire(0); - let hugr = dfg.finish_hugr_with_outputs([wire], &test_quantum_extension::REG)?; + let hugr = dfg.finish_hugr_with_outputs([wire])?; check_hugr_roundtrip(&hugr, true); Ok(()) @@ -371,7 +369,6 @@ fn opaque_ops() -> Result<(), Box> { .add_dataflow_op(extension_op.clone(), [wire]) .unwrap() .out_wire(0); - let not_node = wire.node(); // Add an unresolved opaque operation let opaque_op: OpaqueOp = extension_op.into(); @@ -379,15 +376,12 @@ fn opaque_ops() -> Result<(), Box> { let wire = dfg.add_dataflow_op(opaque_op, [wire]).unwrap().out_wire(0); assert_eq!( - dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY), - Err(ValidationError::ExtensionResolutionError( - ExtensionResolutionError::MissingOpExtension { - node: not_node, - op: "logic.Not".into(), - missing_extension: ext_name, - available_extensions: vec![ExtensionId::new("prelude").unwrap()] - } - ) + dfg.finish_hugr_with_outputs([wire]), + Err(ValidationError::OpaqueOpError(OpaqueOpError::UnresolvedOp( + wire.node(), + "Not".into(), + ext_name, + )) .into()) ); @@ -399,7 +393,7 @@ fn function_type() -> Result<(), Box> { let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()]).with_prelude()); let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]).with_prelude())?; let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?; - let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?; + let h = bldr.finish_hugr_with_outputs(op.outputs())?; check_hugr_roundtrip(&h, true); Ok(()) @@ -417,11 +411,10 @@ fn hierarchy_order() -> Result<(), Box> { hugr.connect(new_in, 0, out, 0); hugr.move_before_sibling(new_in, old_in); hugr.remove_node(old_in); - hugr.update_validate(&PRELUDE_REGISTRY)?; + hugr.validate()?; let rhs: Hugr = check_hugr_roundtrip(&hugr, true); - rhs.validate(&EMPTY_REG).unwrap_err(); - rhs.validate(&PRELUDE_REGISTRY)?; + rhs.validate()?; Ok(()) } @@ -429,10 +422,10 @@ fn hierarchy_order() -> Result<(), Box> { fn constants_roundtrip() -> Result<(), Box> { let mut builder = DFGBuilder::new(inout_sig(vec![], INT_TYPES[4].clone())).unwrap(); let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap()); - let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?; + let hugr = builder.finish_hugr_with_outputs([w])?; - let ser = serde_json::to_string(&hugr)?; - let deser = serde_json::from_str(&ser)?; + let ser = serde_json::to_vec(&hugr)?; + let deser = Hugr::load_json(ser.as_slice(), hugr.extensions())?; assert_eq!(hugr, deser); diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index e9837ddc0..5e1d3ee51 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -4,7 +4,6 @@ use crate::{ hugr::serialize::test::check_hugr_deserialize, std_extensions::logic::LogicOp, types::Signature, - utils::test_quantum_extension, }; use lazy_static::lazy_static; use std::{ @@ -50,9 +49,7 @@ pub fn hugr_with_named_op() -> Hugr { DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])).unwrap(); let [a, b] = builder.input_wires_arr(); let x = builder.add_dataflow_op(LogicOp::And, [a, b]).unwrap(); - builder - .finish_hugr_with_outputs(x.outputs(), &test_quantum_extension::REG) - .unwrap() + builder.finish_hugr_with_outputs(x.outputs()).unwrap() } #[rstest] diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 6c9c92753..663253e1a 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,8 +9,7 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::resolution::ExtensionResolutionError; -use crate::extension::{ExtensionRegistry, SignatureError, TO_BE_INFERRED}; +use crate::extension::{SignatureError, TO_BE_INFERRED}; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; @@ -28,12 +27,10 @@ use super::ExtensionError; /// /// TODO: Consider implementing updatable dominator trees and storing it in the /// Hugr to avoid recomputing it every time. -struct ValidationContext<'a, 'b> { +struct ValidationContext<'a> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. dominators: HashMap>, - /// Registry of available Extensions - extension_registry: &'b ExtensionRegistry, } impl Hugr { @@ -41,20 +38,18 @@ impl Hugr { /// variables. /// TODO: Add a version of validation which allows for open extension /// variables (see github issue #457) - pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> { - self.validate_no_extensions(extension_registry)?; - #[cfg(feature = "extension_inference")] - self.validate_extensions()?; + pub fn validate(&self) -> Result<(), ValidationError> { + self.validate_no_extensions()?; + if cfg!(feature = "extension_inference") { + self.validate_extensions()?; + } Ok(()) } /// Check the validity of the HUGR, but don't check consistency of extension /// requirements between connected nodes or between parents and children. - pub fn validate_no_extensions( - &self, - extension_registry: &ExtensionRegistry, - ) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self, extension_registry); + pub fn validate_no_extensions(&self) -> Result<(), ValidationError> { + let mut validator = ValidationContext::new(self); validator.validate() } @@ -96,17 +91,14 @@ impl Hugr { } } -impl<'a, 'b> ValidationContext<'a, 'b> { +impl<'a> ValidationContext<'a> { /// Create a new validation context. // Allow unused "extension_closure" variable for when // the "extension_inference" feature is disabled. #[allow(unused_variables)] - pub fn new(hugr: &'a Hugr, extension_registry: &'b ExtensionRegistry) -> Self { - Self { - hugr, - dominators: HashMap::new(), - extension_registry, - } + pub fn new(hugr: &'a Hugr) -> Self { + let dominators = HashMap::new(); + Self { hugr, dominators } } /// Check the validity of the HUGR. @@ -308,11 +300,11 @@ impl<'a, 'b> ValidationContext<'a, 'b> { var_decls: &[TypeParam], ) -> Result<(), SignatureError> { match &port_kind { - EdgeKind::Value(ty) => ty.validate(self.extension_registry, var_decls), + EdgeKind::Value(ty) => ty.validate(self.hugr.extensions(), var_decls), // Static edges must *not* refer to type variables declared by enclosing FuncDefns // as these are only types at runtime. - EdgeKind::Const(ty) => ty.validate(self.extension_registry, &[]), - EdgeKind::Function(pf) => pf.validate(self.extension_registry), + EdgeKind::Const(ty) => ty.validate(self.hugr.extensions(), &[]), + EdgeKind::Function(pf) => pf.validate(self.hugr.extensions()), _ => Ok(()), } } @@ -583,7 +575,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Check TypeArgs are valid, and if we can, fit the declared TypeParams ext_op .def() - .validate_args(ext_op.args(), self.extension_registry, var_decls) + .validate_args(ext_op.args(), self.hugr.extensions(), var_decls) .map_err(|cause| ValidationError::SignatureError { node, op: op_type.name(), @@ -600,7 +592,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { ))?; } OpType::Call(c) => { - c.validate(self.extension_registry).map_err(|cause| { + c.validate(self.hugr.extensions()).map_err(|cause| { ValidationError::SignatureError { node, op: op_type.name(), @@ -609,7 +601,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { })?; } OpType::LoadFunction(c) => { - c.validate(self.extension_registry).map_err(|cause| { + c.validate(self.hugr.extensions()).map_err(|cause| { ValidationError::SignatureError { node, op: op_type.name(), @@ -777,11 +769,6 @@ pub enum ValidationError { /// [Type]: crate::types::Type #[error(transparent)] ConstTypeError(#[from] ConstTypeError), - /// Some operations or types in the HUGR reference invalid extensions. - // - // TODO: Remove once `hugr::update_validate` is removed. - #[error(transparent)] - ExtensionResolutionError(#[from] ExtensionResolutionError), } /// Errors related to the inter-graph edge validations. diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index d4309c92d..155449af5 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -11,8 +11,10 @@ use crate::builder::{ FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE, PRELUDE_ID}; -use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; +use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; +use crate::extension::{ + Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE, PRELUDE_REGISTRY, +}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; @@ -72,12 +74,12 @@ fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, No fn invalid_root() { let mut b = Hugr::new(LogicOp::Not); let root = b.root(); - assert_eq!(b.validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(b.validate(), Ok(())); // Change the number of ports in the root b.set_num_ports(root, 1, 0); assert_matches!( - b.validate(&PRELUDE_REGISTRY), + b.validate(), Err(ValidationError::WrongNumberOfPorts { node, .. }) => assert_eq!(node, root) ); b.set_num_ports(root, 2, 2); @@ -85,7 +87,7 @@ fn invalid_root() { // Connect it to itself b.connect(root, 0, root, 0); assert_matches!( - b.validate(&PRELUDE_REGISTRY), + b.validate(), Err(ValidationError::RootWithEdges { node, .. }) => assert_eq!(node, root) ); b.disconnect(root, OutgoingPort::from(0)); @@ -93,21 +95,21 @@ fn invalid_root() { // Add another hierarchy root let module = b.add_node(ops::Module::new().into()); assert_matches!( - b.validate(&PRELUDE_REGISTRY), + b.validate(), Err(ValidationError::NoParent { node }) => assert_eq!(node, module) ); // Make the hugr root not a hierarchy root b.set_parent(root, module); assert_matches!( - b.validate(&PRELUDE_REGISTRY), + b.validate(), Err(ValidationError::RootNotRoot { node }) => assert_eq!(node, root) ); // Fix the root b.root = module.pg_index(); b.remove_node(root); - assert_eq!(b.validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(b.validate(), Ok(())); } #[test] @@ -115,7 +117,7 @@ fn leaf_root() { let leaf_op: OpType = Noop(usize_t()).into(); let b = Hugr::new(leaf_op); - assert_eq!(b.validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(b.validate(), Ok(())); } #[test] @@ -128,13 +130,13 @@ fn dfg_root() { let mut b = Hugr::new(dfg_op); let root = b.root(); add_df_children(&mut b, root, 1); - assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); + assert_eq!(b.validate(), Ok(())); } #[test] fn simple_hugr() { - let mut b = make_simple_hugr(2).0; - assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); + let b = make_simple_hugr(2).0; + assert_eq!(b.validate(), Ok(())); } #[test] @@ -159,7 +161,7 @@ fn children_restrictions() { }, ); assert_matches!( - b.update_validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) ); @@ -167,7 +169,7 @@ fn children_restrictions() { add_df_children(&mut b, new_def, 2); b.set_parent(new_def, copy); assert_matches!( - b.update_validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) ); b.set_parent(new_def, root); @@ -176,7 +178,7 @@ fn children_restrictions() { // add an input node to the module subgraph let new_input = b.add_node_with_parent(root, ops::Input::new(type_row![])); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} ); } @@ -195,7 +197,7 @@ fn df_children_restrictions() { // Replace the output operation of the df subgraph with a copy b.replace_op(output, Noop(usize_t())).unwrap(); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); @@ -203,7 +205,7 @@ fn df_children_restrictions() { b.replace_op(output, ops::Output::new(vec![bool_t()])) .unwrap(); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} ); @@ -214,7 +216,7 @@ fn df_children_restrictions() { b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) .unwrap(); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} ); @@ -248,21 +250,19 @@ fn test_ext_edge() { h.connect(input, 0, sub_dfg, 0); h.connect(sub_dfg, 0, output, 0); - assert_matches!( - h.update_validate(&test_quantum_extension::REG), - Err(ValidationError::UnconnectedPort { .. }) - ); + assert_matches!(h.validate(), Err(ValidationError::UnconnectedPort { .. })); h.connect(input, 1, sub_op, 1); assert_matches!( - h.update_validate(&test_quantum_extension::REG), + h.validate(), Err(ValidationError::InterGraphEdgeError( InterGraphEdgeError::MissingOrderEdge { .. } )) ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg); - h.update_validate(&test_quantum_extension::REG).unwrap(); + h.infer_extensions(false).unwrap(); + h.validate().unwrap(); } #[test] @@ -276,9 +276,9 @@ fn no_ext_edge_into_func() -> Result<(), Box> { let [fn_input] = func.input_wires_arr(); let and_op = func.add_dataflow_op(and_op(), [fn_input, input])?; // 'ext' edge let func = func.finish_with_outputs(and_op.outputs())?; - let loadfn = dfg.load_func(func.handle(), &[], &EMPTY_REG)?; + let loadfn = dfg.load_func(func.handle(), &[])?; let dfg = dfg.finish_with_outputs([loadfn])?; - let res = h.finish_hugr_with_outputs(dfg.outputs(), &test_quantum_extension::REG); + let res = h.finish_hugr_with_outputs(dfg.outputs()); assert_eq!( res, Err(BuildError::InvalidHUGR( @@ -303,7 +303,7 @@ fn test_local_const() { h.connect(input, 0, and, 0); h.connect(and, 0, output, 0); assert_eq!( - h.update_validate(&test_quantum_extension::REG), + h.validate(), Err(ValidationError::UnconnectedPort { node: and, port: IncomingPort::from(1).into(), @@ -324,7 +324,8 @@ fn test_local_const() { h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.update_validate(&test_quantum_extension::REG).unwrap(); + h.infer_extensions(false).unwrap(); + h.validate().unwrap(); } #[test] @@ -340,12 +341,11 @@ fn dfg_with_cycles() { h.connect(input, 1, not2, 0); h.connect(not2, 0, output, 0); // The graph contains a cycle: - assert_matches!( - h.validate(&test_quantum_extension::REG), - Err(ValidationError::NotADag { .. }) - ); + assert_matches!(h.validate(), Err(ValidationError::NotADag { .. })); } +/// An identity hugr. Note that extensions must be updated before validation, +/// as `hugr.extensions` is empty. fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { let mut b = Hugr::default(); let row: TypeRow = vec![t].into(); @@ -366,11 +366,10 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { #[test] fn unregistered_extension() { let (mut h, _def) = identity_hugr_with_type(usize_t()); - assert_matches!( - h.validate(&EMPTY_REG), - Err(ValidationError::SignatureError { .. }) - ); - h.update_validate(&test_quantum_extension::REG).unwrap(); + assert!(h.validate().is_err(),); + h.resolve_extension_defs(&test_quantum_extension::REG) + .unwrap(); + h.validate().unwrap(); } const_extension_ids! { @@ -388,21 +387,20 @@ fn invalid_types() { ) .unwrap(); }); - let reg = ExtensionRegistry::new([ext.clone(), PRELUDE.clone()]); + let reg = ExtensionRegistry::new([ext.clone(), PRELUDE.to_owned()]); reg.validate().unwrap(); - let validate_to_sig_error = |t: CustomType| { - let (h, def) = identity_hugr_with_type(Type::new_extension(t)); - match h.validate(®) { - Err(ValidationError::SignatureError { node, cause, .. }) if node == def => cause, - e => panic!( - "Expected SignatureError at def node, got {}", - match e { - Ok(()) => "Ok".to_owned(), - Err(e) => format!("{}", e), - } - ), - } + let validate_to_sig_error = |t: CustomType| -> SignatureError { + let (mut h, def) = identity_hugr_with_type(Type::new_extension(t)); + h.resolve_extension_defs(®).unwrap(); + + let e = h.validate().unwrap_err(); + let (node, cause) = assert_matches!( + e, + ValidationError::SignatureError{ node, cause, .. } => (node, cause) + ); + assert_eq!(node, def); + cause }; let valid = Type::new_extension(CustomType::new( @@ -412,12 +410,9 @@ fn invalid_types() { TypeBound::Any, &Arc::downgrade(&ext), )); - assert_eq!( - identity_hugr_with_type(valid.clone()) - .0 - .update_validate(®), - Ok(()) - ); + let mut hugr = identity_hugr_with_type(valid.clone()).0; + hugr.resolve_extension_defs(®).unwrap(); + assert_eq!(hugr.validate(), Ok(())); // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( @@ -495,7 +490,7 @@ fn typevars_declared() -> Result<(), Box> { ), )?; let [w] = f.input_wires_arr(); - f.finish_prelude_hugr_with_outputs([w])?; + f.finish_hugr_with_outputs([w])?; // Type refers to undeclared variable let f = FunctionBuilder::new( "myfunc", @@ -505,7 +500,7 @@ fn typevars_declared() -> Result<(), Box> { ), )?; let [w] = f.input_wires_arr(); - assert!(f.finish_prelude_hugr_with_outputs([w]).is_err()); + assert!(f.finish_hugr_with_outputs([w]).is_err()); // Variable declaration incorrectly copied to use site let f = FunctionBuilder::new( "myfunc", @@ -515,7 +510,7 @@ fn typevars_declared() -> Result<(), Box> { ), )?; let [w] = f.input_wires_arr(); - assert!(f.finish_prelude_hugr_with_outputs([w]).is_err()); + assert!(f.finish_hugr_with_outputs([w]).is_err()); Ok(()) } @@ -539,7 +534,7 @@ fn nested_typevars() -> Result<(), Box> { let [w] = inner.input_wires_arr(); inner.finish_with_outputs([w])?; let [w] = outer.input_wires_arr(); - outer.finish_prelude_hugr_with_outputs([w]) + outer.finish_hugr_with_outputs([w]) } assert!(build(Type::new_var_use(0, INNER_BOUND)).is_ok()); assert_matches!( @@ -585,7 +580,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { TypeBound::Copyable, ))); let cst = def.add_load_const(empty_list); - let res = def.finish_hugr_with_outputs([cst], ®); + let res = def.finish_hugr_with_outputs([cst]); assert_matches!( res.unwrap_err(), BuildError::InvalidHUGR(ValidationError::SignatureError { @@ -653,10 +648,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = e.instantiate_extension_op("eval", [uint_seq(2), uint_seq(4)], &PRELUDE_REGISTRY)?; let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; - dfb.finish_hugr_with_outputs( - eval2.outputs(), - &ExtensionRegistry::new([PRELUDE.clone(), e]), - )?; + dfb.finish_hugr_with_outputs(eval2.outputs())?; Ok(()) } @@ -685,7 +677,7 @@ fn row_variables() -> Result<(), Box> { let bldr = fb.define_function("id_usz", Signature::new_endo(usize_t()))?; let vals = bldr.input_wires(); let inner_def = bldr.finish_with_outputs(vals)?; - fb.load_func(inner_def.handle(), &[], &PRELUDE_REGISTRY)? + fb.load_func(inner_def.handle(), &[])? }; let par = e.instantiate_extension_op( "parallel", @@ -693,10 +685,7 @@ fn row_variables() -> Result<(), Box> { &PRELUDE_REGISTRY, )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; - fb.finish_hugr_with_outputs( - par_func.outputs(), - &ExtensionRegistry::new([PRELUDE.clone(), e]), - )?; + fb.finish_hugr_with_outputs(par_func.outputs())?; Ok(()) } @@ -782,8 +771,6 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::new([e, PRELUDE.clone()]); - reg.validate()?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), @@ -791,9 +778,8 @@ fn test_polymorphic_call() -> Result<(), Box> { es: ExtensionSet::singleton(PRELUDE_ID), }], [func, tup], - ®, )?; - let h = d.finish_hugr_with_outputs(call.outputs(), ®)?; + let h = d.finish_hugr_with_outputs(call.outputs())?; let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); let exp_fun_ty = Signature::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair) .with_extension_delta(EXT_ID) @@ -817,9 +803,9 @@ fn test_polymorphic_load() -> Result<(), Box> { vec![Type::new_function(Signature::new_endo(vec![usize_t()]))], ); let mut f = m.define_function("main", sig)?; - let l = f.load_func(&id, &[usize_t().into()], &PRELUDE_REGISTRY)?; + let l = f.load_func(&id, &[usize_t().into()])?; f.finish_with_outputs([l])?; - let _ = m.finish_prelude_hugr()?; + let _ = m.finish_hugr()?; Ok(()) } @@ -835,7 +821,7 @@ fn cfg_children_restrictions() { .unwrap(); // Write Extension annotations into the Hugr while it's still well-formed // enough for us to compute them - b.validate(&test_quantum_extension::REG).unwrap(); + b.validate().unwrap(); b.replace_op( copy, ops::CFG { @@ -844,7 +830,7 @@ fn cfg_children_restrictions() { ) .unwrap(); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::ContainerWithoutChildren { .. }) ); let cfg = copy; @@ -880,7 +866,7 @@ fn cfg_children_restrictions() { }, ); b.add_other_edge(block, exit); - assert_eq!(b.update_validate(&test_quantum_extension::REG), Ok(())); + assert_eq!(b.validate(), Ok(())); // Test malformed errors @@ -892,7 +878,7 @@ fn cfg_children_restrictions() { }, ); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} ); @@ -927,7 +913,7 @@ fn cfg_children_restrictions() { ) .unwrap(); assert_matches!( - b.validate(&test_quantum_extension::REG), + b.validate(), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) => assert_eq!(parent, cfg) ); @@ -956,11 +942,11 @@ fn cfg_connections() -> Result<(), Box> { let exit = hugr.exit_block(); hugr.branch(&entry, 0, &middle)?; hugr.branch(&middle, 0, &exit)?; - let mut h = hugr.finish_hugr(&PRELUDE_REGISTRY)?; + let mut h = hugr.finish_hugr()?; h.connect(middle.node(), 0, middle.node(), 0); assert_eq!( - h.validate(&PRELUDE_REGISTRY), + h.validate(), Err(ValidationError::TooManyConnections { node: middle.node(), port: Port::new(Direction::Outgoing, 0), @@ -975,12 +961,12 @@ fn cfg_connections() -> Result<(), Box> { fn cfg_entry_io_bug() -> Result<(), Box> { // load test file where input node of entry block has types in reversed // order compared to parent CFG node. - let mut hugr: Hugr = serde_json::from_reader(BufReader::new( + let hugr: Hugr = serde_json::from_reader(BufReader::new( File::open(test_file!("issue-1189.json")).unwrap(), )) .unwrap(); assert_matches!( - hugr.update_validate(&PRELUDE_REGISTRY), + hugr.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch{..}, .. }) => assert_eq!(parent, hugr.root()) ); @@ -1039,7 +1025,7 @@ mod extension_tests { hugr.connect(input, 0, lift, 0); hugr.connect(lift, 0, output, 0); - let result = hugr.validate(&PRELUDE_REGISTRY); + let result = hugr.validate(); assert_eq!( result, Err(ValidationError::ExtensionError(ExtensionError { @@ -1069,7 +1055,7 @@ mod extension_tests { let exit = cfg.exit_block(); cfg.branch(&blk, 0, &exit)?; let root = cfg.hugr().root(); - let res = cfg.finish_prelude_hugr(); + let res = cfg.finish_hugr(); if success { assert!(res.is_ok()) } else { @@ -1143,7 +1129,7 @@ mod extension_tests { case }); // case is the last-assigned child, i.e. the one that requires 'XB' - let result = hugr.validate(&PRELUDE_REGISTRY); + let result = hugr.validate(); let expected = if success { Ok(()) } else { @@ -1171,7 +1157,7 @@ mod extension_tests { let lift = dfg.add_dataflow_op(Lift::new(usize_t().into(), XB), dfg.input_wires())?; let pred = make_pred(&mut dfg, lift.outputs())?; let root = dfg.hugr().root(); - let res = dfg.finish_prelude_hugr_with_outputs([pred]); + let res = dfg.finish_hugr_with_outputs([pred]); if success { assert!(res.is_ok()) } else { diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index e0b68ffaa..34396ec68 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -434,16 +434,30 @@ pub trait HugrView: HugrInternals { .map(|(p, t)| (p.as_outgoing().unwrap(), t)) } + /// Returns the set of extensions used by the HUGR. + /// + /// This set may contain extensions that are no longer required by the HUGR. + fn extensions(&self) -> &ExtensionRegistry { + &self.base_hugr().extensions + } + /// Check the validity of the underlying HUGR. - fn validate(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> { - self.base_hugr().validate(reg) + /// + /// This includes checking consistency of extension requirements between + /// connected nodes and between parents and children. + /// See [`HugrView::validate_no_extensions`] for a version that doesn't check + /// extension requirements. + fn validate(&self) -> Result<(), ValidationError> { + self.base_hugr().validate() } /// Check the validity of the underlying HUGR, but don't check consistency /// of extension requirements between connected nodes or between parents and /// children. - fn validate_no_extensions(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> { - self.base_hugr().validate_no_extensions(reg) + /// + /// For a more thorough check, use [`HugrView::validate`]. + fn validate_no_extensions(&self) -> Result<(), ValidationError> { + self.base_hugr().validate_no_extensions() } } diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index f7b893ddf..92a09037d 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -176,7 +176,6 @@ pub(super) mod test { use rstest::rstest; use crate::extension::prelude::{qb_t, usize_t}; - use crate::utils::test_quantum_extension; use crate::IncomingPort; use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, @@ -214,7 +213,7 @@ pub(super) mod test { func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?; (f_id, inner_id) }; - let hugr = module_builder.finish_hugr(&test_quantum_extension::REG)?; + let hugr = module_builder.finish_hugr()?; Ok((hugr, f_id.handle().node(), inner_id.handle().node())) } @@ -291,7 +290,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let extracted = region.extract_hugr(); - extracted.validate(&test_quantum_extension::REG)?; + extracted.validate()?; let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index b9710f00a..0f0794831 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -341,7 +341,7 @@ mod test { use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::ops::{OpTrait, OpType}; use crate::types::Signature; - use crate::utils::test_quantum_extension::{self, EXTENSION_ID}; + use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; @@ -456,7 +456,7 @@ mod test { let ins = dfg.input_wires(); let sub_dfg = dfg.finish_with_outputs(ins)?; let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?; - let h = module_builder.finish_hugr(&test_quantum_extension::REG)?; + let h = module_builder.finish_hugr()?; let sub_dfg = sub_dfg.node(); // We can create a view from a child or grandchild of a hugr: @@ -485,9 +485,7 @@ mod test { /// Mutate a SiblingMut wrapper #[rstest] fn flat_mut(mut simple_dfg_hugr: Hugr) { - simple_dfg_hugr - .update_validate(&test_quantum_extension::REG) - .unwrap(); + simple_dfg_hugr.validate().unwrap(); let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().clone(); @@ -512,9 +510,7 @@ mod test { // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); - assert!(simple_dfg_hugr - .validate(&test_quantum_extension::REG) - .is_err()); + assert!(simple_dfg_hugr.validate().is_err()); } #[rstest] @@ -543,7 +539,7 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; let extracted = region.extract_hugr(); - extracted.validate(&test_quantum_extension::REG)?; + extracted.validate()?; let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index f6494231a..82033db42 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -785,7 +785,6 @@ mod tests { use cool_asserts::assert_matches; use crate::builder::inout_sig; - use crate::extension::{prelude, ExtensionRegistry}; use crate::ops::Const; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::std_extensions::logic::{self, LogicOp}; @@ -849,11 +848,7 @@ mod tests { dfg.finish_with_outputs([w0, w1, w2])? }; let hugr = mod_builder - .finish_hugr(&ExtensionRegistry::new([ - prelude::PRELUDE.to_owned(), - test_quantum_extension::EXTENSION.to_owned(), - float_types::EXTENSION.to_owned(), - ])) + .finish_hugr() .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -875,7 +870,7 @@ mod tests { dfg.finish_with_outputs(outs3.outputs())? }; let hugr = mod_builder - .finish_hugr(&test_quantum_extension::REG) + .finish_hugr() .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -897,7 +892,7 @@ mod tests { dfg.finish_with_outputs([b1, b2])? }; let hugr = mod_builder - .finish_hugr(&test_quantum_extension::REG) + .finish_hugr() .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -918,7 +913,7 @@ mod tests { dfg.finish_with_outputs(outs.outputs())? }; let hugr = mod_builder - .finish_hugr(&test_quantum_extension::REG) + .finish_hugr() .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -951,7 +946,7 @@ mod tests { let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap(); let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + builder.finish_hugr_with_outputs(inputs).unwrap() }; let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap(); @@ -991,7 +986,7 @@ mod tests { let empty_dfg = { let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap(); let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + builder.finish_hugr_with_outputs(inputs).unwrap() }; assert_matches!( @@ -1135,13 +1130,7 @@ mod tests { let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); - extracted - .validate(&ExtensionRegistry::new([ - prelude::PRELUDE.to_owned(), - test_quantum_extension::EXTENSION.to_owned(), - float_types::EXTENSION.to_owned(), - ])) - .unwrap(); + extracted.validate().unwrap(); } #[test] @@ -1161,9 +1150,7 @@ mod tests { .unwrap() .outputs(); let outw = [outw1].into_iter().chain(outw2); - let h = builder - .finish_hugr_with_outputs(outw, &test_quantum_extension::REG) - .unwrap(); + let h = builder.finish_hugr_with_outputs(outw).unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index f69779082..692b14994 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -1,8 +1,6 @@ use portgraph::PortOffset; use rstest::{fixture, rstest}; -use crate::std_extensions::logic::LOGIC_REG; -use crate::utils::test_quantum_extension; use crate::{ builder::{ endo_sig, inout_sig, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, @@ -33,12 +31,7 @@ pub(crate) fn sample_hugr() -> (Hugr, BuildHandle, BuildHandle ExtensionRegistry { - ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) - } - /// Constructs a DFG hugr defining a sum constant, and returning the loaded value. #[test] fn test_sum() -> Result<(), BuildError> { @@ -638,7 +635,7 @@ mod test { pred_ty.clone(), )?); let w = b.load_const(&c); - b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); + b.finish_hugr_with_outputs([w]).unwrap(); let mut b = DFGBuilder::new(Signature::new( type_row![], @@ -646,7 +643,7 @@ mod test { ))?; let c = b.add_constant(Value::sum(1, [], pred_ty.clone())?); let w = b.load_const(&c); - b.finish_hugr_with_outputs([w], &test_registry()).unwrap(); + b.finish_hugr_with_outputs([w]).unwrap(); Ok(()) } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index e6d5793f8..b0312cd9d 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -302,6 +302,9 @@ pub enum OpaqueOpError { /// Unresolved operation encountered during validation. #[error("Unexpected unresolved opaque operation '{1}' in {0}, from Extension {2}.")] UnresolvedOp(Node, OpName, ExtensionId), + /// Error updating the extension registry in the Hugr while resolving opaque ops. + #[error("Error updating extension registry: {0}")] + ExtensionRegistryError(#[from] crate::extension::ExtensionRegistryError), } #[cfg(test)] diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index 16eed59e6..e12e29445 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -215,18 +215,21 @@ impl ChildrenValidationError { #[non_exhaustive] pub enum EdgeValidationError { /// The dataflow signature of two connected basic blocks does not match. - #[error("The dataflow signature of two connected basic blocks does not match. Output signature: {source_op}, input signature: {target_op}", - source_op = edge.source_op, - target_op = edge.target_op + #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}", + source_ty = source_types.clone().unwrap_or_default(), )] - CFGEdgeSignatureMismatch { edge: ChildrenEdgeData }, + CFGEdgeSignatureMismatch { + edge: ChildrenEdgeData, + source_types: Option, + target_types: TypeRow, + }, } impl EdgeValidationError { /// Returns information on the edge that caused the error. pub fn edge(&self) -> &ChildrenEdgeData { match self { - EdgeValidationError::CFGEdgeSignatureMismatch { edge } => edge, + EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge, } } } @@ -342,8 +345,14 @@ fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> _ => panic!("CFG sibling graphs can only contain basic block operations."), }; - if source.successor_input(edge.source_port.index()).as_ref() != Some(target_input) { - return Err(EdgeValidationError::CFGEdgeSignatureMismatch { edge }); + let source_types = source.successor_input(edge.source_port.index()); + if source_types.as_ref() != Some(target_input) { + let target_types = target_input.clone(); + return Err(EdgeValidationError::CFGEdgeSignatureMismatch { + edge, + source_types, + target_types, + }); } Ok(()) diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index aa32f24d5..7461203f4 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -1,41 +1,43 @@ //! Bundles of hugr modules along with the extension required to load them. use derive_more::{Display, Error, From}; -use std::collections::HashMap; +use itertools::Itertools; use std::path::Path; -use std::sync::Arc; use std::{fs, io, mem}; use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder}; -use crate::extension::{ExtensionRegistry, ExtensionRegistryError}; +use crate::extension::resolution::ExtensionResolutionError; +use crate::extension::{ExtensionId, ExtensionRegistry}; use crate::hugr::internal::HugrMutInternals; -use crate::hugr::{HugrView, ValidationError}; +use crate::hugr::{ExtensionError, HugrView, ValidationError}; use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType}; -use crate::{Extension, Hugr}; +use crate::Hugr; -#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] -/// Package of module HUGRs and extensions. -/// The HUGRs are validated against the extensions. +#[derive(Debug, Default, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Package of module HUGRs. pub struct Package { /// Module HUGRs included in the package. pub modules: Vec, - /// Extensions to validate against. - pub extensions: Vec>, + /// Extensions used in the modules. + /// + /// This is a superset of the extensions used in the modules. + pub extensions: ExtensionRegistry, } impl Package { - /// Create a new package from a list of hugrs and extensions. + /// Create a new package from a list of hugrs. /// /// All the HUGRs must have a `Module` operation at the root. /// + /// Collects the extensions used in the modules and stores them in top-level + /// `extensions` attribute. + /// /// # Errors /// /// Returns an error if any of the HUGRs does not have a `Module` root. - pub fn new( - modules: impl IntoIterator, - extensions: impl IntoIterator>, - ) -> Result { + pub fn new(modules: impl IntoIterator) -> Result { let modules: Vec = modules.into_iter().collect(); + let mut extensions = ExtensionRegistry::default(); for (idx, module) in modules.iter().enumerate() { let root_op = module.get_optype(module.root()); if !root_op.is_module() { @@ -44,14 +46,15 @@ impl Package { root_op: root_op.clone(), }); } + extensions.extend(module.extensions()); } Ok(Self { modules, - extensions: extensions.into_iter().collect(), + extensions, }) } - /// Create a new package from a list of hugrs and extensions. + /// Create a new package from a list of hugrs. /// /// HUGRs that do not have a `Module` root will be wrapped in a new `Module` root, /// depending on the root optype. @@ -61,21 +64,24 @@ impl Package { /// # Errors /// /// Returns an error if any of the HUGRs cannot be wrapped in a module. - pub fn from_hugrs( - modules: impl IntoIterator, - extensions: impl IntoIterator>, - ) -> Result { + pub fn from_hugrs(modules: impl IntoIterator) -> Result { let modules: Vec = modules .into_iter() .map(to_module_hugr) .collect::>()?; + + let mut extensions = ExtensionRegistry::default(); + for module in &modules { + extensions.extend(module.extensions()); + } + Ok(Self { modules, - extensions: extensions.into_iter().collect(), + extensions, }) } - /// Create a new package containing a single HUGR, and no extension definitions. + /// Create a new package containing a single HUGR. /// /// If the Hugr is not a module, a new [OpType::Module] root will be added. /// This behaviours depends on the root optype. @@ -88,52 +94,64 @@ impl Package { pub fn from_hugr(hugr: Hugr) -> Result { let mut package = Self::default(); let module = to_module_hugr(hugr)?; + package.extensions = module.extensions().clone(); package.modules.push(module); Ok(package) } - /// Validate the package against an extension registry. + + /// Validate the modules of the package. /// - /// `reg` is updated with any new extensions. - pub fn update_validate( - &mut self, - reg: &mut ExtensionRegistry, - ) -> Result<(), PackageValidationError> { - for ext in &self.extensions { - reg.register_updated_ref(ext); - } - for hugr in self.modules.iter_mut() { - hugr.update_validate(reg)?; + /// Ensures that the top-level extension list is a superset of the extensions used in the modules. + pub fn validate(&self) -> Result<(), PackageValidationError> { + for hugr in self.modules.iter() { + hugr.validate()?; + + let missing_exts = hugr + .extensions() + .ids() + .filter(|id| !self.extensions.contains(id)) + .cloned() + .collect_vec(); + if !missing_exts.is_empty() { + return Err(PackageValidationError::MissingExtension { + missing: missing_exts, + available: self.extensions.ids().cloned().collect(), + }); + } } Ok(()) } - /// Validate the package against an extension registry. - /// - /// `reg` is updated with any new extensions. - /// - /// Returns the validated modules. - /// - /// deprecated: use [Package::update_validate] instead. - #[deprecated(since = "0.13.2", note = "Replaced by `Package::update_validate`")] - pub fn validate( - mut self, - reg: &mut ExtensionRegistry, - ) -> Result, PackageValidationError> { - self.update_validate(reg)?; - Ok(self.modules) - } - /// Read a Package in json format from an io reader. /// /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. - pub fn from_json_reader(reader: impl io::Read) -> Result { + pub fn from_json_reader( + reader: impl io::Read, + extension_registry: &ExtensionRegistry, + ) -> Result { let val: serde_json::Value = serde_json::from_reader(reader)?; - let pkg_load_err = match serde_json::from_value::(val.clone()) { - Ok(p) => return Ok(p), - Err(e) => e, + let loaded_pkg = serde_json::from_value::(val.clone()); + + if let Ok(mut pkg) = loaded_pkg { + // Resolve the operations in the modules using the defined registries. + let mut combined_registry = extension_registry.clone(); + combined_registry.extend(&pkg.extensions); + + for module in &mut pkg.modules { + module.resolve_extension_defs(&combined_registry)?; + pkg.extensions.extend(module.extensions()); + } + + return Ok(pkg); }; + let pkg_load_err = loaded_pkg.unwrap_err(); - if let Ok(hugr) = serde_json::from_value::(val) { + // As a fallback, try to load a hugr json. + if let Ok(mut hugr) = serde_json::from_value::(val) { + hugr.resolve_extension_defs(extension_registry)?; + if cfg!(feature = "extension_inference") { + hugr.infer_extensions(false)?; + } return Ok(Package::from_hugr(hugr)?); } @@ -144,17 +162,23 @@ impl Package { /// Read a Package from a json string. /// /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. - pub fn from_json(json: impl AsRef) -> Result { - Self::from_json_reader(json.as_ref().as_bytes()) + pub fn from_json( + json: impl AsRef, + extension_registry: &ExtensionRegistry, + ) -> Result { + Self::from_json_reader(json.as_ref().as_bytes(), extension_registry) } /// Read a Package from a json file. /// /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. - pub fn from_json_file(path: impl AsRef) -> Result { + pub fn from_json_file( + path: impl AsRef, + extension_registry: &ExtensionRegistry, + ) -> Result { let file = fs::File::open(path)?; let reader = io::BufReader::new(file); - Self::from_json_reader(reader) + Self::from_json_reader(reader, extension_registry) } /// Write the Package in json format into an io writer. @@ -179,24 +203,6 @@ impl Package { } } -impl PartialEq for Package { - fn eq(&self, other: &Self) -> bool { - if self.modules != other.modules || self.extensions.len() != other.extensions.len() { - return false; - } - // Extensions may be in different orders, so we compare them as sets. - let exts = self - .extensions - .iter() - .map(|e| (&e.name, e)) - .collect::>(); - other - .extensions - .iter() - .all(|e| exts.get(&e.name).map_or(false, |&e2| e == e2)) - } -} - impl AsRef<[Hugr]> for Package { fn as_ref(&self) -> &[Hugr] { &self.modules @@ -296,60 +302,30 @@ pub enum PackageEncodingError { IOError(io::Error), /// Improper package definition. Package(PackageError), + /// Could not resolve the extension needed to encode the hugr. + ExtensionResolution(ExtensionResolutionError), + /// Could not resolve the runtime extensions for the hugr. + RuntimeExtensionResolution(ExtensionError), } /// Error raised while validating a package. -#[derive(Debug, From)] +#[derive(Debug, Display, From, Error)] #[non_exhaustive] pub enum PackageValidationError { /// Error raised while processing the package extensions. - Extension(ExtensionRegistryError), + #[display("The package modules use the extension{} {} not present in the defined set. The declared extensions are {}", + if missing.len() > 1 {"s"} else {""}, + missing.iter().map(|id| id.to_string()).collect::>().join(", "), + available.iter().map(|id| id.to_string()).collect::>().join(", "), + )] + MissingExtension { + /// The missing extensions. + missing: Vec, + /// The available extensions. + available: Vec, + }, /// Error raised while validating the package hugrs. Validation(ValidationError), - /// Error validating HUGR. - // TODO: Remove manual Display and Error impls when removing deprecated variants. - #[from(ignore)] - #[deprecated( - since = "0.13.2", - note = "Replaced by `PackageValidationError::Validation`" - )] - Validate(ValidationError), - /// Error registering extension. - // TODO: Remove manual Display and Error impls when removing deprecated variants. - #[from(ignore)] - #[deprecated( - since = "0.13.2", - note = "Replaced by `PackageValidationError::Extension`" - )] - ExtReg(ExtensionRegistryError), -} - -// Note: We cannot use the `derive_more::Error` derive due to a bug with deprecated elements. -// See https://github.com/JelteF/derive_more/issues/419 -#[allow(deprecated)] -impl std::error::Error for PackageValidationError { - fn source(&self) -> Option<&(dyn derive_more::Error + 'static)> { - match self { - PackageValidationError::Extension(source) => Some(source), - PackageValidationError::Validation(source) => Some(source), - PackageValidationError::Validate(source) => Some(source), - PackageValidationError::ExtReg(source) => Some(source), - } - } -} - -// Note: We cannot use the `derive_more::Display` derive due to a bug with deprecated elements. -// See https://github.com/JelteF/derive_more/issues/419 -impl Display for PackageValidationError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - #[allow(deprecated)] - match self { - PackageValidationError::Extension(e) => write!(f, "Error processing extensions: {}", e), - PackageValidationError::Validation(e) => write!(f, "Error validating HUGR: {}", e), - PackageValidationError::Validate(e) => write!(f, "Error validating HUGR: {}", e), - PackageValidationError::ExtReg(e) => write!(f, "Error registering extension: {}", e), - } - } } #[cfg(test)] @@ -359,28 +335,17 @@ mod test { use crate::builder::test::{ simple_cfg_hugr, simple_dfg_hugr, simple_funcdef_hugr, simple_module_hugr, }; - use crate::extension::{ExtensionId, EMPTY_REG}; use crate::ops::dataflow::IOTrait; use crate::ops::Input; use super::*; use rstest::{fixture, rstest}; - use semver::Version; #[fixture] fn simple_package() -> Package { let hugr0 = simple_module_hugr(); let hugr1 = simple_module_hugr(); - - let ext_1_id = ExtensionId::new("ext1").unwrap(); - let ext_2_id = ExtensionId::new("ext2").unwrap(); - let ext1 = Extension::new(ext_1_id.clone(), Version::new(2, 4, 8)); - let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0)); - - Package { - modules: vec![hugr0, hugr1], - extensions: vec![ext1.into(), ext2.into()], - } + Package::new([hugr0, hugr1]).unwrap() } #[fixture] @@ -392,8 +357,10 @@ mod test { #[case::empty(Package::default())] #[case::simple(simple_package())] fn package_roundtrip(#[case] package: Package) { + use crate::extension::PRELUDE_REGISTRY; + let json = package.to_json().unwrap(); - let new_package = Package::from_json(&json).unwrap(); + let new_package = Package::from_json(&json, &PRELUDE_REGISTRY).unwrap(); assert_eq!(package, new_package); } @@ -425,16 +392,15 @@ mod test { let dfg = simple_dfg_hugr(); assert_matches!( - Package::new([module.clone(), dfg.clone()], []), + Package::new([module.clone(), dfg.clone()]), Err(PackageError::NonModuleHugr { module_index: 1, root_op: OpType::DFG(_), }) ); - let mut pkg = Package::from_hugrs([module, dfg], []).unwrap(); - let mut reg = EMPTY_REG.clone(); - pkg.update_validate(&mut reg).unwrap(); + let pkg = Package::from_hugrs([module, dfg]).unwrap(); + pkg.validate().unwrap(); assert_eq!(pkg.modules.len(), 2); } diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 22587ec3e..d263f7a7b 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -235,9 +235,7 @@ pub(crate) mod test { use strum::IntoEnumIterator; use super::*; - use crate::std_extensions::arithmetic::float_types::{ - float64_type, EXTENSION as FLOAT_EXTENSION, - }; + use crate::std_extensions::arithmetic::float_types::float64_type; fn get_opdef(op: impl NamedOp) -> Option<&'static Arc> { EXTENSION.get_op(&op.name()) } @@ -271,7 +269,6 @@ pub(crate) mod test { fn test_build() { let in_row = vec![bool_t(), float64_type()]; - let reg = ExtensionRegistry::new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]); let hugr = { let mut builder = DFGBuilder::new( Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), @@ -285,8 +282,8 @@ pub(crate) mod test { builder.add_write_ptr(new_ptr, read).unwrap(); } - builder.finish_hugr_with_outputs([], ®).unwrap() + builder.finish_hugr_with_outputs([]).unwrap() }; - assert_matches!(hugr.validate(®), Ok(_)); + assert_matches!(hugr.validate(), Ok(_)); } } diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index c8a53902c..8bf8856c1 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -165,7 +165,7 @@ impl SimpleHugrConfig { // unvalidated. // println!("{}", mod_b.hugr().mermaid_string()); - mod_b.finish_hugr(&self.extensions).unwrap() + mod_b.finish_hugr().unwrap() } } @@ -251,7 +251,7 @@ mod test_fns { use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize}; - use hugr_core::extension::{EMPTY_REG, PRELUDE_REGISTRY}; + use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{CallIndirect, Tag, Value}; @@ -370,9 +370,7 @@ mod test_fns { .declare(name, HugrFuncType::new_endo(io).into()) .unwrap(); let mut func_b = mod_b.define_declaration(&f_id).unwrap(); - let call = func_b - .call(&f_id, &[], func_b.input_wires(), &EMPTY_REG) - .unwrap(); + let call = func_b.call(&f_id, &[], func_b.input_wires()).unwrap(); func_b.finish_with_outputs(call.outputs()).unwrap(); } @@ -380,7 +378,7 @@ mod test_fns { build_recursive(&mut mod_b, "main_void", type_row![]); build_recursive(&mut mod_b, "main_unary", vec![bool_t()].into()); build_recursive(&mut mod_b, "main_binary", vec![bool_t(), bool_t()].into()); - let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = mod_b.finish_hugr().unwrap(); check_emission!(hugr, llvm_ctx); } @@ -390,7 +388,7 @@ mod test_fns { let signature = HugrFuncType::new_endo(io); let f_id = mod_b.declare(name, signature.clone().into()).unwrap(); let mut func_b = mod_b.define_declaration(&f_id).unwrap(); - let func = func_b.load_func(&f_id, &[], &EMPTY_REG).unwrap(); + let func = func_b.load_func(&f_id, &[]).unwrap(); let inputs = iter::once(func).chain(func_b.input_wires()); let call_indirect = func_b .add_dataflow_op(CallIndirect { signature }, inputs) @@ -402,7 +400,7 @@ mod test_fns { build_recursive(&mut mod_b, "main_void", type_row![]); build_recursive(&mut mod_b, "main_unary", vec![bool_t()].into()); build_recursive(&mut mod_b, "main_binary", vec![bool_t(), bool_t()].into()); - let hugr = mod_b.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = mod_b.finish_hugr().unwrap(); check_emission!(hugr, llvm_ctx); } @@ -459,7 +457,7 @@ mod test_fns { let _ = builder .declare("decl", HugrFuncType::new_endo(type_row![]).into()) .unwrap(); - builder.finish_hugr(&EMPTY_REG).unwrap() + builder.finish_hugr().unwrap() }; check_emission!(hugr, llvm_ctx); } @@ -484,10 +482,7 @@ mod test_fns { let w = builder.load_const(&konst); builder.finish_with_outputs([w]).unwrap() }; - let [r] = builder - .call(func.handle(), &[], [], &EMPTY_REG) - .unwrap() - .outputs_arr(); + let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); builder.finish_with_outputs([r]).unwrap().outputs_arr() }; builder.finish_with_outputs([r]).unwrap() @@ -518,10 +513,7 @@ mod test_fns { .entry_builder([type_row![]], vec![bool_t()].into()) .unwrap(); let control = builder.add_load_value(Value::unary_unit_sum()); - let [r] = builder - .call(func.handle(), &[], [], &EMPTY_REG) - .unwrap() - .outputs_arr(); + let [r] = builder.call(func.handle(), &[], []).unwrap().outputs_arr(); builder.finish_with_outputs(control, [r]).unwrap() }; let exit = builder.exit_block(); @@ -548,10 +540,10 @@ mod test_fns { Signature::new(type_row![], Type::new_function(target_sig)), ) .unwrap(); - let r = builder.load_func(&target_func, &[], &EMPTY_REG).unwrap(); + let r = builder.load_func(&target_func, &[]).unwrap(); builder.finish_with_outputs([r]).unwrap() }; - builder.finish_hugr(&EMPTY_REG).unwrap() + builder.finish_hugr().unwrap() }; check_emission!(hugr, llvm_ctx); diff --git a/hugr-llvm/src/extension/prelude/array.rs b/hugr-llvm/src/extension/prelude/array.rs index 787b0a60b..4333dde72 100644 --- a/hugr-llvm/src/extension/prelude/array.rs +++ b/hugr-llvm/src/extension/prelude/array.rs @@ -969,9 +969,7 @@ mod test { .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); let func_id = func.finish_with_outputs(vec![v]).unwrap(); - let func_v = builder - .load_func(func_id.handle(), &[], &exec_registry()) - .unwrap(); + let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); let arr = builder .add_dataflow_op(repeat, vec![func_v]) @@ -1023,9 +1021,7 @@ mod test { let delta = func.add_load_value(ConstInt::new_u(6, inc).unwrap()); let out = func.add_iadd(6, elem, delta).unwrap(); let func_id = func.finish_with_outputs(vec![out]).unwrap(); - let func_v = builder - .load_func(func_id.handle(), &[], &exec_registry()) - .unwrap(); + let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); let scan = ArrayScan::new( int_ty.clone(), int_ty.clone(), @@ -1102,9 +1098,7 @@ mod test { .unwrap() .out_wire(0); let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); - let func_v = builder - .load_func(func_id.handle(), &[], &exec_registry()) - .unwrap(); + let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); let scan = ArrayScan::new( int_ty.clone(), Type::UNIT, diff --git a/hugr-llvm/src/utils/array_op_builder.rs b/hugr-llvm/src/utils/array_op_builder.rs index 167584806..1c2492515 100644 --- a/hugr-llvm/src/utils/array_op_builder.rs +++ b/hugr-llvm/src/utils/array_op_builder.rs @@ -200,6 +200,6 @@ pub mod test { #[rstest] fn build_all_ops(all_array_ops: DFGBuilder) { - all_array_ops.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + all_array_ops.finish_hugr().unwrap(); } } diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index aef52cd4d..f946da615 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -106,9 +106,7 @@ mod test { Value::function({ let mut builder = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap(); let r = go(&mut builder); - builder - .finish_hugr_with_outputs([r], &PRELUDE_REGISTRY) - .unwrap() + builder.finish_hugr_with_outputs([r]).unwrap() }) .unwrap() .into() @@ -138,7 +136,7 @@ mod test { .outputs_arr(); builder.finish_with_outputs([r]).unwrap(); }; - builder.finish_hugr(&PRELUDE_REGISTRY).unwrap() + builder.finish_hugr().unwrap() }; inline_constant_functions(&mut hugr, &PRELUDE_REGISTRY).unwrap(); @@ -187,7 +185,7 @@ mod test { .outputs_arr(); builder.finish_with_outputs([r]).unwrap(); }; - builder.finish_hugr(&PRELUDE_REGISTRY).unwrap() + builder.finish_hugr().unwrap() }; inline_constant_functions(&mut hugr, &PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 351f4a19e..7bd181f5f 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -11,7 +11,7 @@ use hugr_core::types::SumType; use hugr_core::Direction; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{fold_out_row, ConstFoldResult, ExtensionRegistry}, + extension::{fold_out_row, ConstFoldResult}, hugr::{ hugrmut::HugrMut, rewrite::consts::{RemoveConst, RemoveLoadConstant}, @@ -51,40 +51,34 @@ impl ConstantFoldPass { } /// Run the Constant Folding pass. - pub fn run( - &self, - hugr: &mut H, - reg: &ExtensionRegistry, - ) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, reg, |hugr: &mut H, _| { - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(hugr, hugr.nodes(), reg).next() - else { - break Ok(()); - }; - hugr.apply_rewrite(replace)?; - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = hugr.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = hugr.apply_rewrite(RemoveConst(const_node)); - } + pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { + self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(hugr, hugr.nodes()).next() else { + break Ok(()); + }; + hugr.apply_rewrite(replace)?; + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = hugr.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = hugr.apply_rewrite(RemoveConst(const_node)); } } - }) + } + }) } } @@ -107,7 +101,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR /// Generate a graph that loads and outputs `consts` in order, validating /// against `reg`. -fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { +fn const_graph(consts: Vec) -> Hugr { let const_types = consts.iter().map(Value::get_type).collect_vec(); let mut b = DFGBuilder::new(inout_sig(type_row![], const_types)).unwrap(); @@ -116,7 +110,7 @@ fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { .map(|c| b.add_load_const(c)) .collect_vec(); - b.finish_hugr_with_outputs(outputs, reg).unwrap() + b.finish_hugr_with_outputs(outputs).unwrap() } /// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, @@ -130,7 +124,6 @@ fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { pub fn find_consts<'a, 'r: 'a>( hugr: &'a impl HugrView, candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, ) -> impl Iterator)> + 'a { // track nodes for operations that have already been considered for folding let mut used_neighbours = BTreeSet::new(); @@ -152,7 +145,7 @@ pub fn find_consts<'a, 'r: 'a>( } let fold_iter = neighbours .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour)); Some(fold_iter) }) .flatten() @@ -162,7 +155,6 @@ pub fn find_consts<'a, 'r: 'a>( fn fold_op( hugr: &impl HugrView, op_node: Node, - reg: &ExtensionRegistry, ) -> Option<(SimpleReplacement, Vec)> { // only support leaf folding for now. let neighbour_op = hugr.get_optype(op_node); @@ -184,7 +176,7 @@ fn fold_op( .map(|np| ((np, i.into()), konst)) }) .unzip(); - let replacement = const_graph(consts, reg); + let replacement = const_graph(consts); let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) .expect("Operation should form valid subgraph."); @@ -213,8 +205,8 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option< } /// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - ConstantFoldPass::default().run(h, reg).unwrap() +pub fn constant_fold_pass(h: &mut H) { + ConstantFoldPass::default().run(h).unwrap() } #[cfg(test)] diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 8b54f7e93..109154657 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -104,12 +104,10 @@ fn test_big() { .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), &TEST_REG) - .unwrap(); + let mut h = build.finish_hugr_with_outputs(to_int.outputs()).unwrap(); assert_eq!(h.node_count(), 8); - constant_fold_pass(&mut h, &TEST_REG); + constant_fold_pass(&mut h); let expected = const_ok(i2c(2).clone(), error_type()); assert_fully_folded(&h, &expected); @@ -153,9 +151,9 @@ fn test_list_ops() -> Result<(), Box> { )? .outputs_arr(); - let mut h = build.finish_hugr_with_outputs([list], &TEST_REG)?; + let mut h = build.finish_hugr_with_outputs([list])?; - constant_fold_pass(&mut h, &TEST_REG); + constant_fold_pass(&mut h); assert_fully_folded(&h, &base_list); Ok(()) @@ -171,10 +169,8 @@ fn test_fold_and() { let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::true_val()); let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -189,10 +185,8 @@ fn test_fold_or() { let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::false_val()); let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -206,10 +200,8 @@ fn test_fold_not() { let mut build = DFGBuilder::new(noargfn(bool_t())).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -235,9 +227,7 @@ fn orphan_output() { .unwrap(); let or_node = r.node(); let parent = build.container_node(); - let mut h = build - .finish_hugr_with_outputs(r.outputs(), &TEST_REG) - .unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs()).unwrap(); // we delete the original Not and create a new One. This means it will be // traversed by `constant_fold_pass` after the Or. @@ -246,7 +236,7 @@ fn orphan_output() { h.disconnect(or_node, IncomingPort::from(1)); h.connect(new_not, 0, or_node, 1); h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, &TEST_REG); + constant_fold_pass(&mut h); assert_fully_folded(&h, &Value::true_val()) } @@ -276,10 +266,8 @@ fn test_folding_pass_issue_996() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x7.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -291,10 +279,8 @@ fn test_const_fold_to_nonfinite() { let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h0, &TEST_REG); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h0); assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); @@ -305,10 +291,8 @@ fn test_const_fold_to_nonfinite() { let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h1, &TEST_REG); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h1); assert_eq!(h1.node_count(), 8); } @@ -324,10 +308,8 @@ fn test_fold_iwiden_u() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); assert_fully_folded(&h, &expected); } @@ -344,10 +326,8 @@ fn test_fold_iwiden_s() { let x1 = build .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -392,10 +372,8 @@ fn test_fold_inarrow, E: std::fmt::Debug>( [x0], ) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); lazy_static! { static ref INARROW_ERROR_VALUE: ConstError = ConstError { signal: 0, @@ -422,10 +400,8 @@ fn test_fold_itobool() { let x1 = build .add_dataflow_op(ConvertOpDef::itobool.without_log_width(), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -442,10 +418,8 @@ fn test_fold_ifrombool() { let x1 = build .add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); assert_fully_folded(&h, &expected); } @@ -462,10 +436,8 @@ fn test_fold_ieq() { let x2 = build .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -482,10 +454,8 @@ fn test_fold_ine() { let x2 = build .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -502,10 +472,8 @@ fn test_fold_ilt_u() { let x2 = build .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -522,10 +490,8 @@ fn test_fold_ilt_s() { let x2 = build .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -542,10 +508,8 @@ fn test_fold_igt_u() { let x2 = build .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -562,10 +526,8 @@ fn test_fold_igt_s() { let x2 = build .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -582,10 +544,8 @@ fn test_fold_ile_u() { let x2 = build .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -602,10 +562,8 @@ fn test_fold_ile_s() { let x2 = build .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -622,10 +580,8 @@ fn test_fold_ige_u() { let x2 = build .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::false_val(); assert_fully_folded(&h, &expected); } @@ -642,10 +598,8 @@ fn test_fold_ige_s() { let x2 = build .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } @@ -662,10 +616,8 @@ fn test_fold_imax_u() { let x2 = build .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); assert_fully_folded(&h, &expected); } @@ -682,10 +634,8 @@ fn test_fold_imax_s() { let x2 = build .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -702,10 +652,8 @@ fn test_fold_imin_u() { let x2 = build .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); assert_fully_folded(&h, &expected); } @@ -722,10 +670,8 @@ fn test_fold_imin_s() { let x2 = build .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); assert_fully_folded(&h, &expected); } @@ -742,10 +688,8 @@ fn test_fold_iadd() { let x2 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); assert_fully_folded(&h, &expected); } @@ -762,10 +706,8 @@ fn test_fold_isub() { let x2 = build .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); assert_fully_folded(&h, &expected); } @@ -781,10 +723,8 @@ fn test_fold_ineg() { let x2 = build .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -801,10 +741,8 @@ fn test_fold_imul() { let x2 = build .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); assert_fully_folded(&h, &expected); } @@ -824,10 +762,8 @@ fn test_fold_idivmod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -853,10 +789,8 @@ fn test_fold_idivmod_u() { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(3), [x2, x3]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x4.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(3, 8).unwrap()); assert_fully_folded(&h, &expected); } @@ -876,10 +810,8 @@ fn test_fold_idivmod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idivmod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -907,10 +839,8 @@ fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { let x4 = build .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x4.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x4.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); assert_fully_folded(&h, &expected); } @@ -928,10 +858,8 @@ fn test_fold_idiv_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -952,10 +880,8 @@ fn test_fold_idiv_u() { let x2 = build .add_dataflow_op(IntOpDef::idiv_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); assert_fully_folded(&h, &expected); } @@ -973,10 +899,8 @@ fn test_fold_imod_checked_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -997,10 +921,8 @@ fn test_fold_imod_u() { let x2 = build .add_dataflow_op(IntOpDef::imod_u.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1018,10 +940,8 @@ fn test_fold_idiv_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1042,10 +962,8 @@ fn test_fold_idiv_s() { let x2 = build .add_dataflow_op(IntOpDef::idiv_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); assert_fully_folded(&h, &expected); } @@ -1063,10 +981,8 @@ fn test_fold_imod_checked_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_checked_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = ConstError { signal: 0, message: "Division by zero".to_string(), @@ -1087,10 +1003,8 @@ fn test_fold_imod_s() { let x2 = build .add_dataflow_op(IntOpDef::imod_s.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1106,10 +1020,8 @@ fn test_fold_iabs() { let x2 = build .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); assert_fully_folded(&h, &expected); } @@ -1126,10 +1038,8 @@ fn test_fold_iand() { let x2 = build .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); assert_fully_folded(&h, &expected); } @@ -1146,10 +1056,8 @@ fn test_fold_ior() { let x2 = build .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); assert_fully_folded(&h, &expected); } @@ -1166,10 +1074,8 @@ fn test_fold_ixor() { let x2 = build .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); assert_fully_folded(&h, &expected); } @@ -1185,10 +1091,8 @@ fn test_fold_inot() { let x2 = build .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); assert_fully_folded(&h, &expected); } @@ -1205,10 +1109,8 @@ fn test_fold_ishl() { let x2 = build .add_dataflow_op(IntOpDef::ishl.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); assert_fully_folded(&h, &expected); } @@ -1225,10 +1127,8 @@ fn test_fold_ishr() { let x2 = build .add_dataflow_op(IntOpDef::ishr.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1245,10 +1145,8 @@ fn test_fold_irotl() { let x2 = build .add_dataflow_op(IntOpDef::irotl.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1265,10 +1163,8 @@ fn test_fold_irotr() { let x2 = build .add_dataflow_op(IntOpDef::irotr.with_log_width(5), [x0, x1]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x2.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); assert_fully_folded(&h, &expected); } @@ -1284,10 +1180,8 @@ fn test_fold_itostring_u() { let x1 = build .add_dataflow_op(ConvertOpDef::itostring_u.with_log_width(5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstString::new("17".into())); assert_fully_folded(&h, &expected); } @@ -1303,10 +1197,8 @@ fn test_fold_itostring_s() { let x1 = build .add_dataflow_op(ConvertOpDef::itostring_s.with_log_width(5), [x0]) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x1.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x1.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::extension(ConstString::new("-17".into())); assert_fully_folded(&h, &expected); } @@ -1343,10 +1235,8 @@ fn test_fold_int_ops() { let x7 = build .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); - let mut h = build - .finish_hugr_with_outputs(x7.outputs(), &TEST_REG) - .unwrap(); - constant_fold_pass(&mut h, &TEST_REG); + let mut h = build.finish_hugr_with_outputs(x7.outputs()).unwrap(); + constant_fold_pass(&mut h); let expected = Value::true_val(); assert_fully_folded(&h, &expected); } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 13815d186..9e745a698 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,7 +1,6 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; -use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; use hugr_core::ops::TailLoop; @@ -10,7 +9,7 @@ use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ prelude::{bool_t, UnpackTuple}, - ExtensionSet, EMPTY_REG, + ExtensionSet, }, ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, @@ -58,7 +57,7 @@ fn test_make_tuple() { let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); - let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let results = Machine::new(&hugr).run(TestContext, []); @@ -74,7 +73,7 @@ fn test_unpack_tuple_const() { .add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v]) .unwrap() .outputs_arr(); - let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let results = Machine::new(&hugr).run(TestContext, []); @@ -100,7 +99,7 @@ fn test_tail_loop_never_iterates() { .unwrap(); let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); let [tl_o] = tail_loop.outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let results = Machine::new(&hugr).run(TestContext, []); @@ -135,7 +134,7 @@ fn test_tail_loop_always_iterates() { let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); let [tl_o1, tl_o2] = tail_loop.outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let results = Machine::new(&hugr).run(TestContext, []); @@ -172,7 +171,7 @@ fn test_tail_loop_two_iters() { let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); let results = Machine::new(&hugr).run(TestContext, []); @@ -235,7 +234,7 @@ fn test_tail_loop_containing_conditional() { let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); let results = Machine::new(&hugr).run(TestContext, []); @@ -284,7 +283,7 @@ fn test_conditional() { let [cond_o1, cond_o2] = cond.outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = builder.finish_hugr().unwrap(); let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( 2, @@ -363,7 +362,7 @@ fn xor_and_cfg() -> Hugr { builder.branch(&a, tru, &b).unwrap(); // if true builder.branch(&a, fals, &x).unwrap(); // if false builder.branch(&b, 0, &x).unwrap(); - builder.finish_hugr(&EMPTY_REG).unwrap() + builder.finish_hugr().unwrap() } #[rstest] @@ -410,16 +409,14 @@ fn test_call( let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); let [a, b] = builder.input_wires_arr(); let [a2] = builder - .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .call(func_defn.handle(), &[], [a]) .unwrap() .outputs_arr(); let [b2] = builder - .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .call(func_defn.handle(), &[], [b]) .unwrap() .outputs_arr(); - let hugr = builder - .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) - .unwrap(); + let hugr = builder.finish_hugr_with_outputs([a2, b2]).unwrap(); let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]); @@ -439,9 +436,7 @@ fn test_region() { .unwrap(); let nested_ins = nested.input_wires(); let nested = nested.finish_with_outputs(nested_ins).unwrap(); - let hugr = builder - .finish_prelude_hugr_with_outputs(nested.outputs()) - .unwrap(); + let hugr = builder.finish_hugr_with_outputs(nested.outputs()).unwrap(); let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); assert_eq!( @@ -494,9 +489,7 @@ fn test_module() { .unwrap(); let [inp] = f2.input_wires_arr(); let cst_true = f2.add_load_value(Value::true_val()); - let f2_call = f2 - .call(leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG) - .unwrap(); + let f2_call = f2.call(leaf_fn.handle(), &[], [inp, cst_true]).unwrap(); let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap(); let mut main = modb @@ -504,11 +497,9 @@ fn test_module() { .unwrap(); let [inp] = main.input_wires_arr(); let cst_false = main.add_load_value(Value::false_val()); - let main_call = main - .call(leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG) - .unwrap(); + let main_call = main.call(leaf_fn.handle(), &[], [inp, cst_false]).unwrap(); main.finish_with_outputs(main_call.outputs()).unwrap(); - let hugr = modb.finish_hugr(&EMPTY_REG).unwrap(); + let hugr = modb.finish_hugr().unwrap(); let [f2_inp, _] = hugr.get_io(f2.node()).unwrap(); let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]); diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index cfbec1682..e5f26248b 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -205,11 +205,10 @@ mod test { use super::*; use hugr_core::builder::{endo_sig, BuildHandle, Dataflow, DataflowHugr}; - use hugr_core::extension::EMPTY_REG; use hugr_core::ops::handle::{DataflowOpID, NodeHandle}; use hugr_core::ops::Value; - use hugr_core::std_extensions::arithmetic::int_ops::{self, IntOpDef}; + use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::types::{Signature, Type}; use hugr_core::{builder::DFGBuilder, hugr::Hugr}; @@ -261,10 +260,7 @@ mod test { .unwrap(); ( builder - .finish_hugr_with_outputs( - [v2.out_wire(0), v3.out_wire(0)], - &int_ops::INT_OPS_REGISTRY, - ) + .finish_hugr_with_outputs([v2.out_wire(0), v3.out_wire(0)]) .unwrap(), nodes, ) @@ -279,8 +275,7 @@ mod test { .iter(&hugr.as_petgraph()) .filter(|n| rank_map.contains_key(n)) .collect_vec(); - hugr.validate_no_extensions(&int_ops::INT_OPS_REGISTRY) - .unwrap(); + hugr.validate_no_extensions().unwrap(); topo_sorted } @@ -326,9 +321,7 @@ mod test { let mut builder = DFGBuilder::new(Signature::new(Type::EMPTY_TYPEROW, Type::UNIT)).unwrap(); let unit = builder.add_load_value(Value::unary_unit_sum()); - builder - .finish_hugr_with_outputs([unit], &EMPTY_REG) - .unwrap() + builder.finish_hugr_with_outputs([unit]).unwrap() }; let root = hugr.root(); force_order(&mut hugr, root, |_, _| 0).unwrap(); diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 3c3d8a40c..cb95b1433 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -96,14 +96,14 @@ mod test { .add_dataflow_op(Noop::new(bool_t()), [b.input_wires().next().unwrap()]) .unwrap() .out_wire(0); - b.finish_prelude_hugr_with_outputs([out]).unwrap() + b.finish_hugr_with_outputs([out]).unwrap() } #[fixture] fn identity_hugr() -> Hugr { let b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let out = b.input_wires().next().unwrap(); - b.finish_prelude_hugr_with_outputs([out]).unwrap() + b.finish_hugr_with_outputs([out]).unwrap() } #[rstest] diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 15752d8cc..9bc2f7f1d 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -159,13 +159,13 @@ mod test { use std::collections::HashSet; use std::sync::Arc; - use hugr_core::extension::prelude::Lift; + use hugr_core::extension::prelude::{Lift, PRELUDE_ID}; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; - use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize, PRELUDE_ID}; - use hugr_core::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}; + use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; + use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; use hugr_core::ops::handle::CfgID; @@ -231,8 +231,6 @@ mod test { let exit_types: TypeRow = vec![usize_t()].into(); let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::new([PRELUDE.clone(), e]); - reg.validate()?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; @@ -263,10 +261,10 @@ mod test { h.branch(&test_block, 0, &loop_backedge_target)?; h.branch(&test_block, 1, &h.exit_block())?; - let mut h = h.finish_hugr(®)?; + let mut h = h.finish_hugr()?; let r = h.root(); merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); - h.update_validate(®).unwrap(); + h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); let [entry, exit] = h @@ -359,12 +357,10 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::new([e, PRELUDE.clone()]); - reg.validate()?; - let mut h = h.finish_hugr(®)?; + let mut h = h.finish_hugr()?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); - h.update_validate(®)?; + h.validate()?; // Should only be one BB left let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index b8091c7af..85ba2242b 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -570,7 +570,6 @@ pub(crate) mod test { use hugr_core::builder::{ endo_sig, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; - use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::extension::{prelude::usize_t, ExtensionSet}; use hugr_core::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; @@ -628,7 +627,7 @@ pub(crate) mod test { let exit = cfg_builder.exit_block(); cfg_builder.branch(&tail, 0, &exit)?; - let mut h = cfg_builder.finish_prelude_hugr()?; + let mut h = cfg_builder.finish_hugr()?; let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); let (entry, exit) = (entry.node(), exit.node()); let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); @@ -654,7 +653,7 @@ pub(crate) mod test { ]) ); transform_cfg_to_nested(&mut IdentityCfgMap::new(rc)); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.validate().unwrap(); assert_eq!(1, depth(&h, entry)); assert_eq!(1, depth(&h, exit)); for n in [split, left, right, merge, head, tail] { @@ -758,7 +757,7 @@ pub(crate) mod test { let root = h.root(); let m = SiblingMut::::try_new(&mut h, root).unwrap(); transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); + h.validate().unwrap(); assert_eq!(1, depth(&h, entry)); assert_eq!(3, depth(&h, head)); for n in [split, left, right, merge] { @@ -902,7 +901,7 @@ pub(crate) mod test { let exit = cfg_builder.exit_block(); cfg_builder.branch(&tail, 0, &exit)?; - let h = cfg_builder.finish_prelude_hugr()?; + let h = cfg_builder.finish_hugr()?; Ok((h, merge, tail)) } @@ -912,7 +911,7 @@ pub(crate) mod test { ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?; let (head, tail) = build_conditional_in_loop(&mut cfg_builder, separate_headers)?; - let h = cfg_builder.finish_prelude_hugr()?; + let h = cfg_builder.finish_hugr()?; Ok((h, head, tail)) } diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 598622089..efb5e7139 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -48,8 +48,6 @@ mod test { types::Signature, }; - use crate::test::TEST_REG; - use super::*; #[test] @@ -62,9 +60,7 @@ mod test { .add_dataflow_op(Noop::new(bool_t()), [in_w]) .unwrap() .outputs_arr(); - builder - .finish_hugr_with_outputs([out_w], &TEST_REG) - .unwrap() + builder.finish_hugr_with_outputs([out_w]).unwrap() }; ensure_no_nonlocal_edges(&hugr).unwrap(); } @@ -91,12 +87,7 @@ mod test { noop_edge, ) }; - ( - builder - .finish_hugr_with_outputs([out_w], &TEST_REG) - .unwrap(), - edge, - ) + (builder.finish_hugr_with_outputs([out_w]).unwrap(), edge) }; assert_eq!( ensure_no_nonlocal_edges(&hugr).unwrap_err(), diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 68fa43601..baf3b86d8 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -3,11 +3,8 @@ use thiserror::Error; -use hugr_core::{ - extension::ExtensionRegistry, - hugr::{hugrmut::HugrMut, ValidationError}, - HugrView, -}; +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; #[derive(Debug, Clone, Copy, Ord, Eq, PartialOrd, PartialEq)] /// A type for running [HugrMut] algorithms with verification. @@ -61,18 +58,19 @@ impl ValidationLevel { pub fn run_validated_pass( &self, hugr: &mut H, - reg: &ExtensionRegistry, pass: impl FnOnce(&mut H, &Self) -> Result, ) -> Result where ValidatePassError: Into, { - self.validation_impl(hugr, reg, |err, pretty_hugr| { - ValidatePassError::InputError { err, pretty_hugr } + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::InputError { + err, + pretty_hugr, })?; let result = pass(hugr, self)?; - self.validation_impl(hugr, reg, |err, pretty_hugr| { - ValidatePassError::OutputError { err, pretty_hugr } + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::OutputError { + err, + pretty_hugr, })?; Ok(result) } @@ -80,7 +78,6 @@ impl ValidationLevel { fn validation_impl( &self, hugr: &impl HugrView, - reg: &ExtensionRegistry, mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, ) -> Result<(), E> where @@ -88,8 +85,8 @@ impl ValidationLevel { { match self { ValidationLevel::None => Ok(()), - ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(reg), - ValidationLevel::WithExtensions => hugr.validate(reg), + ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(), + ValidationLevel::WithExtensions => hugr.validate(), } .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) } diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 9cc64f610..74616e642 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -9,7 +9,6 @@ use hugr::builder::{ use hugr::extension::prelude::{bool_t, qb_t, usize_t}; use hugr::extension::PRELUDE_REGISTRY; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use hugr::std_extensions::arithmetic::float_types::float64_type; use hugr::types::Signature; use hugr::{type_row, Extension, Hugr, Node}; @@ -18,7 +17,7 @@ use lazy_static::lazy_static; pub fn simple_dfg_hugr() -> Hugr { let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap(); let [i1] = dfg_builder.input_wires_arr(); - dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap() + dfg_builder.finish_hugr_with_outputs([i1]).unwrap() } pub fn simple_cfg_builder + AsRef>( @@ -50,7 +49,7 @@ pub fn simple_cfg_hugr() -> Hugr { let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap(); simple_cfg_builder(&mut cfg_builder).unwrap(); - cfg_builder.finish_prelude_hugr().unwrap() + cfg_builder.finish_hugr().unwrap() } lazy_static! { @@ -137,8 +136,5 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { let outs = linear.finish(); f_build.finish_with_outputs(outs).unwrap(); - ( - module_builder.finish_hugr(&FLOAT_OPS_REGISTRY).unwrap(), - layer_ids, - ) + (module_builder.finish_hugr().unwrap(), layer_ids) } diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index cdf4fe42b..88c8c8df0 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -3,6 +3,6 @@ // Exports everything except the `internal` module. pub use hugr_core::hugr::{ hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, NodeMetadata, NodeMetadataMap, OpType, Rewrite, RootTagged, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, RootTagged, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index e7bb54c93..e3313887e 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -102,7 +102,7 @@ //! } //! } //! -//! use mini_quantum_extension::{cx_gate, h_gate, measure, REG}; +//! use mini_quantum_extension::{cx_gate, h_gate, measure}; //! //! // ┌───┐ //! // q_0: ┤ H ├──■───── @@ -120,7 +120,7 @@ //! let h1 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; //! let cx = dfg_builder.add_dataflow_op(cx_gate(), h0.outputs().chain(h1.outputs()))?; //! let measure = dfg_builder.add_dataflow_op(measure(), cx.outputs().last())?; -//! dfg_builder.finish_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs()), ®) +//! dfg_builder.finish_hugr_with_outputs(cx.outputs().take(1).chain(measure.outputs())) //! } //! //! let h: Hugr = make_dfg_hugr().unwrap();