Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: OpDefs and TypeDefs keep a reference to their extension #1719

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use clap::Parser;
use clap_verbosity_flag::Level;
use hugr::package::PackageValidationError;
use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs};
Expand Down Expand Up @@ -64,8 +63,7 @@ impl HugrArgs {
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext)
.map_err(PackageValidationError::Extension)?;
reg.register_updated(ext);
}

package.update_validate(&mut reg)?;
Expand Down
16 changes: 13 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::{ExtensionId, ExtensionSet};
use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
Expand Down Expand Up @@ -298,8 +298,18 @@ mod test {
#[test]
fn with_nonlinear_and_outputs() {
let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
let mut my_ext = Extension::new_test(my_ext_name.clone());
let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB]));
let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
ext.add_op(
"MyOp".into(),
"".to_string(),
Signature::new(vec![QB, NAT], vec![QB]),
extension_ref,
)
.unwrap();
});
let my_custom_op = my_ext
.instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY)
.unwrap();

let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ impl<'a> Context<'a> {

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
_ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()),
};

let key = (opdef.extension().clone(), opdef.name().clone());
let key = (opdef.extension_id().clone(), opdef.name().clone());
let entry = self.decl_operations.entry(key);

let node = match entry {
Expand All @@ -467,7 +467,7 @@ impl<'a> Context<'a> {
};

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let name = this.make_qualified_name(opdef.extension_id(), opdef.name());
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
Expand Down
153 changes: 121 additions & 32 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ pub use semver::Version;
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use std::mem;
use std::sync::{Arc, Weak};

use thiserror::Error;

Expand Down Expand Up @@ -103,10 +104,7 @@ impl ExtensionRegistry {
///
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
Expand All @@ -118,7 +116,6 @@ impl ExtensionRegistry {
ve.insert(extension);
}
}
Ok(())
}

/// Registers a new extension to the registry, keeping most up to date if
Expand All @@ -130,10 +127,7 @@ impl ExtensionRegistry {
///
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Arc<Extension>,
) -> Result<(), ExtensionRegistryError> {
pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
Expand All @@ -144,7 +138,6 @@ impl ExtensionRegistry {
ve.insert(extension.clone());
}
}
Ok(())
}

/// Returns the number of extensions in the registry.
Expand Down Expand Up @@ -335,6 +328,45 @@ impl ExtensionValue {
pub type ExtensionId = IdentList;

/// A extension is a set of capabilities required to execute a graph.
///
/// These are normally defined once and shared across multiple graphs and
/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`].
///
/// # Example
///
/// The following example demonstrates how to define a new extension with a
/// custom operation and a custom type.
///
/// When using `arc`s, the extension can only be modified at creation time. The
/// defined operations and types keep a [`Weak`] reference to their extension. We provide a
/// helper method [`Extension::new_arc`] to aid their definition.
///
/// ```
/// # use hugr_core::types::Signature;
/// # use hugr_core::extension::{Extension, ExtensionId, Version};
/// # use hugr_core::extension::{TypeDefBound};
/// Extension::new_arc(
/// ExtensionId::new_unchecked("my.extension"),
/// Version::new(0, 1, 0),
/// |ext, extension_ref| {
/// // Add a custom type definition
/// ext.add_type(
/// "MyType".into(),
/// vec![], // No type parameters
/// "Some type".into(),
/// TypeDefBound::any(),
/// extension_ref,
/// );
/// // Add a custom operation
/// ext.add_op(
/// "MyOp".into(),
/// "Some operation".into(),
/// Signature::new_endo(vec![]),
/// extension_ref,
/// );
/// },
/// );
/// ```
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Extension version, follows semver.
Expand All @@ -361,6 +393,12 @@ pub struct Extension {

impl Extension {
/// Creates a new extension with the given name.
///
/// In most cases extensions are contained inside an [`Arc`] so that they
/// can be shared across hugr instances and operation definitions.
///
/// See [`Extension::new_arc`] for a more ergonomic way to create boxed
/// extensions.
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
Expand All @@ -372,14 +410,63 @@ impl Extension {
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn with_reqs(self, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
extension_reqs: self.extension_reqs.union(extension_reqs.into()),
..self
/// Creates a new extension wrapped in an [`Arc`].
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn new_arc(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
init(&mut ext, extension_ref);
ext
})
}

/// Creates a new extension wrapped in an [`Arc`], using a fallible
/// initialization function.
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn try_new_arc<E>(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
) -> Result<Arc<Self>, E> {
// Annoying hack around not having `Arc::try_new_cyclic` that can return
// a Result.
// https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381
//
// When there is an error, we store it in `error` and return it at the
// end instead of the partially-initialized extension.
let mut error = None;
let ext = Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
match init(&mut ext, extension_ref) {
Ok(_) => ext,
Err(e) => {
error = Some(e);
ext
}
}
});
match error {
Some(e) => Err(e),
None => Ok(ext),
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn add_requirements(&mut self, extension_reqs: impl Into<ExtensionSet>) {
let reqs = mem::take(&mut self.extension_reqs);
self.extension_reqs = reqs.union(extension_reqs.into());
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
Expand Down Expand Up @@ -634,20 +721,22 @@ pub mod test {

impl Extension {
/// Create a new extension for testing, with a 0 version.
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
pub(crate) fn new_test_arc(
name: ExtensionId,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Self::new_arc(name, Version::new(0, 0, 0), init)
}

/// Add a simple OpDef to the extension and return an extension op for it.
/// No description, no type parameters.
pub(crate) fn simple_ext_op(
&mut self,
name: &str,
signature: impl Into<SignatureFunc>,
) -> ExtensionOp {
self.add_op(name.into(), "".to_string(), signature).unwrap();
self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY)
.unwrap()
/// Create a new extension for testing, with a 0 version.
pub(crate) fn try_new_test_arc(
name: ExtensionId,
init: impl FnOnce(
&mut Extension,
&Weak<Extension>,
) -> Result<(), Box<dyn std::error::Error>>,
) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
Self::try_new_arc(name, Version::new(0, 0, 0), init)
}
}

Expand Down Expand Up @@ -680,14 +769,14 @@ pub mod test {
);

// register with update works
reg_ref.register_updated_ref(&ext1_1).unwrap();
reg.register_updated(ext1_1.clone()).unwrap();
reg_ref.register_updated_ref(&ext1_1);
reg.register_updated(ext1_1.clone());
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

// register with lower version does not change version
reg_ref.register_updated_ref(&ext1_2).unwrap();
reg.register_updated(ext1_2.clone()).unwrap();
reg_ref.register_updated_ref(&ext1_2);
reg.register_updated(ext1_2.clone());
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

Expand Down
32 changes: 19 additions & 13 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod types;

use std::fs::File;
use std::path::Path;
use std::sync::Arc;

use crate::extension::prelude::PRELUDE_ID;
use crate::ops::OpName;
Expand Down Expand Up @@ -150,19 +151,24 @@ impl ExtensionDeclaration {
&self,
imports: &ExtensionSet,
ctx: DeclarationContext<'_>,
) -> Result<Extension, ExtensionDeclarationError> {
let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0))
.with_reqs(imports.clone());

for t in &self.types {
t.register(&mut ext, ctx)?;
}

for o in &self.operations {
o.register(&mut ext, ctx)?;
}

Ok(ext)
) -> Result<Arc<Extension>, ExtensionDeclarationError> {
Extension::try_new_arc(
self.name.clone(),
// TODO: Get the version as a parameter.
crate::extension::Version::new(0, 0, 0),
|ext, extension_ref| {
for t in &self.types {
t.register(ext, ctx, extension_ref)?;
}

for o in &self.operations {
o.register(ext, ctx, extension_ref)?;
}
ext.add_requirements(imports.clone());

Ok(())
},
)
}
}

Expand Down
12 changes: 11 additions & 1 deletion hugr-core/src/extension/declarative/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration

use std::collections::HashMap;
use std::sync::Weak;

use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
Expand Down Expand Up @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration {

impl OperationDeclaration {
/// Register this operation in the given extension.
///
/// Requires a [`Weak`] reference to the extension defining the operation.
/// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
pub fn register<'ext>(
&self,
ext: &'ext mut Extension,
ctx: DeclarationContext<'_>,
extension_ref: &Weak<Extension>,
) -> Result<&'ext mut OpDef, ExtensionDeclarationError> {
// We currently only support explicit signatures.
//
Expand Down Expand Up @@ -88,7 +93,12 @@ impl OperationDeclaration {

let signature_func: SignatureFunc = signature.make_signature(ext, ctx, &params)?;

let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?;
let op_def = ext.add_op(
self.name.clone(),
self.description.clone(),
signature_func,
extension_ref,
)?;

for (k, v) in &self.misc {
op_def.add_misc(k, v.clone());
Expand Down
Loading
Loading