Skip to content

Commit

Permalink
feat!: OpEnum trait for common opdef functionality (#721)
Browse files Browse the repository at this point in the history
BREAKING_CHANGES: const logic op names removd

Closes #656 

Have demonstrated with logic operations, pending comments on general
approach can port other larger extension op sets:
TODO:

- [x] Port logic.rs
- [ ] Port collections.rs
- [ ] Port int_ops.rs
- [ ] Port other arithmetic ops?
  • Loading branch information
ss2165 committed Nov 29, 2023
1 parent 2b81d6f commit 74ff7ff
Show file tree
Hide file tree
Showing 7 changed files with 426 additions and 81 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ serde_json = "1.0.97"
delegate = "0.10.0"
rustversion = "1.0.14"
paste = "1.0"
strum = "0.25.0"
strum_macros = "0.25.3"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand Down
6 changes: 3 additions & 3 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError};

mod op_def;
pub use op_def::{
CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, ValidateJustArgs,
ValidateTypeArgs,
CustomSignatureFunc, CustomValidator, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
pub mod prelude;
pub mod simple_op;
pub mod validate;

pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
70 changes: 45 additions & 25 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,18 @@ impl CustomValidator {
}

/// The two ways in which an OpDef may compute the Signature of each operation node.
/// Either as a TypeScheme (polymorphic function type), with optional custom
/// validation for provided type arguments,
/// or a custom binary which computes a polymorphic function type given values
/// for its static type parameters.
#[derive(serde::Deserialize, serde::Serialize)]
pub enum SignatureFunc {
// Note: except for serialization, we could have type schemes just implement the same
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>.
// However instead we treat all CustomFunc's as non-serializable.
/// A TypeScheme (polymorphic function type), with optional custom
/// validation for provided type arguments,
#[serde(rename = "signature")]
TypeScheme(CustomValidator),
#[serde(skip)]
/// A custom binary which computes a polymorphic function type given values
/// for its static type parameters.
CustomFunc(Box<dyn CustomSignatureFunc>),
}
struct NoValidate;
Expand Down Expand Up @@ -211,6 +211,46 @@ impl SignatureFunc {
SignatureFunc::CustomFunc(func) => func.static_params(),
}
}

/// Compute the concrete signature ([FunctionType]).
///
/// # Panics
///
/// Panics if `self` is a [SignatureFunc::CustomFunc] and there are not enough type
/// arguments provided to match the number of static parameters.
///
/// # Errors
///
/// This function will return an error if the type arguments are invalid or
/// there is some error in type computation.
pub fn compute_signature(
&self,
def: &OpDef,
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, def, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, def, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL/hugr/issues/388
// debug_assert!(res.extension_reqs.contains(def.extension()));
Ok(res)
}
}

impl Debug for SignatureFunc {
Expand Down Expand Up @@ -306,27 +346,7 @@ impl OpDef {
args: &[TypeArg],
exts: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
let temp: PolyFuncType;
let (pf, args) = match &self.signature_func {
SignatureFunc::TypeScheme(custom) => {
custom.validate.validate(args, self, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
temp = func.compute_signature(static_args, self, exts)?;
(&temp, other_args)
}
};

let res = pf.instantiate(args, exts)?;
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Ok(res)
self.signature_func.compute_signature(self, args, exts)
}

pub(crate) fn should_serialize_signature(&self) -> bool {
Expand Down
233 changes: 233 additions & 0 deletions src/extension/simple_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
//! A trait that enum for op definitions that gathers up some shared functionality.

use smol_str::SmolStr;
use strum::IntoEnumIterator;

use crate::{
ops::{custom::ExtensionOp, OpName, OpType},
types::TypeArg,
Extension,
};

use super::{
op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef,
SignatureError,
};
use delegate::delegate;
use thiserror::Error;

/// Error loading operation.
#[derive(Debug, Error, PartialEq)]
#[error("{0}")]
#[allow(missing_docs)]
pub enum OpLoadError {
#[error("Op with name {0} is not a member of this set.")]
NotMember(String),
#[error("Type args invalid: {0}.")]
InvalidArgs(#[from] SignatureError),
}

impl<T> OpName for T
where
for<'a> &'a T: Into<&'static str>,
{
fn name(&self) -> SmolStr {
let s = self.into();
s.into()
}
}

/// Traits implemented by types which can add themselves to [`Extension`]s as
/// [`OpDef`]s or load themselves from an [`OpDef`].
/// Particularly useful with C-style enums that implement [strum::IntoEnumIterator],
/// as then all definitions can be added to an extension at once.
pub trait MakeOpDef: OpName {
/// Try to load one of the operations of this set from an [OpDef].
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized;

/// Return the signature (polymorphic function type) of the operation.
fn signature(&self) -> SignatureFunc;

/// Description of the operation. By default, the same as `self.name()`.
fn description(&self) -> String {
self.name().to_string()
}

/// Edit the opdef before finalising. By default does nothing.
fn post_opdef(&self, _def: &mut OpDef) {}

/// Add an operation implemented as an [MakeOpDef], which can provide the data
/// required to define an [OpDef], to an extension.
fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> {
let def = extension.add_op(self.name(), self.description(), self.signature())?;

self.post_opdef(def);

Ok(())
}

/// Load all variants of an enum of op definitions in to an extension as op defs.
/// See [strum::IntoEnumIterator].
fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError>
where
Self: IntoEnumIterator,
{
for op in Self::iter() {
op.add_to_extension(extension)?;
}
Ok(())
}
}

/// Traits implemented by types which can be loaded from [`ExtensionOp`]s,
/// i.e. concrete instances of [`OpDef`]s, with defined type arguments.
pub trait MakeExtensionOp: OpName {
/// Try to load one of the operations of this set from an [OpDef].
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized;
/// Try to instantiate a variant from an [OpType]. Default behaviour assumes
/// an [ExtensionOp] and loads from the name.
fn from_optype(op: &OpType) -> Option<Self>
where
Self: Sized,
{
let ext: &ExtensionOp = op.as_leaf_op()?.as_extension_op()?;
Self::from_extension_op(ext).ok()
}

/// Any type args which define this operation.
fn type_args(&self) -> Vec<TypeArg>;

/// Given the ID of the extension this operation is defined in, and a
/// registry containing that extension, return a [RegisteredOp].
fn to_registered(
self,
extension_id: ExtensionId,
registry: &ExtensionRegistry,
) -> RegisteredOp<'_, Self>
where
Self: Sized,
{
RegisteredOp {
extension_id,
registry,
op: self,
}
}
}

/// Blanket implementation for non-polymorphic operations - no type parameters.
impl<T: MakeOpDef> MakeExtensionOp for T {
#[inline]
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized,
{
Self::from_def(ext_op.def())
}

#[inline]
fn type_args(&self) -> Vec<TypeArg> {
vec![]
}
}

/// Load an [MakeOpDef] from its name.
/// See [strum_macros::EnumString].
pub fn try_from_name<T>(name: &str) -> Result<T, OpLoadError>
where
T: std::str::FromStr + MakeOpDef,
{
T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))
}

/// Wrap an [MakeExtensionOp] with an extension registry to allow type computation.
/// Generate from [MakeExtensionOp::to_registered]
#[derive(Clone, Debug)]
pub struct RegisteredOp<'r, T> {
/// The name of the extension these ops belong to.
extension_id: ExtensionId,
/// A registry of all extensions, used for type computation.
registry: &'r ExtensionRegistry,
/// The inner [MakeExtensionOp]
op: T,
}

impl<T> RegisteredOp<'_, T> {
/// Extract the inner wrapped value
pub fn to_inner(self) -> T {
self.op
}
}

impl<T: MakeExtensionOp> RegisteredOp<'_, T> {
/// Generate an [OpType].
pub fn to_extension_op(&self) -> Option<ExtensionOp> {
ExtensionOp::new(
self.registry
.get(&self.extension_id)?
.get_op(&self.name())?
.clone(),
self.type_args(),
self.registry,
)
.ok()
}

delegate! {
to self.op {
/// Name of the operation - derived from strum serialization.
pub fn name(&self) -> SmolStr;
/// Any type args which define this operation. Default is no type arguments.
pub fn type_args(&self) -> Vec<TypeArg>;
}
}
}

#[cfg(test)]
mod test {
use crate::{type_row, types::FunctionType};

use super::*;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
#[derive(Clone, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
enum DummyEnum {
Dumb,
}

impl MakeOpDef for DummyEnum {
fn signature(&self) -> SignatureFunc {
FunctionType::new_endo(type_row![]).into()
}

fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
Ok(Self::Dumb)
}
}

#[test]
fn test_dummy_enum() {
let o = DummyEnum::Dumb;

let ext_name = ExtensionId::new("dummy").unwrap();
let mut e = Extension::new(ext_name.clone());

o.add_to_extension(&mut e).unwrap();
assert_eq!(
DummyEnum::from_def(e.get_op(&o.name()).unwrap()).unwrap(),
o
);

let registry = ExtensionRegistry::try_new([e.to_owned()]).unwrap();
let registered = o.clone().to_registered(ext_name, &registry);
assert_eq!(
DummyEnum::from_optype(&registered.to_extension_op().unwrap().into()).unwrap(),
o
);

assert_eq!(registered.to_inner(), o);
}
}
7 changes: 7 additions & 0 deletions src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ impl From<ExtensionOp> for LeafOp {
}
}

impl From<ExtensionOp> for OpType {
fn from(value: ExtensionOp) -> Self {
let leaf: LeafOp = value.into();
leaf.into()
}
}

impl PartialEq for ExtensionOp {
fn eq(&self, other: &Self) -> bool {
Arc::<OpDef>::ptr_eq(&self.def, &other.def) && self.args == other.args
Expand Down
Loading

0 comments on commit 74ff7ff

Please sign in to comment.