Skip to content

Commit

Permalink
substitute takes &ExtensionRegistry; store in OpDefTypeScheme w/ 'a p…
Browse files Browse the repository at this point in the history
…aram
  • Loading branch information
acl-cqc committed Sep 5, 2023
1 parent 4d5abd5 commit 9132e43
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ impl ExtensionSet {
None => vec![e.clone()].into_iter(),
Some(i) => match args.get(i) {
Some(TypeArg::Extensions(es)) => es.iter().cloned().collect::<Vec<_>>().into_iter(),
_ => panic!("value for type var was not extension set"),
_ => panic!("value for type var was not extension set - type scheme should be validate()d first"),
},
}))
}
Expand Down
25 changes: 19 additions & 6 deletions src/extension/type_scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@ use crate::types::FunctionType;
use super::{CustomSignatureFunc, ExtensionRegistry, SignatureError};

/// A polymorphic type scheme for an op
pub struct OpDefTypeScheme {
pub struct OpDefTypeScheme<'a> {
/// The declared type parameters, i.e., every Op must provide [TypeArg]s for these
pub params: Vec<TypeParam>,
/// Template for the Op type. May contain variables up to length of [OpDefTypeScheme::params]
body: FunctionType,
/// Extensions - the [TypeDefBound]s in here will be needed when we instantiate the [OpDefTypeScheme]
/// into a [FunctionType].
///
/// [TypeDefBound]: super::type_def::TypeDefBound
// Note that if the lifetimes, etc., become too painful to store this reference in here,
// and we'd rather own the necessary data, we really only need the TypeDefBounds not the other parts,
// and the validation traversal in new() discovers the small subset of TypeDefBounds that
// each OpDefTypeScheme actually needs.
exts: &'a ExtensionRegistry,
}

impl OpDefTypeScheme {
impl<'a> OpDefTypeScheme<'a> {
/// Create a new OpDefTypeScheme.
///
/// #Errors
Expand All @@ -24,22 +33,26 @@ impl OpDefTypeScheme {
pub fn new(
params: impl Into<Vec<TypeParam>>,
body: FunctionType,
extension_registry: &ExtensionRegistry,
extension_registry: &'a ExtensionRegistry,
) -> Result<Self, SignatureError> {
let params = params.into();
body.validate(extension_registry, &params)?;
Ok(Self { params, body })
Ok(Self {
params,
body,
exts: extension_registry,
})
}
}

impl CustomSignatureFunc for OpDefTypeScheme {
impl<'a> CustomSignatureFunc for OpDefTypeScheme<'a> {
fn compute_signature(
&self,
_name: &smol_str::SmolStr,
args: &[TypeArg],
_misc: &std::collections::HashMap<String, serde_yaml::Value>,
) -> Result<FunctionType, SignatureError> {
check_type_args(args, &self.params).map_err(SignatureError::TypeArgMismatch)?;
Ok(self.body.substitute(args))
Ok(self.body.substitute(self.exts, args))
}
}
34 changes: 25 additions & 9 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,23 @@ impl Type {
}
}

/// Could make this public, but easier to panic if the TypeArg's don't match the TypeParams
/// (should be checked when the args are first given to the type scheme).
pub(crate) fn substitute(&self, args: &[TypeArg]) -> Self {
/// Substitute the specified [TypeArg]s for type variables in this type.
///
/// # Arguments
///
/// * `args`: values to substitute in; there must be at least enough for the
/// typevars in this type (partial substitution is not supported).
///
/// * `extension_registry`: for looking up [TypeDef]s in order to recompute [TypeBound]s
/// as these may get narrower after substitution
///
/// # Panics
///
/// If a [TypeArg] (that is referenced by a typevar in this type) does not contain a [Type],
/// contains a type with an incorrect [TypeBound], or there are not enough `args`.
/// These conditions can be detected ahead of time by [Type::validate]ing against the [TypeParam]s
/// and [check_type_args]ing the [TypeArg]s against the [TypeParam]s.
pub(crate) fn substitute(&self, exts: &ExtensionRegistry, args: &[TypeArg]) -> Self {
match &self.0 {
TypeEnum::Prim(PrimType::Alias(_)) | TypeEnum::Sum(SumType::Simple { .. }) => {
self.clone()
Expand All @@ -311,17 +325,19 @@ impl Type {
),
None => panic!("No value found for variable"), // No need to support partial substitution for just type schemes
},
TypeEnum::Prim(PrimType::Extension(cty)) => Type::new_extension(cty.substitute(args)),
TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(args)),
TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, args)),
TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, args)),
TypeEnum::Prim(PrimType::Extension(cty)) => {
Type::new_extension(cty.substitute(exts, args))
}
TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(exts, args)),
TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, exts, args)),
TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, exts, args)),
}
}
}

fn subst_row(row: &TypeRow, args: &[TypeArg]) -> TypeRow {
fn subst_row(row: &TypeRow, exts: &ExtensionRegistry, args: &[TypeArg]) -> TypeRow {
row.iter()
.map(|t| t.substitute(args))
.map(|t| t.substitute(exts, args))
.collect::<Vec<_>>()
.into()
}
Expand Down
29 changes: 20 additions & 9 deletions src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use smol_str::SmolStr;
use std::fmt::{self, Display};

use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError};
use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef};

use super::{
type_param::{TypeArg, TypeParam},
Expand Down Expand Up @@ -69,26 +69,37 @@ impl CustomType {
.iter()
.try_for_each(|a| a.validate(extension_registry, type_vars))?;
// And check they fit into the TypeParams declared by the TypeDef
let def = self.get_type_def(extension_registry)?;
def.check_custom(self)
}

fn get_type_def<'a>(
&self,
extension_registry: &'a ExtensionRegistry,
) -> Result<&'a TypeDef, SignatureError> {
let ex = extension_registry.get(&self.extension);
// Even if OpDef's (+binaries) are not available, the part of the Extension definition
// describing the TypeDefs can easily be passed around (serialized), so should be available.
let ex = ex.ok_or(SignatureError::ExtensionNotFound(self.extension.clone()))?;
let def = ex
.get_type(&self.id)
ex.get_type(&self.id)
.ok_or(SignatureError::ExtensionTypeNotFound {
exn: self.extension.clone(),
typ: self.id.clone(),
})?;
def.check_custom(self)
})
}

pub(super) fn substitute(&self, args: &[TypeArg]) -> Self {
pub(super) fn substitute(&self, exts: &ExtensionRegistry, args: &[TypeArg]) -> Self {
let bound = self.get_type_def(exts).unwrap().bound(args);
assert!(self.bound.contains(bound));
Self {
args: self.args.iter().map(|arg| arg.substitute(args)).collect(),
args: self
.args
.iter()
.map(|arg| arg.substitute(exts, args))
.collect(),
bound,
..self.clone()
}
// TODO the bound could get narrower as a result of substitution.
// But, we need the TypeDefBound (from the TypeDef in the Extension) to recalculate correctly...
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ impl FunctionType {
self.extension_reqs.validate(type_vars)
}

pub(crate) fn substitute(&self, args: &[TypeArg]) -> Self {
pub(crate) fn substitute(&self, exts: &ExtensionRegistry, args: &[TypeArg]) -> Self {
FunctionType {
input: subst_row(&self.input, args),
output: subst_row(&self.output, args),
input: subst_row(&self.input, exts, args),
output: subst_row(&self.output, exts, args),
extension_reqs: self.extension_reqs.substitute(args),
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,18 @@ impl TypeArg {
}
}

pub(super) fn substitute(&self, args: &[TypeArg]) -> Self {
pub(super) fn substitute(&self, exts: &ExtensionRegistry, args: &[TypeArg]) -> Self {
match self {
TypeArg::Type(t) => TypeArg::Type(t.substitute(args)),
TypeArg::Type(t) => TypeArg::Type(t.substitute(exts, args)),
TypeArg::BoundedNat(_) => self.clone(), // We do not allow variables as bounds on BoundedNat's
TypeArg::Opaque(CustomTypeArg { typ, .. }) => {
// The type must be equal to that declared (in a TypeParam) by the instantiated TypeDef,
// so cannot contain variables declared by the instantiator (providing the TypeArgs)
debug_assert_eq!(&typ.substitute(args), typ);
debug_assert_eq!(&typ.substitute(exts, args), typ);
self.clone()
}
TypeArg::Sequence(elems) => {
TypeArg::Sequence(elems.iter().map(|ta| ta.substitute(args)).collect())
TypeArg::Sequence(elems.iter().map(|ta| ta.substitute(exts, args)).collect())
}
TypeArg::Extensions(es) => TypeArg::Extensions(es.substitute(args)),
// Caller should already have checked arg against bound (cached here):
Expand Down

0 comments on commit 9132e43

Please sign in to comment.