From 74ff7ff2c86ec6945887e7950877027997f4cecd Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 29 Nov 2023 17:48:55 +0000 Subject: [PATCH] feat!: `OpEnum` trait for common opdef functionality (#721) BREAKING_CHANGES: const logic op names removd Closes #656 Have demonstrated with logic operations, pending comments on general approach can port other larger extension op sets: TODO: - [x] Port logic.rs - [ ] Port collections.rs - [ ] Port int_ops.rs - [ ] Port other arithmetic ops? --- Cargo.toml | 2 + src/extension.rs | 6 +- src/extension/op_def.rs | 70 +++++++---- src/extension/simple_op.rs | 233 ++++++++++++++++++++++++++++++++++++ src/ops/custom.rs | 7 ++ src/ops/leaf.rs | 16 ++- src/std_extensions/logic.rs | 173 ++++++++++++++++++-------- 7 files changed, 426 insertions(+), 81 deletions(-) create mode 100644 src/extension/simple_op.rs diff --git a/Cargo.toml b/Cargo.toml index 3a2f70bcc..012a80749 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,8 @@ serde_json = "1.0.97" delegate = "0.10.0" rustversion = "1.0.14" paste = "1.0" +strum = "0.25.0" +strum_macros = "0.25.3" [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } diff --git a/src/extension.rs b/src/extension.rs index dfdfc5acc..f18834baa 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -23,14 +23,14 @@ pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError}; mod op_def; pub use op_def::{ - CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, ValidateJustArgs, - ValidateTypeArgs, + CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, SignatureFunc, + ValidateJustArgs, ValidateTypeArgs, }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; pub mod prelude; +pub mod simple_op; pub mod validate; - pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index b8d43dff9..5ce8b483f 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -148,18 +148,18 @@ impl CustomValidator { } /// The two ways in which an OpDef may compute the Signature of each operation node. -/// Either as a TypeScheme (polymorphic function type), with optional custom -/// validation for provided type arguments, -/// or a custom binary which computes a polymorphic function type given values -/// for its static type parameters. #[derive(serde::Deserialize, serde::Serialize)] pub enum SignatureFunc { // Note: except for serialization, we could have type schemes just implement the same // CustomSignatureFunc trait too, and replace this enum with Box. // However instead we treat all CustomFunc's as non-serializable. + /// A TypeScheme (polymorphic function type), with optional custom + /// validation for provided type arguments, #[serde(rename = "signature")] TypeScheme(CustomValidator), #[serde(skip)] + /// A custom binary which computes a polymorphic function type given values + /// for its static type parameters. CustomFunc(Box), } struct NoValidate; @@ -211,6 +211,46 @@ impl SignatureFunc { SignatureFunc::CustomFunc(func) => func.static_params(), } } + + /// Compute the concrete signature ([FunctionType]). + /// + /// # Panics + /// + /// Panics if `self` is a [SignatureFunc::CustomFunc] and there are not enough type + /// arguments provided to match the number of static parameters. + /// + /// # Errors + /// + /// This function will return an error if the type arguments are invalid or + /// there is some error in type computation. + pub fn compute_signature( + &self, + def: &OpDef, + args: &[TypeArg], + exts: &ExtensionRegistry, + ) -> Result { + let temp: PolyFuncType; + let (pf, args) = match &self { + SignatureFunc::TypeScheme(custom) => { + custom.validate.validate(args, def, exts)?; + (&custom.poly_func, args) + } + SignatureFunc::CustomFunc(func) => { + let static_params = func.static_params(); + let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); + + check_type_args(static_args, static_params)?; + temp = func.compute_signature(static_args, def, exts)?; + (&temp, other_args) + } + }; + + let res = pf.instantiate(args, exts)?; + // TODO bring this assert back once resource inference is done? + // https://github.com/CQCL/hugr/issues/388 + // debug_assert!(res.extension_reqs.contains(def.extension())); + Ok(res) + } } impl Debug for SignatureFunc { @@ -306,27 +346,7 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - let temp: PolyFuncType; - let (pf, args) = match &self.signature_func { - SignatureFunc::TypeScheme(custom) => { - custom.validate.validate(args, self, exts)?; - (&custom.poly_func, args) - } - SignatureFunc::CustomFunc(func) => { - let static_params = func.static_params(); - let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); - - check_type_args(static_args, static_params)?; - temp = func.compute_signature(static_args, self, exts)?; - (&temp, other_args) - } - }; - - let res = pf.instantiate(args, exts)?; - // TODO bring this assert back once resource inference is done? - // https://github.com/CQCL-DEV/hugr/issues/425 - // assert!(res.contains(self.extension())); - Ok(res) + self.signature_func.compute_signature(self, args, exts) } pub(crate) fn should_serialize_signature(&self) -> bool { diff --git a/src/extension/simple_op.rs b/src/extension/simple_op.rs new file mode 100644 index 000000000..bd9f280dd --- /dev/null +++ b/src/extension/simple_op.rs @@ -0,0 +1,233 @@ +//! A trait that enum for op definitions that gathers up some shared functionality. + +use smol_str::SmolStr; +use strum::IntoEnumIterator; + +use crate::{ + ops::{custom::ExtensionOp, OpName, OpType}, + types::TypeArg, + Extension, +}; + +use super::{ + op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef, + SignatureError, +}; +use delegate::delegate; +use thiserror::Error; + +/// Error loading operation. +#[derive(Debug, Error, PartialEq)] +#[error("{0}")] +#[allow(missing_docs)] +pub enum OpLoadError { + #[error("Op with name {0} is not a member of this set.")] + NotMember(String), + #[error("Type args invalid: {0}.")] + InvalidArgs(#[from] SignatureError), +} + +impl OpName for T +where + for<'a> &'a T: Into<&'static str>, +{ + fn name(&self) -> SmolStr { + let s = self.into(); + s.into() + } +} + +/// Traits implemented by types which can add themselves to [`Extension`]s as +/// [`OpDef`]s or load themselves from an [`OpDef`]. +/// Particularly useful with C-style enums that implement [strum::IntoEnumIterator], +/// as then all definitions can be added to an extension at once. +pub trait MakeOpDef: OpName { + /// Try to load one of the operations of this set from an [OpDef]. + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized; + + /// Return the signature (polymorphic function type) of the operation. + fn signature(&self) -> SignatureFunc; + + /// Description of the operation. By default, the same as `self.name()`. + fn description(&self) -> String { + self.name().to_string() + } + + /// Edit the opdef before finalising. By default does nothing. + fn post_opdef(&self, _def: &mut OpDef) {} + + /// Add an operation implemented as an [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + let def = extension.add_op(self.name(), self.description(), self.signature())?; + + self.post_opdef(def); + + Ok(()) + } + + /// Load all variants of an enum of op definitions in to an extension as op defs. + /// See [strum::IntoEnumIterator]. + fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> + where + Self: IntoEnumIterator, + { + for op in Self::iter() { + op.add_to_extension(extension)?; + } + Ok(()) + } +} + +/// Traits implemented by types which can be loaded from [`ExtensionOp`]s, +/// i.e. concrete instances of [`OpDef`]s, with defined type arguments. +pub trait MakeExtensionOp: OpName { + /// Try to load one of the operations of this set from an [OpDef]. + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized; + /// Try to instantiate a variant from an [OpType]. Default behaviour assumes + /// an [ExtensionOp] and loads from the name. + fn from_optype(op: &OpType) -> Option + where + Self: Sized, + { + let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?; + Self::from_extension_op(ext).ok() + } + + /// Any type args which define this operation. + fn type_args(&self) -> Vec; + + /// Given the ID of the extension this operation is defined in, and a + /// registry containing that extension, return a [RegisteredOp]. + fn to_registered( + self, + extension_id: ExtensionId, + registry: &ExtensionRegistry, + ) -> RegisteredOp<'_, Self> + where + Self: Sized, + { + RegisteredOp { + extension_id, + registry, + op: self, + } + } +} + +/// Blanket implementation for non-polymorphic operations - no type parameters. +impl MakeExtensionOp for T { + #[inline] + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + Self::from_def(ext_op.def()) + } + + #[inline] + fn type_args(&self) -> Vec { + vec![] + } +} + +/// Load an [MakeOpDef] from its name. +/// See [strum_macros::EnumString]. +pub fn try_from_name(name: &str) -> Result +where + T: std::str::FromStr + MakeOpDef, +{ + T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string())) +} + +/// Wrap an [MakeExtensionOp] with an extension registry to allow type computation. +/// Generate from [MakeExtensionOp::to_registered] +#[derive(Clone, Debug)] +pub struct RegisteredOp<'r, T> { + /// The name of the extension these ops belong to. + extension_id: ExtensionId, + /// A registry of all extensions, used for type computation. + registry: &'r ExtensionRegistry, + /// The inner [MakeExtensionOp] + op: T, +} + +impl RegisteredOp<'_, T> { + /// Extract the inner wrapped value + pub fn to_inner(self) -> T { + self.op + } +} + +impl RegisteredOp<'_, T> { + /// Generate an [OpType]. + pub fn to_extension_op(&self) -> Option { + ExtensionOp::new( + self.registry + .get(&self.extension_id)? + .get_op(&self.name())? + .clone(), + self.type_args(), + self.registry, + ) + .ok() + } + + delegate! { + to self.op { + /// Name of the operation - derived from strum serialization. + pub fn name(&self) -> SmolStr; + /// Any type args which define this operation. Default is no type arguments. + pub fn type_args(&self) -> Vec; + } + } +} + +#[cfg(test)] +mod test { + use crate::{type_row, types::FunctionType}; + + use super::*; + use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + #[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] + enum DummyEnum { + Dumb, + } + + impl MakeOpDef for DummyEnum { + fn signature(&self) -> SignatureFunc { + FunctionType::new_endo(type_row![]).into() + } + + fn from_def(_op_def: &OpDef) -> Result { + Ok(Self::Dumb) + } + } + + #[test] + fn test_dummy_enum() { + let o = DummyEnum::Dumb; + + let ext_name = ExtensionId::new("dummy").unwrap(); + let mut e = Extension::new(ext_name.clone()); + + o.add_to_extension(&mut e).unwrap(); + assert_eq!( + DummyEnum::from_def(e.get_op(&o.name()).unwrap()).unwrap(), + o + ); + + let registry = ExtensionRegistry::try_new([e.to_owned()]).unwrap(); + let registered = o.clone().to_registered(ext_name, ®istry); + assert_eq!( + DummyEnum::from_optype(®istered.to_extension_op().unwrap().into()).unwrap(), + o + ); + + assert_eq!(registered.to_inner(), o); + } +} diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 5f2af204f..f5c013d60 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -157,6 +157,13 @@ impl From for LeafOp { } } +impl From for OpType { + fn from(value: ExtensionOp) -> Self { + let leaf: LeafOp = value.into(); + leaf.into() + } +} + impl PartialEq for ExtensionOp { fn eq(&self, other: &Self) -> bool { Arc::::ptr_eq(&self.def, &other.def) && self.args == other.args diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 432790cf8..6a1e5ac62 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -2,7 +2,7 @@ use smol_str::SmolStr; -use super::custom::ExternalOp; +use super::custom::{ExtensionOp, ExternalOp}; use super::dataflow::DataflowOpTrait; use super::{OpName, OpTag}; @@ -62,6 +62,20 @@ pub enum LeafOp { }, } +impl LeafOp { + /// If instance of [ExtensionOp] return a reference to it. + pub fn as_extension_op(&self) -> Option<&ExtensionOp> { + let LeafOp::CustomOp(ext) = self else { + return None; + }; + + match ext.as_ref() { + ExternalOp::Extension(e) => Some(e), + ExternalOp::Opaque(_) => None, + } + } +} + /// Records details of an application of a [PolyFuncType] to some [TypeArg]s /// and the result (a less-, but still potentially-, polymorphic type). #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 710d29b1b..359a74cc5 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -1,10 +1,15 @@ //! Basic logical operations. -use smol_str::SmolStr; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::{ - extension::{prelude::BOOL_T, ExtensionId, SignatureError, SignatureFromArgs}, - ops, type_row, + extension::{ + prelude::BOOL_T, + simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, OpLoadError}, + ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc, + }, + ops::{self, custom::ExtensionOp, OpName}, + type_row, types::{ type_param::{TypeArg, TypeParam}, FunctionType, @@ -12,18 +17,87 @@ use crate::{ Extension, }; use lazy_static::lazy_static; - /// Name of extension false value. pub const FALSE_NAME: &str = "FALSE"; /// Name of extension true value. pub const TRUE_NAME: &str = "TRUE"; -/// Name of the "not" operation. -pub const NOT_NAME: &str = "Not"; -/// Name of the "and" operation. -pub const AND_NAME: &str = "And"; -/// Name of the "or" operation. -pub const OR_NAME: &str = "Or"; +/// Logic extension operation definitions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs)] +pub enum NaryLogic { + And, + Or, +} + +impl MakeOpDef for NaryLogic { + fn signature(&self) -> SignatureFunc { + logic_op_sig().into() + } + + fn description(&self) -> String { + match self { + NaryLogic::And => "logical 'and'", + NaryLogic::Or => "logical 'or'", + } + .to_string() + } + + fn from_def(op_def: &OpDef) -> Result { + try_from_name(op_def.name()) + } +} + +/// Make a [NaryLogic] operation concrete by setting the type argument. +pub struct ConcreteLogicOp(pub NaryLogic, u64); + +impl OpName for ConcreteLogicOp { + fn name(&self) -> smol_str::SmolStr { + self.0.name() + } +} +impl MakeExtensionOp for ConcreteLogicOp { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def: NaryLogic = NaryLogic::from_def(ext_op.def())?; + Ok(match def { + NaryLogic::And | NaryLogic::Or => { + let [TypeArg::BoundedNat { n }] = *ext_op.args() else { + return Err(SignatureError::InvalidTypeArgs.into()); + }; + Self(def, n) + } + }) + } + + fn type_args(&self) -> Vec { + vec![TypeArg::BoundedNat { n: self.1 }] + } +} + +/// Not operation. +#[derive(Debug, Copy, Clone)] +pub struct NotOp; +impl OpName for NotOp { + fn name(&self) -> smol_str::SmolStr { + "Not".into() + } +} +impl MakeOpDef for NotOp { + fn from_def(op_def: &OpDef) -> Result { + if op_def.name() == &NotOp.name() { + Ok(NotOp) + } else { + Err(OpLoadError::NotMember(op_def.name().to_string())) + } + } + + fn signature(&self) -> SignatureFunc { + FunctionType::new_endo(type_row![BOOL_T]).into() + } + fn description(&self) -> String { + "logical 'not'".into() + } +} /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic"); @@ -53,30 +127,8 @@ fn logic_op_sig() -> impl SignatureFromArgs { /// Extension for basic logical operations. fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_ID); - - extension - .add_op( - SmolStr::new_inline(NOT_NAME), - "logical 'not'".into(), - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), - ) - .unwrap(); - - extension - .add_op( - SmolStr::new_inline(AND_NAME), - "logical 'and'".into(), - logic_op_sig(), - ) - .unwrap(); - - extension - .add_op( - SmolStr::new_inline(OR_NAME), - "logical 'or'".into(), - logic_op_sig(), - ) - .unwrap(); + NaryLogic::load_all_ops(&mut extension).unwrap(); + NotOp.add_to_extension(&mut extension).unwrap(); extension .add_value(FALSE_NAME, ops::Const::unit_sum(0, 2)) @@ -94,20 +146,37 @@ lazy_static! { #[cfg(test)] pub(crate) mod test { + use super::{ + extension, ConcreteLogicOp, NaryLogic, NotOp, EXTENSION, EXTENSION_ID, FALSE_NAME, + TRUE_NAME, + }; use crate::{ - extension::{prelude::BOOL_T, EMPTY_REG}, - ops::LeafOp, - types::type_param::TypeArg, + extension::{ + prelude::BOOL_T, + simple_op::{MakeExtensionOp, MakeOpDef}, + ExtensionRegistry, + }, + ops::{custom::ExtensionOp, OpName}, Extension, }; - - use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, OR_NAME, TRUE_NAME}; - + use lazy_static::lazy_static; + use strum::IntoEnumIterator; + lazy_static! { + pub(crate) static ref LOGIC_REG: ExtensionRegistry = + ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); + } #[test] fn test_logic_extension() { let r: Extension = extension(); assert_eq!(r.name() as &str, "logic"); assert_eq!(r.operations().count(), 3); + + for op in NaryLogic::iter() { + assert_eq!( + NaryLogic::from_def(r.get_op(&op.name()).unwrap(),).unwrap(), + op + ); + } } #[test] @@ -123,26 +192,26 @@ pub(crate) mod test { } /// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn and_op() -> LeafOp { - EXTENSION - .instantiate_extension_op(AND_NAME, [TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + pub(crate) fn and_op() -> ExtensionOp { + ConcreteLogicOp(NaryLogic::And, 2) + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_extension_op() .unwrap() - .into() } /// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn or_op() -> LeafOp { - EXTENSION - .instantiate_extension_op(OR_NAME, [TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + pub(crate) fn or_op() -> ExtensionOp { + ConcreteLogicOp(NaryLogic::Or, 2) + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_extension_op() .unwrap() - .into() } /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn not_op() -> LeafOp { - EXTENSION - .instantiate_extension_op(NOT_NAME, [], &EMPTY_REG) + pub(crate) fn not_op() -> ExtensionOp { + NotOp + .to_registered(EXTENSION_ID.to_owned(), &LOGIC_REG) + .to_extension_op() .unwrap() - .into() } }