diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 9dc2fcd18..d1062a3ce 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -180,9 +180,6 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - // Hugr's are monomorphic, so check the args have no free variables - args.iter().try_for_each(|ta| ta.validate(exts, &[]))?; - let temp: PolyFuncType; // to keep alive let (pf, args) = match &self.signature_func { SignatureFunc::TypeScheme(ts) => (ts, args), @@ -369,11 +366,16 @@ impl Extension { #[cfg(test)] mod test { + use std::num::NonZeroU64; + use smol_str::SmolStr; use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::prelude::USIZE_T; - use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::extension::{ + CustomSignatureFunc, ExtensionRegistry, SignatureError, EMPTY_REG, PRELUDE, + PRELUDE_REGISTRY, + }; use crate::ops::custom::ExternalOp; use crate::ops::LeafOp; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; @@ -413,4 +415,119 @@ mod test { Ok(()) } + + #[test] + fn binary_polyfunc() -> Result<(), Box> { + // Test a custom binary `compute_signature` that returns a PolyFuncType + // where the latter declares more type params itself. In particular, + // we should be able to substitute (external) type variables into the latter, + // but not pass them into the former (custom binary function). + struct SigFun(); + impl CustomSignatureFunc for SigFun { + fn compute_signature( + &self, + _name: &SmolStr, + arg_values: &[TypeArg], + _misc: &std::collections::HashMap, + _exts: &ExtensionRegistry, + ) -> Result { + const TP: TypeParam = TypeParam::Type(TypeBound::Any); + let [TypeArg::BoundedNat { n }] = arg_values else { + return Err(SignatureError::InvalidTypeArgs); + }; + let n = *n as usize; + let tvs: Vec = (0..n) + .map(|_| Type::new_var_use(0, TypeBound::Any)) + .collect(); + Ok(PolyFuncType::new( + vec![TP], + FunctionType::new(tvs.clone(), vec![Type::new_tuple(tvs)]), + )) + } + } + let mut e = Extension::new(EXT_ID); + let def = e.add_op_custom_sig_simple( + "MyOp".into(), + "".to_string(), + vec![TypeParam::max_nat()], + SigFun(), + )?; + + // Base case, no type variables: + let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(FunctionType::new( + vec![USIZE_T; 3], + vec![Type::new_tuple(vec![USIZE_T; 3])] + )) + ); + assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); + + // Second arg may be a variable (substitutable) + let tyvar = Type::new_var_use(0, TypeBound::Eq); + let tyvars: Vec = vec![tyvar.clone(); 3]; + let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(FunctionType::new( + tyvars.clone(), + vec![Type::new_tuple(tyvars)] + )) + ); + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeParam::Type(TypeBound::Eq)]) + .unwrap(); + + // quick sanity check that we are validating the args - note changed bound: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeParam::Type(TypeBound::Any)]), + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + actual: TypeBound::Any.into(), + cached: TypeBound::Eq.into() + }) + ); + + // First arg must be concrete, not a variable + let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; + // We can't prevent this from getting into our compute_signature implementation: + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Err(SignatureError::InvalidTypeArgs) + ); + // But validation rules it out, even when the variable is declared: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), + Err(SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }) + ); + + Ok(()) + } + + #[test] + fn type_scheme_instantiate_var() -> Result<(), Box> { + // Check that we can instantiate a PolyFuncType-scheme with an (external) + // type variable + let mut e = Extension::new(EXT_ID); + let def = e.add_op_type_scheme_simple( + "SimpleOp".into(), + "".into(), + PolyFuncType::new( + vec![TypeParam::Type(TypeBound::Any)], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let tv = Type::new_var_use(1, TypeBound::Eq); + let args = [TypeArg::Type { ty: tv.clone() }]; + let decls = [TypeParam::Extensions, TypeBound::Eq.into()]; + def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); + assert_eq!( + def.compute_signature(&args, &EMPTY_REG), + Ok(FunctionType::new_endo(vec![tv])) + ); + Ok(()) + } } diff --git a/src/types/type_param.rs b/src/types/type_param.rs index b87e2ad82..16d9e0a5a 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -97,6 +97,12 @@ impl From for TypeParam { } } +impl From for TypeArg { + fn from(ty: Type) -> Self { + Self::Type { ty } + } +} + /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive]