Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Flatten PrimType/PrimValue in to parent enum #685

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod check;
pub mod custom;
mod poly_func;
mod primitive;
mod serialize;
mod signature;
pub mod type_param;
Expand All @@ -26,7 +25,6 @@ use crate::ops::AliasDecl;
use crate::type_row;
use std::fmt::Debug;

pub use self::primitive::PrimType;
use self::type_param::TypeParam;

#[cfg(feature = "pyo3")]
Expand Down Expand Up @@ -153,10 +151,22 @@ impl From<SumType> for Type {
}

#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)]
/// Core types: primitive (leaf), tuple (product) or sum (co-product).
/// Core types
pub enum TypeEnum {
// TODO optimise with Box<CustomType> ?
// or some static version of this?
#[allow(missing_docs)]
Prim(PrimType),
Extension(CustomType),
#[allow(missing_docs)]
#[display(fmt = "Alias({})", "_0.name()")]
Alias(AliasDecl),
#[allow(missing_docs)]
#[display(fmt = "Function({})", "_0")]
Function(Box<PolyFuncType>),
// DeBruijn index, and cache of TypeBound (checked in validation)
#[allow(missing_docs)]
#[display(fmt = "Variable({})", _0)]
Variable(usize, TypeBound),
#[allow(missing_docs)]
#[display(fmt = "Tuple({})", "_0")]
Tuple(TypeRow),
Expand All @@ -168,7 +178,10 @@ impl TypeEnum {
/// The smallest type bound that covers the whole type.
fn least_upper_bound(&self) -> TypeBound {
match self {
TypeEnum::Prim(p) => p.bound(),
TypeEnum::Extension(c) => c.bound(),
TypeEnum::Alias(a) => a.bound,
TypeEnum::Function(_) => TypeBound::Copyable,
TypeEnum::Variable(_, b) => *b,
TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Eq,
TypeEnum::Sum(SumType::General { row }) => {
least_upper_bound(row.iter().map(Type::least_upper_bound))
Expand Down Expand Up @@ -216,7 +229,7 @@ impl Type {

/// Initialize a new function type.
pub fn new_function(fun_ty: impl Into<PolyFuncType>) -> Self {
Self::new(TypeEnum::Prim(PrimType::Function(Box::new(fun_ty.into()))))
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
}

/// Initialize a new tuple type by providing the elements.
Expand All @@ -235,12 +248,12 @@ impl Type {
// TODO remove? Extensions/TypeDefs should just provide `Type` directly
pub const fn new_extension(opaque: CustomType) -> Self {
let bound = opaque.bound();
Type(TypeEnum::Prim(PrimType::Extension(opaque)), bound)
Type(TypeEnum::Extension(opaque), bound)
}

/// Initialize a new alias.
pub fn new_alias(alias: AliasDecl) -> Self {
Self::new(TypeEnum::Prim(PrimType::Alias(alias)))
Self::new(TypeEnum::Alias(alias))
}

fn new(type_e: TypeEnum) -> Self {
Expand All @@ -267,7 +280,7 @@ impl Type {
/// For use in type schemes only: `bound` must match that with which the
/// variable was declared (i.e. as a [TypeParam::Type]`(bound)`).
pub fn new_var_use(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::Prim(PrimType::Variable(idx, bound)), bound)
Self(TypeEnum::Variable(idx, bound), bound)
}

/// Report the least upper TypeBound, if there is one.
Expand Down Expand Up @@ -307,25 +320,21 @@ impl Type {
.iter()
.try_for_each(|t| t.validate(extension_registry, var_decls)),
TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there
TypeEnum::Prim(PrimType::Alias(_)) => Ok(()),
TypeEnum::Prim(PrimType::Extension(custy)) => {
custy.validate(extension_registry, var_decls)
}
TypeEnum::Prim(PrimType::Function(ft)) => ft.validate(extension_registry, var_decls),
TypeEnum::Prim(PrimType::Variable(idx, bound)) => {
TypeEnum::Alias(_) => Ok(()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much better! :)

TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls),
TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls),
TypeEnum::Variable(idx, bound) => {
check_typevar_decl(var_decls, *idx, &TypeParam::Type(*bound))
}
}
}

pub(crate) fn substitute(&self, t: &impl Substitution) -> Self {
match &self.0 {
TypeEnum::Prim(PrimType::Alias(_)) | TypeEnum::Sum(SumType::Unit { .. }) => {
self.clone()
}
TypeEnum::Prim(PrimType::Variable(idx, bound)) => t.apply_typevar(*idx, *bound),
TypeEnum::Prim(PrimType::Extension(cty)) => Type::new_extension(cty.substitute(t)),
TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(t)),
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => self.clone(),
TypeEnum::Variable(idx, bound) => t.apply_typevar(*idx, *bound),
TypeEnum::Extension(cty) => Type::new_extension(cty.substitute(t)),
TypeEnum::Function(bf) => Type::new_function(bf.substitute(t)),
TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, t)),
TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, t)),
}
Expand Down
40 changes: 8 additions & 32 deletions src/types/check.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
//! Logic for checking values against types.
use thiserror::Error;

use crate::{
values::{PrimValue, Value},
HugrView,
};
use crate::{values::Value, HugrView};

use super::{primitive::PrimType, CustomType, Type, TypeEnum};
use super::{CustomType, Type, TypeEnum};

/// Struct for custom type check fails.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
Expand Down Expand Up @@ -48,46 +45,25 @@ pub enum ConstTypeError {
CustomCheckFail(#[from] CustomCheckFailure),
}

impl PrimType {
/// Check that a [`PrimValue`] is a valid instance of this [`PrimType`].
impl Type {
/// Check that a [`Value`] is a valid instance of this [`Type`].
///
/// # Errors
///
/// This function will return an error if there is a type check error.
pub fn check_type(&self, val: &PrimValue) -> Result<(), ConstTypeError> {
if let PrimType::Alias(alias) = self {
return Err(ConstTypeError::NoAliases(alias.name().to_string()));
}

match (self, val) {
(PrimType::Extension(e), PrimValue::Extension { c: e_val }) => {
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driveby nit, this should really be called check_value. The name check_type would be appropriate if this were a method on Value

match (&self.0, val) {
(TypeEnum::Extension(e), Value::Extension { c: e_val }) => {
e_val.0.check_custom_type(e)?;
Ok(())
}
(PrimType::Function(t), PrimValue::Function { hugr: v })
(TypeEnum::Function(t), Value::Function { hugr: v })
if v.get_function_type().is_some_and(|f| &**t == f) =>
{
// exact signature equality, in future this may need to be
// relaxed to be compatibility checks between the signatures.
Ok(())
}
_ => Err(ConstTypeError::ValueCheckFail(
Type::new(TypeEnum::Prim(self.clone())),
Value::Prim { val: val.clone() },
)),
}
}
}

impl Type {
/// Check that a [`Value`] is a valid instance of this [`Type`].
///
/// # Errors
///
/// This function will return an error if there is a type check error.
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
match (&self.0, val) {
(TypeEnum::Prim(p), Value::Prim { val: p_v }) => p.check_type(p_v),
(TypeEnum::Tuple(t), Value::Tuple { vs: t_v }) => {
if t.len() != t_v.len() {
return Err(ConstTypeError::TupleWrongLength);
Expand Down
2 changes: 1 addition & 1 deletion src/types/poly_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{FunctionType, Substitution};
/// A polymorphic function type, e.g. of a [Graph], or perhaps an [OpDef].
/// (Nodes/operations in the Hugr are not polymorphic.)
///
/// [Graph]: crate::values::PrimValue::Function
/// [Graph]: crate::values::Value::Function
/// [OpDef]: crate::extension::OpDef
#[derive(
Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
Expand Down
38 changes: 0 additions & 38 deletions src/types/primitive.rs

This file was deleted.

11 changes: 4 additions & 7 deletions src/types/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use super::custom::CustomType;

use crate::extension::prelude::{array_type, QB_T, USIZE_T};
use crate::ops::AliasDecl;
use crate::types::primitive::PrimType;

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[serde(tag = "t")]
Expand All @@ -31,12 +30,10 @@ impl From<Type> for SerSimpleType {
// TODO short circuiting for array.
let Type(value, _) = value;
match value {
TypeEnum::Prim(t) => match t {
PrimType::Extension(c) => SerSimpleType::Opaque(c),
PrimType::Alias(a) => SerSimpleType::Alias(a),
PrimType::Function(sig) => SerSimpleType::G(sig),
PrimType::Variable(i, b) => SerSimpleType::V { i, b },
},
TypeEnum::Extension(c) => SerSimpleType::Opaque(c),
TypeEnum::Alias(a) => SerSimpleType::Alias(a),
TypeEnum::Function(sig) => SerSimpleType::G(sig),
TypeEnum::Variable(i, b) => SerSimpleType::V { i, b },
TypeEnum::Sum(sum) => SerSimpleType::Sum(sum),
TypeEnum::Tuple(inner) => SerSimpleType::Tuple { inner },
}
Expand Down
56 changes: 15 additions & 41 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};

/// A constant value of a primitive (or leaf) type.
/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "pv")]
pub enum PrimValue {
#[serde(tag = "v")]
pub enum Value {
/// An extension constant value, that can check it is of a given [CustomType].
///
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Expand All @@ -30,32 +31,6 @@ pub enum PrimValue {
#[allow(missing_docs)]
hugr: Box<Hugr>,
},
}

impl PrimValue {
fn name(&self) -> String {
match self {
PrimValue::Extension { c: e } => format!("const:custom:{}", e.0.name()),
PrimValue::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
}
}
}

/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
/// A primitive (non-container) value.
Prim {
#[allow(missing_docs)]
val: PrimValue,
},
/// A tuple
Tuple {
#[allow(missing_docs)]
Expand All @@ -75,7 +50,13 @@ impl Value {
/// Returns the name of this [`Value`].
pub fn name(&self) -> String {
match self {
Value::Prim { val: p } => p.name(),
Value::Extension { c: e } => format!("const:custom:{}", e.0.name()),
Value::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
Value::Tuple { vs: vals } => {
let names: Vec<_> = vals.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.join(", "))
Expand Down Expand Up @@ -123,17 +104,12 @@ impl Value {

/// New custom value (of type that implements [`CustomConst`]).
pub fn custom<C: CustomConst>(c: C) -> Self {
Self::Prim {
val: PrimValue::Extension { c: (Box::new(c),) },
}
Self::Extension { c: (Box::new(c),) }
}

/// For a Const holding a CustomConst, extract the CustomConst by downcasting.
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
if let Value::Prim {
val: PrimValue::Extension { c: (custom,) },
} = self
{
if let Value::Extension { c: (custom,) } = self {
custom.downcast_ref()
} else {
None
Expand Down Expand Up @@ -286,10 +262,8 @@ pub(crate) mod test {

#[rstest]
fn function_value(simple_dfg_hugr: Hugr) {
let v = Value::Prim {
val: PrimValue::Function {
hugr: Box::new(simple_dfg_hugr),
},
let v = Value::Function {
hugr: Box::new(simple_dfg_hugr),
};

let correct_type = Type::new_function(FunctionType::new_linear(type_row![
Expand Down