Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: OpEnum trait for common opdef functionality #721

Merged
merged 14 commits into from
Nov 29, 2023
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ serde_json = "1.0.97"
delegate = "0.10.0"
rustversion = "1.0.14"
paste = "1.0"
strum = "0.25.0"
strum_macros = "0.25.3"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
6 changes: 3 additions & 3 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, ValidateJustArgs,
ValidateTypeArgs,
CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
pub mod prelude;
pub mod simple_op;
pub mod validate;

pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
85 changes: 60 additions & 25 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use std::sync::Arc;

use smol_str::SmolStr;

use super::simple_op::OpEnum;
use super::{
Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
};

use crate::ops::OpName;
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::Hugr;
Expand Down Expand Up @@ -148,18 +150,18 @@ impl CustomValidator {
}

/// The two ways in which an OpDef may compute the Signature of each operation node.
/// Either as a TypeScheme (polymorphic function type), with optional custom
/// validation for provided type arguments,
/// or a custom binary which computes a polymorphic function type given values
/// for its static type parameters.
#[derive(serde::Deserialize, serde::Serialize)]
pub enum SignatureFunc {
// Note: except for serialization, we could have type schemes just implement the same
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>.
// However instead we treat all CustomFunc's as non-serializable.
/// A TypeScheme (polymorphic function type), with optional custom
/// validation for provided type arguments,
#[serde(rename = "signature")]
TypeScheme(CustomValidator),
#[serde(skip)]
/// A custom binary which computes a polymorphic function type given values
/// for its static type parameters.
CustomFunc(Box<dyn CustomSignatureFunc>),
}
struct NoValidate;
Expand Down Expand Up @@ -211,6 +213,46 @@ impl SignatureFunc {
SignatureFunc::CustomFunc(func) => func.static_params(),
}
}

/// Compute the concrete signature ([FunctionType]).
///
/// # Panics
///
/// Panics if is [SignatureFunc::CustomFunc] and there are not enough type
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
/// arguments provided to match the number of static parameters.
///
/// # Errors
///
/// This function will return an error if the type arguments are invalid or
/// there is some error in type computation.
pub fn compute_signature(
&self,
def: &OpDef,
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, def, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, def, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is closed. Does the assert pass?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, debug_assert!(res.extension_reqs.contains(def.extension())); fails on 40 tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there another issue we can link instead?

Ok(res)
}
}

impl Debug for SignatureFunc {
Expand Down Expand Up @@ -306,27 +348,7 @@ impl OpDef {
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self.signature_func {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, self, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, self, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Ok(res)
self.signature_func.compute_signature(self, args, exts)
}

pub(crate) fn should_serialize_signature(&self) -> bool {
Expand Down Expand Up @@ -427,6 +449,19 @@ impl Extension {
Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
}
}

/// Add an operation implemented as an [OpEnum], which can provide the data
/// required to define an [OpDef].
pub fn add_op_enum(
&mut self,
op: &(impl OpEnum + OpName),
) -> Result<&mut OpDef, ExtensionBuildError> {
let def = self.add_op(op.name(), op.description(), op.def_signature())?;

op.post_opdef(def);

Ok(def)
}
}

#[cfg(test)]
Expand Down
232 changes: 232 additions & 0 deletions src/extension/simple_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
//! A trait that enum for op definitions that gathers up some shared functionality.

use smol_str::SmolStr;
use strum::IntoEnumIterator;

use crate::{
ops::{custom::ExtensionOp, LeafOp, OpName, OpType},
types::{FunctionType, TypeArg},
Extension,
};

use super::{
op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef,
SignatureError,
};
use delegate::delegate;
use thiserror::Error;

/// Error loading [OpEnum]
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
#[derive(Debug, Error, PartialEq)]
#[error("{0}")]
#[allow(missing_docs)]
pub enum OpLoadError {
#[error("Op with name {0} is not a member of this enum.")]
NotEnumMember(String),
#[error("Type args invalid: {0}.")]
InvalidArgs(#[from] SignatureError),
}

impl<T> OpName for T
where
for<'a> &'a T: Into<&'static str>,
{
fn name(&self) -> SmolStr {
let s = self.into();
s.into()
}
}
/// A trait that operation sets defined by simple (C-style) enums can implement
/// to simplify interactions with the extension.
/// Relies on `strum_macros::{EnumIter, EnumString, IntoStaticStr}`
pub trait OpEnum: OpName {
/// Try to load one of the operations of this set from an [OpDef].
fn from_op_def(op_def: &OpDef, args: &[TypeArg]) -> Result<Self, OpLoadError>
where
Self: Sized;

/// Return the signature (polymorphic function type) of the operation.
fn def_signature(&self) -> SignatureFunc;

/// Description of the operation. By default, the same as `self.name()`.
fn description(&self) -> String {
self.name().to_string()
}

/// Any type args which define this operation. Default is no type arguments.
fn type_args(&self) -> Vec<TypeArg> {
vec![]
}

/// Edit the opdef before finalising. By default does nothing.
fn post_opdef(&self, _def: &mut OpDef) {}

/// Try to instantiate a variant from an [OpType]. Default behaviour assumes
/// an [ExtensionOp] and loads from the name.
fn from_optype(op: &OpType) -> Option<Self>
where
Self: Sized,
{
let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?;
Self::from_op_def(ext.def(), ext.args()).ok()
}

/// Given the ID of the extension this operation is defined in, and a
/// registry containing that extension, return a [RegisteredEnum].
fn to_registered(
self,
extension_id: ExtensionId,
registry: &ExtensionRegistry,
) -> RegisteredEnum<'_, Self>
where
Self: Sized,
{
RegisteredEnum {
extension_id,
registry,
op_enum: self,
}
}

/// Iterator over all operations in the set. Non-trivial variants will have
/// default values used for the members.
fn all_variants() -> <Self as IntoEnumIterator>::Iterator
where
Self: IntoEnumIterator,
{
<Self as IntoEnumIterator>::iter()
}

/// load all variants of a [OpEnum] in to an extension as op defs.
fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError>
where
Self: IntoEnumIterator,
{
for op in Self::all_variants() {
extension.add_op_enum(&op)?;
}
Ok(())
}
}

/// Load an [OpEnum] from its name. Works best for C-style enums where each
/// variant corresponds to an [OpDef] and an [OpType], i,e, there are no type parameters.
/// See [strum_macros::EnumString].
pub fn try_from_name<T>(name: &str) -> Result<T, OpLoadError>
where
T: std::str::FromStr + OpEnum,
{
T::from_str(name).map_err(|_| OpLoadError::NotEnumMember(name.to_string()))
}

/// Wrap an [OpEnum] with an extension registry to allow type computation.
/// Generate from [OpEnum::to_registered]
pub struct RegisteredEnum<'r, T> {
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
/// The name of the extension these ops belong to.
extension_id: ExtensionId,
/// A registry of all extensions, used for type computation.
registry: &'r ExtensionRegistry,
/// The inner [OpEnum]
op_enum: T,
}

impl<T> RegisteredEnum<'_, T> {
/// Extract the inner wrapped value
pub fn to_inner(self) -> T {
self.op_enum
}
}

impl<T: OpEnum> RegisteredEnum<'_, T> {
/// Generate an [OpType].
pub fn to_optype(&self) -> Option<OpType> {
let leaf: LeafOp = ExtensionOp::new(
self.registry
.get(&self.extension_id)?
.get_op(&self.name())?
.clone(),
self.type_args(),
self.registry,
)
.ok()?
.into();

Some(leaf.into())
}

/// Compute the [FunctionType] for this operation, instantiating with type arguments.
pub fn function_type(&self) -> Result<FunctionType, SignatureError> {
self.op_enum.def_signature().compute_signature(
self.registry
.get(&self.extension_id)
.expect("should return 'Extension not in registry' error here.")
.get_op(&self.name())
.expect("should return 'Op not in extension' error here."),
&self.type_args(),
self.registry,
)
}

delegate! {
to self.op_enum {
/// Name of the operation - derived from strum serialization.
pub fn name(&self) -> SmolStr;
/// Any type args which define this operation. Default is no type arguments.
pub fn type_args(&self) -> Vec<TypeArg>;
/// Description of the operation.
pub fn description(&self) -> String;
}
}
}

#[cfg(test)]
mod test {
use crate::{type_row, types::FunctionType};

use super::*;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
#[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
enum DummyEnum {
Dumb,
}

impl OpEnum for DummyEnum {
fn def_signature(&self) -> SignatureFunc {
FunctionType::new_endo(type_row![]).into()
}

fn from_op_def(_op_def: &OpDef, _args: &[TypeArg]) -> Result<Self, OpLoadError> {
Ok(Self::Dumb)
}
}

#[test]
fn test_dummy_enum() {
let o = DummyEnum::Dumb;

let ext_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(ext_name.clone());

e.add_op_enum(&o).unwrap();

assert_eq!(
DummyEnum::from_op_def(e.get_op(&o.name()).unwrap(), &[]).unwrap(),
o
);

let registry = ExtensionRegistry::try_new([e.to_owned()]).unwrap();
let registered = o.clone().to_registered(ext_name, &registry);
assert_eq!(
DummyEnum::from_optype(&registered.to_optype().unwrap()).unwrap(),
o
);
assert_eq!(
registered.function_type().unwrap(),
FunctionType::new_endo(type_row![])
);

assert_eq!(registered.description(), "Dumb");

assert_eq!(registered.to_inner(), o);
}
}
Loading