diff --git a/crates/bindings-macro/src/module.rs b/crates/bindings-macro/src/module.rs index 32de831aae..a81bdb35f3 100644 --- a/crates/bindings-macro/src/module.rs +++ b/crates/bindings-macro/src/module.rs @@ -135,11 +135,7 @@ pub(crate) fn derive_satstype(ty: &SatsType<'_>, gen_type_alias: bool) -> TokenS algebraic_type: <#ty as spacetimedb::SpacetimeType>::make_type(__typespace), }) }); - quote!(spacetimedb::sats::AlgebraicType::Product( - spacetimedb::sats::ProductType { - elements: vec![#(#fields),*], - } - )) + quote!(spacetimedb::sats::AlgebraicType::product(vec![#(#fields),*])) } SatsTypeData::Sum(variants) => { let unit = syn::Type::Tuple(syn::TypeTuple { @@ -154,9 +150,7 @@ pub(crate) fn derive_satstype(ty: &SatsType<'_>, gen_type_alias: bool) -> TokenS algebraic_type: <#ty as spacetimedb::SpacetimeType>::make_type(__typespace), }) }); - quote!(spacetimedb::sats::AlgebraicType::Sum(spacetimedb::sats::SumType { - variants: vec![#(#variants),*], - })) + quote!(spacetimedb::sats::AlgebraicType::sum(vec![#(#variants),*])) // todo!() } // syn::Data::Union(u) => return Err(syn::Error::new(u.union_token.span, "unions not supported")), }; diff --git a/crates/bindings/src/lib.rs b/crates/bindings/src/lib.rs index 13f1dc21c2..23e143f740 100644 --- a/crates/bindings/src/lib.rs +++ b/crates/bindings/src/lib.rs @@ -11,6 +11,7 @@ mod timestamp; use spacetimedb_lib::buffer::{BufReader, BufWriter, Cursor, DecodeError}; pub use spacetimedb_lib::de::{Deserialize, DeserializeOwned}; +use spacetimedb_lib::sats::{impl_deserialize, impl_serialize, impl_st}; pub use spacetimedb_lib::ser::Serialize; use spacetimedb_lib::{bsatn, ColumnIndexAttribute, IndexType, PrimaryKey, ProductType, ProductValue}; use std::cell::RefCell; @@ -691,21 +692,9 @@ impl Clone for ScheduleToken { } impl Copy for ScheduleToken {} -impl Serialize for ScheduleToken { - fn serialize(&self, serializer: S) -> Result { - self.id.serialize(serializer) - } -} -impl<'de, R> Deserialize<'de> for ScheduleToken { - fn deserialize>(deserializer: D) -> Result { - u64::deserialize(deserializer).map(Self::new) - } -} -impl SpacetimeType for ScheduleToken { - fn make_type(_ts: &mut S) -> spacetimedb_lib::AlgebraicType { - spacetimedb_lib::AlgebraicType::U64 - } -} +impl_serialize!([R] ScheduleToken, (self, ser) => self.id.serialize(ser)); +impl_deserialize!([R] ScheduleToken, de => u64::deserialize(de).map(Self::new)); +impl_st!([R] ScheduleToken, _ts => spacetimedb_lib::AlgebraicType::U64); impl ScheduleToken { /// Wrap the ID under which a reducer is scheduled in a [`ScheduleToken`]. diff --git a/crates/bindings/src/rt.rs b/crates/bindings/src/rt.rs index 87a9d4f659..6d0288009c 100644 --- a/crates/bindings/src/rt.rs +++ b/crates/bindings/src/rt.rs @@ -12,8 +12,8 @@ use crate::{sys, ReducerContext, ScheduleToken, SpacetimeType, TableType, Timest use spacetimedb_lib::auth::{StAccess, StTableType}; use spacetimedb_lib::de::{self, Deserialize, SeqProductAccess}; use spacetimedb_lib::sats::typespace::TypespaceBuilder; -use spacetimedb_lib::sats::{AlgebraicType, AlgebraicTypeRef, ProductTypeElement}; -use spacetimedb_lib::ser::{self, Serialize, SerializeSeqProduct}; +use spacetimedb_lib::sats::{impl_deserialize, impl_serialize, AlgebraicType, AlgebraicTypeRef, ProductTypeElement}; +use spacetimedb_lib::ser::{Serialize, SerializeSeqProduct}; use spacetimedb_lib::{bsatn, Identity, MiscModuleExport, ModuleDef, ReducerDef, TableDef, TypeAlias}; use sys::Buffer; @@ -319,20 +319,15 @@ impl_reducer!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, /// Provides deserialization and serialization for any type `A: Args`. struct SerDeArgs(A); -impl<'de, A: Args<'de>> Deserialize<'de> for SerDeArgs { - fn deserialize>(deserializer: D) -> Result { - deserializer - .deserialize_product(ArgsVisitor { _marker: PhantomData }) - .map(Self) - } -} -impl<'de, A: Args<'de>> Serialize for SerDeArgs { - fn serialize(&self, serializer: S) -> Result { - let mut prod = serializer.serialize_seq_product(A::LEN)?; - self.0.serialize_seq_product(&mut prod)?; - prod.end() - } -} +impl_deserialize!( + [A: Args<'de>] SerDeArgs, + de => de.deserialize_product(ArgsVisitor { _marker: PhantomData }).map(Self) +); +impl_serialize!(['de, A: Args<'de>] SerDeArgs, (self, ser) => { + let mut prod = ser.serialize_seq_product(A::LEN)?; + self.0.serialize_seq_product(&mut prod)?; + prod.end() +}); /// Returns a timestamp that is `duration` from now. #[track_caller] diff --git a/crates/bindings/src/timestamp.rs b/crates/bindings/src/timestamp.rs index e98d0ccb86..a238593faa 100644 --- a/crates/bindings/src/timestamp.rs +++ b/crates/bindings/src/timestamp.rs @@ -3,8 +3,7 @@ use std::ops::{Add, Sub}; use std::time::Duration; -use spacetimedb_lib::de::Deserialize; -use spacetimedb_lib::ser::Serialize; +use spacetimedb_lib::sats::{impl_deserialize, impl_serialize, impl_st}; scoped_tls::scoped_thread_local! { static CURRENT_TIMESTAMP: Timestamp @@ -90,20 +89,6 @@ impl Sub for Timestamp { } } -impl crate::SpacetimeType for Timestamp { - fn make_type(_ts: &mut S) -> spacetimedb_lib::AlgebraicType { - spacetimedb_lib::AlgebraicType::U64 - } -} - -impl<'de> Deserialize<'de> for Timestamp { - fn deserialize>(deserializer: D) -> Result { - u64::deserialize(deserializer).map(|micros_since_epoch| Self { micros_since_epoch }) - } -} - -impl Serialize for Timestamp { - fn serialize(&self, serializer: S) -> Result { - self.micros_since_epoch.serialize(serializer) - } -} +impl_st!([] Timestamp, _ts => spacetimedb_lib::AlgebraicType::U64); +impl_deserialize!([] Timestamp, de => u64::deserialize(de).map(|m| Self { micros_since_epoch: m })); +impl_serialize!([] Timestamp, (self, ser) => self.micros_since_epoch.serialize(ser)); diff --git a/crates/cli/src/edit_distance.rs b/crates/cli/src/edit_distance.rs index 166781a82d..4d178c76db 100644 --- a/crates/cli/src/edit_distance.rs +++ b/crates/cli/src/edit_distance.rs @@ -1,6 +1,7 @@ /* -Copyright, The Rust project developers. +Some parts copyright, The Rust project developers. See https://github.com/rust-lang/rust/blob/8882507bc7dbad0cc0548204eb8777e51ac92332/COPYRIGHT +for the parts where MIT / Apache-2.0 applies. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated diff --git a/crates/cli/src/subcommands/generate/csharp.rs b/crates/cli/src/subcommands/generate/csharp.rs index 86acd0a35f..110fe2756d 100644 --- a/crates/cli/src/subcommands/generate/csharp.rs +++ b/crates/cli/src/subcommands/generate/csharp.rs @@ -1,3 +1,5 @@ +use super::util::fmt_fn; + use std::fmt::{self, Write}; use convert_case::{Case, Casing}; @@ -259,16 +261,6 @@ fn csharp_typename(ctx: &GenCtx, typeref: AlgebraicTypeRef) -> &str { ctx.names[typeref.idx()].as_deref().expect("tuples should have names") } -fn fmt_fn(f: impl Fn(&mut fmt::Formatter) -> fmt::Result) -> impl fmt::Display { - struct FDisplay(F); - impl fmt::Result> fmt::Display for FDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (self.0)(f) - } - } - FDisplay(f) -} - fn is_option_type(ty: &SumType) -> bool { if ty.variants.len() != 2 { return false; diff --git a/crates/cli/src/subcommands/generate/mod.rs b/crates/cli/src/subcommands/generate/mod.rs index 286d65dbd5..b0a18a6dfd 100644 --- a/crates/cli/src/subcommands/generate/mod.rs +++ b/crates/cli/src/subcommands/generate/mod.rs @@ -14,6 +14,7 @@ pub mod csharp; pub mod python; pub mod rust; pub mod typescript; +mod util; const INDENT: &str = "\t"; diff --git a/crates/cli/src/subcommands/generate/python.rs b/crates/cli/src/subcommands/generate/python.rs index 32a3c1a827..53c47c2374 100644 --- a/crates/cli/src/subcommands/generate/python.rs +++ b/crates/cli/src/subcommands/generate/python.rs @@ -1,3 +1,5 @@ +use super::util::fmt_fn; + use convert_case::{Case, Casing}; use spacetimedb_lib::{ sats::{AlgebraicType::Builtin, AlgebraicTypeRef, ArrayType, BuiltinType, MapType}, @@ -137,16 +139,6 @@ fn python_filename(ctx: &GenCtx, typeref: AlgebraicTypeRef) -> String { .to_case(Case::Snake) } -fn fmt_fn(f: impl Fn(&mut fmt::Formatter) -> fmt::Result) -> impl fmt::Display { - struct FDisplay(F); - impl fmt::Result> fmt::Display for FDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (self.0)(f) - } - } - FDisplay(f) -} - fn is_option_type(ty: &SumType) -> bool { if ty.variants.len() != 2 { return false; diff --git a/crates/cli/src/subcommands/generate/typescript.rs b/crates/cli/src/subcommands/generate/typescript.rs index ea349c8f0c..5df22d3d4e 100644 --- a/crates/cli/src/subcommands/generate/typescript.rs +++ b/crates/cli/src/subcommands/generate/typescript.rs @@ -1,3 +1,5 @@ +use super::util::fmt_fn; + use std::fmt::{self, Write}; use convert_case::{Case, Casing}; @@ -227,16 +229,6 @@ fn typescript_filename(ctx: &GenCtx, typeref: AlgebraicTypeRef) -> String { .to_case(Case::Snake) } -fn fmt_fn(f: impl Fn(&mut fmt::Formatter) -> fmt::Result) -> impl fmt::Display { - struct FDisplay(F); - impl fmt::Result> fmt::Display for FDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (self.0)(f) - } - } - FDisplay(f) -} - macro_rules! indent_scope { ($x:ident) => { let mut $x = $x.indented(1); diff --git a/crates/cli/src/subcommands/generate/util.rs b/crates/cli/src/subcommands/generate/util.rs new file mode 100644 index 0000000000..7cf819f372 --- /dev/null +++ b/crates/cli/src/subcommands/generate/util.rs @@ -0,0 +1,14 @@ +//! Various utility functions that the generate modules have in common. + +use std::fmt::{Display, Formatter, Result}; + +/// Turns a closure `f: Fn(&mut Formatter) -> Result` into `fmt::Display`. +pub(super) fn fmt_fn(f: impl Fn(&mut Formatter) -> Result) -> impl Display { + struct FDisplay(F); + impl Result> Display for FDisplay { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + (self.0)(f) + } + } + FDisplay(f) +} diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 0377597b5f..63c18487b0 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -18,7 +18,7 @@ use spacetimedb_lib::name; use spacetimedb_lib::name::DomainName; use spacetimedb_lib::name::DomainParsingError; use spacetimedb_lib::name::PublishOp; -use spacetimedb_lib::sats::TypeInSpace; +use spacetimedb_lib::sats::WithTypespace; use crate::auth::{ SpacetimeAuth, SpacetimeAuthHeader, SpacetimeEnergyUsed, SpacetimeExecutionDurationMicros, SpacetimeIdentity, @@ -201,7 +201,7 @@ async fn extract_db_call_info( }) } -fn entity_description_json(description: TypeInSpace, expand: bool) -> Option { +fn entity_description_json(description: WithTypespace, expand: bool) -> Option { let typ = DescribedEntityType::from_entitydef(description.ty()).as_str(); let len = match description.ty() { EntityDef::Table(t) => description.resolve(t.data).ty().as_product()?.elements.len(), diff --git a/crates/core/src/db/datastore/locking_tx_datastore/mod.rs b/crates/core/src/db/datastore/locking_tx_datastore/mod.rs index baaf08c6e4..d6e65bc46f 100644 --- a/crates/core/src/db/datastore/locking_tx_datastore/mod.rs +++ b/crates/core/src/db/datastore/locking_tx_datastore/mod.rs @@ -2094,7 +2094,7 @@ mod tests { StColumnRow { table_id: 1, col_id: 0, col_name: "table_id".to_string(), col_type: AlgebraicType::U32, is_autoinc: false }, StColumnRow { table_id: 1, col_id: 1, col_name: "col_id".to_string(), col_type: AlgebraicType::U32, is_autoinc: false }, - StColumnRow { table_id: 1, col_id: 2, col_name: "col_type".to_string(), col_type: AlgebraicType::make_array_type(AlgebraicType::U8), is_autoinc: false }, + StColumnRow { table_id: 1, col_id: 2, col_name: "col_type".to_string(), col_type: AlgebraicType::array(AlgebraicType::U8), is_autoinc: false }, StColumnRow { table_id: 1, col_id: 3, col_name: "col_name".to_string(), col_type: AlgebraicType::String, is_autoinc: false }, StColumnRow { table_id: 1, col_id: 4, col_name: "is_autoinc".to_string(), col_type: AlgebraicType::Bool, is_autoinc: false }, diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index 7bd2461296..1735924ef2 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -7,7 +7,7 @@ use spacetimedb_lib::de::serde::SeedWrapper; use spacetimedb_lib::de::DeserializeSeed; use spacetimedb_lib::{bsatn, Hash, Identity}; use spacetimedb_lib::{ProductValue, ReducerDef}; -use spacetimedb_sats::TypeInSpace; +use spacetimedb_sats::WithTypespace; mod host_controller; pub(crate) mod module_host; @@ -35,13 +35,13 @@ pub enum ReducerArgs { } impl ReducerArgs { - fn into_tuple(self, schema: TypeInSpace<'_, ReducerDef>) -> Result { + fn into_tuple(self, schema: WithTypespace<'_, ReducerDef>) -> Result { self._into_tuple(schema).map_err(|err| InvalidReducerArguments { err, reducer: schema.ty().name.clone(), }) } - fn _into_tuple(self, schema: TypeInSpace<'_, ReducerDef>) -> anyhow::Result { + fn _into_tuple(self, schema: WithTypespace<'_, ReducerDef>) -> anyhow::Result { Ok(match self { ReducerArgs::Json(json) => ArgsTuple { tuple: from_json_seed(&json, SeedWrapper(ReducerDef::deserialize(schema)))?, diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index a85f7219a5..9e30b2ead1 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -10,7 +10,7 @@ use crate::protobuf::client_api::{table_row_operation, SubscriptionUpdate, Table use crate::subscription::module_subscription_actor::ModuleSubscriptionManager; use indexmap::IndexMap; use spacetimedb_lib::{ReducerDef, TableDef}; -use spacetimedb_sats::{ProductValue, TypeInSpace, Typespace}; +use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; use std::collections::HashMap; use std::convert::Infallible; use std::sync::Arc; @@ -560,18 +560,18 @@ impl Catalog { &self.0.typespace } - pub fn get(&self, name: &str) -> Option> { + pub fn get(&self, name: &str) -> Option> { self.0.catalog.get(name).map(|ty| self.0.typespace.with_type(ty)) } - pub fn get_reducer(&self, name: &str) -> Option> { + pub fn get_reducer(&self, name: &str) -> Option> { let schema = self.get(name)?; Some(schema.with(schema.ty().as_reducer()?)) } - pub fn get_table(&self, name: &str) -> Option> { + pub fn get_table(&self, name: &str) -> Option> { let schema = self.get(name)?; Some(schema.with(schema.ty().as_table()?)) } - pub fn iter(&self) -> impl Iterator)> + '_ { + pub fn iter(&self) -> impl Iterator)> + '_ { self.0 .catalog .iter() diff --git a/crates/core/src/host/timestamp.rs b/crates/core/src/host/timestamp.rs index 67957b3854..ed69fd1508 100644 --- a/crates/core/src/host/timestamp.rs +++ b/crates/core/src/host/timestamp.rs @@ -1,5 +1,7 @@ use std::time::{Duration, SystemTime}; +use spacetimedb_sats::{impl_deserialize, impl_serialize}; + #[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize)] #[serde(transparent)] #[repr(transparent)] @@ -24,13 +26,5 @@ impl Timestamp { } } -impl<'de> spacetimedb_sats::de::Deserialize<'de> for Timestamp { - fn deserialize>(deserializer: D) -> Result { - u64::deserialize(deserializer).map(Self) - } -} -impl spacetimedb_sats::ser::Serialize for Timestamp { - fn serialize(&self, serializer: S) -> Result { - self.0.serialize(serializer) - } -} +impl_deserialize!([] Timestamp, de => u64::deserialize(de).map(Self)); +impl_serialize!([] Timestamp, (self, ser) => self.0.serialize(ser)); diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 6b203aa0ea..7a96a80358 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -276,24 +276,23 @@ fn extract_field(table: &From, of: &SqlExpr) -> Result, value: &str, is_long: bool) -> Result { - let ty = match field { + match field { None => { - if value.contains('.') { + let ty = if value.contains('.') { if is_long { - &AlgebraicType::F64 + AlgebraicType::F64 } else { - &AlgebraicType::F32 + AlgebraicType::F32 } } else if is_long { - &AlgebraicType::I64 + AlgebraicType::I64 } else { - &AlgebraicType::I32 - } + AlgebraicType::I32 + }; + parse(value, &ty) } - Some(f) => &f.algebraic_type, - }; - - parse(value, ty) + Some(f) => parse(value, &f.algebraic_type), + } } /// Compiles a [SqlExpr] expression into a [ColumnOp] @@ -743,8 +742,8 @@ fn column_def_type(named: &String, is_null: bool, data_type: &DataType) -> Resul DataType::Real => AlgebraicType::F32, DataType::Double => AlgebraicType::F64, DataType::Boolean => AlgebraicType::Bool, - DataType::Array(Some(ty)) => AlgebraicType::make_array_type(column_def_type(named, false, ty)?), - DataType::Enum(values) => AlgebraicType::make_simple_enum(values.iter().map(|x| x.as_str())), + DataType::Array(Some(ty)) => AlgebraicType::array(column_def_type(named, false, ty)?), + DataType::Enum(values) => AlgebraicType::simple_enum(values.iter().map(|x| x.as_str())), x => { return Err(PlanError::Unsupported { feature: format!("Column {} of type {}", named, x), @@ -752,11 +751,7 @@ fn column_def_type(named: &String, is_null: bool, data_type: &DataType) -> Resul } }; - Ok(if is_null { - AlgebraicType::make_option_type(ty) - } else { - ty - }) + Ok(if is_null { AlgebraicType::option(ty) } else { ty }) } /// Extract the column attributes into [ColumnIndexAttribute] diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index bc6061f0d4..af5ac2e312 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -605,7 +605,7 @@ pub(crate) mod tests { if is_null { assert_eq!( col.col_type, - AlgebraicType::make_option_type(AlgebraicType::I64), + AlgebraicType::option(AlgebraicType::I64), "Null type {}.{}", table_name, col.col_name diff --git a/crates/lib/src/address.rs b/crates/lib/src/address.rs index a9462e0b6b..e5fabfea94 100644 --- a/crates/lib/src/address.rs +++ b/crates/lib/src/address.rs @@ -2,8 +2,9 @@ use std::net::Ipv6Addr; use anyhow::Context as _; use hex::FromHex as _; +use sats::{impl_deserialize, impl_serialize, impl_st}; -use crate::sats::{self, de, ser}; +use crate::sats; /// This is the address for a SpacetimeDB database. It is a unique identifier /// for a particular database and once set for a database, does not change. @@ -58,17 +59,8 @@ impl Address { } } -impl ser::Serialize for Address { - fn serialize(&self, serializer: S) -> Result { - self.0.to_be_bytes().serialize(serializer) - } -} - -impl<'de> de::Deserialize<'de> for Address { - fn deserialize>(deserializer: D) -> Result { - <[u8; 16]>::deserialize(deserializer).map(|v| Self(u128::from_be_bytes(v))) - } -} +impl_serialize!([] Address, (self, ser) =>self.0.to_be_bytes().serialize(ser)); +impl_deserialize!([] Address, de => <[u8; 16]>::deserialize(de).map(|v| Self(u128::from_be_bytes(v)))); #[cfg(feature = "serde")] impl<'de> serde::Deserialize<'de> for Address { @@ -81,8 +73,4 @@ impl<'de> serde::Deserialize<'de> for Address { } } -impl sats::SpacetimeType for Address { - fn make_type(_typespace: &mut S) -> sats::AlgebraicType { - crate::AlgebraicType::U128 - } -} +impl_st!([] Address, _ts => sats::AlgebraicType::U128); diff --git a/crates/lib/src/auth.rs b/crates/lib/src/auth.rs index 963c25551a..ba927336e9 100644 --- a/crates/lib/src/auth.rs +++ b/crates/lib/src/auth.rs @@ -1,6 +1,6 @@ -use crate::de::{Deserializer, Error}; -use crate::ser::Serializer; -use crate::{de, ser}; +use spacetimedb_sats::{impl_deserialize, impl_serialize}; + +use crate::de::Error; /// Describe the visibility of the table #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -43,22 +43,15 @@ impl<'a> TryFrom<&'a str> for StAccess { } } -impl ser::Serialize for StAccess { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(self.as_str()) - } -} - -impl<'de> de::Deserialize<'de> for StAccess { - fn deserialize>(deserializer: D) -> Result { - let value = deserializer.deserialize_str_slice()?; - StAccess::try_from(value).map_err(|x| { - Error::custom(format!( - "DecodeError for StAccess: `{x}`. Expected `public` | 'private'" - )) - }) - } -} +impl_serialize!([] StAccess, (self, ser) => ser.serialize_str(self.as_str())); +impl_deserialize!([] StAccess, de => { + let value = de.deserialize_str_slice()?; + StAccess::try_from(value).map_err(|x| { + Error::custom(format!( + "DecodeError for StAccess: `{x}`. Expected `public` | 'private'" + )) + }) +}); /// Describe is the table is a `system table` or not. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -92,19 +85,12 @@ impl<'a> TryFrom<&'a str> for StTableType { } } -impl ser::Serialize for StTableType { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(self.as_str()) - } -} - -impl<'de> de::Deserialize<'de> for StTableType { - fn deserialize>(deserializer: D) -> Result { - let value = deserializer.deserialize_str_slice()?; - StTableType::try_from(value).map_err(|x| { - Error::custom(format!( - "DecodeError for StTableType: `{x}`. Expected 'system' | 'user'" - )) - }) - } -} +impl_serialize!([] StTableType, (self, ser) => ser.serialize_str(self.as_str())); +impl_deserialize!([] StTableType, de => { + let value = de.deserialize_str_slice()?; + StTableType::try_from(value).map_err(|x| { + Error::custom(format!( + "DecodeError for StTableType: `{x}`. Expected 'system' | 'user'" + )) + }) +}); diff --git a/crates/lib/src/error.rs b/crates/lib/src/error.rs index 179d1b0926..0485a9eff3 100644 --- a/crates/lib/src/error.rs +++ b/crates/lib/src/error.rs @@ -1,11 +1,34 @@ use crate::relation::{FieldName, Header}; use crate::{buffer, AlgebraicType}; -use spacetimedb_sats::algebraic_type::TypeError; use spacetimedb_sats::product_value::InvalidFieldError; +use spacetimedb_sats::AlgebraicValue; use std::fmt; use std::string::FromUtf8Error; use thiserror::Error; +#[derive(Error, Debug)] +pub enum TypeError { + #[error("Arrays must be homogeneous. It expects to be `{{expect.to_satns()}}` but `{{value.to_satns()}}` is of type `{{found.to_satns()}}`")] + Array { + expect: AlgebraicType, + found: AlgebraicType, + value: AlgebraicValue, + }, + #[error("Arrays must define a type for the elements")] + ArrayEmpty, + #[error("Maps must be homogeneous. It expects to be `{{key_expect.to_satns()}}:{{value_expect.to_satns()}}` but `{{key.to_satns()}}::{{value.to_satns()}}` is of type `{{key_found.to_satns()}}:{{value_found.to_satns()}}`")] + Map { + key_expect: AlgebraicType, + value_expect: AlgebraicType, + key_found: AlgebraicType, + value_found: AlgebraicType, + key: AlgebraicValue, + value: AlgebraicValue, + }, + #[error("Maps must define a type for both key & value")] + MapEmpty, +} + #[derive(Error, Debug, Clone)] pub enum DecodeError { #[error("Decode UTF8: {0}")] diff --git a/crates/lib/src/hash.rs b/crates/lib/src/hash.rs index f1a9f9ec4f..a47ce85273 100644 --- a/crates/lib/src/hash.rs +++ b/crates/lib/src/hash.rs @@ -2,8 +2,7 @@ use crate::{de, ser}; use core::fmt; use sha3::{Digest, Keccak256}; -use spacetimedb_sats::typespace::SpacetimeType; -use spacetimedb_sats::AlgebraicType; +use spacetimedb_sats::{impl_deserialize, impl_serialize, impl_st, AlgebraicType}; pub const HASH_SIZE: usize = 32; @@ -12,24 +11,9 @@ pub struct Hash { pub data: [u8; HASH_SIZE], } -impl SpacetimeType for Hash { - fn make_type(_ts: &mut S) -> AlgebraicType { - AlgebraicType::bytes() - } -} - -impl ser::Serialize for Hash { - fn serialize(&self, serializer: S) -> Result { - self.data.serialize(serializer) - } -} -impl<'de> de::Deserialize<'de> for Hash { - fn deserialize>(deserializer: D) -> Result { - Ok(Self { - data: <_>::deserialize(deserializer)?, - }) - } -} +impl_st!([] Hash, _ts => AlgebraicType::bytes()); +impl_serialize!([] Hash, (self, ser) => self.data.serialize(ser)); +impl_deserialize!([] Hash, de => Ok(Self { data: <_>::deserialize(de)? })); impl Hash { const ABBREVIATION_LEN: usize = 16; diff --git a/crates/lib/src/identity.rs b/crates/lib/src/identity.rs index 428c5fe0fd..33b1542192 100644 --- a/crates/lib/src/identity.rs +++ b/crates/lib/src/identity.rs @@ -1,5 +1,7 @@ use std::fmt; +use sats::{impl_deserialize, impl_serialize, impl_st}; + use crate::sats::{self, de, ser}; #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -37,24 +39,9 @@ impl Identity { } } -impl sats::SpacetimeType for Identity { - fn make_type(_ts: &mut S) -> crate::AlgebraicType { - crate::AlgebraicType::bytes() - } -} - -impl ser::Serialize for Identity { - fn serialize(&self, serializer: S) -> Result { - self.data.serialize(serializer) - } -} -impl<'de> de::Deserialize<'de> for Identity { - fn deserialize>(deserializer: D) -> Result { - Ok(Self { - data: <_>::deserialize(deserializer)?, - }) - } -} +impl_st!([] Identity, _ts => sats::AlgebraicType::bytes()); +impl_serialize!([] Identity, (self, ser) => self.data.serialize(ser)); +impl_deserialize!([] Identity, de => Ok(Self { data: <_>::deserialize(de)? })); impl Identity { const ABBREVIATION_LEN: usize = 16; diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index a2fcbab359..0aa6309ca0 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -1,5 +1,6 @@ use auth::StAccess; use auth::StTableType; +use sats::impl_serialize; pub use spacetimedb_sats::buffer; pub mod address; pub mod data_key; @@ -111,18 +112,18 @@ impl ReducerDef { bsatn::to_writer(writer, self).unwrap() } - pub fn serialize_args<'a>(ty: sats::TypeInSpace<'a, Self>, value: &'a ProductValue) -> impl ser::Serialize + 'a { + pub fn serialize_args<'a>(ty: sats::WithTypespace<'a, Self>, value: &'a ProductValue) -> impl ser::Serialize + 'a { ReducerArgsWithSchema { value, ty } } pub fn deserialize( - ty: sats::TypeInSpace<'_, Self>, + ty: sats::WithTypespace<'_, Self>, ) -> impl for<'de> de::DeserializeSeed<'de, Output = ProductValue> + '_ { ReducerDeserialize(ty) } } -struct ReducerDeserialize<'a>(sats::TypeInSpace<'a, ReducerDef>); +struct ReducerDeserialize<'a>(sats::WithTypespace<'a, ReducerDef>); impl<'de> de::DeserializeSeed<'de> for ReducerDeserialize<'_> { type Output = ProductValue; @@ -156,20 +157,17 @@ impl<'de> de::ProductVisitor<'de> for ReducerDeserialize<'_> { struct ReducerArgsWithSchema<'a> { value: &'a ProductValue, - ty: sats::TypeInSpace<'a, ReducerDef>, + ty: sats::WithTypespace<'a, ReducerDef>, } - -impl ser::Serialize for ReducerArgsWithSchema<'_> { - fn serialize(&self, serializer: S) -> Result { - use itertools::Itertools; - use ser::SerializeSeqProduct; - let mut seq = serializer.serialize_seq_product(self.value.elements.len())?; - for (value, elem) in self.value.elements.iter().zip_eq(&self.ty.ty().args) { - seq.serialize_element(&self.ty.with(&elem.algebraic_type).with_value(value))?; - } - seq.end() +impl_serialize!([] ReducerArgsWithSchema<'_>, (self, ser) => { + use itertools::Itertools; + use ser::SerializeSeqProduct; + let mut seq = ser.serialize_seq_product(self.value.elements.len())?; + for (value, elem) in self.value.elements.iter().zip_eq(&self.ty.ty().args) { + seq.serialize_element(&self.ty.with(&elem.algebraic_type).with_value(value))?; } -} + seq.end() +}); //WARNING: Change this structure(or any of their members) is an ABI change. #[derive(Debug, Clone, Default, de::Deserialize, ser::Serialize)] @@ -254,30 +252,3 @@ impl TryFrom for ColumnIndexAttribute { } } } - -// use std::fmt; -// -// #[cfg(feature = "serde")] -// use serde::de::Expected as SerdeExpected; -// #[cfg(not(feature = "serde"))] -// use Sized as SerdeExpected; -// fn fmt_fn(f: impl Fn(&mut fmt::Formatter) -> fmt::Result) -> impl fmt::Display + fmt::Debug + SerdeExpected { -// struct FDisplay(F); -// impl fmt::Result> fmt::Display for FDisplay { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// (self.0)(f) -// } -// } -// impl fmt::Result> fmt::Debug for FDisplay { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// (self.0)(f) -// } -// } -// #[cfg(feature = "serde")] -// impl fmt::Result> serde::de::Expected for FDisplay { -// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { -// (self.0)(f) -// } -// } -// FDisplay(f) -// } diff --git a/crates/lib/src/relation.rs b/crates/lib/src/relation.rs index 264f766db8..8323fab10f 100644 --- a/crates/lib/src/relation.rs +++ b/crates/lib/src/relation.rs @@ -8,7 +8,7 @@ use crate::table::ColumnDef; use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::product_value::ProductValue; use spacetimedb_sats::satn::Satn; -use spacetimedb_sats::{algebraic_type, AlgebraicType, ProductType, ProductTypeElement, TypeInSpace, Typespace}; +use spacetimedb_sats::{algebraic_type, AlgebraicType, ProductType, ProductTypeElement, Typespace, WithTypespace}; impl ColumnDef { pub fn name(&self) -> FieldOnly { @@ -147,7 +147,7 @@ impl fmt::Display for FieldExpr { FieldExpr::Value(x) => { let ty = x.type_of(); let ts = Typespace::new(vec![]); - write!(f, "{}", TypeInSpace::new(&ts, &ty).with_value(x).to_satn()) + write!(f, "{}", WithTypespace::new(&ts, &ty).with_value(x).to_satn()) } } } @@ -330,7 +330,7 @@ impl fmt::Display for Header { f, "{}: {}", col.field, - algebraic_type::satn::Formatter::new(&col.algebraic_type) + algebraic_type::fmt::fmt_algebraic_type(&col.algebraic_type) )?; if pos + 1 < self.fields.len() { diff --git a/crates/lib/src/table.rs b/crates/lib/src/table.rs index 9e605ebc47..887c87c6d6 100644 --- a/crates/lib/src/table.rs +++ b/crates/lib/src/table.rs @@ -28,7 +28,7 @@ impl ProductTypeMeta { pub fn with_capacity(capacity: usize) -> Self { Self { attr: Vec::with_capacity(capacity), - columns: ProductType::with_capacity(capacity), + columns: ProductType::new(Vec::with_capacity(capacity)), } } diff --git a/crates/lib/tests/serde.rs b/crates/lib/tests/serde.rs index 6c322a09e2..0a152efb33 100644 --- a/crates/lib/tests/serde.rs +++ b/crates/lib/tests/serde.rs @@ -1,13 +1,13 @@ use spacetimedb_lib::de::serde::SerdeDeserializer; use spacetimedb_lib::de::DeserializeSeed; use spacetimedb_lib::{AlgebraicType, ProductType, ProductTypeElement, ProductValue, SumType}; -use spacetimedb_sats::{satn::Satn, ArrayType, BuiltinType::*, SumTypeVariant, TypeInSpace, Typespace}; +use spacetimedb_sats::{satn::Satn, SumTypeVariant, Typespace, WithTypespace}; macro_rules! de_json_snapshot { ($schema:expr, $json:expr) => { let (schema, json) = (&$schema, &$json); let value = de_json(schema, json).unwrap(); - let value = TypeInSpace::new(&EMPTY_TYPESPACE, schema) + let value = WithTypespace::new(&EMPTY_TYPESPACE, schema) .with_value(&value) .to_satn_pretty(); let debug_expr = format!("de_json({})", json.trim()); @@ -18,9 +18,9 @@ macro_rules! de_json_snapshot { #[test] fn test_json_mappings() { let schema = tuple([ - ("foo", AlgebraicType::Builtin(U32)), + ("foo", AlgebraicType::U32), ("bar", AlgebraicType::bytes()), - ("baz", vec(AlgebraicType::Builtin(String))), + ("baz", AlgebraicType::array(AlgebraicType::String)), ( "quux", AlgebraicType::Sum(enumm([ @@ -28,10 +28,7 @@ fn test_json_mappings() { ("Unit", AlgebraicType::UNIT_TYPE), ])), ), - ( - "and_peggy", - AlgebraicType::make_option_type(AlgebraicType::Builtin(F64)), - ), + ("and_peggy", AlgebraicType::option(AlgebraicType::F64)), ]); let data = r#" { @@ -59,10 +56,7 @@ fn tuple<'a>(elems: impl IntoIterator) -> Produ ProductType { elements: elems .into_iter() - .map(|(name, algebraic_type)| ProductTypeElement { - name: Some(name.into()), - algebraic_type, - }) + .map(|(name, ty)| ProductTypeElement::new_named(ty, name)) .collect(), } } @@ -70,22 +64,15 @@ fn enumm<'a>(elems: impl IntoIterator) -> SumTy SumType { variants: elems .into_iter() - .map(|(name, algebraic_type)| SumTypeVariant { - name: Some(name.into()), - algebraic_type, - }) + .map(|(name, ty)| SumTypeVariant::new_named(ty, name)) .collect(), } } -fn vec(ty: AlgebraicType) -> AlgebraicType { - AlgebraicType::Builtin(Array(ArrayType { elem_ty: Box::new(ty) })) -} - static EMPTY_TYPESPACE: Typespace = Typespace::new(Vec::new()); -fn in_space(x: &T) -> TypeInSpace<'_, T> { - TypeInSpace::new(&EMPTY_TYPESPACE, x) +fn in_space(x: &T) -> WithTypespace<'_, T> { + WithTypespace::new(&EMPTY_TYPESPACE, x) } fn de_json(schema: &ProductType, data: &str) -> serde_json::Result { diff --git a/crates/sats/src/algebraic_type.rs b/crates/sats/src/algebraic_type.rs index 03aa7305e0..5d43ffafcd 100644 --- a/crates/sats/src/algebraic_type.rs +++ b/crates/sats/src/algebraic_type.rs @@ -1,12 +1,14 @@ +pub mod fmt; pub mod map_notation; -pub mod satn; use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer}; use crate::algebraic_value::ser::ValueSerializer; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize, MapType}; -use crate::{AlgebraicTypeRef, AlgebraicValue, ArrayType, BuiltinType, ProductType, SumType, SumTypeVariant}; +use crate::{ + AlgebraicTypeRef, AlgebraicValue, ArrayType, BuiltinType, ProductType, ProductTypeElement, SumType, SumTypeVariant, +}; use enum_as_inner::EnumAsInner; -use thiserror::Error; /// The SpacetimeDB Algebraic Type System (SATS) is a structural type system in /// which a nominal type system can be constructed. @@ -54,244 +56,294 @@ use thiserror::Error; #[derive(EnumAsInner, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub enum AlgebraicType { + /// A structural sum type. + /// + /// Unlike most languages, sums in SATs are *[structural]* and not nominal. + /// When checking whether two nominal types are the same, + /// their names and/or declaration sites (e.g., module / namespace) are considered. + /// Meanwhile, a structural type system would only check the structure of the type itself, + /// e.g., the names of its variants and their inner data types in the case of a sum. + /// + /// This is also known as a discriminated union (implementation) or disjoint union. + /// Another name is [coproduct (category theory)](https://ncatlab.org/nlab/show/coproduct). + /// + /// These structures are known as sum types because the number of possible values a sum + /// ```ignore + /// { N_0(T_0), N_1(T_1), ..., N_n(T_n) } + /// ``` + /// is: + /// ```ignore + /// Σ (i ∈ 0..n). values(T_i) + /// ``` + /// so for example, `values({ A(U64), B(Bool) }) = values(U64) + values(Bool)`. + /// + /// See also: https://ncatlab.org/nlab/show/sum+type. + /// + /// [structural]: https://en.wikipedia.org/wiki/Structural_type_system Sum(SumType), + /// A structural product type. + /// + /// This is also known as `struct` and `tuple` in many languages, + /// but note that unlike most languages, sums in SATs are *[structural]* and not nominal. + /// When checking whether two nominal types are the same, + /// their names and/or declaration sites (e.g., module / namespace) are considered. + /// Meanwhile, a structural type system would only check the structure of the type itself, + /// e.g., the names of its fields and their types in the case of a record. + /// The name "product" comes from category theory. + /// + /// See also: https://ncatlab.org/nlab/show/product+type. + /// + /// These structures are known as product types because the number of possible values in product + /// ```ignore + /// { N_0: T_0, N_1: T_1, ..., N_n: T_n } + /// ``` + /// is: + /// ```ignore + /// Π (i ∈ 0..n). values(T_i) + /// ``` + /// so for example, `values({ A: U64, B: Bool }) = values(U64) * values(Bool)`. + /// + /// [structural]: https://en.wikipedia.org/wiki/Structural_type_system Product(ProductType), + /// A bulltin type, e.g., `bool`. Builtin(BuiltinType), + /// A type where the definition is given by the typing context (`Typespace`). + /// In other words, this is defined by a pointer to another `AlgebraicType`. + /// + /// This should not be conflated with reference and pointer types in languages like Rust, + /// In other words, this is not `&T` or `*const T`. Ref(AlgebraicTypeRef), } +#[allow(non_upper_case_globals)] impl AlgebraicType { - #[allow(non_upper_case_globals)] - pub const Bool: Self = AlgebraicType::Builtin(BuiltinType::Bool); - #[allow(non_upper_case_globals)] - pub const I8: Self = AlgebraicType::Builtin(BuiltinType::I8); - #[allow(non_upper_case_globals)] - pub const U8: Self = AlgebraicType::Builtin(BuiltinType::U8); - #[allow(non_upper_case_globals)] - pub const I16: Self = AlgebraicType::Builtin(BuiltinType::I16); - #[allow(non_upper_case_globals)] - pub const U16: Self = AlgebraicType::Builtin(BuiltinType::U16); - #[allow(non_upper_case_globals)] - pub const I32: Self = AlgebraicType::Builtin(BuiltinType::I32); - #[allow(non_upper_case_globals)] - pub const U32: Self = AlgebraicType::Builtin(BuiltinType::U32); - #[allow(non_upper_case_globals)] - pub const I64: Self = AlgebraicType::Builtin(BuiltinType::I64); - #[allow(non_upper_case_globals)] - pub const U64: Self = AlgebraicType::Builtin(BuiltinType::U64); - #[allow(non_upper_case_globals)] - pub const I128: Self = AlgebraicType::Builtin(BuiltinType::I128); - #[allow(non_upper_case_globals)] - pub const U128: Self = AlgebraicType::Builtin(BuiltinType::U128); - #[allow(non_upper_case_globals)] - pub const F32: Self = AlgebraicType::Builtin(BuiltinType::F32); - #[allow(non_upper_case_globals)] - pub const F64: Self = AlgebraicType::Builtin(BuiltinType::F64); - #[allow(non_upper_case_globals)] - pub const String: Self = AlgebraicType::Builtin(BuiltinType::String); - - #[allow(non_upper_case_globals)] + /// The built-in Bool type. + pub const Bool: Self = Self::Builtin(BuiltinType::Bool); + + /// The built-in signed 8-bit integer type. + pub const I8: Self = Self::Builtin(BuiltinType::I8); + + /// The built-in unsigned 8-bit integer type. + pub const U8: Self = Self::Builtin(BuiltinType::U8); + + /// The built-in signed 16-bit integer type. + pub const I16: Self = Self::Builtin(BuiltinType::I16); + + /// The built-in unsigned 16-bit integer type. + pub const U16: Self = Self::Builtin(BuiltinType::U16); + + /// The built-in signed 32-bit integer type. + pub const I32: Self = Self::Builtin(BuiltinType::I32); + + /// The built-in unsigned 32-bit integer type. + pub const U32: Self = Self::Builtin(BuiltinType::U32); + + /// The built-in signed 64-bit integer type. + pub const I64: Self = Self::Builtin(BuiltinType::I64); + + /// The built-in unsigned 64-bit integer type. + pub const U64: Self = Self::Builtin(BuiltinType::U64); + + /// The built-in signed 128-bit integer type. + pub const I128: Self = Self::Builtin(BuiltinType::I128); + + /// The built-in unsigned 128-bit integer type. + pub const U128: Self = Self::Builtin(BuiltinType::U128); + + /// The built-in 32-bit floating point type. + pub const F32: Self = Self::Builtin(BuiltinType::F32); + + /// The built-in 64-bit floating point type. + pub const F64: Self = Self::Builtin(BuiltinType::F64); + + /// The built-in string type. + pub const String: Self = Self::Builtin(BuiltinType::String); + + /// The canonical 0-element unit type. + pub const UNIT_TYPE: Self = Self::product(Vec::new()); + + /// The canonical 0-variant "never" / "absurd" / "void" type. + pub const NEVER_TYPE: Self = Self::sum(Vec::new()); + + /// A type representing an array of `U8`s. pub fn bytes() -> Self { - Self::make_array_type(Self::U8) + Self::array(Self::U8) } } -impl AlgebraicType { - /// This is a static function that constructs the type of AlgebraicType and - /// returns it as an AlgebraicType. This could alternatively be implemented +impl MetaType for AlgebraicType { + /// This is a static function that constructs the type of `AlgebraicType` + /// and returns it as an `AlgebraicType`. + /// + /// This could alternatively be implemented /// as a regular AlgebraicValue or as a static variable. - pub fn make_meta_type() -> AlgebraicType { - AlgebraicType::Sum(SumType::new(vec![ - SumTypeVariant::new_named(SumType::make_meta_type(), "sum"), - SumTypeVariant::new_named(ProductType::make_meta_type(), "product"), - SumTypeVariant::new_named(BuiltinType::make_meta_type(), "builtin"), - SumTypeVariant::new_named(AlgebraicTypeRef::make_meta_type(), "ref"), - ])) + fn meta_type() -> Self { + AlgebraicType::sum(vec![ + SumTypeVariant::new_named(SumType::meta_type(), "sum"), + SumTypeVariant::new_named(ProductType::meta_type(), "product"), + SumTypeVariant::new_named(BuiltinType::meta_type(), "builtin"), + SumTypeVariant::new_named(AlgebraicTypeRef::meta_type(), "ref"), + ]) } +} - pub fn make_never_type() -> AlgebraicType { - AlgebraicType::Sum(SumType { variants: vec![] }) +impl AlgebraicType { + /// Returns a sum type with the given `variants`. + pub const fn sum(variants: Vec) -> Self { + AlgebraicType::Sum(SumType { variants }) } - pub const UNIT_TYPE: AlgebraicType = AlgebraicType::Product(ProductType { elements: Vec::new() }); + /// Returns a product type with the given `factors`. + pub const fn product(factors: Vec) -> Self { + AlgebraicType::Product(ProductType::new(factors)) + } - pub fn make_option_type(some_type: AlgebraicType) -> AlgebraicType { - AlgebraicType::Sum(SumType { - variants: vec![ - SumTypeVariant::new_named(some_type, "some"), - SumTypeVariant::new_named(AlgebraicType::UNIT_TYPE, "none"), - ], - }) + /// Returns a structural option type where `some_type` is the type for the `some` variant. + pub fn option(some_type: Self) -> Self { + Self::sum(vec![ + SumTypeVariant::new_named(some_type, "some"), + SumTypeVariant::new_named(AlgebraicType::UNIT_TYPE, "none"), + ]) } - pub fn make_array_type(ty: AlgebraicType) -> AlgebraicType { + /// Returns an unsized array type where the element type is `ty`. + pub fn array(ty: Self) -> Self { AlgebraicType::Builtin(BuiltinType::Array(ArrayType { elem_ty: Box::new(ty) })) } - pub fn make_map_type(key: AlgebraicType, value: AlgebraicType) -> AlgebraicType { + /// Returns a map type from the type `key` to the type `value`. + pub fn map(key: Self, value: Self) -> Self { let value = MapType::new(key, value); AlgebraicType::Builtin(BuiltinType::Map(value)) } - pub fn make_simple_enum<'a>(arms: impl Iterator) -> AlgebraicType { - AlgebraicType::Sum(SumType { - variants: arms + /// Returns a sum type of unit variants with names taken from `var_names`. + pub fn simple_enum<'a>(var_names: impl Iterator) -> Self { + Self::sum( + var_names .into_iter() .map(|x| SumTypeVariant::new_named(AlgebraicType::UNIT_TYPE, x)) .collect(), - }) + ) } pub fn as_value(&self) -> AlgebraicValue { self.serialize(ValueSerializer).unwrap_or_else(|x| match x {}) } - pub fn from_value(value: &AlgebraicValue) -> Result { + pub fn from_value(value: &AlgebraicValue) -> Result { Self::deserialize(ValueDeserializer::from_ref(value)) } } -#[derive(Error, Debug)] -pub enum TypeError { - #[error("Arrays must be homogeneous. It expects to be `{{expect.to_satns()}}` but `{{value.to_satns()}}` is of type `{{found.to_satns()}}`")] - Array { - expect: AlgebraicType, - found: AlgebraicType, - value: AlgebraicValue, - }, - #[error("Arrays must define a type for the elements")] - ArrayEmpty, - #[error("Maps must be homogeneous. It expects to be `{{key_expect.to_satns()}}:{{value_expect.to_satns()}}` but `{{key.to_satns()}}::{{value.to_satns()}}` is of type `{{key_found.to_satns()}}:{{value_found.to_satns()}}`")] - Map { - key_expect: AlgebraicType, - value_expect: AlgebraicType, - key_found: AlgebraicType, - value_found: AlgebraicType, - key: AlgebraicValue, - value: AlgebraicValue, - }, - #[error("Maps must define a type for both key & value")] - MapEmpty, -} - #[cfg(test)] mod tests { use super::AlgebraicType; - use crate::algebraic_type::map_notation; + use crate::meta_type::MetaType; use crate::satn::Satn; use crate::{ - algebraic_type::satn::Formatter, algebraic_type_ref::AlgebraicTypeRef, builtin_type::BuiltinType, - product_type::ProductType, product_type_element::ProductTypeElement, sum_type::SumType, typespace::Typespace, + algebraic_type::fmt::fmt_algebraic_type, algebraic_type::map_notation::fmt_algebraic_type as fmt_map, + algebraic_type_ref::AlgebraicTypeRef, product_type_element::ProductTypeElement, typespace::Typespace, }; - use crate::{TypeInSpace, ValueWithType}; + use crate::{ValueWithType, WithTypespace}; #[test] fn never() { - let never = AlgebraicType::Sum(SumType { variants: vec![] }); - assert_eq!("(|)", Formatter::new(&never).to_string()); + assert_eq!("(|)", fmt_algebraic_type(&AlgebraicType::NEVER_TYPE).to_string()); } #[test] fn never_map() { - let never = AlgebraicType::Sum(SumType { variants: vec![] }); - assert_eq!("{ ty_: Sum }", map_notation::Formatter::new(&never).to_string()); + assert_eq!("{ ty_: Sum }", fmt_map(&AlgebraicType::NEVER_TYPE).to_string()); } #[test] fn unit() { - let unit = AlgebraicType::Product(ProductType { elements: vec![] }); - assert_eq!("()", Formatter::new(&unit).to_string()); + assert_eq!("()", fmt_algebraic_type(&AlgebraicType::UNIT_TYPE).to_string()); } #[test] fn unit_map() { - let unit = AlgebraicType::Product(ProductType { elements: vec![] }); - assert_eq!("{ ty_: Product }", map_notation::Formatter::new(&unit).to_string()); + assert_eq!("{ ty_: Product }", fmt_map(&AlgebraicType::UNIT_TYPE).to_string()); } #[test] fn primitive() { - let u8 = AlgebraicType::Builtin(BuiltinType::U8); - assert_eq!("U8", Formatter::new(&u8).to_string()); + assert_eq!("U8", fmt_algebraic_type(&AlgebraicType::U8).to_string()); } #[test] fn primitive_map() { - let u8 = AlgebraicType::Builtin(BuiltinType::U8); - assert_eq!("{ ty_: Builtin, 0: U8 }", map_notation::Formatter::new(&u8).to_string()); + assert_eq!("{ ty_: Builtin, 0: U8 }", fmt_map(&AlgebraicType::U8).to_string()); } #[test] fn option() { - let never = AlgebraicType::Sum(SumType { variants: vec![] }); - let option = AlgebraicType::make_option_type(never); - assert_eq!("(some: (|) | none: ())", Formatter::new(&option).to_string()); + let option = AlgebraicType::option(AlgebraicType::NEVER_TYPE); + assert_eq!("(some: (|) | none: ())", fmt_algebraic_type(&option).to_string()); } #[test] fn option_map() { - let never = AlgebraicType::Sum(SumType { variants: vec![] }); - let option = AlgebraicType::make_option_type(never); + let option = AlgebraicType::option(AlgebraicType::NEVER_TYPE); assert_eq!( "{ ty_: Sum, some: { ty_: Sum }, none: { ty_: Product } }", - map_notation::Formatter::new(&option).to_string() + fmt_map(&option).to_string() ); } #[test] fn algebraic_type() { - let algebraic_type = AlgebraicType::make_meta_type(); - assert_eq!("(sum: (variants: Array<(name: (some: String | none: ()), algebraic_type: &0)>) | product: (elements: Array<(name: (some: String | none: ()), algebraic_type: &0)>) | builtin: (bool: () | i8: () | u8: () | i16: () | u16: () | i32: () | u32: () | i64: () | u64: () | i128: () | u128: () | f32: () | f64: () | string: () | array: &0 | map: (key_ty: &0, ty: &0)) | ref: U32)", Formatter::new(&algebraic_type).to_string()); + let algebraic_type = AlgebraicType::meta_type(); + assert_eq!( + "(sum: (variants: Array<(name: (some: String | none: ()), algebraic_type: &0)>) | product: (elements: Array<(name: (some: String | none: ()), algebraic_type: &0)>) | builtin: (bool: () | i8: () | u8: () | i16: () | u16: () | i32: () | u32: () | i64: () | u64: () | i128: () | u128: () | f32: () | f64: () | string: () | array: &0 | map: (key_ty: &0, ty: &0)) | ref: U32)", + fmt_algebraic_type(&algebraic_type).to_string() + ); } #[test] fn algebraic_type_map() { - let algebraic_type = AlgebraicType::make_meta_type(); - assert_eq!("{ ty_: Sum, sum: { ty_: Product, variants: { ty_: Builtin, 0: Array, 1: { ty_: Product, name: { ty_: Sum, some: { ty_: Builtin, 0: String }, none: { ty_: Product } }, algebraic_type: { ty_: Ref, 0: 0 } } } }, product: { ty_: Product, elements: { ty_: Builtin, 0: Array, 1: { ty_: Product, name: { ty_: Sum, some: { ty_: Builtin, 0: String }, none: { ty_: Product } }, algebraic_type: { ty_: Ref, 0: 0 } } } }, builtin: { ty_: Sum, bool: { ty_: Product }, i8: { ty_: Product }, u8: { ty_: Product }, i16: { ty_: Product }, u16: { ty_: Product }, i32: { ty_: Product }, u32: { ty_: Product }, i64: { ty_: Product }, u64: { ty_: Product }, i128: { ty_: Product }, u128: { ty_: Product }, f32: { ty_: Product }, f64: { ty_: Product }, string: { ty_: Product }, array: { ty_: Ref, 0: 0 }, map: { ty_: Product, key_ty: { ty_: Ref, 0: 0 }, ty: { ty_: Ref, 0: 0 } } }, ref: { ty_: Builtin, 0: U32 } }", map_notation::Formatter::new(&algebraic_type).to_string()); + let algebraic_type = AlgebraicType::meta_type(); + assert_eq!( + "{ ty_: Sum, sum: { ty_: Product, variants: { ty_: Builtin, 0: Array, 1: { ty_: Product, name: { ty_: Sum, some: { ty_: Builtin, 0: String }, none: { ty_: Product } }, algebraic_type: { ty_: Ref, 0: 0 } } } }, product: { ty_: Product, elements: { ty_: Builtin, 0: Array, 1: { ty_: Product, name: { ty_: Sum, some: { ty_: Builtin, 0: String }, none: { ty_: Product } }, algebraic_type: { ty_: Ref, 0: 0 } } } }, builtin: { ty_: Sum, bool: { ty_: Product }, i8: { ty_: Product }, u8: { ty_: Product }, i16: { ty_: Product }, u16: { ty_: Product }, i32: { ty_: Product }, u32: { ty_: Product }, i64: { ty_: Product }, u64: { ty_: Product }, i128: { ty_: Product }, u128: { ty_: Product }, f32: { ty_: Product }, f64: { ty_: Product }, string: { ty_: Product }, array: { ty_: Ref, 0: 0 }, map: { ty_: Product, key_ty: { ty_: Ref, 0: 0 }, ty: { ty_: Ref, 0: 0 } } }, ref: { ty_: Builtin, 0: U32 } }", + fmt_map(&algebraic_type).to_string() + ); } #[test] fn nested_products_and_sums() { - let never = AlgebraicType::Sum(SumType { variants: vec![] }); - let builtin = AlgebraicType::Builtin(BuiltinType::U8); - let product = AlgebraicType::Product(ProductType::new(vec![ProductTypeElement { + let builtin = AlgebraicType::U8; + let product = AlgebraicType::product(vec![ProductTypeElement { name: Some("thing".into()), - algebraic_type: AlgebraicType::Builtin(BuiltinType::U8), - }])); - let next = AlgebraicType::Sum(SumType::new_unnamed(vec![builtin.clone(), builtin.clone(), product])); - let next = AlgebraicType::Product(ProductType::new(vec![ + algebraic_type: AlgebraicType::U8, + }]); + let next = AlgebraicType::sum(vec![builtin.clone().into(), builtin.clone().into(), product.into()]); + let next = AlgebraicType::product(vec![ ProductTypeElement { algebraic_type: builtin.clone(), name: Some("test".into()), }, + next.into(), + builtin.into(), ProductTypeElement { - algebraic_type: next, - name: None, //Some("foo".into()), - }, - ProductTypeElement { - algebraic_type: builtin, - name: None, - }, - ProductTypeElement { - algebraic_type: never, + algebraic_type: AlgebraicType::NEVER_TYPE, name: Some("never".into()), }, - ])); + ]); assert_eq!( "(test: U8, 1: (U8 | U8 | (thing: U8)), 2: U8, never: (|))", - Formatter::new(&next).to_string() + fmt_algebraic_type(&next).to_string() ); } fn in_space<'a, T: crate::Value>(ts: &'a Typespace, ty: &'a T::Type, val: &'a T) -> ValueWithType<'a, T> { - TypeInSpace::new(ts, ty).with_value(val) + WithTypespace::new(ts, ty).with_value(val) } #[test] fn option_as_value() { - let never = AlgebraicType::Sum(SumType::new(Vec::new())); - let option = AlgebraicType::make_option_type(never); - let algebraic_type = AlgebraicType::make_meta_type(); + let option = AlgebraicType::option(AlgebraicType::NEVER_TYPE); + let algebraic_type = AlgebraicType::meta_type(); let typespace = Typespace::new(vec![algebraic_type]); let at_ref = AlgebraicType::Ref(AlgebraicTypeRef(0)); assert_eq!( @@ -302,8 +354,8 @@ mod tests { #[test] fn builtin_as_value() { - let array = AlgebraicType::Builtin(BuiltinType::U8); - let algebraic_type = AlgebraicType::make_meta_type(); + let array = AlgebraicType::U8; + let algebraic_type = AlgebraicType::meta_type(); let typespace = Typespace::new(vec![algebraic_type]); let at_ref = AlgebraicType::Ref(AlgebraicTypeRef(0)); assert_eq!( @@ -314,7 +366,7 @@ mod tests { #[test] fn algebraic_type_as_value() { - let algebraic_type = AlgebraicType::make_meta_type(); + let algebraic_type = AlgebraicType::meta_type(); let typespace = Typespace::new(vec![algebraic_type.clone()]); let at_ref = AlgebraicType::Ref(AlgebraicTypeRef(0)); assert_eq!( @@ -325,25 +377,24 @@ mod tests { #[test] fn option_from_value() { - let never = AlgebraicType::Sum(SumType::new(Vec::new())); - let option = AlgebraicType::make_option_type(never); + let option = AlgebraicType::option(AlgebraicType::NEVER_TYPE); AlgebraicType::from_value(&option.as_value()).expect("No errors."); } #[test] fn builtin_from_value() { - let u8 = AlgebraicType::Builtin(BuiltinType::U8); + let u8 = AlgebraicType::U8; AlgebraicType::from_value(&u8.as_value()).expect("No errors."); } #[test] fn algebraic_type_from_value() { - let algebraic_type = AlgebraicType::make_meta_type(); + let algebraic_type = AlgebraicType::meta_type(); AlgebraicType::from_value(&algebraic_type.as_value()).expect("No errors."); } fn _legacy_encoding_comparison() { - let algebraic_type = AlgebraicType::make_meta_type(); + let algebraic_type = AlgebraicType::meta_type(); let mut buf = Vec::new(); algebraic_type.as_value().encode(&mut buf); diff --git a/crates/sats/src/algebraic_type/fmt.rs b/crates/sats/src/algebraic_type/fmt.rs new file mode 100644 index 0000000000..9f52316937 --- /dev/null +++ b/crates/sats/src/algebraic_type/fmt.rs @@ -0,0 +1,86 @@ +use super::{AlgebraicType, BuiltinType, ProductType, SumType}; +use crate::de::fmt_fn; +use std::fmt::Display; + +/// Wraps the algebraic `ty` into a `Display`able. +/// +/// NOTE: You might ask: Why do we have a formatter and a notation for +/// `AlgebraicType`s if we don't have an encoding for `AlgebraicType`s? +/// +/// This is because we just want an easier to read text format for algebraic +/// types. This could just as easily take in an algebraic value, which +/// represents an algebraic type and format it that way. It's just more +/// convenient to format it from the Rust type. +pub fn fmt_algebraic_type(ty: &AlgebraicType) -> impl '_ + Display { + fmt_fn(move |f| match ty { + AlgebraicType::Sum(ty) => write!(f, "{}", fmt_sum_type(ty)), + AlgebraicType::Product(ty) => write!(f, "{}", fmt_product_type(ty)), + AlgebraicType::Builtin(p) => write!(f, "{}", fmt_builtin_type(p)), + AlgebraicType::Ref(r) => write!(f, "{}", r), + }) +} + +/// Wraps the builtin `ty` into a `Display`able. +fn fmt_product_type(ty: &ProductType) -> impl '_ + Display { + fmt_fn(move |f| { + write!(f, "(")?; + for (i, e) in ty.elements.iter().enumerate() { + if let Some(name) = &e.name { + write!(f, "{}", name)?; + } else { + write!(f, "{}", i)?; + } + write!(f, ": ")?; + write!(f, "{}", fmt_algebraic_type(&e.algebraic_type))?; + if i < ty.elements.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, ")") + }) +} + +/// Wraps the builtin `ty` into a `Display`able. +fn fmt_sum_type(ty: &SumType) -> impl '_ + Display { + fmt_fn(move |f| { + if ty.variants.is_empty() { + return write!(f, "(|)"); + } + write!(f, "(")?; + for (i, e) in ty.variants.iter().enumerate() { + if let Some(name) = &e.name { + write!(f, "{}", name)?; + write!(f, ": ")?; + } + write!(f, "{}", fmt_algebraic_type(&e.algebraic_type))?; + if i < ty.variants.len() - 1 { + write!(f, " | ")?; + } + } + write!(f, ")") + }) +} + +/// Wraps the builtin `ty` into a `Display`able. +fn fmt_builtin_type(ty: &BuiltinType) -> impl '_ + Display { + use fmt_algebraic_type as fmt; + + fmt_fn(move |f| match ty { + BuiltinType::Bool => write!(f, "Bool"), + BuiltinType::I8 => write!(f, "I8"), + BuiltinType::U8 => write!(f, "U8"), + BuiltinType::I16 => write!(f, "I16"), + BuiltinType::U16 => write!(f, "U16"), + BuiltinType::I32 => write!(f, "I32"), + BuiltinType::U32 => write!(f, "U32"), + BuiltinType::I64 => write!(f, "I64"), + BuiltinType::U64 => write!(f, "U64"), + BuiltinType::I128 => write!(f, "I128"), + BuiltinType::U128 => write!(f, "U128"), + BuiltinType::F32 => write!(f, "F32"), + BuiltinType::F64 => write!(f, "F64"), + BuiltinType::String => write!(f, "String"), + BuiltinType::Array(a) => write!(f, "Array<{}>", fmt(&a.elem_ty)), + BuiltinType::Map(m) => write!(f, "Map<{}, {}>", fmt(&m.key_ty), fmt(&m.ty)), + }) +} diff --git a/crates/sats/src/algebraic_type/map_notation.rs b/crates/sats/src/algebraic_type/map_notation.rs index f6c2732732..433f483373 100644 --- a/crates/sats/src/algebraic_type/map_notation.rs +++ b/crates/sats/src/algebraic_type/map_notation.rs @@ -1,76 +1,58 @@ use super::AlgebraicType; use crate::builtin_type::BuiltinType; +use crate::de::fmt_fn; use crate::{ArrayType, MapType}; -use std::fmt::Display; +use std::fmt::{self, Formatter}; -pub struct Formatter<'a> { - ty: &'a AlgebraicType, -} +/// Wraps an algebraic `ty` in a `Display` impl using a map notation. +pub fn fmt_algebraic_type(ty: &AlgebraicType) -> impl '_ + fmt::Display { + use fmt_algebraic_type as fmt; -impl<'a> Formatter<'a> { - pub fn new(ty: &'a AlgebraicType) -> Self { - Self { ty } - } -} + // Format name/index + type. + let fmt_name_ty = |f: &mut Formatter<'_>, i, name, ty| match name { + Some(name) => write!(f, "{}: {}", name, fmt(ty)), + None => write!(f, "{}: {}", i, fmt(ty)), + }; -impl<'a> Display for Formatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.ty { - AlgebraicType::Sum(ty) => { - write!(f, "{{ ty_: Sum",)?; - for (i, e_ty) in ty.variants.iter().enumerate() { - write!(f, ", ")?; - if let Some(name) = &e_ty.name { - write!(f, "{}: {}", name, Formatter::new(&e_ty.algebraic_type))?; - } else { - write!(f, "{}: {}", i, Formatter::new(&e_ty.algebraic_type))?; - } - } - write!(f, " }}",) - } - AlgebraicType::Product(ty) => { - write!(f, "{{ ty_: Product",)?; - for (i, e_ty) in ty.elements.iter().enumerate() { - write!(f, ", ")?; - if let Some(name) = &e_ty.name { - write!(f, "{}: {}", name, Formatter::new(&e_ty.algebraic_type))?; - } else { - write!(f, "{}: {}", i, Formatter::new(&e_ty.algebraic_type))?; - } - } - write!(f, " }}",) + fmt_fn(move |f| match ty { + AlgebraicType::Sum(ty) => { + write!(f, "{{ ty_: Sum")?; + for (i, e_ty) in ty.variants.iter().enumerate() { + write!(f, ", ")?; + fmt_name_ty(f, i, e_ty.name.as_ref(), &e_ty.algebraic_type)?; } - AlgebraicType::Builtin(ty) => { - write!(f, "{{ ty_: Builtin")?; - match &ty { - BuiltinType::Bool => write!(f, ", 0: Bool")?, - BuiltinType::I8 => write!(f, ", 0: I8")?, - BuiltinType::U8 => write!(f, ", 0: U8")?, - BuiltinType::I16 => write!(f, ", 0: I16")?, - BuiltinType::U16 => write!(f, ", 0: U16")?, - BuiltinType::I32 => write!(f, ", 0: I32")?, - BuiltinType::U32 => write!(f, ", 0: U32")?, - BuiltinType::I64 => write!(f, ", 0: I64")?, - BuiltinType::U64 => write!(f, ", 0: U64")?, - BuiltinType::I128 => write!(f, ", 0: I128")?, - BuiltinType::U128 => write!(f, ", 0: U128")?, - BuiltinType::F32 => write!(f, ", 0: F32")?, - BuiltinType::F64 => write!(f, ", 0: F64")?, - BuiltinType::String => write!(f, ", 0: String")?, - BuiltinType::Array(ArrayType { elem_ty }) => { - write!(f, ", 0: Array, 1: {}", Formatter::new(elem_ty))? - } - BuiltinType::Map(MapType { key_ty, ty }) => { - write!(f, "0: Map, 1: {}, 2: {}", Formatter::new(key_ty), Formatter::new(ty))? - } - } - write!(f, " }}",) + write!(f, " }}") + } + AlgebraicType::Product(ty) => { + write!(f, "{{ ty_: Product")?; + for (i, e_ty) in ty.elements.iter().enumerate() { + write!(f, ", ")?; + fmt_name_ty(f, i, e_ty.name.as_ref(), &e_ty.algebraic_type)?; } - AlgebraicType::Ref(r) => { - write!(f, "{{ ty_: Ref, 0: ")?; - write!(f, "{}", r.0)?; - write!(f, " }}",) + write!(f, " }}") + } + AlgebraicType::Builtin(ty) => { + write!(f, "{{ ty_: Builtin")?; + match &ty { + BuiltinType::Bool => write!(f, ", 0: Bool")?, + BuiltinType::I8 => write!(f, ", 0: I8")?, + BuiltinType::U8 => write!(f, ", 0: U8")?, + BuiltinType::I16 => write!(f, ", 0: I16")?, + BuiltinType::U16 => write!(f, ", 0: U16")?, + BuiltinType::I32 => write!(f, ", 0: I32")?, + BuiltinType::U32 => write!(f, ", 0: U32")?, + BuiltinType::I64 => write!(f, ", 0: I64")?, + BuiltinType::U64 => write!(f, ", 0: U64")?, + BuiltinType::I128 => write!(f, ", 0: I128")?, + BuiltinType::U128 => write!(f, ", 0: U128")?, + BuiltinType::F32 => write!(f, ", 0: F32")?, + BuiltinType::F64 => write!(f, ", 0: F64")?, + BuiltinType::String => write!(f, ", 0: String")?, + BuiltinType::Array(ArrayType { elem_ty }) => write!(f, ", 0: Array, 1: {}", fmt(elem_ty))?, + BuiltinType::Map(MapType { key_ty, ty }) => write!(f, "0: Map, 1: {}, 2: {}", fmt(key_ty), fmt(ty))?, } + write!(f, " }}") } - } + AlgebraicType::Ref(r) => write!(f, "{{ ty_: Ref, 0: {} }}", r.0), + }) } diff --git a/crates/sats/src/algebraic_type/satn.rs b/crates/sats/src/algebraic_type/satn.rs deleted file mode 100644 index 9173cb24cd..0000000000 --- a/crates/sats/src/algebraic_type/satn.rs +++ /dev/null @@ -1,39 +0,0 @@ -use super::AlgebraicType; -use crate::{builtin_type, product_type, sum_type}; -use std::fmt::Display; - -/// NOTE: You might ask: Why do we have a formatter and a notation for -/// `AlgebraicType`s if we don't have an encoding for `AlgebraicType`s? -/// -/// This is because we just want an easier to read text format for algebraic -/// types. This could just as easily take in an algebraic value, which -/// represents an algebraic type and format it that way. It's just more -/// convenient to format it from the Rust type. -pub struct Formatter<'a> { - ty: &'a AlgebraicType, -} - -impl<'a> Formatter<'a> { - pub fn new(ty: &'a AlgebraicType) -> Self { - Self { ty } - } -} - -impl<'a> Display for Formatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.ty { - AlgebraicType::Sum(ty) => { - write!(f, "{}", sum_type::satn::Formatter::new(ty)) - } - AlgebraicType::Product(ty) => { - write!(f, "{}", product_type::satn::Formatter::new(ty)) - } - AlgebraicType::Builtin(p) => { - write!(f, "{}", builtin_type::satn::Formatter::new(p)) - } - AlgebraicType::Ref(r) => { - write!(f, "{}", r) - } - } - } -} diff --git a/crates/sats/src/algebraic_type_legacy_encoding.rs b/crates/sats/src/algebraic_type_legacy_encoding.rs deleted file mode 100644 index e785edbb95..0000000000 --- a/crates/sats/src/algebraic_type_legacy_encoding.rs +++ /dev/null @@ -1,270 +0,0 @@ -use crate::MapType; -use crate::{ - algebraic_type::AlgebraicType, builtin_type::BuiltinType, product_type::ProductType, - product_type_element::ProductTypeElement, sum_type::SumType, sum_type_variant::SumTypeVariant, -}; - -const TAG_SUM: u8 = 0x0; -const TAG_PRODUCT: u8 = 0x1; -const TAG_BOOL: u8 = 0x02; -const TAG_I8: u8 = 0x03; -const TAG_U8: u8 = 0x04; -const TAG_I16: u8 = 0x05; -const TAG_U16: u8 = 0x06; -const TAG_I32: u8 = 0x07; -const TAG_U32: u8 = 0x08; -const TAG_I64: u8 = 0x09; -const TAG_U64: u8 = 0x0a; -const TAG_I128: u8 = 0x0b; -const TAG_U128: u8 = 0x0c; -const TAG_F32: u8 = 0x0d; -const TAG_F64: u8 = 0x0e; -const TAG_STRING: u8 = 0x0f; -const TAG_ARRAY: u8 = 0x10; -const TAG_MAP: u8 = 0x11; -const TAG_REF: u8 = 0x12; - -impl AlgebraicType { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let bytes = bytes.as_ref(); - if bytes.len() == 0 { - return Err("Byte array length is invalid.".to_string()); - } - match bytes[0] { - TAG_PRODUCT => { - let (ty, bytes_read) = ProductType::decode(&bytes[1..])?; - Ok((AlgebraicType::Product(ty), bytes_read + 1)) - } - TAG_SUM => { - let (ty, bytes_read) = SumType::decode(&bytes[1..])?; - Ok((AlgebraicType::Sum(ty), bytes_read + 1)) - } - _ => { - let (ty, bytes_read) = BuiltinType::decode(&bytes[0..])?; - Ok((AlgebraicType::Builtin(ty), bytes_read)) - } - } - } - - pub fn encode(&self, bytes: &mut Vec) { - match self { - AlgebraicType::Product(ty) => { - bytes.push(TAG_PRODUCT); - ty.encode(bytes); - } - AlgebraicType::Sum(ty) => { - bytes.push(TAG_SUM); - ty.encode(bytes); - } - AlgebraicType::Builtin(ty) => { - ty.encode(bytes); - } - AlgebraicType::Ref(ty) => { - bytes.push(TAG_REF); - bytes.extend(ty.0.to_le_bytes()); - } - } - } -} - -impl BuiltinType { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let bytes = bytes.as_ref(); - if bytes.len() == 0 { - return Err("Byte array length is invalid.".to_string()); - } - match bytes[0] { - TAG_BOOL => Ok((Self::Bool, 1)), - TAG_I8 => Ok((Self::I8, 1)), - TAG_U8 => Ok((Self::U8, 1)), - TAG_I16 => Ok((Self::I16, 1)), - TAG_U16 => Ok((Self::U16, 1)), - TAG_I32 => Ok((Self::I32, 1)), - TAG_U32 => Ok((Self::U32, 1)), - TAG_I64 => Ok((Self::I64, 1)), - TAG_U64 => Ok((Self::U64, 1)), - TAG_I128 => Ok((Self::I128, 1)), - TAG_U128 => Ok((Self::U128, 1)), - TAG_F32 => Ok((Self::F32, 1)), - TAG_F64 => Ok((Self::F64, 1)), - TAG_STRING => Ok((Self::String, 1)), - TAG_ARRAY => { - let (ty, num_read) = AlgebraicType::decode(bytes)?; - Ok((Self::Array { ty: Box::new(ty) }, num_read)) - } - TAG_MAP => { - let mut num_read = 0; - let (key_ty, nr) = AlgebraicType::decode(bytes)?; - num_read += nr; - let (ty, nr) = AlgebraicType::decode(bytes)?; - num_read += nr; - Ok(( - Self::Map(MapType { - key_ty: Box::new(key_ty), - ty: Box::new(ty), - }), - num_read, - )) - } - b => panic!("Unknown {}", b), - } - } - - pub fn encode(&self, bytes: &mut Vec) { - match self { - BuiltinType::Bool => bytes.push(TAG_BOOL), - BuiltinType::I8 => bytes.push(TAG_I8), - BuiltinType::U8 => bytes.push(TAG_U8), - BuiltinType::I16 => bytes.push(TAG_I16), - BuiltinType::U16 => bytes.push(TAG_U16), - BuiltinType::I32 => bytes.push(TAG_I32), - BuiltinType::U32 => bytes.push(TAG_U32), - BuiltinType::I64 => bytes.push(TAG_I64), - BuiltinType::U64 => bytes.push(TAG_U64), - BuiltinType::I128 => bytes.push(TAG_I128), - BuiltinType::U128 => bytes.push(TAG_U128), - BuiltinType::F32 => bytes.push(TAG_F32), - BuiltinType::F64 => bytes.push(TAG_F64), - BuiltinType::String => bytes.push(TAG_STRING), - BuiltinType::Array { ty } => { - bytes.push(TAG_ARRAY); - ty.encode(bytes); - } - BuiltinType::Map(MapType { key_ty, ty }) => { - bytes.push(TAG_MAP); - key_ty.encode(bytes); - ty.encode(bytes); - } - } - } -} - -impl ProductType { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let mut num_read = 0; - let bytes = bytes.as_ref(); - if bytes.len() == 0 { - return Err("Byte array has invalid length.".to_string()); - } - - let len = bytes[num_read]; - num_read += 1; - - let mut elements = Vec::new(); - for _ in 0..len { - let (element, nr) = ProductTypeElement::decode(&bytes[num_read..])?; - elements.push(element); - num_read += nr; - } - Ok((ProductType { elements }, num_read)) - } - - pub fn encode(&self, bytes: &mut Vec) { - bytes.push(self.elements.len() as u8); - for item in &self.elements { - item.encode(bytes); - } - } -} - -impl SumType { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let mut num_read = 0; - let bytes = bytes.as_ref(); - if bytes.len() <= 0 { - return Err("Bytes array length is invalid.".to_string()); - } - - let len = bytes[num_read]; - num_read += 1; - - let mut items = Vec::new(); - for _ in 0..len { - let (item, nr) = SumTypeVariant::decode(&bytes[num_read..])?; - items.push(item); - num_read += nr; - } - Ok((SumType { variants: items }, num_read)) - } - - pub fn encode(&self, bytes: &mut Vec) { - bytes.push(self.variants.len() as u8); - for item in &self.variants { - item.encode(bytes); - } - } -} - -impl ProductTypeElement { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let mut num_read = 0; - let bytes = bytes.as_ref(); - if bytes.len() <= 0 { - return Err("Byte array has invalid length.".to_string()); - } - - let name_len = bytes[num_read]; - num_read += 1; - - let name = if name_len == 0 { - None - } else { - let name_bytes = &bytes[num_read..num_read + name_len as usize]; - num_read += name_len as usize; - Some(String::from_utf8(name_bytes.to_vec()).expect("Yeah this should really return a result.")) - }; - - let (algebraic_type, nr) = AlgebraicType::decode(&bytes[num_read..])?; - num_read += nr; - - Ok((ProductTypeElement { algebraic_type, name }, num_read)) - } - - pub fn encode(&self, bytes: &mut Vec) { - if let Some(name) = &self.name { - bytes.push(name.len() as u8); - bytes.extend(name.as_bytes()) - } else { - bytes.push(0); - } - - self.algebraic_type.encode(bytes); - } -} - -impl SumTypeVariant { - pub fn decode(bytes: impl AsRef<[u8]>) -> Result<(Self, usize), String> { - let mut num_read = 0; - let bytes = bytes.as_ref(); - if bytes.len() <= 0 { - return Err("Byte array has invalid length.".to_string()); - } - - let name_len = bytes[num_read]; - num_read += 1; - - let name = if name_len == 0 { - None - } else { - let name_bytes = &bytes[num_read..num_read + name_len as usize]; - num_read += name_len as usize; - Some(String::from_utf8(name_bytes.to_vec()).expect("Yeah this should really return a result.")) - }; - - let (algebraic_type, nr) = AlgebraicType::decode(&bytes[num_read..])?; - num_read += nr; - - Ok((SumTypeVariant { algebraic_type, name }, num_read)) - } - - pub fn encode(&self, bytes: &mut Vec) { - if let Some(name) = &self.name { - bytes.push(name.len() as u8); - bytes.extend(name.as_bytes()) - } else { - bytes.push(0); - } - - self.algebraic_type.encode(bytes); - } -} diff --git a/crates/sats/src/algebraic_type_ref.rs b/crates/sats/src/algebraic_type_ref.rs index f14a79ac51..bb7ac53f51 100644 --- a/crates/sats/src/algebraic_type_ref.rs +++ b/crates/sats/src/algebraic_type_ref.rs @@ -1,35 +1,35 @@ -use crate::{algebraic_type::AlgebraicType, builtin_type::BuiltinType}; +use crate::{algebraic_type::AlgebraicType, impl_deserialize, impl_serialize, meta_type::MetaType}; use std::fmt::Display; +/// A reference to an [`AlgebraicType`] within a `Typespace`. +/// +/// Using this in a different `Typespace` than its maker +/// will most likely result in a panic. #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] -pub struct AlgebraicTypeRef(pub u32); +pub struct AlgebraicTypeRef( + /// The index into the specific `Typespace`'s list of types. + pub u32, +); impl AlgebraicTypeRef { - pub fn idx(self) -> usize { + /// Returns the index into the specific `Typespace`'s list of types. + pub const fn idx(self) -> usize { self.0 as usize } } -impl crate::ser::Serialize for AlgebraicTypeRef { - fn serialize(&self, serializer: S) -> Result { - self.0.serialize(serializer) - } -} - -impl<'de> crate::de::Deserialize<'de> for AlgebraicTypeRef { - fn deserialize>(deserializer: D) -> Result { - u32::deserialize(deserializer).map(Self) - } -} +impl_serialize!([] AlgebraicTypeRef, (self, ser) => self.0.serialize(ser)); +impl_deserialize!([] AlgebraicTypeRef, de => u32::deserialize(de).map(Self)); impl Display for AlgebraicTypeRef { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // For example: `&42`. write!(f, "&{}", self.0) } } -impl AlgebraicTypeRef { - pub fn make_meta_type() -> AlgebraicType { - AlgebraicType::Builtin(BuiltinType::U32) +impl MetaType for AlgebraicTypeRef { + fn meta_type() -> AlgebraicType { + AlgebraicType::U32 } } diff --git a/crates/sats/src/algebraic_value.rs b/crates/sats/src/algebraic_value.rs index 7996673f02..30b7a70df2 100644 --- a/crates/sats/src/algebraic_value.rs +++ b/crates/sats/src/algebraic_value.rs @@ -3,297 +3,427 @@ pub mod ser; use std::collections::BTreeMap; use crate::builtin_value::{F32, F64}; -use crate::{ - AlgebraicType, ArrayValue, BuiltinType, BuiltinValue, ProductType, ProductTypeElement, ProductValue, SumValue, -}; +use crate::{AlgebraicType, ArrayValue, BuiltinType, BuiltinValue, ProductValue, SumValue}; use enum_as_inner::EnumAsInner; +/// A value in SATS. +/// +/// These values are fully evaluated. #[derive(EnumAsInner, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum AlgebraicValue { + /// A structural sum value. + /// + /// Given a sum type `{ N_0(T_0), N_1(T_1), ..., N_n(T_n) }` + /// where `N_i` denotes a variant name + /// and where `T_i` denotes the type the variant stores, + /// a sum value makes a specific choice as to the variant. + /// So for example, we might chose `N_1(T_1)` + /// and represent this choice with `(1, v)` where `v` is a value of type `T_1`. Sum(SumValue), + /// A structural product value. + /// + /// Given a product type `{ N_0: T_0, N_1: T_1, ..., N_n: T_n }` + /// where `N_i` denotes a field / element name + /// and where `T_i` denotes the type the field stores, + /// a product value stores a value `v_i` of type `T_i` for each field `N_i`. Product(ProductValue), + /// A builtin value that has a builtin type. Builtin(BuiltinValue), } +#[allow(non_snake_case)] impl AlgebraicValue { - pub const UNIT: Self = AlgebraicValue::Product(ProductValue { elements: Vec::new() }); + /// The canonical unit value defined as the nullary product value `()`. + /// + /// The type of `UNIT` is `()`. + pub const UNIT: Self = Self::product(Vec::new()); + /// Interpret the value as a `bool` or `None` if it isn't a `bool` value. #[inline] pub fn as_bool(&self) -> Option<&bool> { self.as_builtin()?.as_bool() } + + /// Interpret the value as an `i8` or `None` if it isn't a `i8` value. #[inline] pub fn as_i8(&self) -> Option<&i8> { self.as_builtin()?.as_i8() } + + /// Interpret the value as a `u8` or `None` if it isn't a `u8` value. #[inline] pub fn as_u8(&self) -> Option<&u8> { self.as_builtin()?.as_u8() } + + /// Interpret the value as an `i16` or `None` if it isn't an `i16` value. #[inline] pub fn as_i16(&self) -> Option<&i16> { self.as_builtin()?.as_i16() } + + /// Interpret the value as a `u16` or `None` if it isn't a `u16` value. #[inline] pub fn as_u16(&self) -> Option<&u16> { self.as_builtin()?.as_u16() } + + /// Interpret the value as an `i32` or `None` if it isn't an `i32` value. #[inline] pub fn as_i32(&self) -> Option<&i32> { self.as_builtin()?.as_i32() } + + /// Interpret the value as a `u32` or `None` if it isn't a `u32` value. #[inline] pub fn as_u32(&self) -> Option<&u32> { self.as_builtin()?.as_u32() } + + /// Interpret the value as an `i64` or `None` if it isn't an `i64` value. #[inline] pub fn as_i64(&self) -> Option<&i64> { self.as_builtin()?.as_i64() } + + /// Interpret the value as a `u64` or `None` if it isn't a `u64` value. #[inline] pub fn as_u64(&self) -> Option<&u64> { self.as_builtin()?.as_u64() } + + /// Interpret the value as an `i128` or `None` if it isn't an `i128` value. #[inline] pub fn as_i128(&self) -> Option<&i128> { self.as_builtin()?.as_i128() } + + /// Interpret the value as a `u128` or `None` if it isn't a `u128` value. #[inline] pub fn as_u128(&self) -> Option<&u128> { self.as_builtin()?.as_u128() } + + /// Interpret the value as a `f32` or `None` if it isn't a `f32` value. #[inline] pub fn as_f32(&self) -> Option<&F32> { self.as_builtin()?.as_f32() } + + /// Interpret the value as a `f64` or `None` if it isn't a `f64` value. #[inline] pub fn as_f64(&self) -> Option<&F64> { self.as_builtin()?.as_f64() } + + /// Interpret the value as a `String` or `None` if it isn't a `String` value. #[inline] pub fn as_string(&self) -> Option<&String> { self.as_builtin()?.as_string() } + + /// Interpret the value as a `Vec` or `None` if it isn't a `Vec` value. #[inline] pub fn as_bytes(&self) -> Option<&Vec> { self.as_builtin()?.as_bytes() } + + /// Interpret the value as an `ArrayValue` or `None` if it isn't an `ArrayValue` value. #[inline] pub fn as_array(&self) -> Option<&ArrayValue> { self.as_builtin()?.as_array() } + + /// Interpret the value as a map or `None` if it isn't a map value. #[inline] - pub fn as_map(&self) -> Option<&BTreeMap> { + pub fn as_map(&self) -> Option<&BTreeMap> { self.as_builtin()?.as_map() } + /// Convert the value into a `bool` or `Err(self)` if it isn't a `bool` value. #[inline] pub fn into_bool(self) -> Result { self.into_builtin()?.into_bool().map_err(Self::Builtin) } + + /// Convert the value into an `i8` or `Err(self)` if it isn't an `i8` value. #[inline] pub fn into_i8(self) -> Result { self.into_builtin()?.into_i8().map_err(Self::Builtin) } + + /// Convert the value into a `u8` or `Err(self)` if it isn't a `u8` value. #[inline] pub fn into_u8(self) -> Result { self.into_builtin()?.into_u8().map_err(Self::Builtin) } + + /// Convert the value into an `i16` or `Err(self)` if it isn't an `i16` value. #[inline] pub fn into_i16(self) -> Result { self.into_builtin()?.into_i16().map_err(Self::Builtin) } + + /// Convert the value into a `u16` or `Err(self)` if it isn't a `u16` value. #[inline] pub fn into_u16(self) -> Result { self.into_builtin()?.into_u16().map_err(Self::Builtin) } + + /// Convert the value into an `i32` or `Err(self)` if it isn't an `i32` value. #[inline] pub fn into_i32(self) -> Result { self.into_builtin()?.into_i32().map_err(Self::Builtin) } + + /// Convert the value into a `u32` or `Err(self)` if it isn't a `u32` value. #[inline] pub fn into_u32(self) -> Result { self.into_builtin()?.into_u32().map_err(Self::Builtin) } + + /// Convert the value into an `i64` or `Err(self)` if it isn't an `i64` value. #[inline] pub fn into_i64(self) -> Result { self.into_builtin()?.into_i64().map_err(Self::Builtin) } + + /// Convert the value into a `u64` or `Err(self)` if it isn't a `u64` value. #[inline] pub fn into_u64(self) -> Result { self.into_builtin()?.into_u64().map_err(Self::Builtin) } + + /// Convert the value into an `i128` or `Err(self)` if it isn't an `i128` value. #[inline] pub fn into_i128(self) -> Result { self.into_builtin()?.into_i128().map_err(Self::Builtin) } + + /// Convert the value into a `u128` or `Err(self)` if it isn't a `u128` value. #[inline] pub fn into_u128(self) -> Result { self.into_builtin()?.into_u128().map_err(Self::Builtin) } + + /// Convert the value into a `f32` or `Err(self)` if it isn't a `f32` value. #[inline] pub fn into_f32(self) -> Result { self.into_builtin()?.into_f32().map_err(Self::Builtin) } + + /// Convert the value into a `f64` or `Err(self)` if it isn't a `f64` value. #[inline] pub fn into_f64(self) -> Result { self.into_builtin()?.into_f64().map_err(Self::Builtin) } + + /// Convert the value into a `String` or `Err(self)` if it isn't a `String` value. #[inline] pub fn into_string(self) -> Result { self.into_builtin()?.into_string().map_err(Self::Builtin) } + + /// Convert the value into a `Vec` or `Err(self)` if it isn't a `Vec` value. #[inline] pub fn into_bytes(self) -> Result, Self> { self.into_builtin()?.into_bytes().map_err(Self::Builtin) } + + /// Convert the value into an [`ArrayValue`] or `Err(self)` if it isn't an [`ArrayValue`] value. #[inline] pub fn into_array(self) -> Result { self.into_builtin()?.into_array().map_err(Self::Builtin) } + + /// Convert the value into a map or `Err(self)` if it isn't a map value. #[inline] - pub fn into_map(self) -> Result, Self> { + pub fn into_map(self) -> Result, Self> { self.into_builtin()?.into_map().map_err(Self::Builtin) } - #[allow(non_snake_case)] + /// Returns an [`AlgebraicValue`] representing `v: bool`. #[inline] - pub fn Bool(v: bool) -> Self { + pub const fn Bool(v: bool) -> Self { Self::Builtin(BuiltinValue::Bool(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: i8`. #[inline] - pub fn I8(v: i8) -> Self { + pub const fn I8(v: i8) -> Self { Self::Builtin(BuiltinValue::I8(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: u8`. #[inline] - pub fn U8(v: u8) -> Self { + pub const fn U8(v: u8) -> Self { Self::Builtin(BuiltinValue::U8(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: i16`. #[inline] - pub fn I16(v: i16) -> Self { + pub const fn I16(v: i16) -> Self { Self::Builtin(BuiltinValue::I16(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: u16`. #[inline] - pub fn U16(v: u16) -> Self { + pub const fn U16(v: u16) -> Self { Self::Builtin(BuiltinValue::U16(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: i32`. #[inline] - pub fn I32(v: i32) -> Self { + pub const fn I32(v: i32) -> Self { Self::Builtin(BuiltinValue::I32(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: u32`. #[inline] - pub fn U32(v: u32) -> Self { + pub const fn U32(v: u32) -> Self { Self::Builtin(BuiltinValue::U32(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: i64`. #[inline] - pub fn I64(v: i64) -> Self { + pub const fn I64(v: i64) -> Self { Self::Builtin(BuiltinValue::I64(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: u64`. #[inline] - pub fn U64(v: u64) -> Self { + pub const fn U64(v: u64) -> Self { Self::Builtin(BuiltinValue::U64(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: i128`. #[inline] - pub fn I128(v: i128) -> Self { + pub const fn I128(v: i128) -> Self { Self::Builtin(BuiltinValue::I128(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: u128`. #[inline] - pub fn U128(v: u128) -> Self { + pub const fn U128(v: u128) -> Self { Self::Builtin(BuiltinValue::U128(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: f32`. #[inline] - pub fn F32(v: F32) -> Self { + pub const fn F32(v: F32) -> Self { Self::Builtin(BuiltinValue::F32(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: f64`. #[inline] - pub fn F64(v: F64) -> Self { + pub const fn F64(v: F64) -> Self { Self::Builtin(BuiltinValue::F64(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: String`. #[inline] - pub fn String(v: String) -> Self { + pub const fn String(v: String) -> Self { Self::Builtin(BuiltinValue::String(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] representing `v: Vec`. #[inline] - pub fn Bytes(v: Vec) -> Self { + pub const fn Bytes(v: Vec) -> Self { Self::Builtin(BuiltinValue::Bytes(v)) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] for a `val` which can be converted into an [`ArrayValue`]. #[inline] - pub fn ArrayOf>(val: T) -> Self { + pub fn ArrayOf(val: impl Into) -> Self { Self::Builtin(BuiltinValue::Array { val: val.into() }) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] for `some: v`. + /// + /// The `some` variant is assigned the tag `0`. #[inline] - pub fn OptionSome(v: AlgebraicValue) -> Self { - Self::Sum(SumValue { - tag: 0, - value: Box::new(v), - }) + pub fn OptionSome(v: Self) -> Self { + Self::sum(0, v) } - #[allow(non_snake_case)] + + /// Returns an [`AlgebraicValue`] for `none`. + /// + /// The `none` variant is assigned the tag `1`. #[inline] pub fn OptionNone() -> Self { - Self::Sum(SumValue { - tag: 1, - value: Box::new(AlgebraicValue::Product(ProductValue { elements: Vec::new() })), - }) + Self::sum(1, Self::UNIT) + } + + /// Returns an [`AlgebraicValue`] representing a sum value with `tag` and `value`. + pub fn sum(tag: u8, value: Self) -> Self { + let value = Box::new(value); + Self::Sum(SumValue { tag, value }) + } + + /// Returns an [`AlgebraicValue`] representing a product value with the given `elements`. + pub const fn product(elements: Vec) -> Self { + Self::Product(ProductValue { elements }) } + + /// Returns an [`AlgebraicValue`] representing a map value defined by the given `map`. + pub const fn map(map: BTreeMap) -> Self { + Self::Builtin(BuiltinValue::Map { val: map }) + } + + /// Returns the [`AlgebraicType`] of the sum value `x`. pub(crate) fn type_of_sum(x: &SumValue) -> AlgebraicType { - AlgebraicType::Product(ProductType::new(vec![ProductTypeElement::new(x.value.type_of(), None)])) + // TODO(centril): This is unsound! + // + // The type of a sum value must be a sum type and *not* a product type. + // Suppose `x.tag` is for the variant `VarName(VarType)`. + // Then `VarType` is *not* the same type as `{ VarName(VarType) | r }` + // where `r` represents a polymorphic variants compontent. + // + // To assign this a correct type we either have to store the type with the value + // or alternatively, we must have polymorphic variants (see row polymorphism) + // *and* derive the correct variant name. + AlgebraicType::product(vec![x.value.type_of().into()]) } - pub(crate) fn type_of_product(x: &ProductValue) -> AlgebraicType { - let ty = x.elements.iter().map(|x| ProductTypeElement::new(x.type_of(), None)); - AlgebraicType::Product(ProductType::new(ty.collect())) + /// Returns the [`AlgebraicType`] of the product value `x`. + pub(crate) fn type_of_product(x: &ProductValue) -> AlgebraicType { + AlgebraicType::product(x.elements.iter().map(|x| x.type_of().into()).collect()) } - pub(crate) fn type_of_map(val: &BTreeMap) -> AlgebraicType { - let ty = if let Some((k, v)) = val.first_key_value() { - ProductType::new(vec![ - ProductTypeElement::new(k.type_of(), None), - ProductTypeElement::new(v.type_of(), None), - ]) + + /// Returns the [`AlgebraicType`] of the map with key type `k` and value type `v`. + pub(crate) fn type_of_map(val: &BTreeMap) -> AlgebraicType { + AlgebraicType::product(if let Some((k, v)) = val.first_key_value() { + vec![k.type_of().into(), v.type_of().into()] } else { - let ty = ProductTypeElement::new(AlgebraicType::make_never_type(), None); - ProductType::new(vec![ty.clone(), ty]) - }; - AlgebraicType::Product(ty) + // TODO(centril): What is the motivation for this? + // I think this requires a soundness argument. + // I could see that it is OK with the argument that this is an empty map + // under the requirement that we cannot insert elements into the map. + vec![AlgebraicType::NEVER_TYPE.into(); 2] + }) } - /// Infer the [AlgebraicType] of [Self]. + /// Infer the [`AlgebraicType`] of an [`AlgebraicValue`]. pub fn type_of(&self) -> AlgebraicType { - //todo: What are the types of empty arrays/maps/sums... + // TODO: What are the types of empty arrays/maps/sums? match self { AlgebraicValue::Sum(x) => Self::type_of_sum(x), AlgebraicValue::Product(x) => Self::type_of_product(x), AlgebraicValue::Builtin(x) => match x { - BuiltinValue::Bool(_) => BuiltinType::Bool.into(), - BuiltinValue::I8(_) => BuiltinType::I8.into(), - BuiltinValue::U8(_) => BuiltinType::U8.into(), - BuiltinValue::I16(_) => BuiltinType::I16.into(), - BuiltinValue::U16(_) => BuiltinType::U16.into(), - BuiltinValue::I32(_) => BuiltinType::I32.into(), - BuiltinValue::U32(_) => BuiltinType::U32.into(), - BuiltinValue::I64(_) => BuiltinType::I64.into(), - BuiltinValue::U64(_) => BuiltinType::U64.into(), - BuiltinValue::I128(_) => BuiltinType::I128.into(), - BuiltinValue::U128(_) => BuiltinType::U128.into(), - BuiltinValue::F32(_) => BuiltinType::F32.into(), - BuiltinValue::F64(_) => BuiltinType::F64.into(), - BuiltinValue::String(_) => BuiltinType::String.into(), + BuiltinValue::Bool(_) => AlgebraicType::Bool, + BuiltinValue::I8(_) => AlgebraicType::I8, + BuiltinValue::U8(_) => AlgebraicType::U8, + BuiltinValue::I16(_) => AlgebraicType::I16, + BuiltinValue::U16(_) => AlgebraicType::U16, + BuiltinValue::I32(_) => AlgebraicType::I32, + BuiltinValue::U32(_) => AlgebraicType::U32, + BuiltinValue::I64(_) => AlgebraicType::I64, + BuiltinValue::U64(_) => AlgebraicType::U64, + BuiltinValue::I128(_) => AlgebraicType::I128, + BuiltinValue::U128(_) => AlgebraicType::U128, + BuiltinValue::F32(_) => AlgebraicType::F32, + BuiltinValue::F64(_) => AlgebraicType::F64, + BuiltinValue::String(_) => AlgebraicType::String, BuiltinValue::Array { val } => AlgebraicType::Builtin(BuiltinType::Array(val.type_of())), BuiltinValue::Map { val } => Self::type_of_map(val), }, @@ -316,32 +446,26 @@ mod tests { use crate::satn::Satn; use crate::{ - AlgebraicType, AlgebraicValue, ArrayType, BuiltinType, BuiltinValue, MapType, ProductType, ProductTypeElement, - ProductValue, TypeInSpace, Typespace, ValueWithType, + AlgebraicType, AlgebraicValue, ArrayValue, ProductTypeElement, Typespace, ValueWithType, WithTypespace, }; fn in_space<'a, T: crate::Value>(ts: &'a Typespace, ty: &'a T::Type, val: &'a T) -> ValueWithType<'a, T> { - TypeInSpace::new(ts, ty).with_value(val) + WithTypespace::new(ts, ty).with_value(val) } #[test] fn unit() { - let val = AlgebraicValue::Product(ProductValue { elements: vec![] }); - let unit = AlgebraicType::Product(ProductType::new(vec![])); + let val = AlgebraicValue::UNIT; + let unit = AlgebraicType::UNIT_TYPE; let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &unit, &val).to_satn(), "()"); } #[test] fn product_value() { - let product_type = AlgebraicType::Product(ProductType::new(vec![ProductTypeElement::new( - AlgebraicType::Builtin(BuiltinType::I32), - Some("foo".into()), - )])); + let product_type = AlgebraicType::product(vec![ProductTypeElement::new_named(AlgebraicType::I32, "foo")]); let typespace = Typespace::new(vec![]); - let product_value = AlgebraicValue::Product(ProductValue { - elements: vec![AlgebraicValue::Builtin(BuiltinValue::I32(42))], - }); + let product_value = AlgebraicValue::product(vec![AlgebraicValue::I32(42)]); assert_eq!( "(foo = 42)", in_space(&typespace, &product_type, &product_value).to_satn(), @@ -350,8 +474,7 @@ mod tests { #[test] fn option_some() { - let never = AlgebraicType::make_never_type(); - let option = AlgebraicType::make_option_type(never); + let option = AlgebraicType::option(AlgebraicType::NEVER_TYPE); let sum_value = AlgebraicValue::OptionNone(); let typespace = Typespace::new(vec![]); assert_eq!("(none = ())", in_space(&typespace, &option, &sum_value).to_satn(),); @@ -359,57 +482,42 @@ mod tests { #[test] fn primitive() { - let u8 = AlgebraicType::Builtin(BuiltinType::U8); - let value = AlgebraicValue::Builtin(BuiltinValue::U8(255)); + let u8 = AlgebraicType::U8; + let value = AlgebraicValue::U8(255); let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &u8, &value).to_satn(), "255"); } #[test] fn array() { - let array = AlgebraicType::Builtin(BuiltinType::Array(ArrayType { - elem_ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - })); - let value = AlgebraicValue::Builtin(BuiltinValue::Array { - val: Default::default(), - }); + let array = AlgebraicType::array(AlgebraicType::U8); + let value = AlgebraicValue::ArrayOf(ArrayValue::Sum(Vec::new())); let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &array, &value).to_satn(), "[]"); } #[test] fn array_of_values() { - let array = AlgebraicType::Builtin(BuiltinType::Array(ArrayType { - elem_ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - })); - let value = AlgebraicValue::Builtin(BuiltinValue::Array { val: vec![3u8].into() }); + let array = AlgebraicType::array(AlgebraicType::U8); + let value = AlgebraicValue::ArrayOf(vec![3u8]); let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &array, &value).to_satn(), "[3]"); } #[test] fn map() { - let map = AlgebraicType::Builtin(BuiltinType::Map(MapType { - key_ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - })); - let value = AlgebraicValue::Builtin(BuiltinValue::Map { val: BTreeMap::new() }); + let map = AlgebraicType::map(AlgebraicType::U8, AlgebraicType::U8); + let value = AlgebraicValue::map(BTreeMap::new()); let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &map, &value).to_satn(), "[:]"); } #[test] fn map_of_values() { - let map = AlgebraicType::Builtin(BuiltinType::Map(MapType { - key_ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - ty: Box::new(AlgebraicType::Builtin(BuiltinType::U8)), - })); - let mut value = BTreeMap::::new(); - value.insert( - AlgebraicValue::Builtin(BuiltinValue::U8(2)), - AlgebraicValue::Builtin(BuiltinValue::U8(3)), - ); - let value = AlgebraicValue::Builtin(BuiltinValue::Map { val: value }); + let map = AlgebraicType::map(AlgebraicType::U8, AlgebraicType::U8); + let mut val = BTreeMap::::new(); + val.insert(AlgebraicValue::U8(2), AlgebraicValue::U8(3)); + let value = AlgebraicValue::map(val); let typespace = Typespace::new(vec![]); assert_eq!(in_space(&typespace, &map, &value).to_satn(), "[2: 3]"); } diff --git a/crates/sats/src/algebraic_value/de.rs b/crates/sats/src/algebraic_value/de.rs index 52736ddbea..c99a1b40f1 100644 --- a/crates/sats/src/algebraic_value/de.rs +++ b/crates/sats/src/algebraic_value/de.rs @@ -1,113 +1,129 @@ use crate::builtin_value::{ArrayValueIntoIter, ArrayValueIterCloned}; use crate::{de, AlgebraicValue, SumValue}; +/// An implementation of [`Deserializer`](de::Deserializer) +/// where the input of deserialization is an `AlgebraicValue`. #[repr(transparent)] pub struct ValueDeserializer { + /// The value to deserialize to some `T`. val: AlgebraicValue, } impl ValueDeserializer { + /// Returns a `ValueDeserializer` with `val` as the input for deserialization. pub fn new(val: AlgebraicValue) -> Self { Self { val } } + + /// Converts `&AlgebraicValue` to `&ValueDeserialize`. pub fn from_ref(val: &AlgebraicValue) -> &Self { + // SAFETY: The conversion is OK due to `repr(transparent)`. unsafe { &*(val as *const AlgebraicValue as *const ValueDeserializer) } } } + impl From for ValueDeserializer { fn from(val: AlgebraicValue) -> Self { Self { val } } } +/// Errors that can occur when deserializing the `AlgebraicValue`. #[derive(Debug)] pub enum ValueDeserializeError { + /// The input type does not match the target type. MismatchedType, + /// An unstructured error message. Custom(String), } + impl de::Error for ValueDeserializeError { fn custom(msg: impl std::fmt::Display) -> Self { Self::Custom(msg.to_string()) } } +/// Turns any error into `ValueDeserializeError::MismatchedType`. +fn map_err(res: Result) -> Result { + res.map_err(|_| ValueDeserializeError::MismatchedType) +} + +/// Turns any option into `ValueDeserializeError::MismatchedType`. +fn ok_or(res: Option) -> Result { + res.ok_or(ValueDeserializeError::MismatchedType) +} + impl<'de> de::Deserializer<'de> for ValueDeserializer { type Error = ValueDeserializeError; fn deserialize_product>(self, visitor: V) -> Result { - let prod = self - .val - .into_product() - .map_err(|_| ValueDeserializeError::MismatchedType)?; - let vals = prod.elements.into_iter(); + let vals = map_err(self.val.into_product())?.elements.into_iter(); visitor.visit_seq_product(ProductAccess { vals }) } fn deserialize_sum>(self, visitor: V) -> Result { - let sum = self.val.into_sum().map_err(|_| ValueDeserializeError::MismatchedType)?; + let sum = map_err(self.val.into_sum())?; visitor.visit_sum(SumAccess { sum }) } fn deserialize_bool(self) -> Result { - self.val.into_bool().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_bool()) } + fn deserialize_u8(self) -> Result { - self.val.into_u8().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_u8()) } + fn deserialize_u16(self) -> Result { - self.val.into_u16().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_u16()) } + fn deserialize_u32(self) -> Result { - self.val.into_u32().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_u32()) } + fn deserialize_u64(self) -> Result { - self.val.into_u64().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_u64()) } + fn deserialize_u128(self) -> Result { - self.val.into_u128().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_u128()) } + fn deserialize_i8(self) -> Result { - self.val.into_i8().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_i8()) } + fn deserialize_i16(self) -> Result { - self.val.into_i16().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_i16()) } + fn deserialize_i32(self) -> Result { - self.val.into_i32().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_i32()) } + fn deserialize_i64(self) -> Result { - self.val.into_i64().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_i64()) } + fn deserialize_i128(self) -> Result { - self.val.into_i128().map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_i128()) } + fn deserialize_f32(self) -> Result { - self.val - .into_f32() - .map(f32::from) - .map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_f32().map(f32::from)) } + fn deserialize_f64(self) -> Result { - self.val - .into_f64() - .map(f64::from) - .map_err(|_| ValueDeserializeError::MismatchedType) + map_err(self.val.into_f64().map(f64::from)) } fn deserialize_str>(self, visitor: V) -> Result { - let s = self - .val - .into_string() - .map_err(|_| ValueDeserializeError::MismatchedType)?; - visitor.visit_owned(s) + visitor.visit_owned(map_err(self.val.into_string())?) } fn deserialize_bytes>(self, visitor: V) -> Result { - let b = self - .val - .into_bytes() - .map_err(|_| ValueDeserializeError::MismatchedType)?; - visitor.visit_owned(b) + visitor.visit_owned(map_err(self.val.into_bytes())?) } fn deserialize_array_seed, T: de::DeserializeSeed<'de> + Clone>( @@ -115,11 +131,7 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer { visitor: V, seed: T, ) -> Result { - let iter = self - .val - .into_array() - .map_err(|_| ValueDeserializeError::MismatchedType)? - .into_iter(); + let iter = map_err(self.val.into_array())?.into_iter(); visitor.visit(ArrayAccess { iter, seed }) } @@ -133,16 +145,14 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer { kseed: K, vseed: V, ) -> Result { - let iter = self - .val - .into_map() - .map_err(|_| ValueDeserializeError::MismatchedType)? - .into_iter(); + let iter = map_err(self.val.into_map())?.into_iter(); visitor.visit(MapAccess { iter, kseed, vseed }) } } +/// Defines deserialization for [`ValueDeserializer`] where product elements are in the input. struct ProductAccess { + /// The element values of the product as an iterator of owned values. vals: std::vec::IntoIter, } @@ -157,17 +167,22 @@ impl<'de> de::SeqProductAccess<'de> for ProductAccess { } } +/// Defines deserialization for [`ValueDeserializer`] where a sum value is in the input. #[repr(transparent)] struct SumAccess { + /// The input sum value to deserialize. sum: SumValue, } + impl SumAccess { + /// Converts `&SumValue` to `&SumAccess`. fn from_ref(sum: &SumValue) -> &Self { + // SAFETY: `repr(transparent)` allows this. unsafe { &*(sum as *const SumValue as *const SumAccess) } } } -impl<'de> de::SumAccess<'de> for SumAccess { +impl de::SumAccess<'_> for SumAccess { type Error = ValueDeserializeError; type Variant = ValueDeserializer; @@ -187,8 +202,12 @@ impl<'de> de::VariantAccess<'de> for ValueDeserializer { } } +/// Defines deserialization for [`ValueDeserializer`] where an array value is in the input. struct ArrayAccess { + /// The elements of the array as an iterator of owned elements. iter: ArrayValueIntoIter, + /// A seed value provided by the caller of + /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed). seed: T, } @@ -204,9 +223,15 @@ impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for ArrayAcc } } +/// Defines deserialization for [`ValueDeserializer`] where a map value is in the input. struct MapAccess { + /// The elements of the map as an iterator of owned key/value entries. iter: std::collections::btree_map::IntoIter, + /// A key seed value provided by the caller of + /// [`deserialize_map_seed`](de::Deserializer::deserialize_map_seed). kseed: K, + /// A value seed value provided by the caller of + /// [`deserialize_map_seed`](de::Deserializer::deserialize_map_seed). vseed: V, } @@ -234,72 +259,61 @@ impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { type Error = ValueDeserializeError; fn deserialize_product>(self, visitor: V) -> Result { - let prod = self.val.as_product().ok_or(ValueDeserializeError::MismatchedType)?; - let vals = prod.elements.iter(); + let vals = ok_or(self.val.as_product())?.elements.iter(); visitor.visit_seq_product(RefProductAccess { vals }) } fn deserialize_sum>(self, visitor: V) -> Result { - let sum = self.val.as_sum().ok_or(ValueDeserializeError::MismatchedType)?; + let sum = ok_or(self.val.as_sum())?; visitor.visit_sum(SumAccess::from_ref(sum)) } fn deserialize_bool(self) -> Result { - self.val.as_bool().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_bool().copied()) } fn deserialize_u8(self) -> Result { - self.val.as_u8().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_u8().copied()) } fn deserialize_u16(self) -> Result { - self.val.as_u16().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_u16().copied()) } fn deserialize_u32(self) -> Result { - self.val.as_u32().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_u32().copied()) } fn deserialize_u64(self) -> Result { - self.val.as_u64().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_u64().copied()) } fn deserialize_u128(self) -> Result { - self.val.as_u128().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_u128().copied()) } fn deserialize_i8(self) -> Result { - self.val.as_i8().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_i8().copied()) } fn deserialize_i16(self) -> Result { - self.val.as_i16().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_i16().copied()) } fn deserialize_i32(self) -> Result { - self.val.as_i32().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_i32().copied()) } fn deserialize_i64(self) -> Result { - self.val.as_i64().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_i64().copied()) } fn deserialize_i128(self) -> Result { - self.val.as_i128().copied().ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_i128().copied()) } fn deserialize_f32(self) -> Result { - self.val - .as_f32() - .copied() - .map(f32::from) - .ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_f32().copied().map(f32::from)) } fn deserialize_f64(self) -> Result { - self.val - .as_f64() - .copied() - .map(f64::from) - .ok_or(ValueDeserializeError::MismatchedType) + ok_or(self.val.as_f64().copied().map(f64::from)) } fn deserialize_str>(self, visitor: V) -> Result { - let s = self.val.as_string().ok_or(ValueDeserializeError::MismatchedType)?; - visitor.visit_borrowed(s) + visitor.visit_borrowed(ok_or(self.val.as_string())?) } fn deserialize_bytes>(self, visitor: V) -> Result { - let b = self.val.as_bytes().ok_or(ValueDeserializeError::MismatchedType)?; - visitor.visit_borrowed(b) + visitor.visit_borrowed(ok_or(self.val.as_bytes())?) } fn deserialize_array_seed, T: de::DeserializeSeed<'de> + Clone>( @@ -307,11 +321,7 @@ impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { visitor: V, seed: T, ) -> Result { - let iter = self - .val - .as_array() - .ok_or(ValueDeserializeError::MismatchedType)? - .iter_cloned(); + let iter = ok_or(self.val.as_array())?.iter_cloned(); visitor.visit(RefArrayAccess { iter, seed }) } @@ -325,12 +335,14 @@ impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { kseed: K, vseed: V, ) -> Result { - let iter = self.val.as_map().ok_or(ValueDeserializeError::MismatchedType)?.iter(); + let iter = ok_or(self.val.as_map())?.iter(); visitor.visit(RefMapAccess { iter, kseed, vseed }) } } +/// Defines deserialization for [`&'de ValueDeserializer`] where product elements are in the input. struct RefProductAccess<'a> { + /// The element values of the product as an iterator of borrowed values. vals: std::slice::Iter<'a, AlgebraicValue>, } @@ -364,9 +376,13 @@ impl<'de> de::VariantAccess<'de> for &'de ValueDeserializer { } } +/// Defines deserialization for [`&'de ValueDeserializer`] where an array value is in the input. struct RefArrayAccess<'a, T> { // TODO: idk this kinda sucks + /// The elements of the array as an iterator of cloned elements. iter: ArrayValueIterCloned<'a>, + /// A seed value provided by the caller of + /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed). seed: T, } @@ -382,9 +398,15 @@ impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for RefArray } } +/// Defines deserialization for [`&'de ValueDeserializer`] where an map value is in the input. struct RefMapAccess<'a, K, V> { + /// The elements of the map as an iterator of borrowed key/value entries. iter: std::collections::btree_map::Iter<'a, AlgebraicValue, AlgebraicValue>, + /// A key seed value provided by the caller of + /// [`deserialize_map_seed`](de::Deserializer::deserialize_map_seed). kseed: K, + /// A value seed value provided by the caller of + /// [`deserialize_map_seed`](de::Deserializer::deserialize_map_seed). vseed: V, } diff --git a/crates/sats/src/algebraic_value/ser.rs b/crates/sats/src/algebraic_value/ser.rs index 5faa24c7d5..645d4eaf01 100644 --- a/crates/sats/src/algebraic_value/ser.rs +++ b/crates/sats/src/algebraic_value/ser.rs @@ -2,10 +2,20 @@ use std::convert::Infallible; use super::AlgebraicValue; use crate::ser::{self, ForwardNamedToSeqProduct}; -use crate::{ArrayValue, BuiltinValue, ProductValue, SumValue}; +use crate::ArrayValue; +/// An implementation of [`Serializer`](ser::Serializer) +/// where the output of serialization is an `AlgebraicValue`. pub struct ValueSerializer; +macro_rules! method { + ($name:ident -> $t:ty) => { + fn $name(self, v: $t) -> Result { + Ok(v.into()) + } + }; +} + impl ser::Serializer for ValueSerializer { type Ok = AlgebraicValue; type Error = Infallible; @@ -15,45 +25,20 @@ impl ser::Serializer for ValueSerializer { type SerializeSeqProduct = SerializeProductValue; type SerializeNamedProduct = ForwardNamedToSeqProduct; - fn serialize_bool(self, v: bool) -> Result { - Ok(AlgebraicValue::Bool(v)) - } - fn serialize_u8(self, v: u8) -> Result { - Ok(AlgebraicValue::U8(v)) - } - fn serialize_u16(self, v: u16) -> Result { - Ok(AlgebraicValue::U16(v)) - } - fn serialize_u32(self, v: u32) -> Result { - Ok(AlgebraicValue::U32(v)) - } - fn serialize_u64(self, v: u64) -> Result { - Ok(AlgebraicValue::U64(v)) - } - fn serialize_u128(self, v: u128) -> Result { - Ok(AlgebraicValue::U128(v)) - } - fn serialize_i8(self, v: i8) -> Result { - Ok(AlgebraicValue::I8(v)) - } - fn serialize_i16(self, v: i16) -> Result { - Ok(AlgebraicValue::I16(v)) - } - fn serialize_i32(self, v: i32) -> Result { - Ok(AlgebraicValue::I32(v)) - } - fn serialize_i64(self, v: i64) -> Result { - Ok(AlgebraicValue::I64(v)) - } - fn serialize_i128(self, v: i128) -> Result { - Ok(AlgebraicValue::I128(v)) - } - fn serialize_f32(self, v: f32) -> Result { - Ok(AlgebraicValue::F32(v.into())) - } - fn serialize_f64(self, v: f64) -> Result { - Ok(AlgebraicValue::F64(v.into())) - } + method!(serialize_bool -> bool); + method!(serialize_u8 -> u8); + method!(serialize_u16 -> u16); + method!(serialize_u32 -> u32); + method!(serialize_u64 -> u64); + method!(serialize_u128 -> u128); + method!(serialize_i8 -> i8); + method!(serialize_i16 -> i16); + method!(serialize_i32 -> i32); + method!(serialize_i64 -> i64); + method!(serialize_i128 -> i128); + method!(serialize_f32 -> f32); + method!(serialize_f64 -> f64); + fn serialize_str(self, v: &str) -> Result { Ok(AlgebraicValue::String(v.to_owned())) } @@ -61,19 +46,22 @@ impl ser::Serializer for ValueSerializer { Ok(AlgebraicValue::Bytes(v.to_owned())) } - fn serialize_array(self, _len: usize) -> Result { - Ok(SerializeArrayValue { v: Default::default() }) + fn serialize_array(self, len: usize) -> Result { + Ok(SerializeArrayValue { + len: Some(len), + array: Default::default(), + }) } fn serialize_map(self, len: usize) -> Result { Ok(SerializeMapValue { - v: Vec::with_capacity(len), + entries: Vec::with_capacity(len), }) } fn serialize_seq_product(self, len: usize) -> Result { Ok(SerializeProductValue { - v: Vec::with_capacity(len), + elements: Vec::with_capacity(len), }) } @@ -87,13 +75,17 @@ impl ser::Serializer for ValueSerializer { _name: Option<&str>, value: &T, ) -> Result { - let value = Box::new(value.serialize(self)?); - Ok(AlgebraicValue::Sum(SumValue { tag, value })) + value.serialize(self).map(|v| AlgebraicValue::sum(tag, v)) } } +/// Continuation for serializing an array. pub struct SerializeArrayValue { - v: ArrayValue, + /// For efficiency, the first time `serialize_element` is done, + /// this is used to allocate with capacity. + len: Option, + /// The array being built. + array: ArrayValue, } impl ser::SerializeArray for SerializeArrayValue { @@ -101,19 +93,21 @@ impl ser::SerializeArray for SerializeArrayValue { type Error = Infallible; fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - // TODO: this can be more efficient - self.v - .push(elem.serialize(ValueSerializer)?) + self.array + .push(elem.serialize(ValueSerializer)?, self.len.take()) .expect("heterogeneous array"); Ok(()) } + fn end(self) -> Result { - Ok(AlgebraicValue::Builtin(BuiltinValue::Array { val: self.v })) + Ok(AlgebraicValue::ArrayOf(self.array)) } } +/// Continuation for serializing a map value. pub struct SerializeMapValue { - v: Vec<(AlgebraicValue, AlgebraicValue)>, + /// The entry pairs to collect and convert into a map. + entries: Vec<(AlgebraicValue, AlgebraicValue)>, } impl ser::SerializeMap for SerializeMapValue { @@ -125,20 +119,20 @@ impl ser::SerializeMap for SerializeMapValue { key: &K, value: &V, ) -> Result<(), Self::Error> { - self.v + self.entries .push((key.serialize(ValueSerializer)?, value.serialize(ValueSerializer)?)); Ok(()) } fn end(self) -> Result { - Ok(AlgebraicValue::Builtin(BuiltinValue::Map { - val: self.v.into_iter().collect(), - })) + Ok(AlgebraicValue::map(self.entries.into_iter().collect())) } } +/// Continuation for serializing a map value. pub struct SerializeProductValue { - v: Vec, + /// The elements serialized so far. + elements: Vec, } impl ser::SerializeSeqProduct for SerializeProductValue { @@ -146,10 +140,10 @@ impl ser::SerializeSeqProduct for SerializeProductValue { type Error = Infallible; fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - self.v.push(elem.serialize(ValueSerializer)?); + self.elements.push(elem.serialize(ValueSerializer)?); Ok(()) } fn end(self) -> Result { - Ok(AlgebraicValue::Product(ProductValue { elements: self.v })) + Ok(AlgebraicValue::product(self.elements)) } } diff --git a/crates/sats/src/bsatn.rs b/crates/sats/src/bsatn.rs index 8262656824..7cfa9fcbb9 100644 --- a/crates/sats/src/bsatn.rs +++ b/crates/sats/src/bsatn.rs @@ -11,26 +11,28 @@ pub use ser::Serializer; pub use crate::buffer::DecodeError; +/// Serialize `value` into the buffered writer `w` in the BSATN format. pub fn to_writer(w: &mut W, value: &T) -> Result<(), ser::BsatnError> { value.serialize(Serializer::new(w)) } +/// Serialize `value` into a `Vec` in the BSATN format. pub fn to_vec(value: &T) -> Result, ser::BsatnError> { let mut v = Vec::new(); to_writer(&mut v, value)?; Ok(v) } -pub fn from_reader<'de, R: BufReader<'de>, T: Deserialize<'de>>(r: &mut R) -> Result { - T::deserialize(Deserializer::new(r)) +/// Deserialize a `T` from the BSATM format in the buffered `reader`. +pub fn from_reader<'de, T: Deserialize<'de>>(reader: &mut impl BufReader<'de>) -> Result { + T::deserialize(Deserializer::new(reader)) } -pub fn from_slice<'de, T: Deserialize<'de>>(b: &'de [u8]) -> Result { - from_reader(&mut &b[..]) +/// Deserialize a `T` from the BSATM format in `bytes`. +pub fn from_slice<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result { + from_reader(&mut &*bytes) } -static EMPTY_TYPESPACE: Typespace = Typespace::new(Vec::new()); - macro_rules! codec_funcs { ($ty:ty) => { impl $ty { @@ -49,7 +51,8 @@ macro_rules! codec_funcs { algebraic_type: &::Type, bytes: &mut impl BufReader<'a>, ) -> Result { - crate::TypeInSpace::new(&EMPTY_TYPESPACE, algebraic_type).deserialize(Deserializer::new(bytes)) + crate::WithTypespace::new(&Typespace::new(Vec::new()), algebraic_type) + .deserialize(Deserializer::new(bytes)) } pub fn encode(&self, bytes: &mut impl BufWriter) { diff --git a/crates/sats/src/bsatn/de.rs b/crates/sats/src/bsatn/de.rs index 2724870454..e8c2a2de88 100644 --- a/crates/sats/src/bsatn/de.rs +++ b/crates/sats/src/bsatn/de.rs @@ -2,15 +2,19 @@ use crate::buffer::{BufReader, DecodeError}; use crate::de::{self, SeqProductAccess, SumAccess, VariantAccess}; +/// Deserializer from the BSATN data format. pub struct Deserializer<'a, R> { + // The input to deserialize. reader: &'a mut R, } impl<'a, 'de, R: BufReader<'de>> Deserializer<'a, R> { + /// Returns a deserializer using the given `reader`. pub fn new(reader: &'a mut R) -> Self { Self { reader } } + /// Reborrows the deserializer. #[inline] fn reborrow(&mut self) -> Deserializer<'_, R> { Deserializer { reader: self.reader } @@ -27,10 +31,17 @@ impl de::Error for DecodeError { } } +/// Read a length as a `u32` then converted to `usize`. fn get_len<'de>(reader: &mut impl BufReader<'de>) -> Result { Ok(reader.get_u32()? as usize) } +/// Read a byte slice from the `reader`. +fn read_bytes<'a, 'de: 'a>(reader: &'a mut impl BufReader<'de>) -> Result<&'de [u8], DecodeError> { + let len = get_len(reader)?; + reader.get_slice(len) +} + impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { type Error = DecodeError; @@ -83,15 +94,13 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> { } fn deserialize_str>(self, visitor: V) -> Result { - let len = get_len(self.reader)?; - let slice = self.reader.get_slice(len)?; + let slice = read_bytes(self.reader)?; let slice = core::str::from_utf8(slice)?; visitor.visit_borrowed(slice) } fn deserialize_bytes>(self, visitor: V) -> Result { - let len = get_len(self.reader)?; - let slice = self.reader.get_slice(len)?; + let slice = read_bytes(self.reader)?; visitor.visit_borrowed(slice) } @@ -146,6 +155,7 @@ impl<'de, 'a, R: BufReader<'de>> VariantAccess<'de> for Deserializer<'a, R> { } } +/// Deserializer for array elements. pub struct ArrayAccess<'a, R, T> { de: Deserializer<'a, R>, seeds: itertools::RepeatN, @@ -167,6 +177,7 @@ impl<'de, 'a, R: BufReader<'de>, T: de::DeserializeSeed<'de> + Clone> de::ArrayA } } +/// Deserializer for map elements. pub struct MapAccess<'a, R, K, V> { de: Deserializer<'a, R>, seeds: itertools::RepeatN<(K, V)>, diff --git a/crates/sats/src/bsatn/ser.rs b/crates/sats/src/bsatn/ser.rs index 0d77c466d4..ebdd9a1bcf 100644 --- a/crates/sats/src/bsatn/ser.rs +++ b/crates/sats/src/bsatn/ser.rs @@ -4,31 +4,38 @@ use crate::buffer::BufWriter; use crate::ser::{self, Error, ForwardNamedToSeqProduct, Serialize, SerializeArray, SerializeMap, SerializeSeqProduct}; +/// Defines the BSATN serialization data format. pub struct Serializer<'a, W> { writer: &'a mut W, } impl<'a, W> Serializer<'a, W> { + /// Returns a serializer using the given `writer`. pub fn new(writer: &'a mut W) -> Self { Self { writer } } + /// Reborrows the serializer. #[inline] fn reborrow(&mut self) -> Serializer<'_, W> { Serializer { writer: self.writer } } } +/// An error during BSATN serialization. #[derive(Debug)] pub struct BsatnError { + /// The error message for the BSATN error. custom: String, } + impl fmt::Display for BsatnError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(&self.custom) } } impl std::error::Error for BsatnError {} + impl Error for BsatnError { fn custom(msg: T) -> Self { let custom = msg.to_string(); @@ -36,6 +43,9 @@ impl Error for BsatnError { } } +/// Writes `len` converted to a `u32` to `writer`. +/// +/// Errors if `len` would not fit in a `u32`. fn put_len(writer: &mut impl BufWriter, len: usize) -> Result<(), BsatnError> { let len = len.try_into().map_err(|_| BsatnError::custom("len too long"))?; writer.put_u32(len); @@ -106,22 +116,23 @@ impl ser::Serializer for Serializer<'_, W> { self.serialize_bytes(v.as_bytes()) } fn serialize_bytes(self, v: &[u8]) -> Result { - put_len(self.writer, v.len())?; + put_len(self.writer, v.len())?; // N.B. `v.len() > u32::MAX` isn't allowed. self.writer.put_slice(v); Ok(()) } fn serialize_array(self, len: usize) -> Result { - put_len(self.writer, len)?; + put_len(self.writer, len)?; // N.B. `len > u32::MAX` isn't allowed. Ok(self) } fn serialize_map(self, len: usize) -> Result { - put_len(self.writer, len)?; + put_len(self.writer, len)?; // N.B. `len > u32::MAX` isn't allowed. Ok(self) } fn serialize_seq_product(self, _len: usize) -> Result { Ok(self) } fn serialize_named_product(self, len: usize) -> Result { + // Serialize named like unnamed. self.serialize_seq_product(len).map(ForwardNamedToSeqProduct::new) } fn serialize_variant( diff --git a/crates/sats/src/buffer.rs b/crates/sats/src/buffer.rs index 1c8ef22c02..345cdab138 100644 --- a/crates/sats/src/buffer.rs +++ b/crates/sats/src/buffer.rs @@ -1,16 +1,21 @@ +//! Minimal utility for reading & writing the types we need to internal storage, +//! without relying on types in third party libraries like `bytes::Bytes`, etc. +//! Meant to be kept slim and trim for use across both native and WASM. + use std::cell::Cell; use std::fmt; use std::str::Utf8Error; -/// Minimal utility for reading & writing the types we need to internal storage, without relying -/// on third party libraries like bytes::Bytes, etc. -/// Meant to be kept slim and trim for use across both native and wasm. - +/// An error that occurred when decoding. #[derive(Debug, Clone)] pub enum DecodeError { + /// Not enough data was provided in the input. BufferLength, + /// The tag does not exist for the sum. InvalidTag, + /// Expected data to be UTF-8 but it wasn't. InvalidUtf8, + /// Custom error not in the other variants of `DecodeError`. Other(String), } @@ -18,7 +23,7 @@ impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { DecodeError::BufferLength => f.write_str("data too short"), - DecodeError::InvalidTag => f.write_str("invalid tag for enum"), + DecodeError::InvalidTag => f.write_str("invalid tag for sum"), DecodeError::InvalidUtf8 => f.write_str("invalid utf8"), DecodeError::Other(err) => f.write_str(err), } @@ -37,86 +42,161 @@ impl From for DecodeError { } } +/// A buffered writer of some kind. pub trait BufWriter { + /// Writes the `slice` to the buffer. + /// + /// This is the only method implementations are required to define. + /// All other methods are provided. fn put_slice(&mut self, slice: &[u8]); + + /// Writes a `u8` to the buffer in little-endian (LE) encoding. fn put_u8(&mut self, val: u8) { self.put_slice(&val.to_le_bytes()) } + + /// Writes a `u16` to the buffer in little-endian (LE) encoding. fn put_u16(&mut self, val: u16) { self.put_slice(&val.to_le_bytes()) } + + /// Writes a `u32` to the buffer in little-endian (LE) encoding. fn put_u32(&mut self, val: u32) { self.put_slice(&val.to_le_bytes()) } + + /// Writes a `u64` to the buffer in little-endian (LE) encoding. fn put_u64(&mut self, val: u64) { self.put_slice(&val.to_le_bytes()) } + + /// Writes a `u128` to the buffer in little-endian (LE) encoding. fn put_u128(&mut self, val: u128) { self.put_slice(&val.to_le_bytes()) } + + /// Writes an `i8` to the buffer in little-endian (LE) encoding. fn put_i8(&mut self, val: i8) { self.put_slice(&val.to_le_bytes()) } + + /// Writes an `i16` to the buffer in little-endian (LE) encoding. fn put_i16(&mut self, val: i16) { self.put_slice(&val.to_le_bytes()) } + + /// Writes an `i32` to the buffer in little-endian (LE) encoding. fn put_i32(&mut self, val: i32) { self.put_slice(&val.to_le_bytes()) } + + /// Writes an `i64` to the buffer in little-endian (LE) encoding. fn put_i64(&mut self, val: i64) { self.put_slice(&val.to_le_bytes()) } + + /// Writes an `i128` to the buffer in little-endian (LE) encoding. fn put_i128(&mut self, val: i128) { self.put_slice(&val.to_le_bytes()) } } +/// A buffered reader of some kind. +/// +/// The lifetime `'de` allows the output of deserialization to borrow from the input. pub trait BufReader<'de> { + /// Reads and returns a byte slice of `.len() = size` advancing the cursor. fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError>; + + /// Returns the number of bytes left to read in the input. fn remaining(&self) -> usize; + /// Reads a `u8` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_u8(&mut self) -> Result { self.get_array().map(u8::from_le_bytes) } + + /// Reads a `u16` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_u16(&mut self) -> Result { self.get_array().map(u16::from_le_bytes) } + + /// Reads a `u32` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_u32(&mut self) -> Result { self.get_array().map(u32::from_le_bytes) } + + /// Reads a `u64` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_u64(&mut self) -> Result { self.get_array().map(u64::from_le_bytes) } + + /// Reads a `u128` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_u128(&mut self) -> Result { self.get_array().map(u128::from_le_bytes) } + + /// Reads an `i8` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_i8(&mut self) -> Result { self.get_array().map(i8::from_le_bytes) } + + /// Reads an `i16` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_i16(&mut self) -> Result { self.get_array().map(i16::from_le_bytes) } + + /// Reads an `i32` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_i32(&mut self) -> Result { self.get_array().map(i32::from_le_bytes) } + + /// Reads an `i64` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_i64(&mut self) -> Result { self.get_array().map(i64::from_le_bytes) } + + /// Reads an `i128` in little endian (LE) encoding from the input. + /// + /// This method is provided for convenience + /// and is derived from [`get_slice`](BufReader::get_slice)'s definition. fn get_i128(&mut self) -> Result { self.get_array().map(i128::from_le_bytes) } + /// Reads an array of type `[u8; C]` from the input. fn get_array(&mut self) -> Result<[u8; C], DecodeError> { let mut buf: [u8; C] = [0; C]; - self.get_into_array(&mut buf, C)?; + buf.copy_from_slice(self.get_slice(C)?); Ok(buf) } - - fn get_into_array(&mut self, buf: &mut [u8; C], amount: usize) -> Result<(), DecodeError> { - let bytes = self.get_slice(amount)?; - buf.copy_from_slice(bytes); - Ok(()) - } } impl BufWriter for Vec { @@ -151,21 +231,33 @@ impl<'de> BufReader<'de> for &'de [u8] { } } -pub struct Cursor { - pub buf: B, +/// A cursor based [`BufReader<'de>`] implementation. +pub struct Cursor { + /// The underlying input read from. + pub buf: I, + /// The position within the reader. pub pos: Cell, } -impl Cursor { - pub fn new(buf: B) -> Self { + +impl Cursor { + /// Returns a new cursor on the provided `buf` input. + /// + /// The cursor starts at the beginning of `buf`. + pub fn new(buf: I) -> Self { Self { buf, pos: 0.into() } } } -impl<'de, B: AsRef<[u8]>> BufReader<'de> for &'de Cursor { + +impl<'de, I: AsRef<[u8]>> BufReader<'de> for &'de Cursor { fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError> { + // "Read" the slice `buf[pos..size]`. let ret = self.buf.as_ref()[self.pos.get()..] .get(..size) .ok_or(DecodeError::BufferLength)?; + + // Advance the cursor by `size` bytes. self.pos.set(self.pos.get() + size); + Ok(ret) } diff --git a/crates/sats/src/builtin_type.rs b/crates/sats/src/builtin_type.rs index 8be106ab12..cf42f7b17f 100644 --- a/crates/sats/src/builtin_type.rs +++ b/crates/sats/src/builtin_type.rs @@ -1,57 +1,82 @@ -pub mod satn; - use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer}; use crate::algebraic_value::ser::ValueSerializer; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize}; use crate::{ - AlgebraicType, AlgebraicTypeRef, AlgebraicValue, ProductType, ProductTypeElement, SumType, SumTypeVariant, + impl_deserialize, impl_serialize, AlgebraicType, AlgebraicTypeRef, AlgebraicValue, ProductType, ProductTypeElement, + SumTypeVariant, }; use enum_as_inner::EnumAsInner; +/// Represents the built-in types in SATS. +/// +/// Some of these types are nominal in our otherwise structural type system. #[derive(EnumAsInner, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub enum BuiltinType { + /// The bool type. Values `BuiltinValue::Bool(b)` will have this type. Bool, + /// The `I8` type. Values `BuiltinValue::I8(v)` will have this type. I8, + /// The `U8` type. Values `BuiltinValue::U8(v)` will have this type. U8, + /// The `I16` type. Values `BuiltinValue::I16(v)` will have this type. I16, + /// The `U16` type. Values `BuiltinValue::U16(v)` will have this type. U16, + /// The `I32` type. Values `BuiltinValue::I32(v)` will have this type. I32, + /// The `U32` type. Values `BuiltinValue::U32(v)` will have this type. U32, + /// The `I64` type. Values `BuiltinValue::I64(v)` will have this type. I64, + /// The `U64` type. Values `BuiltinValue::U64(v)` will have this type. U64, + /// The `I128` type. Values `BuiltinValue::I128(v)` will have this type. I128, + /// The `U128` type. Values `BuiltinValue::U128(v)` will have this type. U128, + /// The `F32` type. Values `BuiltinValue::F32(v)` will have this type. F32, + /// The `F64` type. Values `BuiltinValue::F64(v)` will have this type. F64, - String, // Keep this because it is easy to just use Rust's String (utf-8) + /// The UTF-8 encoded `String` type. + /// Values `BuiltinValue::String(s)` will have this type. + String, // Keep this because it is easy to just use Rust's `String` (UTF-8). + /// The type of array values where elements are of a base type `elem_ty`. + /// Values `BuiltinValue::Array(array)` will have this type. Array(ArrayType), + /// The type of map values consisting of a key type `key_ty` and value `ty`. + /// Values `BuiltinValue::Map(map)` will have this type. Map(MapType), } +/// An array type is a homegeneous product type of dynamic length. +/// +/// That is, it is a product type +/// where every element / factor / field is of the same type +/// and where the length is statically unknown. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct ArrayType { + /// The base type every element of the array has. pub elem_ty: Box, } -impl Serialize for ArrayType { - fn serialize(&self, serializer: S) -> Result { - self.elem_ty.serialize(serializer) - } -} -impl<'de> Deserialize<'de> for ArrayType { - fn deserialize>(deserializer: D) -> Result { - Deserialize::deserialize(deserializer).map(|elem_ty| Self { elem_ty }) - } -} +impl_serialize!([] ArrayType, (self, ser) => self.elem_ty.serialize(ser)); +impl_deserialize!([] ArrayType, de => Deserialize::deserialize(de).map(|elem_ty| Self { elem_ty })); + +/// A map type from keys of type `key_ty` to values of type `ty`. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub struct MapType { + /// The key type of the map. pub key_ty: Box, + /// The value type of the map. pub ty: Box, } impl MapType { + /// Returns a map type with keys of type `key` and values of type `value`. pub fn new(key: AlgebraicType, value: AlgebraicType) -> Self { Self { key_ty: Box::new(key), @@ -60,37 +85,40 @@ impl MapType { } } -impl BuiltinType { - pub fn make_meta_type() -> AlgebraicType { - // TODO: sats(rename_all = "lowercase"), otherwise json won't work - AlgebraicType::Sum(SumType::new(vec![ - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "bool"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "i8"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "u8"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "i16"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "u16"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "i32"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "u32"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "i64"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "u64"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "i128"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "u128"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "f32"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "f64"), - SumTypeVariant::new_named(AlgebraicType::Product(ProductType { elements: Vec::new() }), "string"), - SumTypeVariant::new_named(AlgebraicType::Ref(AlgebraicTypeRef(0)), "array"), +impl MetaType for BuiltinType { + fn meta_type() -> AlgebraicType { + let product = |elements| AlgebraicType::Product(ProductType { elements }); + let unit = || product(Vec::new()); + let zero_ref = || AlgebraicType::Ref(AlgebraicTypeRef(0)); + // TODO: sats(rename_all = "lowercase"), otherwise json won't work. + AlgebraicType::sum(vec![ + SumTypeVariant::new_named(unit(), "bool"), + SumTypeVariant::new_named(unit(), "i8"), + SumTypeVariant::new_named(unit(), "u8"), + SumTypeVariant::new_named(unit(), "i16"), + SumTypeVariant::new_named(unit(), "u16"), + SumTypeVariant::new_named(unit(), "i32"), + SumTypeVariant::new_named(unit(), "u32"), + SumTypeVariant::new_named(unit(), "i64"), + SumTypeVariant::new_named(unit(), "u64"), + SumTypeVariant::new_named(unit(), "i128"), + SumTypeVariant::new_named(unit(), "u128"), + SumTypeVariant::new_named(unit(), "f32"), + SumTypeVariant::new_named(unit(), "f64"), + SumTypeVariant::new_named(unit(), "string"), + SumTypeVariant::new_named(zero_ref(), "array"), SumTypeVariant::new_named( - AlgebraicType::Product(ProductType { - elements: vec![ - ProductTypeElement::new_named(AlgebraicType::Ref(AlgebraicTypeRef(0)), "key_ty"), - ProductTypeElement::new_named(AlgebraicType::Ref(AlgebraicTypeRef(0)), "ty"), - ], - }), + product(vec![ + ProductTypeElement::new_named(zero_ref(), "key_ty"), + ProductTypeElement::new_named(zero_ref(), "ty"), + ]), "map", ), - ])) + ]) } +} +impl BuiltinType { pub fn as_value(&self) -> AlgebraicValue { self.serialize(ValueSerializer).unwrap_or_else(|x| match x {}) } diff --git a/crates/sats/src/builtin_type/satn.rs b/crates/sats/src/builtin_type/satn.rs deleted file mode 100644 index 3790c980f8..0000000000 --- a/crates/sats/src/builtin_type/satn.rs +++ /dev/null @@ -1,43 +0,0 @@ -use super::BuiltinType; -use crate::{algebraic_type, ArrayType, MapType}; -use std::fmt::Display; - -pub struct Formatter<'a> { - ty: &'a BuiltinType, -} - -impl<'a> Formatter<'a> { - pub fn new(ty: &'a BuiltinType) -> Self { - Self { ty } - } -} - -impl<'a> Display for Formatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.ty { - BuiltinType::Bool => write!(f, "Bool"), - BuiltinType::I8 => write!(f, "I8"), - BuiltinType::U8 => write!(f, "U8"), - BuiltinType::I16 => write!(f, "I16"), - BuiltinType::U16 => write!(f, "U16"), - BuiltinType::I32 => write!(f, "I32"), - BuiltinType::U32 => write!(f, "U32"), - BuiltinType::I64 => write!(f, "I64"), - BuiltinType::U64 => write!(f, "U64"), - BuiltinType::I128 => write!(f, "I128"), - BuiltinType::U128 => write!(f, "U128"), - BuiltinType::F32 => write!(f, "F32"), - BuiltinType::F64 => write!(f, "F64"), - BuiltinType::String => write!(f, "String"), - BuiltinType::Array(ArrayType { elem_ty }) => { - write!(f, "Array<{}>", algebraic_type::satn::Formatter::new(elem_ty)) - } - BuiltinType::Map(MapType { key_ty, ty }) => write!( - f, - "Map<{}, {}>", - algebraic_type::satn::Formatter::new(key_ty), - algebraic_type::satn::Formatter::new(ty) - ), - } - } -} diff --git a/crates/sats/src/builtin_value.rs b/crates/sats/src/builtin_value.rs index cc78a9b927..88cb637127 100644 --- a/crates/sats/src/builtin_value.rs +++ b/crates/sats/src/builtin_value.rs @@ -5,32 +5,59 @@ use enum_as_inner::EnumAsInner; use std::collections::BTreeMap; use std::fmt; -/// Totally ordered [f32] +/// Totally ordered [`f32`]. pub type F32 = decorum::Total; -/// Totally ordered [f64] +/// Totally ordered [`f64`]. pub type F64 = decorum::Total; +/// A built-in value of a [`BuiltinType`]. #[derive(EnumAsInner, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum BuiltinValue { + /// A [`bool`] value. Bool(bool), + /// An [`i8`] value. I8(i8), + /// A [`u8`] value. U8(u8), + /// An [`i16`] value. I16(i16), + /// A [`u16`] value. U16(u16), + /// An [`i32`] value. I32(i32), + /// A [`u32`] value. U32(u32), + /// An [`i64`] value. I64(i64), + /// A [`u64`] value. U64(u64), + /// An [`i128`] value. I128(i128), + /// A [`u128`] value. U128(u128), + /// A totally ordered [`F32`] value. F32(F32), + /// A totally ordered [`F64`] value. F64(F64), + /// A UTF-8 string value. + /// + /// Uses Rust's standard representation of strings. String(String), + /// A homogeneous array of `AlgebraicValue`s. + /// + /// The contained values are stored packed in a representation appropriate for their type. + /// See [`ArrayValue`] for details on the representation. Array { val: ArrayValue }, + /// An ordered map value of `key: AlgebraicValue`s mapped to `value: AlgebraicValue`s. + /// Each `key` must be of the same [`AlgebraicType`] as all the others + /// and the same applies to each `value`. + /// + /// Maps are implemented internally as `BTreeMap`. Map { val: MapValue }, } +/// A map value `AlgebraicValue` → `AlgebraicValue`. pub type MapValue = BTreeMap; impl crate::Value for MapValue { @@ -38,11 +65,13 @@ impl crate::Value for MapValue { } impl BuiltinValue { + /// Returns the byte string `v` as a [`BuiltinValue`]. #[allow(non_snake_case)] - pub fn Bytes(v: Vec) -> Self { - Self::Array { val: v.into() } + pub const fn Bytes(v: Vec) -> Self { + Self::Array { val: ArrayValue::U8(v) } } + /// Returns `self` as a borrowed byte string, if applicable. pub fn as_bytes(&self) -> Option<&Vec> { match self { BuiltinValue::Array { val: ArrayValue::U8(v) } => Some(v), @@ -50,6 +79,7 @@ impl BuiltinValue { } } + /// Converts `self` into a byte string, if applicable. pub fn into_bytes(self) -> Result, Self> { match self { BuiltinValue::Array { val: ArrayValue::U8(v) } => Ok(v), @@ -62,44 +92,62 @@ impl crate::Value for BuiltinValue { type Type = BuiltinType; } +/// An array value in "monomorphized form". +/// +/// Arrays are represented in this way monomorphized fashion for efficiency +/// rather than unnecessary indirections and tags of `Vec`. +/// We can do this as we know statically that the type of each element is the same +/// as arrays are homogenous dynamically sized product types. #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub enum ArrayValue { + /// An array of [`SumValue`](crate::SumValue)s. Sum(Vec), + /// An array of [`ProductValue`](crate::ProductValue)s. Product(Vec), + /// An array of [`bool`]s. Bool(Vec), + /// An array of [`i8`]s. I8(Vec), + /// An array of [`u8`]s. U8(Vec), + /// An array of [`i16`]s. I16(Vec), + /// An array of [`u16`]s. U16(Vec), + /// An array of [`i32`]s. I32(Vec), + /// An array of [`u32`]s. U32(Vec), + /// An array of [`i64`]s. I64(Vec), + /// An array of [`u64`]s. U64(Vec), + /// An array of [`i128`]s. I128(Vec), + /// An array of [`u128`]s. U128(Vec), + /// An array of totally ordered [`F32`]s. F32(Vec), + /// An array of totally ordered [`F64`]s. F64(Vec), + /// An array of UTF-8 strings. String(Vec), + /// An array of arrays. Array(Vec), + /// An array of maps. Map(Vec), } impl crate::Value for ArrayValue { - // element type type Type = ArrayType; } impl ArrayValue { + /// Determines (infers / synthesises) the type of the value. pub(crate) fn type_of(&self) -> ArrayType { - let elem_ty = match self { - ArrayValue::Sum(v) => v - .first() - .map(AlgebraicValue::type_of_sum) - .unwrap_or_else(AlgebraicType::make_never_type), - ArrayValue::Product(v) => v - .first() - .map(AlgebraicValue::type_of_product) - .unwrap_or_else(AlgebraicType::make_never_type), + let elem_ty = Box::new(match self { + ArrayValue::Sum(v) => Self::first_type_of(v, AlgebraicValue::type_of_sum), + ArrayValue::Product(v) => Self::first_type_of(v, AlgebraicValue::type_of_product), ArrayValue::Bool(_) => AlgebraicType::Bool, ArrayValue::I8(_) => AlgebraicType::I8, ArrayValue::U8(_) => AlgebraicType::U8, @@ -114,20 +162,19 @@ impl ArrayValue { ArrayValue::F32(_) => AlgebraicType::F32, ArrayValue::F64(_) => AlgebraicType::F64, ArrayValue::String(_) => AlgebraicType::String, - ArrayValue::Array(v) => v - .first() - .map(|a| AlgebraicType::Builtin(BuiltinType::Array(a.type_of()))) - .unwrap_or_else(AlgebraicType::make_never_type), - ArrayValue::Map(v) => v - .first() - .map(AlgebraicValue::type_of_map) - .unwrap_or_else(AlgebraicType::make_never_type), - }; - ArrayType { - elem_ty: Box::new(elem_ty), - } + ArrayValue::Array(v) => Self::first_type_of(v, |a| AlgebraicType::Builtin(BuiltinType::Array(a.type_of()))), + ArrayValue::Map(v) => Self::first_type_of(v, AlgebraicValue::type_of_map), + }); + ArrayType { elem_ty } + } + + /// Helper for `type_of` above. + /// Infers the `AlgebraicType` from the first element by running `then` on it. + fn first_type_of(arr: &[T], then: impl FnOnce(&T) -> AlgebraicType) -> AlgebraicType { + arr.first().map(then).unwrap_or_else(|| AlgebraicType::NEVER_TYPE) } + /// Returns the length of the array. pub fn len(&self) -> usize { match self { ArrayValue::Sum(v) => v.len(), @@ -151,35 +198,50 @@ impl ArrayValue { } } + /// Returns whether the array is empty. #[must_use] pub fn is_empty(&self) -> bool { self.len() == 0 } - fn from_one(val: AlgebraicValue) -> Self { + /// Returns a singleton array with `val` as its only element. + /// + /// Optionally allocates the backing `Vec<_>`s with `capacity`. + fn from_one_with_capacity(val: AlgebraicValue, capacity: Option) -> Self { + fn vec(e: T, c: Option) -> Vec { + let mut vec = c.map_or(Vec::new(), Vec::with_capacity); + vec.push(e); + vec + } + match val { - AlgebraicValue::Sum(x) => vec![x].into(), - AlgebraicValue::Product(x) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::Bool(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::I8(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::U8(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::I16(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::U16(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::I32(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::U32(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::I64(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::U64(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::I128(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::U128(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::F32(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::F64(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::String(x)) => vec![x].into(), - AlgebraicValue::Builtin(BuiltinValue::Array { val }) => vec![val].into(), - AlgebraicValue::Builtin(BuiltinValue::Map { val }) => vec![val].into(), + AlgebraicValue::Sum(x) => vec(x, capacity).into(), + AlgebraicValue::Product(x) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::Bool(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::I8(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::U8(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::I16(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::U16(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::I32(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::U32(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::I64(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::U64(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::I128(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::U128(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::F32(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::F64(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::String(x)) => vec(x, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::Array { val }) => vec(val, capacity).into(), + AlgebraicValue::Builtin(BuiltinValue::Map { val }) => vec(val, capacity).into(), } } - pub fn push(&mut self, val: AlgebraicValue) -> Result<(), AlgebraicValue> { + /// Pushes the value `val` onto the array `self` + /// or returns back `Err(val)` if there was a type mismatch + /// between the base type of the array and `val`. + /// + /// Optionally allocates the backing `Vec<_>`s with `capacity`. + pub fn push(&mut self, val: AlgebraicValue, capacity: Option) -> Result<(), AlgebraicValue> { match (self, val) { (ArrayValue::Sum(v), AlgebraicValue::Sum(val)) => v.push(val), (ArrayValue::Product(v), AlgebraicValue::Product(val)) => v.push(val), @@ -199,37 +261,39 @@ impl ArrayValue { (ArrayValue::String(v), AlgebraicValue::Builtin(BuiltinValue::String(val))) => v.push(val), (ArrayValue::Array(v), AlgebraicValue::Builtin(BuiltinValue::Array { val })) => v.push(val), (ArrayValue::Map(v), AlgebraicValue::Builtin(BuiltinValue::Map { val })) => v.push(val), - (me, val) if me.is_empty() => *me = Self::from_one(val), + (me, val) if me.is_empty() => *me = Self::from_one_with_capacity(val, capacity), (_, val) => return Err(val), } Ok(()) } + /// Returns a cloning iterator on the elements of `self` as `AlgebraicValue`s. pub fn iter_cloned(&self) -> ArrayValueIterCloned { match self { - ArrayValue::Sum(v) => ArrayValueIterCloned::Sum(v.iter().cloned()), - ArrayValue::Product(v) => ArrayValueIterCloned::Product(v.iter().cloned()), - ArrayValue::Bool(v) => ArrayValueIterCloned::Bool(v.iter().cloned()), - ArrayValue::I8(v) => ArrayValueIterCloned::I8(v.iter().cloned()), - ArrayValue::U8(v) => ArrayValueIterCloned::U8(v.iter().cloned()), - ArrayValue::I16(v) => ArrayValueIterCloned::I16(v.iter().cloned()), - ArrayValue::U16(v) => ArrayValueIterCloned::U16(v.iter().cloned()), - ArrayValue::I32(v) => ArrayValueIterCloned::I32(v.iter().cloned()), - ArrayValue::U32(v) => ArrayValueIterCloned::U32(v.iter().cloned()), - ArrayValue::I64(v) => ArrayValueIterCloned::I64(v.iter().cloned()), - ArrayValue::U64(v) => ArrayValueIterCloned::U64(v.iter().cloned()), - ArrayValue::I128(v) => ArrayValueIterCloned::I128(v.iter().cloned()), - ArrayValue::U128(v) => ArrayValueIterCloned::U128(v.iter().cloned()), - ArrayValue::F32(v) => ArrayValueIterCloned::F32(v.iter().cloned()), - ArrayValue::F64(v) => ArrayValueIterCloned::F64(v.iter().cloned()), - ArrayValue::String(v) => ArrayValueIterCloned::String(v.iter().cloned()), - ArrayValue::Array(v) => ArrayValueIterCloned::Array(v.iter().cloned()), - ArrayValue::Map(v) => ArrayValueIterCloned::Map(v.iter().cloned()), + ArrayValue::Sum(v) => ArrayValueIterCloned::Sum(v.iter()), + ArrayValue::Product(v) => ArrayValueIterCloned::Product(v.iter()), + ArrayValue::Bool(v) => ArrayValueIterCloned::Bool(v.iter()), + ArrayValue::I8(v) => ArrayValueIterCloned::I8(v.iter()), + ArrayValue::U8(v) => ArrayValueIterCloned::U8(v.iter()), + ArrayValue::I16(v) => ArrayValueIterCloned::I16(v.iter()), + ArrayValue::U16(v) => ArrayValueIterCloned::U16(v.iter()), + ArrayValue::I32(v) => ArrayValueIterCloned::I32(v.iter()), + ArrayValue::U32(v) => ArrayValueIterCloned::U32(v.iter()), + ArrayValue::I64(v) => ArrayValueIterCloned::I64(v.iter()), + ArrayValue::U64(v) => ArrayValueIterCloned::U64(v.iter()), + ArrayValue::I128(v) => ArrayValueIterCloned::I128(v.iter()), + ArrayValue::U128(v) => ArrayValueIterCloned::U128(v.iter()), + ArrayValue::F32(v) => ArrayValueIterCloned::F32(v.iter()), + ArrayValue::F64(v) => ArrayValueIterCloned::F64(v.iter()), + ArrayValue::String(v) => ArrayValueIterCloned::String(v.iter()), + ArrayValue::Array(v) => ArrayValueIterCloned::Array(v.iter()), + ArrayValue::Map(v) => ArrayValueIterCloned::Map(v.iter()), } } } impl Default for ArrayValue { + /// The default `ArrayValue` is an empty array of sum values. fn default() -> Self { Self::from(Vec::::default()) } @@ -265,6 +329,7 @@ impl_from_array!(ArrayValue, Array); impl_from_array!(MapValue, Map); impl ArrayValue { + /// Returns `self` as `&dyn Debug`. fn as_dyn_debug(&self) -> &dyn fmt::Debug { match self { Self::Sum(v) => v, @@ -323,24 +388,43 @@ impl IntoIterator for ArrayValue { } } +/// A by-value iterator on the elements of an `ArrayValue` as `AlgebraicValue`s. pub enum ArrayValueIntoIter { + /// An iterator on a sum value array. Sum(std::vec::IntoIter), + /// An iterator on a product value array. Product(std::vec::IntoIter), + /// An iterator on a [`bool`] array. Bool(std::vec::IntoIter), + /// An iterator on an [`i8`] array. I8(std::vec::IntoIter), + /// An iterator on a [`u8`] array. U8(std::vec::IntoIter), + /// An iterator on an [`i16`] array. I16(std::vec::IntoIter), + /// An iterator on a [`u16`] array. U16(std::vec::IntoIter), + /// An iterator on an [`i32`] array. I32(std::vec::IntoIter), + /// An iterator on a [`u32`] array. U32(std::vec::IntoIter), + /// An iterator on an [`i64`] array. I64(std::vec::IntoIter), + /// An iterator on a [`u64`] array. U64(std::vec::IntoIter), + /// An iterator on an [`i128`] array. I128(std::vec::IntoIter), + /// An iterator on a [`u128`] array. U128(std::vec::IntoIter), + /// An iterator on a [`F32`] array. F32(std::vec::IntoIter), + /// An iterator on a [`F64`] array. F64(std::vec::IntoIter), + /// An iterator on an array of UTF-8 strings. String(std::vec::IntoIter), + /// An iterator on an array of arrays. Array(std::vec::IntoIter), + /// An iterator on an array of maps. Map(std::vec::IntoIter), } @@ -365,33 +449,31 @@ impl Iterator for ArrayValueIntoIter { ArrayValueIntoIter::F32(it) => it.next().map(|f| f32::from(f).into()), ArrayValueIntoIter::F64(it) => it.next().map(|f| f64::from(f).into()), ArrayValueIntoIter::String(it) => it.next().map(Into::into), - ArrayValueIntoIter::Array(it) => it - .next() - .map(|val| AlgebraicValue::Builtin(BuiltinValue::Array { val })), - ArrayValueIntoIter::Map(it) => it.next().map(|val| AlgebraicValue::Builtin(BuiltinValue::Map { val })), + ArrayValueIntoIter::Array(it) => it.next().map(AlgebraicValue::ArrayOf), + ArrayValueIntoIter::Map(it) => it.next().map(AlgebraicValue::map), } } } pub enum ArrayValueIterCloned<'a> { - Sum(std::iter::Cloned>), - Product(std::iter::Cloned>), - Bool(std::iter::Cloned>), - I8(std::iter::Cloned>), - U8(std::iter::Cloned>), - I16(std::iter::Cloned>), - U16(std::iter::Cloned>), - I32(std::iter::Cloned>), - U32(std::iter::Cloned>), - I64(std::iter::Cloned>), - U64(std::iter::Cloned>), - I128(std::iter::Cloned>), - U128(std::iter::Cloned>), - F32(std::iter::Cloned>), - F64(std::iter::Cloned>), - String(std::iter::Cloned>), - Array(std::iter::Cloned>), - Map(std::iter::Cloned>), + Sum(std::slice::Iter<'a, crate::SumValue>), + Product(std::slice::Iter<'a, crate::ProductValue>), + Bool(std::slice::Iter<'a, bool>), + I8(std::slice::Iter<'a, i8>), + U8(std::slice::Iter<'a, u8>), + I16(std::slice::Iter<'a, i16>), + U16(std::slice::Iter<'a, u16>), + I32(std::slice::Iter<'a, i32>), + U32(std::slice::Iter<'a, u32>), + I64(std::slice::Iter<'a, i64>), + U64(std::slice::Iter<'a, u64>), + I128(std::slice::Iter<'a, i128>), + U128(std::slice::Iter<'a, u128>), + F32(std::slice::Iter<'a, F32>), + F64(std::slice::Iter<'a, F64>), + String(std::slice::Iter<'a, String>), + Array(std::slice::Iter<'a, ArrayValue>), + Map(std::slice::Iter<'a, MapValue>), } impl Iterator for ArrayValueIterCloned<'_> { @@ -399,26 +481,24 @@ impl Iterator for ArrayValueIterCloned<'_> { fn next(&mut self) -> Option { match self { - ArrayValueIterCloned::Sum(it) => it.next().map(AlgebraicValue::Sum), - ArrayValueIterCloned::Product(it) => it.next().map(Into::into), - ArrayValueIterCloned::Bool(it) => it.next().map(Into::into), - ArrayValueIterCloned::I8(it) => it.next().map(Into::into), - ArrayValueIterCloned::U8(it) => it.next().map(Into::into), - ArrayValueIterCloned::I16(it) => it.next().map(Into::into), - ArrayValueIterCloned::U16(it) => it.next().map(Into::into), - ArrayValueIterCloned::I32(it) => it.next().map(Into::into), - ArrayValueIterCloned::U32(it) => it.next().map(Into::into), - ArrayValueIterCloned::I64(it) => it.next().map(Into::into), - ArrayValueIterCloned::U64(it) => it.next().map(Into::into), - ArrayValueIterCloned::I128(it) => it.next().map(Into::into), - ArrayValueIterCloned::U128(it) => it.next().map(Into::into), - ArrayValueIterCloned::F32(it) => it.next().map(|f| f32::from(f).into()), - ArrayValueIterCloned::F64(it) => it.next().map(|f| f64::from(f).into()), - ArrayValueIterCloned::String(it) => it.next().map(Into::into), - ArrayValueIterCloned::Array(it) => it - .next() - .map(|val| AlgebraicValue::Builtin(BuiltinValue::Array { val })), - ArrayValueIterCloned::Map(it) => it.next().map(|val| AlgebraicValue::Builtin(BuiltinValue::Map { val })), + ArrayValueIterCloned::Sum(it) => it.next().cloned().map(AlgebraicValue::Sum), + ArrayValueIterCloned::Product(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::Bool(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::I8(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::U8(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::I16(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::U16(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::I32(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::U32(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::I64(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::U64(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::I128(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::U128(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::F32(it) => it.next().map(|f| f32::from(*f).into()), + ArrayValueIterCloned::F64(it) => it.next().map(|f| f64::from(*f).into()), + ArrayValueIterCloned::String(it) => it.next().cloned().map(Into::into), + ArrayValueIterCloned::Array(it) => it.next().cloned().map(AlgebraicValue::ArrayOf), + ArrayValueIterCloned::Map(it) => it.next().cloned().map(AlgebraicValue::map), } } } diff --git a/crates/sats/src/convert.rs b/crates/sats/src/convert.rs index 14c9f8059d..eee3565fd6 100644 --- a/crates/sats/src/convert.rs +++ b/crates/sats/src/convert.rs @@ -34,15 +34,13 @@ impl From for ProductValue { impl From<&AlgebraicValue> for ProductValue { fn from(x: &AlgebraicValue) -> Self { - Self { - elements: vec![x.clone()], - } + x.clone().into() } } impl From for ProductType { fn from(x: AlgebraicType) -> Self { - Self::new(vec![ProductTypeElement::new(x, None)]) + Self::new(vec![x.into()]) } } diff --git a/crates/sats/src/de.rs b/crates/sats/src/de.rs index 6d72d78a94..8a0fdf0f76 100644 --- a/crates/sats/src/de.rs +++ b/crates/sats/src/de.rs @@ -1,3 +1,6 @@ +// Some parts copyright Serde developers under the MIT / Apache-2.0 licenses at your option. +// See `serde` version `v1.0.169` for the parts where MIT / Apache-2.0 applies. + mod impls; #[cfg(feature = "serde")] pub mod serde; @@ -10,36 +13,115 @@ use std::collections::BTreeMap; use std::fmt; use std::marker::PhantomData; -use crate::{fmt_fn, FDisplay}; - +/// A **data format** that can deserialize any data structure supported by SATS. +/// +/// The `Deserializer` trait in SATS performs the same function as [`serde::Deserializer`] in [`serde`]. +/// See the documentation of [`serde::Deserializer`] for more information of the data model. +/// +/// Implementations of `Deserialize` map themselves into this data model +/// by passing to the `Deserializer` a visitor that can receive the necessary types. +/// The kind of visitor depends on the `deserialize_*` method. +/// Unlike in Serde, there isn't a single monolithic `Visitor` trait, +/// but rather, this functionality is split up into more targeted traits such as `SumVisitor<'de>`. +/// +/// The lifetime `'de` allows us to deserialize lifetime-generic types in a zero-copy fashion. +/// +/// [`serde::Deserializer`]: ::serde::Deserializer +/// [`serde`]: https://crates.io/crates/serde pub trait Deserializer<'de>: Sized { + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + /// Deserializes a product value from the input. fn deserialize_product>(self, visitor: V) -> Result; + /// Deserializes a sum value from the input. + /// + /// The entire process of deserializing a sum, starting from `deserialize(args...)`, is roughly: + /// + /// - [`deserialize`][Deserialize::deserialize] calls this method, + /// [`deserialize_sum(sum_visitor)`](Deserializer::deserialize_sum), + /// providing us with a [`sum_visitor`](SumVisitor). + /// + /// - This method calls [`sum_visitor.visit_sum(sum_access)`](SumVisitor::visit_sum), + /// where [`sum_access`](SumAccess) deals with extracting the tag and the variant data, + /// with the latter provided as [`VariantAccess`]). + /// The `SumVisitor` will then assemble these into the representation of a sum value + /// that the [`Deserialize`] implementation wants. + /// + /// - [`visit_sum`](SumVisitor::visit_sum) then calls + /// [`sum_access.variant(variant_visitor)`](SumAccess::variant), + /// and uses the provided `variant_visitor` to translate extracted variant names / tags + /// into something that is meaningful for `visit_sum`, e.g., an index. + /// + /// The call to `variant` will also return [`variant_access`](VariantAccess) + /// that can deserialize the contents of the variant. + /// + /// - Finally, after `variant` returns, + /// `visit_sum` deserializes the variant data using + /// [`variant_access.deserialize_seed(seed)`](VariantAccess::deserialize_seed) + /// or [`variant_access.deserialize()`](VariantAccess::deserialize). + /// This part may require some conditional logic depending on the identified variant. + /// + /// + /// The data format will also return an object ([`VariantAccess`]) + /// that can deserialize the contents of the variant. fn deserialize_sum>(self, visitor: V) -> Result; + /// Deserializes a `bool` value from the input. fn deserialize_bool(self) -> Result; + + /// Deserializes a `u8` value from the input. fn deserialize_u8(self) -> Result; + + /// Deserializes a `u16` value from the input. fn deserialize_u16(self) -> Result; + + /// Deserializes a `u32` value from the input. fn deserialize_u32(self) -> Result; + + /// Deserializes a `u64` value from the input. fn deserialize_u64(self) -> Result; + + /// Deserializes a `u128` value from the input. fn deserialize_u128(self) -> Result; + + /// Deserializes an `i8 value from the input. fn deserialize_i8(self) -> Result; + + /// Deserializes an `i16 value from the input. fn deserialize_i16(self) -> Result; + + /// Deserializes an `i32 value from the input. fn deserialize_i32(self) -> Result; + + /// Deserializes an `i64 value from the input. fn deserialize_i64(self) -> Result; + + /// Deserializes an `i128 value from the input. fn deserialize_i128(self) -> Result; + + /// Deserializes an `f32 value from the input. fn deserialize_f32(self) -> Result; + + /// Deserializes an `f64 value from the input. fn deserialize_f64(self) -> Result; + /// Deserializes a string-like object the input. fn deserialize_str>(self, visitor: V) -> Result; + + /// Deserializes an `&str` string value. fn deserialize_str_slice(self) -> Result<&'de str, Self::Error> { self.deserialize_str(BorrowedSliceVisitor) } + /// Deserializes a byte slice-like value. fn deserialize_bytes>(self, visitor: V) -> Result; + /// Deserializes an array value. + /// + /// This is typically the same as [`deserialize_array_seed`](Deserializer::deserialize_array_seed) + /// with an uninteresting `seed` value. fn deserialize_array, T: Deserialize<'de>>( self, visitor: V, @@ -47,12 +129,19 @@ pub trait Deserializer<'de>: Sized { self.deserialize_array_seed(visitor, PhantomData) } + /// Deserializes an array value. + /// + /// The deserialization is provided with a `seed` value. fn deserialize_array_seed, T: DeserializeSeed<'de> + Clone>( self, visitor: V, seed: T, ) -> Result; + /// Deserializes a map value. + /// + /// This is typically the same as [`deserialize_map_seed`](Deserializer::deserialize_map_seed) + /// with an uninteresting `seed` value. fn deserialize_map, K: Deserialize<'de>, V: Deserialize<'de>>( self, visitor: Vi, @@ -60,6 +149,9 @@ pub trait Deserializer<'de>: Sized { self.deserialize_map_seed(visitor, PhantomData, PhantomData) } + /// Deserializes a map value. + /// + /// The deserialization is provided with `kseed` and `vseed` for keys and values respectively. fn deserialize_map_seed< Vi: MapVisitor<'de, K::Output, V::Output>, K: DeserializeSeed<'de> + Clone, @@ -72,9 +164,20 @@ pub trait Deserializer<'de>: Sized { ) -> Result; } +/// The `Error` trait allows [`Deserialize`] implementations to create descriptive error messages +/// belonging to the [`Deserializer`] against which they are currently running. +/// +/// Every [`Deserializer`] declares an [`Error`] type +/// that encompasses both general-purpose deserialization errors +/// as well as errors specific to the particular deserialization format. +/// +/// Most deserializers should only need to provide the [`Error::custom`] method +/// and inherit the default behavior for the other methods. pub trait Error: Sized { + /// Raised when there is general error when deserializing a type. fn custom(msg: impl fmt::Display) -> Self; + /// The product length was not as promised. fn invalid_product_length<'de, T: ProductVisitor<'de>>(len: usize, expected: &T) -> Self { Self::custom(format_args!( "invalid length {}, expected {}", @@ -83,26 +186,30 @@ pub trait Error: Sized { )) } - fn missing_field<'de, T: ProductVisitor<'de>>(field: usize, field_name: Option<&str>, prod: &T) -> Self { - Self::custom(error_on_field("missing ", field, field_name, prod)) + /// There was a missing field at `index`. + fn missing_field<'de, T: ProductVisitor<'de>>(index: usize, field_name: Option<&str>, prod: &T) -> Self { + Self::custom(error_on_field("missing ", index, field_name, prod)) } - fn duplicate_field<'de, T: ProductVisitor<'de>>(field: usize, field_name: Option<&str>, prod: &T) -> Self { - Self::custom(error_on_field("duplicate ", field, field_name, prod)) + /// A field with `index` was specified more than once. + fn duplicate_field<'de, T: ProductVisitor<'de>>(index: usize, field_name: Option<&str>, prod: &T) -> Self { + Self::custom(error_on_field("duplicate ", index, field_name, prod)) } + /// A field with name `field_name` does not exist. fn unknown_field_name<'de, T: FieldNameVisitor<'de>>(field_name: &str, expected: &T) -> Self { let el_ty = match expected.kind() { ProductKind::Normal => "field", ProductKind::ReducerArgs => "reducer argument", }; if let Some(one_of) = one_of_names(|n| expected.field_names(n)) { - Self::custom(format_args!("unknown {el_ty} `{field_name}`, expected {one_of}",)) + Self::custom(format_args!("unknown {el_ty} `{field_name}`, expected {one_of}")) } else { Self::custom(format_args!("unknown {el_ty} `{field_name}`, there are no {el_ty}s")) } } + /// The `tag` does not specify a variant of the sum type. fn unknown_variant_tag<'de, T: SumVisitor<'de>>(tag: u8, expected: &T) -> Self { Self::custom(format_args!( "unknown tag {tag:#x} for sum type {}", @@ -110,6 +217,7 @@ pub trait Error: Sized { )) } + /// The `name` is not that of a variant of the sum type. fn unknown_variant_name(name: &str, expected: &T) -> Self { if let Some(one_of) = one_of_names(|n| expected.variant_names(n)) { Self::custom(format_args!("unknown variant `{name}`, expected {one_of}",)) @@ -119,11 +227,26 @@ pub trait Error: Sized { } } -fn error_on_field<'a, 'de, T: ProductVisitor<'de>>( +/// Turns a closure `impl Fn(&mut Formatter) -> Result` into a `Display`able object. +pub struct FDisplay(F); + +impl fmt::Result> fmt::Display for FDisplay { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self.0)(f) + } +} + +/// Turns a closure `F: Fn(&mut Formatter) -> Result` into a `Display`able object. +pub fn fmt_fn fmt::Result>(f: F) -> FDisplay { + FDisplay(f) +} + +/// Returns an error message for a `problem` with field at `index` and an optional `name`. +fn error_on_field<'a, 'de>( problem: &'static str, - field: usize, - field_name: Option<&'a str>, - prod: &T, + index: usize, + name: Option<&'a str>, + prod: &impl ProductVisitor<'de>, ) -> impl fmt::Display + 'a { let field_kind = match prod.product_kind() { ProductKind::Normal => "field", @@ -133,17 +256,18 @@ fn error_on_field<'a, 'de, T: ProductVisitor<'de>>( // e.g. "missing field `foo`" f.write_str(problem)?; f.write_str(field_kind)?; - if let Some(name) = field_name { + if let Some(name) = name { write!(f, " `{}`", name) } else { - write!(f, " (index {})", field) + write!(f, " (index {})", index) } }) } -fn fmt_invalid_len<'de, T: ProductVisitor<'de>>( - expected: &T, -) -> FDisplay fmt::Result + '_> { +/// Returns an error message for invalid product type lengths. +fn fmt_invalid_len<'de>( + expected: &impl ProductVisitor<'de>, +) -> FDisplay fmt::Result> { fmt_fn(|f| { let ty = match expected.product_kind() { ProductKind::Normal => "product type", @@ -156,168 +280,332 @@ fn fmt_invalid_len<'de, T: ProductVisitor<'de>>( }) } +/// A visitor walking through a [`Deserializer`] for products. pub trait ProductVisitor<'de> { + /// The resulting product. type Output; + /// Returns the name of the product, if any. fn product_name(&self) -> Option<&str>; + + /// Returns the length of the product. fn product_len(&self) -> usize; + + /// Returns the kind of the product. fn product_kind(&self) -> ProductKind { ProductKind::Normal } + /// The input contains an unnamed product. fn visit_seq_product>(self, prod: A) -> Result; + + /// The input contains a named product. fn visit_named_product>(self, prod: A) -> Result; } +/// What kind of product is this? #[derive(Clone, Copy)] pub enum ProductKind { + // A normal product. Normal, + /// A product in the context of reducer arguments. ReducerArgs, } +/// Provides a [`ProductVisitor`] with access to each element of the unnamed product in the input. +/// +/// This is a trait that a [`Deserializer`] passes to a [`ProductVisitor`] implementation. pub trait SeqProductAccess<'de> { + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + /// Deserializes an `T` from the input. + /// + /// Returns `Ok(Some(value))` for the next element in the product, + /// or `Ok(None)` if there are no more remaining items. + /// + /// This method exists as a convenience for [`Deserialize`] implementations. + /// [`SeqProductAccess`] implementations should not override the default behavior. fn next_element>(&mut self) -> Result, Self::Error> { self.next_element_seed(PhantomData) } + /// Statefully deserializes `T::Output` from the input provided a `seed` value. + /// + /// Returns `Ok(Some(value))` for the next element in the unnamed product, + /// or `Ok(None)` if there are no more remaining items. + /// + /// [`Deserialize`] implementations should typically use + /// [`next_element`](SeqProductAccess::next_element) instead. fn next_element_seed>(&mut self, seed: T) -> Result, Self::Error>; } +/// Provides a [`ProductVisitor`] with access to each element of the named product in the input. +/// +/// This is a trait that a [`Deserializer`] passes to a [`ProductVisitor`] implementation. pub trait NamedProductAccess<'de> { + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + /// Deserializes field name of type `V::Output` from the input using a visitor + /// provided by the deserializer. fn get_field_ident>(&mut self, visitor: V) -> Result, Self::Error>; + /// Deserializes field value of type `T` from the input. + /// + /// This method exists as a convenience for [`Deserialize`] implementations. + /// [`NamedProductAccess`] implementations should not override the default behavior. fn get_field_value>(&mut self) -> Result { self.get_field_value_seed(PhantomData) } + /// Statefully deserializes the field value `T::Output` from the input provided a `seed` value. + /// + /// [`Deserialize`] implementations should typically use + /// [`next_element`](NamedProductAccess::get_field_value) instead. fn get_field_value_seed>(&mut self, seed: T) -> Result; } +/// Visitor used to deserialize the name of a field. pub trait FieldNameVisitor<'de> { + /// The resulting field name. type Output; + /// The sort of product deserialized. fn kind(&self) -> ProductKind { ProductKind::Normal } + + /// Provides the visitor the chance to add valid names into `names`. fn field_names(&self, names: &mut dyn ValidNames); fn visit(self, name: &str) -> Result; } +/// A trait for types storing a set of valid names. pub trait ValidNames { + /// Adds the name `s` to the set. fn push(&mut self, s: &str); + + /// Runs the function `names` provided with `self` as the store + /// and then returns back `self`. + /// This method exists for convenience. + fn run(mut self, names: &impl Fn(&mut dyn ValidNames)) -> Self + where + Self: Sized, + { + names(&mut self); + self + } } + impl dyn ValidNames + '_ { - pub fn extend(&mut self, i: I) + /// Adds the names in `iter` to the set. + pub fn extend(&mut self, iter: I) where I::Item: AsRef, { - for name in i { + for name in iter { self.push(name.as_ref()) } } } +/// A visitor walking through a [`Deserializer`] for sums. +/// +/// This side is provided by a [`Deserialize`] implementation +/// when calling [`Deserializer::deserialize_sum`]. pub trait SumVisitor<'de> { + /// The resulting sum. type Output; + /// Returns the name of the sum, if any. fn sum_name(&self) -> Option<&str>; + + /// Returns whether an option is expected. + /// + /// The provided implementation does not. fn is_option(&self) -> bool { false } + /// Drives the deserialization of a sum value. + /// + /// This method will ask the data format ([`A: SumAccess`][SumAccess]) + /// which variant of the sum to select in terms of a variant name / tag. + /// `A` will use a [`VariantVisitor`], that `SumVisitor` has provided, + /// to translate into something that is meaningful for `visit_sum`, e.g., an index. + /// + /// The data format will also return an object ([`VariantAccess`]) + /// that can deserialize the contents of the variant. fn visit_sum>(self, data: A) -> Result; } +/// Provides a [`SumVisitor`] access to the data of a sum in the input. +/// +/// An `A: SumAccess` object is created by the [`D: Deserializer`] +/// which passes `A` to a [`V: SumVisitor`] that `D` in turn was passed. +/// `A` is then used by `V` to split tag and value input apart. pub trait SumAccess<'de> { + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + + /// The visitor used to deserialize the content of the sum variant. type Variant: VariantAccess<'de, Error = Self::Error>; + /// Called to identify which variant to deserialize. + /// Returns a tuple with the result of identification (`V::Output`) + /// and the input to variant data deserialization. + /// + /// The `visitor` is provided by the [`Deserializer`]. + /// This method is typically called from [`SumVisitor::visit_sum`] + /// which will provide the [`V: VariantVisitor`](VariantVisitor). fn variant(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error>; } +/// A visitor passed from [`SumVisitor`] to [`SumAccess::variant`] +/// which the latter uses to decide what variant to deserialize. pub trait VariantVisitor { + /// The result of identifying a variant, e.g., some index type. type Output; + /// Provides the visitor the chance to add valid names into `names`. fn variant_names(&self, names: &mut dyn ValidNames); + /// Identify the variant based on `tag`. fn visit_tag(self, tag: u8) -> Result; + + /// Identify the variant based on `name`. fn visit_name(self, name: &str) -> Result; } +/// A visitor passed from [`SumAccess`] to [`SumVisitor::visit_sum`] +/// which the latter uses to deserialize the data of a selected variant. pub trait VariantAccess<'de>: Sized { type Error: Error; + /// Called when deserializing the contents of a sum variant. + /// + /// This method exists as a convenience for [`Deserialize`] implementations. fn deserialize>(self) -> Result { self.deserialize_seed(PhantomData) } + /// Called when deserializing the contents of a sum variant, and provided with a `seed` value. fn deserialize_seed>(self, seed: T) -> Result; } +/// A `SliceVisitor` is provided a slice `T` of some elements by a [`Deserializer`] +/// and is tasked with translating this slice to the `Output` type. pub trait SliceVisitor<'de, T: ToOwned + ?Sized>: Sized { + /// The output produced by this visitor. type Output; + /// The input contains a slice. + /// + /// The lifetime of the slice is ephemeral + /// and it may be destroyed after this method returns. fn visit(self, slice: &T) -> Result; + /// The input contains a slice and ownership of the slice is being given to the [`SliceVisitor`]. fn visit_owned(self, buf: T::Owned) -> Result { self.visit(buf.borrow()) } + /// The input contains a slice that lives at least as long (`'de`) as the [`Deserializer`]. fn visit_borrowed(self, borrowed_slice: &'de T) -> Result { self.visit(borrowed_slice) } } +/// A visitor walking through a [`Deserializer`] for arrays. pub trait ArrayVisitor<'de, T> { + /// The output produced by this visitor. type Output; + /// The input contains an array. fn visit>(self, vec: A) -> Result; } +/// Provides an [`ArrayVisitor`] with access to each element of the array in the input. +/// +/// This is a trait that a [`Deserializer`] passes to an [`ArrayVisitor`] implementation. pub trait ArrayAccess<'de> { + /// The element / base type of the array. type Element; + + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + /// This returns `Ok(Some(value))` for the next element in the array, + /// or `Ok(None)` if there are no more remaining elements. fn next_element(&mut self) -> Result, Self::Error>; + /// Returns the number of elements remaining in the array, if known. fn size_hint(&self) -> Option { None } } +/// A visitor walking through a [`Deserializer`] for maps. pub trait MapVisitor<'de, K, V> { + /// The output produced by this visitor. type Output; + /// The input contains a key-value map. fn visit>(self, map: A) -> Result; } +/// Provides a [`MapVisitor`] with access to each element of the array in the input. +/// +/// This is a trait that a [`Deserializer`] passes to a [`MapVisitor`] implementation. pub trait MapAccess<'de> { + /// The key type of the map. type Key; + + /// The value type of the map. type Value; + + /// The error type that can be returned if some error occurs during deserialization. type Error: Error; + /// This returns `Ok(Some((key, value)))` for the next (key-value) pair in the map, + /// or `Ok(None)` if there are no more remaining items. #[allow(clippy::type_complexity)] fn next_entry(&mut self) -> Result, Self::Error>; + /// Returns the number of elements remaining in the map, if known. fn size_hint(&self) -> Option { None } } +/// `DeserializeSeed` is the stateful form of the [`Deserialize`] trait. pub trait DeserializeSeed<'de> { + /// The type produced by using this seed. type Output; + + /// Equivalent to the more common [`Deserialize::deserialize`] associated function, + /// except with some initial piece of data (the seed `self`) passed in. fn deserialize>(self, deserializer: D) -> Result; } use crate::de::impls::BorrowedSliceVisitor; pub use spacetimedb_bindings_macro::Deserialize; +/// A **datastructure** that can be deserialized from any data format supported by SATS. +/// +/// In most cases, implementations of `Deserialize` may be `#[derive(Deserialize)]`d. +/// +/// The `Deserialize` trait in SATS performs the same function as [`serde::Deserialize`] in [`serde`]. +/// See the documentation of [`serde::Deserialize`] for more information of the data model. +/// +/// The lifetime `'de` allows us to deserialize lifetime-generic types in a zero-copy fashion. +/// +/// [`serde::Deserialize`]: ::serde::Deserialize +/// [`serde`]: https://crates.io/crates/serde pub trait Deserialize<'de>: Sized { + /// Deserialize this value from the given `deserializer`. fn deserialize>(deserializer: D) -> Result; /// used in the Deserialize for Vec impl to allow specializing deserializing Vec as bytes @@ -334,8 +622,10 @@ pub trait Deserialize<'de>: Sized { } } +/// A data structure that can be deserialized in SATS +/// without borrowing any data from the deserializer. pub trait DeserializeOwned: for<'de> Deserialize<'de> {} -impl DeserializeOwned for T where T: for<'de> Deserialize<'de> {} +impl Deserialize<'de>> DeserializeOwned for T {} impl<'de, T: Deserialize<'de>> DeserializeSeed<'de> for PhantomData { type Output = T; @@ -345,6 +635,7 @@ impl<'de, T: Deserialize<'de>> DeserializeSeed<'de> for PhantomData { } } +/// An implementation of [`ArrayVisitor<'de, T>`] where the output is a `Vec`. pub struct BasicVecVisitor; impl<'de, T> ArrayVisitor<'de, T> for BasicVecVisitor { @@ -359,6 +650,7 @@ impl<'de, T> ArrayVisitor<'de, T> for BasicVecVisitor { } } +/// An implementation of [`MapVisitor<'de, K, V>`] where the output is a `BTreeMap`. pub struct BasicMapVisitor; impl<'de, K: Ord, V> MapVisitor<'de, K, V> for BasicMapVisitor { @@ -373,6 +665,7 @@ impl<'de, K: Ord, V> MapVisitor<'de, K, V> for BasicMapVisitor { } } +/// An implementation of [`ArrayVisitor<'de, T>`] where the output is a `[T; N]`. struct BasicArrayVisitor; impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicArrayVisitor { @@ -388,51 +681,67 @@ impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicArrayVisitor { } } +/// Provided a function `names` that is allowed to store a name into a valid set, +/// returns a human readable list of all the names, +/// or `None` in the case of an empty list of names. fn one_of_names(names: impl Fn(&mut dyn ValidNames)) -> Option { - let mut n = NNames(0); - names(&mut n); - let NNames(n) = n; - (n != 0).then(|| { - fmt_fn(move |f| { - let mut f = OneOfNames::new(n != 2, f); - names(&mut f); - f.f.map(drop) - }) - }) -} + /// An implementation of `ValidNames` that just counts how many valid names are pushed into it. + struct CountNames(usize); -struct NNames(usize); -impl ValidNames for NNames { - fn push(&mut self, _: &str) { - self.0 += 1 + impl ValidNames for CountNames { + fn push(&mut self, _: &str) { + self.0 += 1 + } } -} -struct OneOfNames<'a, 'b> { - at_start: bool, - many: bool, - f: Result<&'a mut fmt::Formatter<'b>, fmt::Error>, -} -impl<'a, 'b> OneOfNames<'a, 'b> { - fn new(many: bool, f: &'a mut fmt::Formatter<'b>) -> Self { - Self { - at_start: true, - many, - f: Ok(f), + /// An implementation of `ValidNames` that provides a human friendly enumeration of names. + struct OneOfNames<'a, 'b> { + /// A `.push(_)` counter. + index: usize, + /// How many names there were. + count: usize, + /// Result of formatting thus far. + f: Result<&'a mut fmt::Formatter<'b>, fmt::Error>, + } + + impl<'a, 'b> OneOfNames<'a, 'b> { + fn new(count: usize, f: &'a mut fmt::Formatter<'b>) -> Self { + Self { + index: 0, + count, + f: Ok(f), + } } } -} -impl ValidNames for OneOfNames<'_, '_> { - fn push(&mut self, name: &str) { - let (start, sep) = if self.many { ("", " or ") } else { ("one of", ", ") }; - if let Ok(f) = &mut self.f { - let mut go = || -> fmt::Result { - f.write_str(if std::mem::take(&mut self.at_start) { start } else { sep })?; - write!(f, "`{name}`") - }; - if let Err(e) = go() { + + impl ValidNames for OneOfNames<'_, '_> { + fn push(&mut self, name: &str) { + // This will give us, after all `.push()`es have been made, the following: + // + // count = 1 -> "`foo`" + // = 2 -> "`foo` or `bar`" + // > 2 -> "one of `foo`, `bar`, or `baz`" + + let Ok(f) = &mut self.f else { return; }; + + self.index += 1; + + if let Err(e) = match (self.count, self.index) { + (1, _) => write!(f, "`{name}`"), + (2, 1) => write!(f, "`{name}`"), + (2, 2) => write!(f, "`or `{name}`"), + (_, 1) => write!(f, "one of `{name}`"), + (c, i) if i < c => write!(f, ", `{name}`"), + (_, _) => write!(f, ", `, or {name}`"), + } { self.f = Err(e); } } } + + // Count how many names have been pushed. + let count = CountNames(0).run(&names).0; + + // There was at least one name; render those names. + (count != 0).then(|| fmt_fn(move |fmt| OneOfNames::new(count, fmt).run(&names).f.map(drop))) } diff --git a/crates/sats/src/de/impls.rs b/crates/sats/src/de/impls.rs index 835d66e076..bb19160299 100644 --- a/crates/sats/src/de/impls.rs +++ b/crates/sats/src/de/impls.rs @@ -8,7 +8,7 @@ use std::marker::PhantomData; use crate::builtin_value::{F32, F64}; use crate::{ AlgebraicType, AlgebraicValue, ArrayType, ArrayValue, BuiltinType, BuiltinValue, MapType, MapValue, ProductType, - ProductTypeElement, ProductValue, SumType, SumValue, TypeInSpace, + ProductTypeElement, ProductValue, SumType, SumValue, WithTypespace, }; use super::{ @@ -16,101 +16,98 @@ use super::{ ProductVisitor, SeqProductAccess, SliceVisitor, SumAccess, SumVisitor, VariantAccess, VariantVisitor, }; +/// Implements [`Deserialize`] for a type in a simplified manner. +/// +/// An example: +/// ```ignore +/// impl_deserialize!( +/// // Type parameters Optional where Impl type +/// // v v v +/// // ---------------- --------------- ---------- +/// [T: Deserialize<'de>] where [T: Copy] std::rc::Rc, +/// // The `deserialize` implementation where `de` is the `Deserializer<'de>` +/// // and the expression right of `=>` is the body of `deserialize`. +/// de => T::deserialize(de).map(std::rc::Rc::new) +/// ); +/// ``` +#[macro_export] +macro_rules! impl_deserialize { + ([$($generics:tt)*] $(where [$($wc:tt)*])? $typ:ty, $de:ident => $body:expr) => { + impl<'de, $($generics)*> $crate::de::Deserialize<'de> for $typ { + fn deserialize>($de: D) -> Result { $body } + } + }; +} + +/// Implements [`Deserialize`] for a primitive type. +/// +/// The `$method` is a parameterless method on `deserializer` to call. macro_rules! impl_prim { ($(($prim:ty, $method:ident))*) => { - $(impl<'de> Deserialize<'de> for $prim { - fn deserialize>(de: D) -> Result { - de.$method() - } - })* + $(impl_deserialize!([] $prim, de => de.$method());)* }; } -impl<'de> Deserialize<'de> for () { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_product(UnitVisitor) - } +impl_prim! { + (bool, deserialize_bool) /*(u8, deserialize_u8)*/ (u16, deserialize_u16) + (u32, deserialize_u32) (u64, deserialize_u64) (u128, deserialize_u128) (i8, deserialize_i8) + (i16, deserialize_i16) (i32, deserialize_i32) (i64, deserialize_i64) (i128, deserialize_i128) + (f32, deserialize_f32) (f64, deserialize_f64) } + +impl_deserialize!([] (), de => de.deserialize_product(UnitVisitor)); + +/// The `UnitVisitor` looks for a unit product. +/// That is, it consumes nothing from the input. struct UnitVisitor; impl<'de> ProductVisitor<'de> for UnitVisitor { type Output = (); + fn product_name(&self) -> Option<&str> { None } + fn product_len(&self) -> usize { 0 } + fn visit_seq_product>(self, _prod: A) -> Result { Ok(()) } + fn visit_named_product>(self, _prod: A) -> Result { Ok(()) } } -impl_prim! { - (bool, deserialize_bool) /*(u8, deserialize_u8)*/ (u16, deserialize_u16) - (u32, deserialize_u32) (u64, deserialize_u64) (u128, deserialize_u128) (i8, deserialize_i8) - (i16, deserialize_i16) (i32, deserialize_i32) (i64, deserialize_i64) (i128, deserialize_i128) - (f32, deserialize_f32) (f64, deserialize_f64) -} - impl<'de> Deserialize<'de> for u8 { fn deserialize>(deserializer: D) -> Result { deserializer.deserialize_u8() } - // specialize Vec deserialization + + // Specialize `Vec` deserialization. + // This is more likely to compile down to a `memcpy`. fn __deserialize_vec>(deserializer: D) -> Result, D::Error> { deserializer.deserialize_bytes(OwnedSliceVisitor) } + fn __deserialize_array, const N: usize>(deserializer: D) -> Result<[Self; N], D::Error> { deserializer.deserialize_bytes(ByteArrayVisitor) } } -impl<'de> Deserialize<'de> for F32 { - fn deserialize>(deserializer: D) -> Result { - f32::deserialize(deserializer).map(Into::into) - } -} -impl<'de> Deserialize<'de> for F64 { - fn deserialize>(deserializer: D) -> Result { - f64::deserialize(deserializer).map(Into::into) - } -} - -impl<'de> Deserialize<'de> for String { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_str(OwnedSliceVisitor) - } -} - -impl<'de, T: Deserialize<'de>> Deserialize<'de> for Vec { - fn deserialize>(deserializer: D) -> Result { - T::__deserialize_vec(deserializer) - } -} - -impl<'de, T: Deserialize<'de>, const N: usize> Deserialize<'de> for [T; N] { - fn deserialize>(deserializer: D) -> Result { - T::__deserialize_array(deserializer) - } -} - -impl<'de> Deserialize<'de> for Box { - fn deserialize>(deserializer: D) -> Result { - String::deserialize(deserializer).map(|s| s.into_boxed_str()) - } -} - -impl<'de, T: Deserialize<'de>> Deserialize<'de> for Box<[T]> { - fn deserialize>(deserializer: D) -> Result { - Vec::deserialize(deserializer).map(|s| s.into_boxed_slice()) - } -} +impl_deserialize!([] F32, de => f32::deserialize(de).map(Into::into)); +impl_deserialize!([] F64, de => f64::deserialize(de).map(Into::into)); +impl_deserialize!([] String, de => de.deserialize_str(OwnedSliceVisitor)); +impl_deserialize!([T: Deserialize<'de>] Vec, de => T::__deserialize_vec(de)); +impl_deserialize!([T: Deserialize<'de>, const N: usize] [T; N], de => T::__deserialize_array(de)); +impl_deserialize!([] Box, de => String::deserialize(de).map(|s| s.into_boxed_str())); +impl_deserialize!([T: Deserialize<'de>] Box<[T]>, de => Vec::deserialize(de).map(|s| s.into_boxed_slice())); +/// The visitor converts the slice to its owned version. struct OwnedSliceVisitor; -impl<'de, T: ToOwned + ?Sized> SliceVisitor<'de, T> for OwnedSliceVisitor { + +impl SliceVisitor<'_, T> for OwnedSliceVisitor { type Output = T::Owned; fn visit(self, slice: &T) -> Result { @@ -122,8 +119,12 @@ impl<'de, T: ToOwned + ?Sized> SliceVisitor<'de, T> for OwnedSliceVisitor { } } +/// The visitor will convert the byte slice to `[u8; N]`. +/// +/// When `slice.len() != N` an error will be raised. struct ByteArrayVisitor; -impl<'de, const N: usize> SliceVisitor<'de, [u8]> for ByteArrayVisitor { + +impl SliceVisitor<'_, [u8]> for ByteArrayVisitor { type Output = [u8; N]; fn visit(self, slice: &[u8]) -> Result { @@ -137,19 +138,12 @@ impl<'de, const N: usize> SliceVisitor<'de, [u8]> for ByteArrayVisitor { } } -impl<'de> Deserialize<'de> for &'de str { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_str(BorrowedSliceVisitor) - } -} - -impl<'de> Deserialize<'de> for &'de [u8] { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_bytes(BorrowedSliceVisitor) - } -} +impl_deserialize!([] &'de str, de => de.deserialize_str(BorrowedSliceVisitor)); +impl_deserialize!([] &'de [u8], de => de.deserialize_bytes(BorrowedSliceVisitor)); +/// The visitor returns the slice as-is and borrowed. pub(crate) struct BorrowedSliceVisitor; + impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for BorrowedSliceVisitor { type Output = &'de T; @@ -162,19 +156,12 @@ impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for BorrowedSliceVisit } } -impl<'de> Deserialize<'de> for Cow<'de, str> { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_str(CowSliceVisitor) - } -} - -impl<'de> Deserialize<'de> for Cow<'de, [u8]> { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_bytes(CowSliceVisitor) - } -} +impl_deserialize!([] Cow<'de, str>, de => de.deserialize_str(CowSliceVisitor)); +impl_deserialize!([] Cow<'de, [u8]>, de => de.deserialize_bytes(CowSliceVisitor)); +/// The visitor works with either owned or borrowed versions to produce `Cow<'de, T>`. struct CowSliceVisitor; + impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for CowSliceVisitor { type Output = Cow<'de, T>; @@ -191,35 +178,33 @@ impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for CowSliceVisitor { } } -impl<'de, K: Deserialize<'de> + Ord, V: Deserialize<'de>> Deserialize<'de> for BTreeMap { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_map(BasicMapVisitor) - } -} +impl_deserialize!( + [K: Deserialize<'de> + Ord, V: Deserialize<'de>] BTreeMap, + de => de.deserialize_map(BasicMapVisitor) +); -impl<'de, T: Deserialize<'de>> Deserialize<'de> for Box { - fn deserialize>(deserializer: D) -> Result { - T::deserialize(deserializer).map(Box::new) - } -} - -impl<'de, T: Deserialize<'de>> Deserialize<'de> for Option { - fn deserialize>(deserializer: D) -> Result { - deserializer.deserialize_sum(OptionVisitor(PhantomData)) - } -} +impl_deserialize!([T: Deserialize<'de>] Box, de => T::deserialize(de).map(Box::new)); +impl_deserialize!([T: Deserialize<'de>] Option, de => de.deserialize_sum(OptionVisitor(PhantomData))); +/// The visitor deserializes an `Option`. struct OptionVisitor(PhantomData); + impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor { type Output = Option; + fn sum_name(&self) -> Option<&str> { Some("option") } + fn is_option(&self) -> bool { true } + fn visit_sum>(self, data: A) -> Result { + // Determine the variant. let (some, data) = data.variant(self)?; + + // Deserialize contents for it. Ok(if some { Some(data.deserialize()?) } else { @@ -228,6 +213,7 @@ impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor { }) } } + impl<'de, T: Deserialize<'de>> VariantVisitor for OptionVisitor { type Output = bool; @@ -252,7 +238,7 @@ impl<'de, T: Deserialize<'de>> VariantVisitor for OptionVisitor { } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, AlgebraicType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, AlgebraicType> { type Output = AlgebraicValue; fn deserialize>(self, deserializer: D) -> Result { @@ -265,7 +251,7 @@ impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, AlgebraicType> { } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, BuiltinType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, BuiltinType> { type Output = BuiltinValue; fn deserialize>(self, deserializer: D) -> Result { @@ -294,7 +280,7 @@ impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, BuiltinType> { } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, SumType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, SumType> { type Output = SumValue; fn deserialize>(self, deserializer: D) -> Result { @@ -302,49 +288,57 @@ impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, SumType> { } } -impl<'de> SumVisitor<'de> for TypeInSpace<'_, SumType> { +impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> { type Output = SumValue; fn sum_name(&self) -> Option<&str> { None } + fn is_option(&self) -> bool { self.ty().looks_like_option().is_some() } fn visit_sum>(self, data: A) -> Result { let (tag, data) = data.variant(self)?; + // Find the variant type by `tag`. let variant_ty = self.map(|ty| &ty.variants[tag as usize].algebraic_type); + let value = Box::new(data.deserialize_seed(variant_ty)?); Ok(SumValue { tag, value }) } } -impl VariantVisitor for TypeInSpace<'_, SumType> { + +impl VariantVisitor for WithTypespace<'_, SumType> { type Output = u8; fn variant_names(&self, names: &mut dyn super::ValidNames) { - names.extend(self.ty().variants.iter().filter_map(|v| v.name.as_deref())) + // Provide the names known from the `SumType`. + names.extend(self.ty().variants.iter().filter_map(|v| v.name())) } fn visit_tag(self, tag: u8) -> Result { + // Verify that tag identifies a valid variant in `SumType`. self.ty() .variants .get(tag as usize) .ok_or_else(|| E::unknown_variant_tag(tag, &self))?; + Ok(tag) } fn visit_name(self, name: &str) -> Result { + // Translate the variant `name` to its tag. self.ty() .variants .iter() - .position(|var| var.name.as_deref() == Some(name)) + .position(|var| var.has_name(name)) .map(|pos| pos as u8) .ok_or_else(|| E::unknown_variant_name(name, &self)) } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, ProductType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ProductType> { type Output = ProductValue; fn deserialize>(self, deserializer: D) -> Result { @@ -352,7 +346,7 @@ impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, ProductType> { } } -impl<'de> ProductVisitor<'de> for TypeInSpace<'_, ProductType> { +impl<'de> ProductVisitor<'de> for WithTypespace<'_, ProductType> { type Output = ProductValue; fn product_name(&self) -> Option<&str> { @@ -371,77 +365,62 @@ impl<'de> ProductVisitor<'de> for TypeInSpace<'_, ProductType> { } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, ArrayType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ArrayType> { type Output = ArrayValue; fn deserialize>(self, deserializer: D) -> Result { + /// Deserialize a vector and `map` it to the appropriate `ArrayValue` variant. + fn de_array<'de, D: Deserializer<'de>, T: Deserialize<'de>>( + de: D, + map: impl FnOnce(Vec) -> ArrayValue, + ) -> Result { + de.deserialize_array(BasicVecVisitor).map(map) + } + let mut ty = &*self.ty().elem_ty; + + // Loop, resolving `Ref`s, until we reach a non-`Ref` type. loop { break match ty { + AlgebraicType::Ref(r) => { + // The only arm that will loop. + ty = self.resolve(*r).ty(); + continue; + } AlgebraicType::Sum(ty) => deserializer .deserialize_array_seed(BasicVecVisitor, self.with(ty)) .map(ArrayValue::Sum), AlgebraicType::Product(ty) => deserializer .deserialize_array_seed(BasicVecVisitor, self.with(ty)) .map(ArrayValue::Product), - AlgebraicType::Builtin(BuiltinType::Bool) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::Bool) - } - AlgebraicType::Builtin(BuiltinType::I8) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::I8) - } + AlgebraicType::Builtin(BuiltinType::Bool) => de_array(deserializer, ArrayValue::Bool), + AlgebraicType::Builtin(BuiltinType::I8) => de_array(deserializer, ArrayValue::I8), AlgebraicType::Builtin(BuiltinType::U8) => { deserializer.deserialize_bytes(OwnedSliceVisitor).map(ArrayValue::U8) } - AlgebraicType::Builtin(BuiltinType::I16) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::I16) - } - AlgebraicType::Builtin(BuiltinType::U16) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::U16) - } - AlgebraicType::Builtin(BuiltinType::I32) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::I32) - } - AlgebraicType::Builtin(BuiltinType::U32) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::U32) - } - AlgebraicType::Builtin(BuiltinType::I64) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::I64) - } - AlgebraicType::Builtin(BuiltinType::U64) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::U64) - } - AlgebraicType::Builtin(BuiltinType::I128) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::I128) - } - AlgebraicType::Builtin(BuiltinType::U128) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::U128) - } - AlgebraicType::Builtin(BuiltinType::F32) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::F32) - } - AlgebraicType::Builtin(BuiltinType::F64) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::F64) - } - AlgebraicType::Builtin(BuiltinType::String) => { - deserializer.deserialize_array(BasicVecVisitor).map(ArrayValue::String) - } + AlgebraicType::Builtin(BuiltinType::I16) => de_array(deserializer, ArrayValue::I16), + AlgebraicType::Builtin(BuiltinType::U16) => de_array(deserializer, ArrayValue::U16), + AlgebraicType::Builtin(BuiltinType::I32) => de_array(deserializer, ArrayValue::I32), + AlgebraicType::Builtin(BuiltinType::U32) => de_array(deserializer, ArrayValue::U32), + AlgebraicType::Builtin(BuiltinType::I64) => de_array(deserializer, ArrayValue::I64), + AlgebraicType::Builtin(BuiltinType::U64) => de_array(deserializer, ArrayValue::U64), + AlgebraicType::Builtin(BuiltinType::I128) => de_array(deserializer, ArrayValue::I128), + AlgebraicType::Builtin(BuiltinType::U128) => de_array(deserializer, ArrayValue::U128), + AlgebraicType::Builtin(BuiltinType::F32) => de_array(deserializer, ArrayValue::F32), + AlgebraicType::Builtin(BuiltinType::F64) => de_array(deserializer, ArrayValue::F64), + AlgebraicType::Builtin(BuiltinType::String) => de_array(deserializer, ArrayValue::String), AlgebraicType::Builtin(BuiltinType::Array(ty)) => deserializer .deserialize_array_seed(BasicVecVisitor, self.with(ty)) .map(ArrayValue::Array), AlgebraicType::Builtin(BuiltinType::Map(ty)) => deserializer .deserialize_array_seed(BasicVecVisitor, self.with(ty)) .map(ArrayValue::Map), - AlgebraicType::Ref(r) => { - ty = self.resolve(*r).ty(); - continue; - } }; } } } -impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, MapType> { +impl<'de> DeserializeSeed<'de> for WithTypespace<'_, MapType> { type Output = MapValue; fn deserialize>(self, deserializer: D) -> Result { @@ -480,8 +459,9 @@ impl<'de> DeserializeSeed<'de> for TypeInSpace<'_, MapType> { // } // } +/// Deserialize, provided the fields' types, a product value with unnamed fields. pub fn visit_seq_product<'de, A: SeqProductAccess<'de>>( - elems: TypeInSpace<[ProductTypeElement]>, + elems: WithTypespace<[ProductTypeElement]>, visitor: &impl ProductVisitor<'de>, mut tup: A, ) -> Result { @@ -493,55 +473,76 @@ pub fn visit_seq_product<'de, A: SeqProductAccess<'de>>( Ok(ProductValue { elements }) } +/// Deserialize, provided the fields' types, a product value with named fields. pub fn visit_named_product<'de, A: super::NamedProductAccess<'de>>( - elems_tys: TypeInSpace<[ProductTypeElement]>, + elems_tys: WithTypespace<[ProductTypeElement]>, visitor: &impl ProductVisitor<'de>, mut tup: A, ) -> Result { let elems = elems_tys.ty(); let mut elements = vec![None; elems.len()]; - let mut n = 0; let kind = visitor.product_kind(); - // under a certain threshold, just do linear searches - while n < elems.len() { - let tag = tup.get_field_ident(TupleNameVisitor { elems, kind })?.ok_or_else(|| { + + // Deserialize a product value corresponding to each product type field. + // This is worst case quadratic in complexity + // as fields can be specified out of order (value side) compared to `elems` (type side). + for _ in 0..elems.len() { + // Deserialize a field name, match against the element types, . + let index = tup.get_field_ident(TupleNameVisitor { elems, kind })?.ok_or_else(|| { + // Couldn't deserialize a field name. + // Find the first field name we haven't filled an element for. let missing = elements.iter().position(|field| field.is_none()).unwrap(); - let field_name = elems[missing].name.as_deref(); + let field_name = elems[missing].name(); Error::missing_field(missing, field_name, visitor) })?; - let element = &elems[tag]; - let slot = &mut elements[tag]; + + let element = &elems[index]; + + // By index we can select which element to deserialize a value for. + let slot = &mut elements[index]; if slot.is_some() { - return Err(Error::duplicate_field(tag, element.name.as_deref(), visitor)); + return Err(Error::duplicate_field(index, element.name(), visitor)); } + + // Deserialize the value for this field's type. *slot = Some(tup.get_field_value_seed(elems_tys.with(&element.algebraic_type))?); - n += 1; } + + // Get rid of the `Option<_>` layer. let elements = elements .into_iter() + // We reached here, so we know nothing was missing, i.e., `None`. .map(|x| x.unwrap_or_else(|| unreachable!("visit_named_product"))) .collect(); + Ok(ProductValue { elements }) } +/// A visitor for extracting indices of field names in the elements of a [`ProductType`]. struct TupleNameVisitor<'a> { + /// The elements of a product type, in order. elems: &'a [ProductTypeElement], + /// The kind of product this is. kind: ProductKind, } -impl<'de> FieldNameVisitor<'de> for TupleNameVisitor<'_> { + +impl FieldNameVisitor<'_> for TupleNameVisitor<'_> { + // The index of the field name. type Output = usize; fn field_names(&self, names: &mut dyn super::ValidNames) { - names.extend(self.elems.iter().filter_map(|f| f.name.as_deref())) + names.extend(self.elems.iter().filter_map(|f| f.name())) } + fn kind(&self) -> ProductKind { self.kind } fn visit(self, name: &str) -> Result { + // Finds the index of a field with `name`. self.elems .iter() - .position(|f| f.name.as_deref() == Some(name)) + .position(|f| f.has_name(name)) .ok_or_else(|| Error::unknown_field_name(name, &self)) } } diff --git a/crates/sats/src/de/serde.rs b/crates/sats/src/de/serde.rs index e82bd270f9..c258d29c88 100644 --- a/crates/sats/src/de/serde.rs +++ b/crates/sats/src/de/serde.rs @@ -4,16 +4,21 @@ use std::marker::PhantomData; use super::Deserializer; use ::serde::de as serde; +/// Converts any [`serde::Deserializer`] to a SATS [`Deserializer`] +/// so that Serde's data formats can be reused. pub struct SerdeDeserializer { + /// A deserialization data format in Serde. de: D, } impl SerdeDeserializer { + /// Wraps a Serde deserializer. pub fn new(de: D) -> Self { Self { de } } } +/// An error that occured when deserializing SATS to a Serde data format. #[repr(transparent)] pub struct SerdeError(pub E); #[inline] @@ -32,6 +37,11 @@ impl super::Error for SerdeError { } } +/// Deserialize a `T` provided a serde deserializer `D`. +fn deserialize<'de, D: serde::Deserializer<'de>, T: serde::Deserialize<'de>>(de: D) -> Result> { + serde::Deserialize::deserialize(de).map_err(SerdeError) +} + impl<'de, D: serde::Deserializer<'de>> Deserializer<'de> for SerdeDeserializer { type Error = SerdeError; @@ -52,43 +62,43 @@ impl<'de, D: serde::Deserializer<'de>> Deserializer<'de> for SerdeDeserializer Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_u8(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_u16(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_u32(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_u64(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_u128(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_i8(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_i16(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_i32(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_i64(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_i128(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_f32(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_f64(self) -> Result { - serde::Deserialize::deserialize(self.de).map_err(SerdeError) + deserialize(self.de) } fn deserialize_str>(self, visitor: V) -> Result { @@ -129,10 +139,14 @@ impl<'de, D: serde::Deserializer<'de>> Deserializer<'de> for SerdeDeserializer` in SATS to the one in Serde. #[repr(transparent)] pub struct SeedWrapper(pub T); + impl SeedWrapper { + /// Convert `&T` to `&SeedWrapper`. pub fn from_ref(t: &T) -> &Self { + // SAFETY: `repr(transparent)` allows this. unsafe { &*(t as *const T as *const SeedWrapper) } } } @@ -148,7 +162,9 @@ impl<'de, T: super::DeserializeSeed<'de>> serde::DeserializeSeed<'de> for SeedWr } } +/// Converts a `ProductVisitor` to a `serde::Visitor`. struct TupleVisitor { + /// The `ProductVisitor` to convert. visitor: V, } @@ -163,34 +179,23 @@ impl<'de, V: super::ProductVisitor<'de>> serde::Visitor<'de> for TupleVisitor } } - fn visit_map(self, map: A) -> Result - where - A: serde::MapAccess<'de>, - { + fn visit_map>(self, map: A) -> Result { self.visitor .visit_named_product(NamedTupleAccess { map }) .map_err(unwrap_error) } - fn visit_seq(self, seq: A) -> Result - where - A: serde::SeqAccess<'de>, - { + fn visit_seq>(self, seq: A) -> Result { self.visitor .visit_seq_product(SeqTupleAccess { seq }) .map_err(unwrap_error) } } -struct NullProduct(PhantomData); -impl<'de, E: super::Error> super::SeqProductAccess<'de> for NullProduct { - type Error = E; - fn next_element_seed>(&mut self, _: T) -> Result, Self::Error> { - Ok(None) - } -} - +/// Turns Serde's style of deserializing map entries +/// into deserializing field names and their values. struct NamedTupleAccess { + /// An implementation of `serde::MapAccess<'de>` to convert. map: A, } @@ -209,17 +214,16 @@ impl<'de, A: serde::MapAccess<'de>> super::NamedProductAccess<'de> for NamedTupl } } +/// Converts a SATS field name visitor for use in [`NamedTupleAccess`]. struct FieldNameVisitor { + /// The underlying field name visitor. visitor: V, } impl<'de, V: super::FieldNameVisitor<'de>> serde::DeserializeSeed<'de> for FieldNameVisitor { type Value = V::Output; - fn deserialize(self, deserializer: D) -> Result - where - D: ::serde::Deserializer<'de>, - { + fn deserialize>(self, deserializer: D) -> Result { deserializer.deserialize_str(self) } } @@ -235,15 +239,15 @@ impl<'de, V: super::FieldNameVisitor<'de>> serde::Visitor<'de> for FieldNameVisi } } - fn visit_str(self, v: &str) -> Result - where - E: serde::Error, - { + fn visit_str(self, v: &str) -> Result { self.visitor.visit(v).map_err(unwrap_error) } } +/// Turns `serde::SeqAccess` deserializing the elements of a sequence +/// into `SeqProductAccess`. struct SeqTupleAccess { + /// The `serde::SeqAccess` to convert. seq: A, } @@ -256,7 +260,9 @@ impl<'de, A: serde::SeqAccess<'de>> super::SeqProductAccess<'de> for SeqTupleAcc } } +/// Converts a `SumVisitor` into a `serde::Visitor` for deserializing option. struct OptionVisitor { + /// The visitor to convert. visitor: V, } @@ -267,22 +273,19 @@ impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for OptionVisitor { f.write_str("option") } - fn visit_map(self, map: A) -> Result - where - A: serde::MapAccess<'de>, - { + fn visit_map>(self, map: A) -> Result { self.visitor.visit_sum(SomeAccess(map)).map_err(unwrap_error) } - fn visit_unit(self) -> Result - where - E: serde::Error, - { + fn visit_unit(self) -> Result { self.visitor.visit_sum(NoneAccess(PhantomData)).map_err(unwrap_error) } } +/// Deserializes `some` variant of an optional value. +/// Converts Serde's map deserialization to SATS. struct SomeAccess(A); + impl<'de, A: serde::MapAccess<'de>> super::SumAccess<'de> for SomeAccess { type Error = SerdeError; type Variant = Self; @@ -306,6 +309,8 @@ impl<'de, A: serde::MapAccess<'de>> super::VariantAccess<'de> for SomeAccess Ok(ret) } } + +/// Deserializes nothing, producing `!` effectively. struct NothingVisitor; impl<'de> serde::DeserializeSeed<'de> for NothingVisitor { type Value = std::convert::Infallible; @@ -313,15 +318,16 @@ impl<'de> serde::DeserializeSeed<'de> for NothingVisitor { deserializer.deserialize_identifier(self) } } -impl<'de> serde::Visitor<'de> for NothingVisitor { +impl serde::Visitor<'_> for NothingVisitor { type Value = std::convert::Infallible; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("nothing") } } +/// Deserializes `none` variant of an optional value. struct NoneAccess(PhantomData); -impl<'de, E: super::Error> super::SumAccess<'de> for NoneAccess { +impl super::SumAccess<'_> for NoneAccess { type Error = E; type Variant = Self; @@ -341,7 +347,9 @@ impl<'de, E: super::Error> super::VariantAccess<'de> for NoneAccess { } } +/// Converts a SATS `SumVisitor` to `serde::Visitor`. struct EnumVisitor { + /// The `SumVisitor`. visitor: V, } @@ -352,15 +360,14 @@ impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for EnumVisitor { f.write_str("enum") } - fn visit_enum(self, access: A) -> Result - where - A: serde::EnumAccess<'de>, - { + fn visit_enum>(self, access: A) -> Result { self.visitor.visit_sum(EnumAccess { access }).map_err(unwrap_error) } } +/// Converts Serde's `EnumAccess` to SATS `SumAccess`. struct EnumAccess { + /// The Serde `EnumAccess`. access: A, } @@ -376,46 +383,38 @@ impl<'de, A: serde::EnumAccess<'de>> super::SumAccess<'de> for EnumAccess { } } +/// Converts SATS way of identifying a variant to Serde's way. struct VariantVisitor { + /// The SATS `VariantVisitor` to convert. visitor: V, } + impl<'de, V: super::VariantVisitor> serde::DeserializeSeed<'de> for VariantVisitor { type Value = V::Output; - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { + fn deserialize>(self, deserializer: D) -> Result { deserializer.deserialize_identifier(self) } } -impl<'de, V: super::VariantVisitor> serde::Visitor<'de> for VariantVisitor { + +impl serde::Visitor<'_> for VariantVisitor { type Value = V::Output; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("variant identifier (string or int)") } - fn visit_u8(self, v: u8) -> Result - where - E: serde::Error, - { + fn visit_u8(self, v: u8) -> Result { self.visitor.visit_tag(v).map_err(unwrap_error) } - fn visit_u64(self, v: u64) -> Result - where - E: serde::Error, - { + fn visit_u64(self, v: u64) -> Result { let v: u8 = v .try_into() .map_err(|_| E::invalid_value(serde::Unexpected::Unsigned(v), &"a u8 tag"))?; self.visit_u8(v) } - fn visit_str(self, v: &str) -> Result - where - E: serde::Error, - { + fn visit_str(self, v: &str) -> Result { if let Ok(tag) = v.parse::() { self.visit_u8(tag) } else { @@ -424,7 +423,9 @@ impl<'de, V: super::VariantVisitor> serde::Visitor<'de> for VariantVisitor { } } +/// Deserializes the data of a variant using Serde's `serde::VariantAccess` translating this to SATS. struct VariantAccess { + // Implements `serde::VariantAccess`. access: A, } @@ -436,7 +437,10 @@ impl<'de, A: serde::VariantAccess<'de>> super::VariantAccess<'de> for VariantAcc } } +/// Translates a `SliceVisitor<'de, str>` to `serde::Visitor<'de>` +/// for implementing `deserialize_str`. struct StrVisitor { + /// The `SliceVisitor<'de, str>`. visitor: V, } @@ -447,29 +451,23 @@ impl<'de, V: super::SliceVisitor<'de, str>> serde::Visitor<'de> for StrVisitor(self, v: &str) -> Result - where - E: serde::Error, - { + fn visit_str(self, v: &str) -> Result { self.visitor.visit(v).map_err(unwrap_error) } - fn visit_borrowed_str(self, v: &'de str) -> Result - where - E: serde::Error, - { + fn visit_borrowed_str(self, v: &'de str) -> Result { self.visitor.visit_borrowed(v).map_err(unwrap_error) } - fn visit_string(self, v: String) -> Result - where - E: serde::Error, - { + fn visit_string(self, v: String) -> Result { self.visitor.visit_owned(v).map_err(unwrap_error) } } +/// Translates a `SliceVisitor<'de, str>` to `serde::Visitor<'de>` +/// for implementing `deserialize_bytes`. struct BytesVisitor { + /// The `SliceVisitor<'de, [u8]>`. visitor: V, } @@ -480,39 +478,24 @@ impl<'de, V: super::SliceVisitor<'de, [u8]>> serde::Visitor<'de> for BytesVisito f.write_str("a byte array") } - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::Error, - { + fn visit_bytes(self, v: &[u8]) -> Result { self.visitor.visit(v).map_err(unwrap_error) } - fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result - where - E: serde::Error, - { + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result { self.visitor.visit_borrowed(v).map_err(unwrap_error) } - fn visit_byte_buf(self, v: Vec) -> Result - where - E: serde::Error, - { + fn visit_byte_buf(self, v: Vec) -> Result { self.visitor.visit_owned(v).map_err(unwrap_error) } - fn visit_str(self, v: &str) -> Result - where - E: serde::Error, - { + fn visit_str(self, v: &str) -> Result { let data = hex_string(v, &self)?; self.visitor.visit_owned(data).map_err(unwrap_error) } - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::SeqAccess<'de>, - { + fn visit_seq>(self, mut seq: A) -> Result { let mut v = Vec::with_capacity(std::cmp::min(seq.size_hint().unwrap_or(0), 4096)); while let Some(val) = seq.next_element()? { v.push(val); @@ -521,6 +504,7 @@ impl<'de, V: super::SliceVisitor<'de, [u8]>> serde::Visitor<'de> for BytesVisito } } +/// Hex decodes the string `v`. fn hex_string, E: serde::Error>( v: &str, exp: &dyn serde::Expected, @@ -556,8 +540,12 @@ fn hex_string, E: serde::Error>( // } // } +/// Translates `ArrayVisitor<'de, T::Output>` (the trait) to `serde::Visitor<'de>` +/// for implementing `deserialize_array`. struct ArrayVisitor { + /// The SATS visitor to translate to a Serde visitor. visitor: V, + /// The seed value to provide to `DeserializeSeed`. seed: T, } @@ -570,18 +558,19 @@ impl<'de, T: super::DeserializeSeed<'de> + Clone, V: super::ArrayVisitor<'de, T: f.write_str("a vec") } - fn visit_seq(self, seq: A) -> Result - where - A: serde::SeqAccess<'de>, - { + fn visit_seq>(self, seq: A) -> Result { self.visitor .visit(ArrayAccess { seq, seed: self.seed }) .map_err(unwrap_error) } } +/// Translates `serde::SeqAcess<'de>` (the trait) to `ArrayAccess<'de>` +/// for implementing deserialization of array elements. struct ArrayAccess { + /// The `serde::SeqAcess<'de>` implementation. seq: A, + /// The seed to pass onto `DeserializeSeed`. seed: T, } @@ -602,9 +591,16 @@ impl<'de, A: serde::SeqAccess<'de>, T: super::DeserializeSeed<'de> + Clone> supe } } +/// Translates SATS's `MapVisior<'de>` (the trait) to `serde::Visitor<'de>` +/// for implementing deserialization of maps. struct MapVisitor { + /// The SATS visitor to translate to a Serde visitor. visitor: Vi, + /// The seed value to provide to `DeserializeSeed` for deserializing keys. + /// As this is reused for every entry element, it will be `.cloned()`. kseed: K, + /// The seed value to provide to `DeserializeSeed` for deserializing values. + /// As this is reused for every entry element, it will be `.cloned()`. vseed: V, } @@ -621,10 +617,7 @@ impl< f.write_str("a vec") } - fn visit_map(self, map: A) -> Result - where - A: serde::MapAccess<'de>, - { + fn visit_map>(self, map: A) -> Result { self.visitor .visit(MapAccess { map, @@ -636,8 +629,13 @@ impl< } struct MapAccess { + /// An implementation of `serde::MapAccess<'de>`. map: A, + /// The seed value to provide to `DeserializeSeed` for deserializing keys. + /// As this is reused for every entry element, it will be `.cloned()`. kseed: K, + /// The seed value to provide to `DeserializeSeed` for deserializing values. + /// As this is reused for every entry element, it will be `.cloned()`. vseed: V, } @@ -665,21 +663,20 @@ impl fmt::Result> serde::Expected for super::FDisp } } +/// Deserializes `T` as a SATS object from `deserializer: D` +/// where `D` is a serde data format. pub fn deserialize_from<'de, T: super::Deserialize<'de>, D: serde::Deserializer<'de>>( deserializer: D, ) -> Result { T::deserialize(SerdeDeserializer::new(deserializer)).map_err(unwrap_error) } +/// Turns a type deserializable in SATS into one deserializiable in Serde. +/// +/// That is, `T: sats::Deserialize<'de> => DeserializeWrapper: serde::Deserialize`. pub struct DeserializeWrapper(pub T); -impl<'de, T> serde::Deserialize<'de> for DeserializeWrapper -where - T: super::Deserialize<'de>, -{ - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { +impl<'de, T: super::Deserialize<'de>> serde::Deserialize<'de> for DeserializeWrapper { + fn deserialize>(deserializer: D) -> Result { deserialize_from(deserializer).map(Self) } } @@ -687,10 +684,7 @@ where macro_rules! delegate_serde { ($($t:ty),*) => { $(impl<'de> serde::Deserialize<'de> for $t { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { + fn deserialize>(deserializer: D) -> Result { deserialize_from(deserializer) } })* diff --git a/crates/sats/src/lib.rs b/crates/sats/src/lib.rs index 7d79f53c1b..a9025b3063 100644 --- a/crates/sats/src/lib.rs +++ b/crates/sats/src/lib.rs @@ -1,23 +1,23 @@ pub mod algebraic_type; +mod algebraic_type_ref; +pub mod algebraic_value; +pub mod bsatn; +pub mod buffer; pub mod builtin_type; pub mod builtin_value; pub mod convert; +pub mod de; +pub mod meta_type; pub mod product_type; pub mod product_type_element; pub mod product_value; +mod resolve_refs; +pub mod satn; +pub mod ser; pub mod sum_type; pub mod sum_type_variant; pub mod sum_value; pub mod typespace; -// mod algebraic_type_legacy_encoding; -mod algebraic_type_ref; -pub mod algebraic_value; -pub mod bsatn; -pub mod buffer; -pub mod de; -mod resolve_refs; -pub mod satn; -pub mod ser; pub use algebraic_type::AlgebraicType; pub use algebraic_type_ref::AlgebraicTypeRef; @@ -32,16 +32,24 @@ pub use sum_type_variant::SumTypeVariant; pub use sum_value::SumValue; pub use typespace::{SpacetimeType, Typespace}; +/// The `Value` trait provides an abstract notion of a value. +/// +/// All we know about values abstractly is that they have a `Type`. pub trait Value { + /// The type of this value. type Type; } impl Value for Vec { + // TODO(centril/phoebe): This looks weird; shouldn't it be ArrayType? type Type = T::Type; } +/// A borrowed value combined with its type and typing context (`Typespace`). pub struct ValueWithType<'a, T: Value> { - ty: TypeInSpace<'a, T::Type>, + /// The type combined with the context of this `val`ue. + ty: WithTypespace<'a, T::Type>, + /// The borrowed value. val: &'a T, } @@ -53,18 +61,27 @@ impl Clone for ValueWithType<'_, T> { } impl<'a, T: Value> ValueWithType<'a, T> { - pub fn new(ty: TypeInSpace<'a, T::Type>, val: &'a T) -> Self { + /// Wraps the borrowed value `val` with its type combined with context. + pub fn new(ty: WithTypespace<'a, T::Type>, val: &'a T) -> Self { Self { ty, val } } + + /// Returns the borrowed value. pub fn value(&self) -> &'a T { self.val } + + /// Returns the type of the value. pub fn ty(&self) -> &'a T::Type { self.ty.ty } + + /// Returns the typing context (`Typespace`). pub fn typespace(&self) -> &'a Typespace { self.ty.typespace } + + /// Reuses the typespace we already have and returns `val` and `ty` wrapped with it. pub fn with<'b, U: Value>(&self, ty: &'b U::Type, val: &'b U) -> ValueWithType<'b, U> where 'a: 'b, @@ -82,42 +99,50 @@ impl<'a, T: Value> ValueWithType<'a, Vec> { } } +/// Adds a `Typespace` context atop of a borrowed type. #[derive(Debug)] -pub struct TypeInSpace<'a, T: ?Sized> { +pub struct WithTypespace<'a, T: ?Sized> { + /// The typespace context that has been added to `ty`. typespace: &'a Typespace, + /// What we've added the context to. ty: &'a T, } -impl Copy for TypeInSpace<'_, T> {} -impl Clone for TypeInSpace<'_, T> { +impl Copy for WithTypespace<'_, T> {} +impl Clone for WithTypespace<'_, T> { fn clone(&self) -> Self { *self } } -impl<'a, T: ?Sized> TypeInSpace<'a, T> { - pub fn new(typespace: &'a Typespace, ty: &'a T) -> Self { +impl<'a, T: ?Sized> WithTypespace<'a, T> { + /// Wraps `ty` in a context combined with the `typespace`. + pub const fn new(typespace: &'a Typespace, ty: &'a T) -> Self { Self { typespace, ty } } - pub fn ty(&self) -> &'a T { + /// Returns the object that the context was created with. + pub const fn ty(&self) -> &'a T { self.ty } - pub fn typespace(&self) -> &'a Typespace { + /// Returns the typespace context. + pub const fn typespace(&self) -> &'a Typespace { self.typespace } - pub fn with<'b, U>(&self, ty: &'b U) -> TypeInSpace<'b, U> + /// Reuses the typespace we already have and returns `ty: U` wrapped with it. + pub fn with<'b, U>(&self, ty: &'b U) -> WithTypespace<'b, U> where 'a: 'b, { - TypeInSpace { + WithTypespace { typespace: self.typespace, ty, } } + /// Wraps `val` with the type and typespace context in `self`. pub fn with_value<'b, V: Value>(&self, val: &'b V) -> ValueWithType<'b, V> where 'a: 'b, @@ -125,32 +150,24 @@ impl<'a, T: ?Sized> TypeInSpace<'a, T> { ValueWithType::new(*self, val) } - pub fn resolve(&self, r: AlgebraicTypeRef) -> TypeInSpace<'a, AlgebraicType> { - TypeInSpace { + /// Returns the `AlgebraicType` that `r` resolves to in the context of our `Typespace`. + /// + /// Panics if `r` is not known by our `Typespace`. + pub fn resolve(&self, r: AlgebraicTypeRef) -> WithTypespace<'a, AlgebraicType> { + WithTypespace { typespace: self.typespace, ty: &self.typespace[r], } } - pub fn map(&self, f: impl FnOnce(&'a T) -> &'a U) -> TypeInSpace<'a, U> { - TypeInSpace { + /// Maps the object we've wrapped from `&T -> &U` in our context. + /// + /// This can be used to e.g., project fields and through a structure. + /// This provides an implementation of functor mapping for `WithTypespace`. + pub fn map(&self, f: impl FnOnce(&'a T) -> &'a U) -> WithTypespace<'a, U> { + WithTypespace { typespace: self.typespace, ty: f(self.ty), } } } - -struct FDisplay(F); -impl std::fmt::Result> std::fmt::Display for FDisplay { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - (self.0)(f) - } -} -impl std::fmt::Result> std::fmt::Debug for FDisplay { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - (self.0)(f) - } -} -fn fmt_fn std::fmt::Result>(f: F) -> FDisplay { - FDisplay(f) -} diff --git a/crates/sats/src/meta_type.rs b/crates/sats/src/meta_type.rs new file mode 100644 index 0000000000..fe8369302c --- /dev/null +++ b/crates/sats/src/meta_type.rs @@ -0,0 +1,21 @@ +//! Provides the `MetaType` trait. + +use crate::AlgebraicType; + +/// Rust types which represent components of the SATS type system +/// and can themselves be represented as algebraic objects will implement [`MetaType`]. +/// +/// A type's meta-type is an [`AlgebraicType`] +/// which can store the data associated with a definition of that type. +/// +/// For example, the `MetaType` of [`ProductType`](crate::ProductType) is +/// ```ignore +/// AlgebraicType::product(vec![ProductTypeElement::new_named( +/// AlgebraicType::array(ProductTypeElement::meta_type()), +/// "elements", +/// )]) +/// ``` +pub trait MetaType { + /// Returns the type structure of this type as an `AlgebraicType`. + fn meta_type() -> AlgebraicType; +} diff --git a/crates/sats/src/product_type.rs b/crates/sats/src/product_type.rs index 918b30d141..68826a9fed 100644 --- a/crates/sats/src/product_type.rs +++ b/crates/sats/src/product_type.rs @@ -1,32 +1,47 @@ -pub mod satn; - use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer}; use crate::algebraic_value::ser::ValueSerializer; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize}; -use crate::{AlgebraicType, AlgebraicTypeRef, AlgebraicValue, ArrayType, BuiltinType, ProductTypeElement}; +use crate::{AlgebraicType, AlgebraicValue, ProductTypeElement}; +/// A structural product type of the factors given by `elements`. +/// +/// This is also known as `struct` and `tuple` in many languages, +/// but note that unlike most languages, products in SATs are *[structural]* and not nominal. +/// When checking whether two nominal types are the same, +/// their names and/or declaration sites (e.g., module / namespace) are considered. +/// Meanwhile, a structural type system would only check the structure of the type itself, +/// e.g., the names of its fields and their types in the case of a record. +/// The name "product" comes from category theory. +/// +/// See also: https://ncatlab.org/nlab/show/product+type. +/// +/// These structures are known as product types because the number of possible values in product +/// ```ignore +/// { N_0: T_0, N_1: T_1, ..., N_n: T_n } +/// ``` +/// is: +/// ```ignore +/// Π (i ∈ 0..n). values(T_i) +/// ``` +/// so for example, `values({ A: U64, B: Bool }) = values(U64) * values(Bool)`. +/// +/// [structural]: https://en.wikipedia.org/wiki/Structural_type_system #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub struct ProductType { + /// The factors of the product type. + /// + /// These factors can either be named or unnamed. + /// When all the factors are unnamed, we can regard this as a plain tuple type. pub elements: Vec, } impl ProductType { - pub fn new(elements: Vec) -> Self { + /// Returns a product type with the given `elements` as its factors. + pub const fn new(elements: Vec) -> Self { Self { elements } } - - pub fn with_capacity(capacity: usize) -> Self { - Self { - elements: Vec::with_capacity(capacity), - } - } -} - -impl From for ProductTypeElement { - fn from(value: AlgebraicType) -> Self { - ProductTypeElement::new(value, None) - } } impl> FromIterator for ProductType { @@ -50,29 +65,16 @@ impl<'a, I: Into> FromIterator<(Option<&'a str>, I)> for ProductT } } -impl ProductType { - pub fn make_meta_type() -> AlgebraicType { - let string = AlgebraicType::Builtin(BuiltinType::String); - let option = AlgebraicType::make_option_type(string); - let element_type = AlgebraicType::Product(ProductType::new(vec![ - ProductTypeElement { - algebraic_type: option, - name: Some("name".into()), - }, - ProductTypeElement { - algebraic_type: AlgebraicType::Ref(AlgebraicTypeRef(0)), - name: Some("algebraic_type".into()), - }, - ])); - let array = AlgebraicType::Builtin(BuiltinType::Array(ArrayType { - elem_ty: Box::new(element_type), - })); - AlgebraicType::Product(ProductType::new(vec![ProductTypeElement { - algebraic_type: array, - name: Some("elements".into()), - }])) +impl MetaType for ProductType { + fn meta_type() -> AlgebraicType { + AlgebraicType::product(vec![ProductTypeElement::new_named( + AlgebraicType::array(ProductTypeElement::meta_type()), + "elements", + )]) } +} +impl ProductType { pub fn as_value(&self) -> AlgebraicValue { self.serialize(ValueSerializer).unwrap_or_else(|x| match x {}) } diff --git a/crates/sats/src/product_type/satn.rs b/crates/sats/src/product_type/satn.rs deleted file mode 100644 index d91603f7b2..0000000000 --- a/crates/sats/src/product_type/satn.rs +++ /dev/null @@ -1,32 +0,0 @@ -use super::ProductType; -use crate::algebraic_type; -use std::fmt::Display; - -pub struct Formatter<'a> { - ty: &'a ProductType, -} - -impl<'a> Formatter<'a> { - pub fn new(ty: &'a ProductType) -> Self { - Self { ty } - } -} - -impl<'a> Display for Formatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "(")?; - for (i, e) in self.ty.elements.iter().enumerate() { - if let Some(name) = &e.name { - write!(f, "{}", name)?; - } else { - write!(f, "{}", i)?; - } - write!(f, ": ")?; - write!(f, "{}", algebraic_type::satn::Formatter::new(&e.algebraic_type))?; - if i < self.ty.elements.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, ")") - } -} diff --git a/crates/sats/src/product_type_element.rs b/crates/sats/src/product_type_element.rs index c09f18e7f9..3e2aa3c882 100644 --- a/crates/sats/src/product_type_element.rs +++ b/crates/sats/src/product_type_element.rs @@ -1,24 +1,61 @@ -use crate::AlgebraicType; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize}; +use crate::{AlgebraicType, AlgebraicTypeRef}; +/// A factor / element of a product type. +/// +/// An element consist of an optional name and a type. +/// /// NOTE: Each element has an implicit element tag based on its order. /// Uniquely identifies an element similarly to protobuf tags. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub struct ProductTypeElement { + /// The name of the field / element. + /// + /// As our type system is structural, + /// a type like `{ foo: U8 }`, where `foo: U8` is the `ProductTypeElement`, + /// is inequal to `{ bar: U8 }`, although their `algebraic_type`s (`U8`) match. pub name: Option, + /// The type of the element. + /// + /// Only values of this type can be stored in the element. pub algebraic_type: AlgebraicType, } impl ProductTypeElement { - pub fn new(algebraic_type: AlgebraicType, name: Option) -> Self { + /// Returns an element with the given `name` and `algebraic_type`. + pub const fn new(algebraic_type: AlgebraicType, name: Option) -> Self { Self { algebraic_type, name } } + /// Returns a named element with `name` and `algebraic_type`. pub fn new_named(algebraic_type: AlgebraicType, name: impl Into) -> Self { - Self { - algebraic_type, - name: Some(name.into()), - } + Self::new(algebraic_type, Some(name.into())) + } + + /// Returns the name of the field. + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + /// Returns whether the field has the given name. + pub fn has_name(&self, name: &str) -> bool { + self.name() == Some(name) + } +} + +impl MetaType for ProductTypeElement { + fn meta_type() -> AlgebraicType { + AlgebraicType::product(vec![ + Self::new_named(AlgebraicType::option(AlgebraicType::String), "name"), + Self::new_named(AlgebraicType::Ref(AlgebraicTypeRef(0)), "algebraic_type"), + ]) + } +} + +impl From for ProductTypeElement { + fn from(value: AlgebraicType) -> Self { + ProductTypeElement::new(value, None) } } diff --git a/crates/sats/src/product_value.rs b/crates/sats/src/product_value.rs index 34c77dc196..89d415ec30 100644 --- a/crates/sats/src/product_value.rs +++ b/crates/sats/src/product_value.rs @@ -1,11 +1,19 @@ use crate::algebraic_value::AlgebraicValue; use crate::product_type::ProductType; +/// A product value is made of a a list of +/// "elements" / "fields" / "factors" of other `AlgebraicValue`s. +/// +/// The type of a product value is a [product type](`ProductType`). #[derive(Debug, Clone, Ord, PartialOrd, PartialEq, Eq, Hash)] pub struct ProductValue { + /// The values that make up this product value. pub elements: Vec, } +/// Constructs a product value from a list of fields with syntax `product![v1, v2, ...]`. +/// +/// Repeat notation from `vec![x; n]` is not supported. #[macro_export] macro_rules! product { [$($elems:expr),*$(,)?] => { @@ -16,6 +24,7 @@ macro_rules! product { } impl ProductValue { + /// Returns a product value constructed from the given values in `elements`. pub fn new(elements: &[AlgebraicValue]) -> Self { Self { elements: elements.into(), @@ -34,57 +43,71 @@ impl crate::Value for ProductValue { type Type = ProductType; } -#[derive(thiserror::Error, Debug, Clone)] -#[error("Field {0}({1:?}) not found or has an invalid type")] -pub struct InvalidFieldError(pub usize, pub Option<&'static str>); +/// An error that occurs when a field, of a product value, is accessed that doesn't exist. +#[derive(thiserror::Error, Debug, Copy, Clone)] +#[error("Field {index}({name:?}) not found or has an invalid type")] +pub struct InvalidFieldError { + /// The claimed index of the field within the product value. + pub index: usize, + /// The name of the field, if any. + pub name: Option<&'static str>, +} impl ProductValue { - pub fn get_field(&self, index: usize, named: Option<&'static str>) -> Result<&AlgebraicValue, InvalidFieldError> { - self.elements.get(index).ok_or(InvalidFieldError(index, named)) + /// Borrow the value at field of `self` indentified by `index`. + /// + /// The `name` is non-functional and is only used for error-messages. + pub fn get_field(&self, index: usize, name: Option<&'static str>) -> Result<&AlgebraicValue, InvalidFieldError> { + self.elements.get(index).ok_or(InvalidFieldError { index, name }) } - pub fn extract_field<'a, T, F>( + /// Extracts the `value` at field of `self` identified by `index` + /// and then runs it through the function `f` which possibly returns a `T` derived from `value`. + pub fn extract_field<'a, T>( &'a self, index: usize, - named: Option<&'static str>, - f: F, - ) -> Result - where - F: Fn(&'a AlgebraicValue) -> Option + 'a, - { - let v = self.elements.get(index).ok_or(InvalidFieldError(index, named))?; - let r = f(v).ok_or(InvalidFieldError(index, named))?; - Ok(r) + name: Option<&'static str>, + f: impl 'a + Fn(&'a AlgebraicValue) -> Option, + ) -> Result { + f(self.get_field(index, name)?).ok_or(InvalidFieldError { index, name }) } + /// Interprets the value at field of `self` indentified by `index` as a `bool`. pub fn field_as_bool(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_bool().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a `u8`. pub fn field_as_u8(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_u8().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a `u32`. pub fn field_as_u32(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_u32().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a `i64`. pub fn field_as_i64(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_i64().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a `i128`. pub fn field_as_i128(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_i128().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a `u128`. pub fn field_as_u128(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_u128().copied()) } + /// Interprets the value at field of `self` indentified by `index` as a string slice. pub fn field_as_str(&self, index: usize, named: Option<&'static str>) -> Result<&str, InvalidFieldError> { self.extract_field(index, named, |f| f.as_string().map(|x| x.as_str())) } + /// Interprets the value at field of `self` indentified by `index` as a byte slice. pub fn field_as_bytes(&self, index: usize, named: Option<&'static str>) -> Result<&[u8], InvalidFieldError> { self.extract_field(index, named, |f| f.as_bytes().map(|x| x.as_slice())) } diff --git a/crates/sats/src/resolve_refs.rs b/crates/sats/src/resolve_refs.rs index 500324ea21..af941f5d8b 100644 --- a/crates/sats/src/resolve_refs.rs +++ b/crates/sats/src/resolve_refs.rs @@ -1,21 +1,69 @@ use crate::{ AlgebraicType, AlgebraicTypeRef, ArrayType, BuiltinType, MapType, ProductType, ProductTypeElement, SumType, - SumTypeVariant, TypeInSpace, + SumTypeVariant, WithTypespace, }; +/// Resolver for [`AlgebraicTypeRef`]s within a structure. #[derive(Default)] pub struct ResolveRefState { + /// The stack used to handle cycle detection for [recursive types] (`μα. T`). + /// + /// [recursive types]: https://en.wikipedia.org/wiki/Recursive_data_type#Theory stack: Vec, } +/// A trait for types that know how to resolve their [`AlgebraicTypeRef`]s +/// provided a typing context and the resolver `state`. pub trait ResolveRefs { + /// Output type after type references have been resolved. type Output; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option; + + /// Returns, if possible, an output with all [`AlgebraicTypeRef`]s + /// within `this` (typing context carried) resolved + /// using the provided resolver `state`. + /// + /// `None` is only returned if there were cycles in the precense of recursive μ-types. + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option; +} + +// ----------------------------------------------------------------------------- +// The interesting logic: +// ----------------------------------------------------------------------------- + +impl ResolveRefs for AlgebraicTypeRef { + type Output = AlgebraicType; + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { + // Suppose we have `&0 = { Nil, Cons({ elem: U8, tail: &0 }) }`. + // This is our standard cons-list type. + // In this setup, when getting to `tail`, + // we would recurse back to expanding `tail` again, and so or... + // So we will never halt. This check breaks that cycle. + if state.stack.contains(this.ty()) { + return None; + } + + // Push ourselves to the stack. + state.stack.push(*this.ty()); + + // Extract the `at: AlgebraicType` pointed to by `this` and then resolve `at`. + let ret = this + .typespace() + .get(*this.ty()) + .and_then(|at| this.with(at)._resolve_refs(state)); + + // Remove ourselves. + state.stack.pop(); + ret + } } +// ----------------------------------------------------------------------------- +// All the below is just plumbing: +// ----------------------------------------------------------------------------- + impl ResolveRefs for AlgebraicType { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { match this.ty() { AlgebraicType::Sum(sum) => this.with(sum)._resolve_refs(state).map(Self::Sum), AlgebraicType::Product(prod) => this.with(prod)._resolve_refs(state).map(Self::Product), @@ -24,36 +72,41 @@ impl ResolveRefs for AlgebraicType { } } } + impl ResolveRefs for BuiltinType { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { match this.ty() { BuiltinType::Array(ty) => this.with(ty)._resolve_refs(state).map(Self::Array), BuiltinType::Map(m) => this.with(m)._resolve_refs(state).map(Self::Map), + // These types are plain and cannot have refs in them. x => Some(x.clone()), } } } + impl ResolveRefs for ArrayType { type Output = ArrayType; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { Some(ArrayType { elem_ty: Box::new(this.map(|m| &*m.elem_ty)._resolve_refs(state)?), }) } } + impl ResolveRefs for MapType { type Output = MapType; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { Some(MapType { key_ty: Box::new(this.map(|m| &*m.key_ty)._resolve_refs(state)?), ty: Box::new(this.map(|m| &*m.ty)._resolve_refs(state)?), }) } } + impl ResolveRefs for ProductType { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { let elements = this .ty() .elements @@ -63,18 +116,20 @@ impl ResolveRefs for ProductType { Some(ProductType { elements }) } } + impl ResolveRefs for ProductTypeElement { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { Some(ProductTypeElement { algebraic_type: this.map(|e| &e.algebraic_type)._resolve_refs(state)?, name: this.ty().name.clone(), }) } } + impl ResolveRefs for SumType { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { let variants = this .ty() .variants @@ -84,32 +139,18 @@ impl ResolveRefs for SumType { Some(SumType { variants }) } } + impl ResolveRefs for SumTypeVariant { type Output = Self; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { + fn resolve_refs(this: WithTypespace<'_, Self>, state: &mut ResolveRefState) -> Option { Some(SumTypeVariant { algebraic_type: this.map(|v| &v.algebraic_type)._resolve_refs(state)?, name: this.ty().name.clone(), }) } } -impl ResolveRefs for AlgebraicTypeRef { - type Output = AlgebraicType; - fn resolve_refs(this: TypeInSpace<'_, Self>, state: &mut ResolveRefState) -> Option { - if state.stack.contains(this.ty()) { - return None; - } - state.stack.push(*this.ty()); - let ret = this - .typespace() - .get(*this.ty()) - .and_then(|ty| this.with(ty)._resolve_refs(state)); - state.stack.pop(); - ret - } -} -impl TypeInSpace<'_, T> { +impl WithTypespace<'_, T> { pub fn resolve_refs(self) -> Option { T::resolve_refs(self, &mut ResolveRefState::default()) } diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index e502e5a6ff..f8cf256462 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -2,18 +2,26 @@ use std::fmt::{self, Write as _}; use crate::ser; +/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods. pub trait Satn: ser::Serialize { + /// Formats the value using the SATN data format into the formatter `f`. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { Writer::with(f, |f| self.serialize(SatnFormatter { f }))?; Ok(()) } + + /// Formats the value using the postgres SATN data format into the formatter `f`. fn fmt_psql(&self, f: &mut fmt::Formatter) -> fmt::Result { Writer::with(f, |f| self.serialize(PsqlFormatter(SatnFormatter { f })))?; Ok(()) } + + /// Formats the value using the SATN data format into the returned `String`. fn to_satn(&self) -> String { Wrapper::from_ref(self).to_string() } + + /// Pretty prints the value using the SATN data format into the returned `String`. fn to_satn_pretty(&self) -> String { format!("{:#}", Wrapper::from_ref(self)) } @@ -21,12 +29,18 @@ pub trait Satn: ser::Serialize { impl Satn for T {} +/// A wrapper around a `T: Satn` +/// providing `Display` and `Debug` implementations +/// that uses the SATN formatting for `T`. #[repr(transparent)] pub struct Wrapper(pub T); impl Wrapper { + /// Converts `&T` to `&Wrapper`. pub fn from_ref(t: &T) -> &Self { - unsafe { &*(t as *const T as *const Wrapper) } + // SAFETY: `repr(transparent)` turns the ABI of `T` + // into the same as `Self` so we can also cast `&T` to `&Self`. + unsafe { &*(t as *const T as *const Self) } } } @@ -35,18 +49,25 @@ impl fmt::Display for Wrapper { self.0.fmt(f) } } + impl fmt::Debug for Wrapper { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } +/// A wrapper around a `T: Satn` +/// providing `Display` and `Debug` implementations +/// that uses postgres SATN formatting for `T`. #[repr(transparent)] pub struct PsqlWrapper(pub T); impl PsqlWrapper { + /// Converts `&T` to `&PsqlWrapper`. pub fn from_ref(t: &T) -> &Self { - unsafe { &*(t as *const T as *const PsqlWrapper) } + // SAFETY: `repr(transparent)` turns the ABI of `T` + // into the same as `Self` so we can also cast `&T` to `&Self`. + unsafe { &*(t as *const T as *const Self) } } } @@ -55,22 +76,32 @@ impl fmt::Display for PsqlWrapper { self.0.fmt_psql(f) } } + impl fmt::Debug for PsqlWrapper { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt_psql(f) } } +/// Wraps a writer for formatting lists separated by `SEP` into it. struct EntryWrapper<'a, 'b, const SEP: char> { + /// The writer we're formatting into. fmt: Writer<'a, 'b>, + /// Whether there were any fields. + /// Initially `false` and then `true` after calling [`.entry(..)`](EntryWrapper::entry). has_fields: bool, } impl<'a, 'b, const SEP: char> EntryWrapper<'a, 'b, SEP> { + /// Constructs the entry wrapper using the writer `fmt`. fn new(fmt: Writer<'a, 'b>) -> Self { Self { fmt, has_fields: false } } - fn entry fmt::Result>(&mut self, entry: F) -> fmt::Result { + + /// Formats another entry in the larger structure. + /// + /// The formatting for the element / entry itself is provided by the function `entry`. + fn entry(&mut self, entry: impl FnOnce(Writer) -> fmt::Result) -> fmt::Result { let res = (|| match &mut self.fmt { Writer::Pretty(f) => { if !self.has_fields { @@ -96,14 +127,19 @@ impl<'a, 'b, const SEP: char> EntryWrapper<'a, 'b, SEP> { } } +/// An implementation of [`fmt::Write`] supporting indented and non-idented formatting. enum Writer<'a, 'b> { + /// Uses the standard library's formatter i.e. plain formatting. Normal(&'a mut fmt::Formatter<'b>), + /// Uses indented formatting. Pretty(IndentedWriter<'a, 'b>), } impl<'a, 'b> Writer<'a, 'b> { + /// Provided with a formatter `f`, runs `func` provided with a `Writer`. fn with(f: &mut fmt::Formatter<'_>, func: impl FnOnce(Writer<'_, '_>) -> R) -> R { let mut state; + // We use `alternate`, i.e., the `#` flag to let the user trigger pretty printing. let f = if f.alternate() { state = IndentState { indent: 0, @@ -115,6 +151,8 @@ impl<'a, 'b> Writer<'a, 'b> { }; func(f) } + + /// Returns a sub-writer without moving `self`. fn as_mut(&mut self) -> Writer<'_, 'b> { match self { Writer::Normal(f) => Writer::Normal(f), @@ -123,17 +161,22 @@ impl<'a, 'b> Writer<'a, 'b> { } } +/// A formatter that adds decoration atop of the standard library's formatter. struct IndentedWriter<'a, 'b> { f: &'a mut fmt::Formatter<'b>, state: &'a mut IndentState, } +/// The indentation state. struct IndentState { + /// Number of tab indentations to make. indent: u32, + /// Whether we were last on a newline. on_newline: bool, } impl<'a, 'b> IndentedWriter<'a, 'b> { + /// Returns a sub-writer without moving `self`. fn as_mut(&mut self) -> IndentedWriter<'_, 'b> { IndentedWriter { f: self.f, @@ -146,6 +189,7 @@ impl<'a, 'b> fmt::Write for IndentedWriter<'a, 'b> { fn write_str(&mut self, s: &str) -> fmt::Result { for s in s.split_inclusive('\n') { if self.state.on_newline { + // Indent 4 characters times the indentation level. for _ in 0..self.state.indent { self.f.write_str(" ")?; } @@ -167,16 +211,21 @@ impl<'a, 'b> fmt::Write for Writer<'a, 'b> { } } +/// Provides the SATN data format implementing [`Serializer`](ser::Serializer). struct SatnFormatter<'a, 'b> { + /// The sink / writer / output / formatter. f: Writer<'a, 'b>, } +/// An error occured during serialization to the SATS data format. struct SatnError(fmt::Error); + impl From for fmt::Error { fn from(err: SatnError) -> Self { err.0 } } + impl From for SatnError { fn from(err: fmt::Error) -> Self { SatnError(err) @@ -190,6 +239,7 @@ impl ser::Error for SatnError { } impl<'a, 'b> SatnFormatter<'a, 'b> { + /// Writes `args` formatted to `self`. #[inline(always)] fn write_fmt(&mut self, args: fmt::Arguments) -> Result<(), SatnError> { self.f.write_fmt(args)?; @@ -254,14 +304,14 @@ impl<'a, 'b> ser::Serializer for SatnFormatter<'a, 'b> { } fn serialize_array(mut self, _len: usize) -> Result { - write!(self, "[")?; + write!(self, "[")?; // Closed via `.end()`. Ok(ArrayFormatter { f: EntryWrapper::new(self.f), }) } fn serialize_map(mut self, len: usize) -> Result { - write!(self, "[")?; + write!(self, "[")?; // Closed via `.end()`. if len == 0 { write!(self, ":")?; } @@ -271,14 +321,15 @@ impl<'a, 'b> ser::Serializer for SatnFormatter<'a, 'b> { } fn serialize_seq_product(self, len: usize) -> Result { + // Delegate to named products handling of element formatting. self.serialize_named_product(len).map(|inner| SeqFormatter { inner }) } fn serialize_named_product(mut self, _len: usize) -> Result { - write!(self, "(")?; + write!(self, "(")?; // Closed via `.end()`. Ok(NamedFormatter { f: EntryWrapper::new(self.f), - i: 0, + idx: 0, }) } @@ -301,7 +352,9 @@ impl<'a, 'b> ser::Serializer for SatnFormatter<'a, 'b> { } } +/// Defines the SATN formatting for arrays. struct ArrayFormatter<'a, 'b> { + /// The formatter for each element separating elements by a `,`. f: EntryWrapper<'a, 'b, ','>, } @@ -320,7 +373,9 @@ impl<'a, 'b> ser::SerializeArray for ArrayFormatter<'a, 'b> { } } +/// Provides the data format for maps for SATN. struct MapFormatter<'a, 'b> { + /// The formatter for each element separating elements by a `,`. f: EntryWrapper<'a, 'b, ','>, } @@ -348,7 +403,9 @@ impl<'a, 'b> ser::SerializeMap for MapFormatter<'a, 'b> { } } +/// Provides the data format for unnamed products for SATN. struct SeqFormatter<'a, 'b> { + /// Delegates to the named format. inner: NamedFormatter<'a, 'b>, } @@ -365,9 +422,12 @@ impl<'a, 'b> ser::SerializeSeqProduct for SeqFormatter<'a, 'b> { } } +/// Provides the data format for named products for SATN. struct NamedFormatter<'a, 'b> { + /// The formatter for each element separating elements by a `,`. f: EntryWrapper<'a, 'b, ','>, - i: usize, + /// The index of the element. + idx: usize, } impl<'a, 'b> ser::SerializeNamedProduct for NamedFormatter<'a, 'b> { @@ -380,16 +440,17 @@ impl<'a, 'b> ser::SerializeNamedProduct for NamedFormatter<'a, 'b> { elem: &T, ) -> Result<(), Self::Error> { let res = self.f.entry(|mut f| { + // Format the name or use the index if unnamed. if let Some(name) = name { write!(f, "{}", name)?; } else { - write!(f, "{}", self.i)?; + write!(f, "{}", self.idx)?; } write!(f, " = ")?; elem.serialize(SatnFormatter { f })?; Ok(()) }); - self.i += 1; + self.idx += 1; res?; Ok(()) } @@ -400,6 +461,8 @@ impl<'a, 'b> ser::SerializeNamedProduct for NamedFormatter<'a, 'b> { } } +/// An implementation of [`Serializer`](ser::Serializer) +/// that borrows from [`SatnFormatter`] except in `serialize_str`. struct PsqlFormatter<'a, 'b>(SatnFormatter<'a, 'b>); impl<'a, 'b> ser::Serializer for PsqlFormatter<'a, 'b> { diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index f394a24cfd..13fe55772a 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -1,38 +1,123 @@ +// Some parts copyright Serde developers under the MIT / Apache-2.0 licenses at your option. +// See `serde` version `v1.0.169` for the parts where MIT / Apache-2.0 applies. + mod impls; #[cfg(feature = "serde")] pub mod serde; use std::fmt; +/// A **data format** that can deserialize any data structure supported by SATs. +/// +/// The `Serializer` trait in SATS performs the same function as [`serde::Serializer`] in [`serde`]. +/// See the documentation of [`serde::Serializer`] for more information of the data model. +/// +/// [`serde::Serializer`]: ::serde::Serializer +/// [`serde`]: https://crates.io/crates/serde pub trait Serializer: Sized { + /// The output type produced by this `Serializer` during successful serialization. + /// + /// Most serializers that produce text or binary output should set `Ok = ()` + /// and serialize into an [`io::Write`] or buffer contained within the `Serializer` instance. + /// Serializers that build in-memory data structures may be simplified by using `Ok` to propagate + /// the data structure around. + /// + /// [`io::Write`]: https://doc.rust-lang.org/std/io/trait.Write.html type Ok; + + /// The error type when some error occurs during serialization. type Error: Error; + + /// Type returned from [`serialize_array`](Serializer::serialize_array) + /// for serializing the contents of the array. type SerializeArray: SerializeArray; + + /// Type returned from [`serialize_map`](Serializer::serialize_map) + /// for serializing the contents of the map. type SerializeMap: SerializeMap; + + /// Type returned from [`serialize_seq_product`](Serializer::serialize_seq_product) + /// for serializing the contents of the *unnamed* product. type SerializeSeqProduct: SerializeSeqProduct; + + /// Type returned from [`serialize_named_product`](Serializer::serialize_named_product) + /// for serializing the contents of the *named* product. type SerializeNamedProduct: SerializeNamedProduct; + /// Serialize a `bool` value. fn serialize_bool(self, v: bool) -> Result; + + /// Serialize a `u8` value. fn serialize_u8(self, v: u8) -> Result; + + /// Serialize a `u16` value. fn serialize_u16(self, v: u16) -> Result; + + /// Serialize a `u32` value. fn serialize_u32(self, v: u32) -> Result; + + /// Serialize a `u64` value. fn serialize_u64(self, v: u64) -> Result; + + /// Serialize a `u128` value. fn serialize_u128(self, v: u128) -> Result; + + /// Serialize an `i8` value. fn serialize_i8(self, v: i8) -> Result; + + /// Serialize an `i16` value. fn serialize_i16(self, v: i16) -> Result; + + /// Serialize an `i32` value. fn serialize_i32(self, v: i32) -> Result; + + /// Serialize an `i64` value. fn serialize_i64(self, v: i64) -> Result; + + /// Serialize an `i128` value. fn serialize_i128(self, v: i128) -> Result; + + /// Serialize an `f32` value. fn serialize_f32(self, v: f32) -> Result; + + /// Serialize an `f64` value. fn serialize_f64(self, v: f64) -> Result; + + /// Serialize a `&str` string slice. fn serialize_str(self, v: &str) -> Result; + + /// Serialize a `&[u8]` byte slice. fn serialize_bytes(self, v: &[u8]) -> Result; + /// Begin to serialize a variably sized array. + /// This call must be followed by zero or more calls to [`SerializeArray::serialize_element`], + /// then a call to [`SerializeArray::end`]. + /// + /// The argument is the number of elements in the sequence. fn serialize_array(self, len: usize) -> Result; + + /// Begin to serialize a variably sized map. + /// This call must be followed by zero or more calls to [`SerializeMap::serialize_element`], + /// then a call to [`SerializeMap::end`]. + /// + /// The argument is the number of elements in the map. fn serialize_map(self, len: usize) -> Result; + /// Begin to serialize a product with unnamed fields. + /// This call must be followed by zero or more calls to [`SerializeSeqProduct::serialize_element`], + /// then a call to [`SerializeSeqProduct::end`]. + /// + /// The argument is the number of fields in the product. fn serialize_seq_product(self, len: usize) -> Result; + + /// Begin to serialize a product with named fields. + /// This call must be followed by zero or more calls to [`SerializeNamedProduct::serialize_element`], + /// then a call to [`SerializeNamedProduct::end`]. + /// + /// The argument is the number of fields in the product. fn serialize_named_product(self, len: usize) -> Result; + + /// Serialize a sum value provided the chosen `tag`, `name`, and `value`. fn serialize_variant( self, tag: u8, @@ -42,10 +127,22 @@ pub trait Serializer: Sized { } pub use spacetimedb_bindings_macro::Serialize; + +/// A **data structure** that can be serialized into any data format supported by SATS. +/// +/// In most cases, implementations of `Serialize` may be `#[derive(Serialize)]`d. +/// +/// The `Serialize` trait in SATS performs the same function as [`serde::Serialize`] in [`serde`]. +/// See the documentation of [`serde::Serialize`] for more information of the data model. +/// +/// [`serde::Serialize`]: ::serde::Serialize +/// [`serde`]: https://crates.io/crates/serde pub trait Serialize { + /// Serialize `self` in the data format of `S` using the provided `serializer`. fn serialize(&self, serializer: S) -> Result; - /// used in the Serialize for Vec impl to allow specializing serializing Vec as bytes + /// Used in the `Serialize for Vec` implementation + /// to allow a specialized serialization of `Vec` as bytes. #[doc(hidden)] #[inline(always)] fn __serialize_array(this: &[Self], serializer: S) -> Result @@ -60,7 +157,9 @@ pub trait Serialize { } } +/// The base trait serialization error types must implement. pub trait Error { + /// Returns an error derived from `msg: impl Display`. fn custom(msg: T) -> Self; } @@ -69,55 +168,109 @@ impl Error for String { msg.to_string() } } + impl Error for std::convert::Infallible { fn custom(msg: T) -> Self { panic!("error generated for Infallible serializer: {msg}") } } +/// Returned from [`Serializer::serialize_array`]. +/// +/// This provides a continuation of sorts +/// where you can call [`serialize_element`](SerializeArray::serialize_element) however many times +/// and then finally the [`end`](SerializeArray::end) is reached. pub trait SerializeArray { + /// Must match the `Ok` type of any `Serializer` that uses this type. type Ok; + + /// Must match the `Error` type of any `Serializer` that uses this type. type Error: Error; - fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error>; + /// Serialize an array `element`. + fn serialize_element(&mut self, element: &T) -> Result<(), Self::Error>; + + /// Consumes and finalizes the array serializer returning the `Self::Ok` data. fn end(self) -> Result; } +/// Returned from [`Serializer::serialize_map`]. +/// +/// This provides a continuation of sorts +/// where you can call [`serialize_entry`](SerializeMap::serialize_entry) however many times +/// and then finally the [`end`](SerializeMap::end) is reached. pub trait SerializeMap { + /// Must match the `Ok` type of any `Serializer` that uses this type. type Ok; + + /// Must match the `Error` type of any `Serializer` that uses this type. type Error: Error; + /// Serialize a map entry given by its `key` and `value`. fn serialize_entry( &mut self, key: &K, value: &V, ) -> Result<(), Self::Error>; + + /// Consumes and finalizes the map serializer returning the `Self::Ok` data. fn end(self) -> Result; } +/// Returned from [`Serializer::serialize_seq_product`]. +/// +/// This provides a continuation of sorts +/// where you can call [`serialize_element`](SerializeSeqProduct::serialize_element) however many times +/// and then finally the [`end`](SerializeSeqProduct::end) is reached. pub trait SerializeSeqProduct { + /// Must match the `Ok` type of any `Serializer` that uses this type. type Ok; + + /// Must match the `Error` type of any `Serializer` that uses this type. type Error: Error; - fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error>; + /// Serialize an unnamed product `element`. + fn serialize_element(&mut self, element: &T) -> Result<(), Self::Error>; + + /// Consumes and finalizes the product serializer returning the `Self::Ok` data. fn end(self) -> Result; } +/// Returned from [`Serializer::serialize_named_product`]. +/// +/// This provides a continuation of sorts +/// where you can call [`serialize_element`](SerializeNamedProduct::serialize_element) however many times +/// and then finally the [`end`](SerializeNamedProduct::end) is reached. pub trait SerializeNamedProduct { + /// Must match the `Ok` type of any `Serializer` that uses this type. type Ok; + + /// Must match the `Error` type of any `Serializer` that uses this type. type Error: Error; + /// Serialize a named product `element` with `name`. fn serialize_element(&mut self, name: Option<&str>, elem: &T) -> Result<(), Self::Error>; + + /// Consumes and finalizes the product serializer returning the `Self::Ok` data. fn end(self) -> Result; } +/// Forwards the implementation of a named product value +/// to the implementation of the unnamed kind, +/// thereby ignoring any field names. pub struct ForwardNamedToSeqProduct { + /// The unnamed product serializer. tup: S, } + impl ForwardNamedToSeqProduct { + /// Returns a forwarder based on the provided unnamed product serializer. pub fn new(tup: S) -> Self { Self { tup } } + + /// Forwards the serialization of a named product of `len` fields + /// to an unnamed serialization format. pub fn forward(ser: Ser, len: usize) -> Result where Ser: Serializer, @@ -125,6 +278,7 @@ impl ForwardNamedToSeqProduct { ser.serialize_seq_product(len).map(Self::new) } } + impl SerializeNamedProduct for ForwardNamedToSeqProduct { type Ok = S::Ok; type Error = S::Error; @@ -132,6 +286,7 @@ impl SerializeNamedProduct for ForwardNamedToSeqProduct< fn serialize_element(&mut self, _name: Option<&str>, elem: &T) -> Result<(), Self::Error> { self.tup.serialize_element(elem) } + fn end(self) -> Result { self.tup.end() } diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index 49216c8702..fa724735e0 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -7,22 +7,45 @@ use crate::{ use super::{Serialize, SerializeArray, SerializeMap, SerializeNamedProduct, SerializeSeqProduct, Serializer}; -macro_rules! impl_prim { - ($(($prim:ty, $method:ident))*) => { - $(impl Serialize for $prim { - fn serialize(&self, ser: S) -> Result { - ser.$method((*self).into()) +/// Implements [`Serialize`] for a type in a simplified manner. +/// +/// An example: +/// ```ignore +/// struct Foo<'a, T: Copy>(&'a T, u8); +/// impl_serialize!( +/// // Type parameters Optional where Impl type +/// // v v v +/// // ---------------- --------------- ---------- +/// ['a, T: Serialize] where [T: Copy] Foo<'a, T>, +/// // The `serialize` implementation where `self` is serialized into `ser` +/// // and the expression right of `=>` is the body of `serialize`. +/// (self, ser) => { +/// let mut prod = ser.serialize_seq_product(2)?; +/// prod.serialize_element(&self.0)?; +/// prod.serialize_element(&self.1)?; +/// prod.end() +/// } +/// ); +/// ``` +#[macro_export] +macro_rules! impl_serialize { + ([$($generics:tt)*] $(where [$($wc:tt)*])? $typ:ty, ($self:ident, $ser:ident) => $body:expr) => { + impl<$($generics)*> $crate::ser::Serialize for $typ $(where $($wc)*)? { + fn serialize($self: &Self, $ser: S) -> Result { + $body } - })* + } }; } -impl Serialize for () { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_seq_product(0)?.end() - } +macro_rules! impl_prim { + ($(($prim:ty, $method:ident))*) => { + $(impl_serialize!([] $prim, (self, ser) => ser.$method((*self).into()));)* + }; } +impl_serialize!([] (), (self, ser) => ser.serialize_seq_product(0)?.end()); + impl_prim! { (bool, serialize_bool) /*(u8, serialize_u8)*/ (u16, serialize_u16) (u32, serialize_u32) (u64, serialize_u64) (u128, serialize_u128) (i8, serialize_i8) @@ -43,261 +66,166 @@ impl Serialize for u8 { } } -impl Serialize for crate::builtin_value::F32 { - fn serialize(&self, serializer: S) -> Result { - f32::from(*self).serialize(serializer) - } -} -impl Serialize for crate::builtin_value::F64 { - fn serialize(&self, serializer: S) -> Result { - f64::from(*self).serialize(serializer) - } -} - -impl Serialize for Vec { - fn serialize(&self, serializer: S) -> Result { - (**self).serialize(serializer) - } -} -impl Serialize for [T] { - fn serialize(&self, serializer: S) -> Result { - T::__serialize_array(self, serializer) - } -} - -impl Serialize for [T; N] { - fn serialize(&self, serializer: S) -> Result { - T::__serialize_array(self, serializer) - } -} - -impl Serialize for Box { - fn serialize(&self, serializer: S) -> Result { - (**self).serialize(serializer) - } -} -impl Serialize for &T { - fn serialize(&self, serializer: S) -> Result { - (**self).serialize(serializer) - } -} - -impl Serialize for String { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(self) - } -} - -impl Serialize for Option { - fn serialize(&self, serializer: S) -> Result { - match self { - Some(v) => serializer.serialize_variant(0, Some("some"), v), - None => serializer.serialize_variant(1, Some("none"), &()), - } - } -} - -impl Serialize for BTreeMap { - fn serialize(&self, serializer: S) -> Result { - let mut map = serializer.serialize_map(self.len())?; - for (k, v) in self { - map.serialize_entry(k, v)?; - } - map.end() - } -} - -impl Serialize for AlgebraicValue { - fn serialize(&self, serializer: S) -> Result { - match self { - AlgebraicValue::Sum(sum) => sum.serialize(serializer), - AlgebraicValue::Product(prod) => prod.serialize(serializer), - AlgebraicValue::Builtin(b) => b.serialize(serializer), - } - } -} - -impl Serialize for BuiltinValue { - fn serialize(&self, serializer: S) -> Result { - match self { - BuiltinValue::Bool(v) => serializer.serialize_bool(*v), - BuiltinValue::I8(v) => serializer.serialize_i8(*v), - BuiltinValue::U8(v) => serializer.serialize_u8(*v), - BuiltinValue::I16(v) => serializer.serialize_i16(*v), - BuiltinValue::U16(v) => serializer.serialize_u16(*v), - BuiltinValue::I32(v) => serializer.serialize_i32(*v), - BuiltinValue::U32(v) => serializer.serialize_u32(*v), - BuiltinValue::I64(v) => serializer.serialize_i64(*v), - BuiltinValue::U64(v) => serializer.serialize_u64(*v), - BuiltinValue::I128(v) => serializer.serialize_i128(*v), - BuiltinValue::U128(v) => serializer.serialize_u128(*v), - BuiltinValue::F32(v) => serializer.serialize_f32((*v).into()), - BuiltinValue::F64(v) => serializer.serialize_f64((*v).into()), - BuiltinValue::String(v) => serializer.serialize_str(v), - // BuiltinValue::Bytes(v) => serializer.serialize_bytes(v), - BuiltinValue::Array { val } => val.serialize(serializer), - BuiltinValue::Map { val } => val.serialize(serializer), - } - } -} - -impl Serialize for ProductValue { - fn serialize(&self, serializer: S) -> Result { - let mut tup = serializer.serialize_seq_product(self.elements.len())?; - for elem in &*self.elements { - tup.serialize_element(elem)?; - } - tup.end() - } -} - -impl Serialize for SumValue { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_variant(self.tag, None, &*self.value) - } -} - -impl Serialize for ArrayValue { - fn serialize(&self, serializer: S) -> Result { - match self { - ArrayValue::Sum(v) => v.serialize(serializer), - ArrayValue::Product(v) => v.serialize(serializer), - ArrayValue::Bool(v) => v.serialize(serializer), - ArrayValue::I8(v) => v.serialize(serializer), - ArrayValue::U8(v) => v.serialize(serializer), - ArrayValue::I16(v) => v.serialize(serializer), - ArrayValue::U16(v) => v.serialize(serializer), - ArrayValue::I32(v) => v.serialize(serializer), - ArrayValue::U32(v) => v.serialize(serializer), - ArrayValue::I64(v) => v.serialize(serializer), - ArrayValue::U64(v) => v.serialize(serializer), - ArrayValue::I128(v) => v.serialize(serializer), - ArrayValue::U128(v) => v.serialize(serializer), - ArrayValue::F32(v) => v.serialize(serializer), - ArrayValue::F64(v) => v.serialize(serializer), - ArrayValue::String(v) => v.serialize(serializer), - ArrayValue::Array(v) => v.serialize(serializer), - ArrayValue::Map(v) => v.serialize(serializer), - } - } -} - -impl Serialize for ValueWithType<'_, AlgebraicValue> { - fn serialize(&self, serializer: S) -> Result { - let mut ty = self.ty(); - loop { - break match (self.value(), ty) { - (AlgebraicValue::Sum(val), AlgebraicType::Sum(ty)) => self.with(ty, val).serialize(serializer), - (AlgebraicValue::Product(val), AlgebraicType::Product(ty)) => self.with(ty, val).serialize(serializer), - (AlgebraicValue::Builtin(val), AlgebraicType::Builtin(ty)) => self.with(ty, val).serialize(serializer), - (_, &AlgebraicType::Ref(r)) => { - ty = &self.typespace()[r]; - continue; - } - _ => panic!("mismatched value and schema"), - }; - } - } -} - -impl Serialize for ValueWithType<'_, BuiltinValue> { - fn serialize(&self, serializer: S) -> Result { - match (self.value(), self.ty()) { - (BuiltinValue::Bool(v), BuiltinType::Bool) => serializer.serialize_bool(*v), - (BuiltinValue::I8(v), BuiltinType::I8) => serializer.serialize_i8(*v), - (BuiltinValue::U8(v), BuiltinType::U8) => serializer.serialize_u8(*v), - (BuiltinValue::I16(v), BuiltinType::I16) => serializer.serialize_i16(*v), - (BuiltinValue::U16(v), BuiltinType::U16) => serializer.serialize_u16(*v), - (BuiltinValue::I32(v), BuiltinType::I32) => serializer.serialize_i32(*v), - (BuiltinValue::U32(v), BuiltinType::U32) => serializer.serialize_u32(*v), - (BuiltinValue::I64(v), BuiltinType::I64) => serializer.serialize_i64(*v), - (BuiltinValue::U64(v), BuiltinType::U64) => serializer.serialize_u64(*v), - (BuiltinValue::I128(v), BuiltinType::I128) => serializer.serialize_i128(*v), - (BuiltinValue::U128(v), BuiltinType::U128) => serializer.serialize_u128(*v), - (BuiltinValue::F32(v), BuiltinType::F32) => serializer.serialize_f32((*v).into()), - (BuiltinValue::F64(v), BuiltinType::F64) => serializer.serialize_f64((*v).into()), - (BuiltinValue::String(s), BuiltinType::String) => serializer.serialize_str(s), - (BuiltinValue::Array { val }, BuiltinType::Array(ty)) => self.with(ty, val).serialize(serializer), - (BuiltinValue::Map { val }, BuiltinType::Map(ty)) => self.with(ty, val).serialize(serializer), - (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), - } - } -} - -impl Serialize for ValueWithType<'_, Vec> -where - for<'a> ValueWithType<'a, T>: Serialize, -{ - fn serialize(&self, serializer: S) -> Result { - let mut vec = serializer.serialize_array(self.value().len())?; +impl_serialize!([] crate::builtin_value::F32, (self, ser) => f32::from(*self).serialize(ser)); +impl_serialize!([] crate::builtin_value::F64, (self, ser) => f64::from(*self).serialize(ser)); +impl_serialize!([T: Serialize] Vec, (self, ser) => (**self).serialize(ser)); +impl_serialize!([T: Serialize] [T], (self, ser) => T::__serialize_array(self, ser)); +impl_serialize!([T: Serialize, const N: usize] [T; N], (self, ser) => T::__serialize_array(self, ser)); +impl_serialize!([T: Serialize + ?Sized] Box, (self, ser) => (**self).serialize(ser)); +impl_serialize!([T: Serialize + ?Sized] &T, (self, ser) => (**self).serialize(ser)); +impl_serialize!([] String, (self, ser) => ser.serialize_str(self)); +impl_serialize!([T: Serialize] Option, (self, ser) => match self { + Some(v) => ser.serialize_variant(0, Some("some"), v), + None => ser.serialize_variant(1, Some("none"), &()), +}); +impl_serialize!([K: Serialize, V: Serialize] BTreeMap, (self, ser) => { + let mut map = ser.serialize_map(self.len())?; + for (k, v) in self { + map.serialize_entry(k, v)?; + } + map.end() +}); +impl_serialize!([] AlgebraicValue, (self, ser) => match self { + Self::Sum(sum) => sum.serialize(ser), + Self::Product(prod) => prod.serialize(ser), + Self::Builtin(b) => b.serialize(ser), +}); +impl_serialize!([] BuiltinValue, (self, ser) => match self { + Self::Bool(v) => ser.serialize_bool(*v), + Self::I8(v) => ser.serialize_i8(*v), + Self::U8(v) => ser.serialize_u8(*v), + Self::I16(v) => ser.serialize_i16(*v), + Self::U16(v) => ser.serialize_u16(*v), + Self::I32(v) => ser.serialize_i32(*v), + Self::U32(v) => ser.serialize_u32(*v), + Self::I64(v) => ser.serialize_i64(*v), + Self::U64(v) => ser.serialize_u64(*v), + Self::I128(v) => ser.serialize_i128(*v), + Self::U128(v) => ser.serialize_u128(*v), + Self::F32(v) => ser.serialize_f32((*v).into()), + Self::F64(v) => ser.serialize_f64((*v).into()), + Self::String(v) => ser.serialize_str(v), + // Self::Bytes(v) => ser.serialize_bytes(v), + Self::Array { val } => val.serialize(ser), + Self::Map { val } => val.serialize(ser), +}); +impl_serialize!([] ProductValue, (self, ser) => { + let mut tup = ser.serialize_seq_product(self.elements.len())?; + for elem in &*self.elements { + tup.serialize_element(elem)?; + } + tup.end() +}); +impl_serialize!([] SumValue, (self, ser) => ser.serialize_variant(self.tag, None, &*self.value)); +impl_serialize!([] ArrayValue, (self, ser) => match self { + Self::Sum(v) => v.serialize(ser), + Self::Product(v) => v.serialize(ser), + Self::Bool(v) => v.serialize(ser), + Self::I8(v) => v.serialize(ser), + Self::U8(v) => v.serialize(ser), + Self::I16(v) => v.serialize(ser), + Self::U16(v) => v.serialize(ser), + Self::I32(v) => v.serialize(ser), + Self::U32(v) => v.serialize(ser), + Self::I64(v) => v.serialize(ser), + Self::U64(v) => v.serialize(ser), + Self::I128(v) => v.serialize(ser), + Self::U128(v) => v.serialize(ser), + Self::F32(v) => v.serialize(ser), + Self::F64(v) => v.serialize(ser), + Self::String(v) => v.serialize(ser), + Self::Array(v) => v.serialize(ser), + Self::Map(v) => v.serialize(ser), +}); +impl_serialize!([] ValueWithType<'_, AlgebraicValue>, (self, ser) => { + let mut ty = self.ty(); + loop { // We're doing this because of `Ref`s. + break match (self.value(), ty) { + (AlgebraicValue::Sum(val), AlgebraicType::Sum(ty)) => self.with(ty, val).serialize(ser), + (AlgebraicValue::Product(val), AlgebraicType::Product(ty)) => self.with(ty, val).serialize(ser), + (AlgebraicValue::Builtin(val), AlgebraicType::Builtin(ty)) => self.with(ty, val).serialize(ser), + (_, &AlgebraicType::Ref(r)) => { + ty = &self.typespace()[r]; + continue; + } + _ => panic!("mismatched value and schema"), + }; + } +}); +impl_serialize!([] ValueWithType<'_, BuiltinValue>, (self, ser) => match (self.value(), self.ty()) { + (BuiltinValue::Bool(v), BuiltinType::Bool) => ser.serialize_bool(*v), + (BuiltinValue::I8(v), BuiltinType::I8) => ser.serialize_i8(*v), + (BuiltinValue::U8(v), BuiltinType::U8) => ser.serialize_u8(*v), + (BuiltinValue::I16(v), BuiltinType::I16) => ser.serialize_i16(*v), + (BuiltinValue::U16(v), BuiltinType::U16) => ser.serialize_u16(*v), + (BuiltinValue::I32(v), BuiltinType::I32) => ser.serialize_i32(*v), + (BuiltinValue::U32(v), BuiltinType::U32) => ser.serialize_u32(*v), + (BuiltinValue::I64(v), BuiltinType::I64) => ser.serialize_i64(*v), + (BuiltinValue::U64(v), BuiltinType::U64) => ser.serialize_u64(*v), + (BuiltinValue::I128(v), BuiltinType::I128) => ser.serialize_i128(*v), + (BuiltinValue::U128(v), BuiltinType::U128) => ser.serialize_u128(*v), + (BuiltinValue::F32(v), BuiltinType::F32) => ser.serialize_f32((*v).into()), + (BuiltinValue::F64(v), BuiltinType::F64) => ser.serialize_f64((*v).into()), + (BuiltinValue::String(s), BuiltinType::String) => ser.serialize_str(s), + (BuiltinValue::Array { val }, BuiltinType::Array(ty)) => self.with(ty, val).serialize(ser), + (BuiltinValue::Map { val }, BuiltinType::Map(ty)) => self.with(ty, val).serialize(ser), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), +}); +impl_serialize!( + [T: crate::Value] where [for<'a> ValueWithType<'a, T>: Serialize] + ValueWithType<'_, Vec>, + (self, ser) => { + let mut vec = ser.serialize_array(self.value().len())?; for val in self.iter() { vec.serialize_element(&val)?; } vec.end() } -} - -impl Serialize for ValueWithType<'_, SumValue> { - fn serialize(&self, serializer: S) -> Result { - let &SumValue { tag, ref value } = self.value(); - let var_ty = &self.ty().variants[tag as usize]; - serializer.serialize_variant( - tag, - var_ty.name.as_deref(), - &self.with(&var_ty.algebraic_type, &**value), - ) - } -} - -impl Serialize for ValueWithType<'_, ProductValue> { - fn serialize(&self, serializer: S) -> Result { - let val = &self.value().elements; - assert_eq!(val.len(), self.ty().elements.len()); - let mut prod = serializer.serialize_named_product(val.len())?; - for (val, el_ty) in val.iter().zip(&self.ty().elements) { - prod.serialize_element(el_ty.name.as_deref(), &self.with(&el_ty.algebraic_type, val))? - } - prod.end() - } -} - -impl Serialize for ValueWithType<'_, ArrayValue> { - fn serialize(&self, serializer: S) -> Result { - match (self.value(), &*self.ty().elem_ty) { - (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => self.with(ty, v).serialize(serializer), - (ArrayValue::Product(v), AlgebraicType::Product(ty)) => self.with(ty, v).serialize(serializer), - (ArrayValue::Bool(v), &AlgebraicType::Bool) => v.serialize(serializer), - (ArrayValue::I8(v), &AlgebraicType::I8) => v.serialize(serializer), - (ArrayValue::U8(v), &AlgebraicType::U8) => v.serialize(serializer), - (ArrayValue::I16(v), &AlgebraicType::I16) => v.serialize(serializer), - (ArrayValue::U16(v), &AlgebraicType::U16) => v.serialize(serializer), - (ArrayValue::I32(v), &AlgebraicType::I32) => v.serialize(serializer), - (ArrayValue::U32(v), &AlgebraicType::U32) => v.serialize(serializer), - (ArrayValue::I64(v), &AlgebraicType::I64) => v.serialize(serializer), - (ArrayValue::U64(v), &AlgebraicType::U64) => v.serialize(serializer), - (ArrayValue::I128(v), &AlgebraicType::I128) => v.serialize(serializer), - (ArrayValue::U128(v), &AlgebraicType::U128) => v.serialize(serializer), - (ArrayValue::F32(v), &AlgebraicType::F32) => v.serialize(serializer), - (ArrayValue::F64(v), &AlgebraicType::F64) => v.serialize(serializer), - (ArrayValue::String(v), &AlgebraicType::String) => v.serialize(serializer), - (ArrayValue::Array(v), AlgebraicType::Builtin(BuiltinType::Array(ty))) => { - self.with(ty, v).serialize(serializer) - } - (ArrayValue::Map(v), AlgebraicType::Builtin(BuiltinType::Map(m))) => self.with(m, v).serialize(serializer), - (val, _) if val.is_empty() => serializer.serialize_array(0)?.end(), - (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), - } - } -} - -impl Serialize for ValueWithType<'_, MapValue> { - fn serialize(&self, serializer: S) -> Result { - let val = self.value(); - let MapType { key_ty, ty } = self.ty(); - let mut map = serializer.serialize_map(val.len())?; - for (key, val) in val { - map.serialize_entry(&self.with(&**key_ty, key), &self.with(&**ty, val))?; - } - map.end() - } -} +); +impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => { + let &SumValue { tag, ref value } = self.value(); + let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag. + ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, &**value)) +}); +impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { + let val = &self.value().elements; + assert_eq!(val.len(), self.ty().elements.len()); + let mut prod = ser.serialize_named_product(val.len())?; + for (val, el_ty) in val.iter().zip(&self.ty().elements) { + prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))? + } + prod.end() +}); +impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => match (self.value(), &*self.ty().elem_ty) { + (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => self.with(ty, v).serialize(ser), + (ArrayValue::Product(v), AlgebraicType::Product(ty)) => self.with(ty, v).serialize(ser), + (ArrayValue::Bool(v), &AlgebraicType::Builtin(BuiltinType::Bool)) => v.serialize(ser), + (ArrayValue::I8(v), &AlgebraicType::Builtin(BuiltinType::I8)) => v.serialize(ser), + (ArrayValue::U8(v), &AlgebraicType::Builtin(BuiltinType::U8)) => v.serialize(ser), + (ArrayValue::I16(v), &AlgebraicType::Builtin(BuiltinType::I16)) => v.serialize(ser), + (ArrayValue::U16(v), &AlgebraicType::Builtin(BuiltinType::U16)) => v.serialize(ser), + (ArrayValue::I32(v), &AlgebraicType::Builtin(BuiltinType::I32)) => v.serialize(ser), + (ArrayValue::U32(v), &AlgebraicType::Builtin(BuiltinType::U32)) => v.serialize(ser), + (ArrayValue::I64(v), &AlgebraicType::Builtin(BuiltinType::I64)) => v.serialize(ser), + (ArrayValue::U64(v), &AlgebraicType::Builtin(BuiltinType::U64)) => v.serialize(ser), + (ArrayValue::I128(v), &AlgebraicType::Builtin(BuiltinType::I128)) => v.serialize(ser), + (ArrayValue::U128(v), &AlgebraicType::Builtin(BuiltinType::U128)) => v.serialize(ser), + (ArrayValue::F32(v), &AlgebraicType::Builtin(BuiltinType::F32)) => v.serialize(ser), + (ArrayValue::F64(v), &AlgebraicType::Builtin(BuiltinType::F64)) => v.serialize(ser), + (ArrayValue::String(v), &AlgebraicType::Builtin(BuiltinType::String)) => v.serialize(ser), + (ArrayValue::Array(v), AlgebraicType::Builtin(BuiltinType::Array(ty))) => { + self.with(ty, v).serialize(ser) + } + (ArrayValue::Map(v), AlgebraicType::Builtin(BuiltinType::Map(m))) => self.with(m, v).serialize(ser), + (val, _) if val.is_empty() => ser.serialize_array(0)?.end(), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), +}); +impl_serialize!([] ValueWithType<'_, MapValue>, (self, ser) => { + let val = self.value(); + let MapType { key_ty, ty } = self.ty(); + let mut map = ser.serialize_map(val.len())?; + for (key, val) in val { + map.serialize_entry(&self.with(&**key_ty, key), &self.with(&**ty, val))?; + } + map.end() +}); diff --git a/crates/sats/src/ser/serde.rs b/crates/sats/src/ser/serde.rs index 40f1eab01e..831da5ef8a 100644 --- a/crates/sats/src/ser/serde.rs +++ b/crates/sats/src/ser/serde.rs @@ -4,21 +4,22 @@ use ::serde::ser as serde; use crate::ser::{self, Serializer}; +/// Converts any [`serde::Serializer`] to a SATS [`Serializer`] +/// so that Serde's data formats can be reused. pub struct SerdeSerializer { + /// A serialization data format in Serde. ser: S, } -impl SerdeSerializer { +impl SerdeSerializer { + /// Returns a wrapped serializer. pub fn new(ser: S) -> Self { Self { ser } } } +/// An error that occured when serializing SATS to a Serde data format. pub struct SerdeError(pub E); -fn unwrap_error(err: SerdeError) -> E { - let SerdeError(err) = err; - err -} impl ser::Error for SerdeError { fn custom(msg: T) -> Self { @@ -121,7 +122,9 @@ impl Serializer for SerdeSerializer { } } +/// Serializes array elements by forwarding to `S: serde::SerializeSeq`. pub struct SerializeArray { + /// An implementation of `serde::SerializeSeq`. seq: S, } @@ -140,7 +143,9 @@ impl ser::SerializeArray for SerializeArray { } } +/// Serializes map entries by forwarding to `S: serde::SerializeMap`. pub struct SerializeMap { + /// An implementation of `serde::SerializeMap`. map: S, } @@ -163,7 +168,9 @@ impl ser::SerializeMap for SerializeMap { } } +/// Serializes unnamed product elements by forwarding to `S: serde::SerializeTuple`. pub struct SerializeSeqProduct { + /// An implementation of `serde::SerializeTuple`. tup: S, } @@ -182,7 +189,9 @@ impl ser::SerializeSeqProduct for SerializeSeqProduct< } } +/// Serializes named product elements by forwarding to `S: serde::SerializeMap`. pub struct SerializeNamedProduct { + /// An implementation of `serde::SerializeMap`. map: S, } @@ -206,31 +215,41 @@ impl ser::SerializeNamedProduct for SerializeNamedProduc } } +/// Serializes `T` as a SATS object into `serializer: S` +/// where `S` is a serde data format. pub fn serialize_to( value: &T, serializer: S, ) -> Result { - value.serialize(SerdeSerializer::new(serializer)).map_err(unwrap_error) + value + .serialize(SerdeSerializer::new(serializer)) + .map_err(|SerdeError(e)| e) } +/// Turns a type serializable in SATS into one serializable in serde. +/// +/// That is, `T: sats::Serialize => SerializeWrapper: serde::Serialize`. #[repr(transparent)] pub struct SerializeWrapper(T); + impl SerializeWrapper { + /// Wraps an object serializable in SATS so that it's serializable in Serde. pub fn new(t: T) -> Self where T: Sized, { Self(t) } + + /// Converts `&T` to `&SerializeWrapper`. pub fn from_ref(t: &T) -> &Self { + // SAFETY: OK because of `repr(transparent)`. unsafe { &*(t as *const T as *const SerializeWrapper) } } } + impl serde::Serialize for SerializeWrapper { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { + fn serialize(&self, serializer: S) -> Result { serialize_to(&self.0, serializer) } } diff --git a/crates/sats/src/sum_type.rs b/crates/sats/src/sum_type.rs index 20d529a37a..47bd6faa7b 100644 --- a/crates/sats/src/sum_type.rs +++ b/crates/sats/src/sum_type.rs @@ -1,37 +1,65 @@ -pub mod satn; use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer}; use crate::algebraic_value::ser::ValueSerializer; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize}; -use crate::{ - AlgebraicType, AlgebraicTypeRef, AlgebraicValue, ArrayType, BuiltinType, ProductType, ProductTypeElement, - SumTypeVariant, -}; +use crate::{AlgebraicType, AlgebraicValue, ProductTypeElement, SumTypeVariant}; +/// A structural sum type. +/// +/// Unlike most languages, sums in SATS are *[structural]* and not nominal. +/// When checking whether two nominal types are the same, +/// their names and/or declaration sites (e.g., module / namespace) are considered. +/// Meanwhile, a structural type system would only check the structure of the type itself, +/// e.g., the names of its variants and their inner data types in the case of a sum. +/// +/// This is also known as a discriminated union (implementation) or disjoint union. +/// Another name is [coproduct (category theory)](https://ncatlab.org/nlab/show/coproduct). +/// +/// These structures are known as sum types because the number of possible values a sum +/// ```ignore +/// { N_0(T_0), N_1(T_1), ..., N_n(T_n) } +/// ``` +/// is: +/// ```ignore +/// Σ (i ∈ 0..n). values(T_i) +/// ``` +/// so for example, `values({ A(U64), B(Bool) }) = values(U64) + values(Bool)`. +/// +/// See also: https://ncatlab.org/nlab/show/sum+type. +/// +/// [structural]: https://en.wikipedia.org/wiki/Structural_type_system #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub struct SumType { + /// The possible variants of the sum type. + /// + /// The order is relevant as it defines the tags of the variants at runtime. pub variants: Vec, } impl SumType { - pub fn new(variants: Vec) -> Self { + /// Returns a sum type with these possible `variants`. + pub const fn new(variants: Vec) -> Self { Self { variants } } + /// Returns a sum type of unnamed variants taken from `types`. pub fn new_unnamed(types: Vec) -> Self { - let variants = types - .iter() - .map(|ty| SumTypeVariant::new(ty.clone(), None)) - .collect::>(); + let variants = types.into_iter().map(|ty| ty.into()).collect::>(); Self { variants } } + /// Returns whether this sum type looks like an option type. + /// + /// An option type has `some(T)` as its first variant and `none` as its second. + /// That is, `{ some(T), none }` or `some: T | none` depending on your notation. pub fn looks_like_option(&self) -> Option<&AlgebraicType> { match &*self.variants { [first, second] - if first.name.as_deref() == Some("some") - && second.name.as_deref() == Some("none") - && second.algebraic_type == AlgebraicType::UNIT_TYPE => + if second.algebraic_type == AlgebraicType::UNIT_TYPE + // ^-- Done first to avoid pointer indirection when it doesn't matter. + && first.has_name("some") + && second.has_name("none") => { Some(&first.algebraic_type) } @@ -40,29 +68,16 @@ impl SumType { } } -impl SumType { - pub fn make_meta_type() -> AlgebraicType { - let string = AlgebraicType::Builtin(BuiltinType::String); - let option = AlgebraicType::make_option_type(string); - let variant_type = AlgebraicType::Product(ProductType::new(vec![ - ProductTypeElement { - algebraic_type: option, - name: Some("name".into()), - }, - ProductTypeElement { - algebraic_type: AlgebraicType::Ref(AlgebraicTypeRef(0)), - name: Some("algebraic_type".into()), - }, - ])); - let array = AlgebraicType::Builtin(BuiltinType::Array(ArrayType { - elem_ty: Box::new(variant_type), - })); - AlgebraicType::Product(ProductType::new(vec![ProductTypeElement { - algebraic_type: array, - name: Some("variants".into()), - }])) +impl MetaType for SumType { + fn meta_type() -> AlgebraicType { + AlgebraicType::product(vec![ProductTypeElement::new_named( + AlgebraicType::array(SumTypeVariant::meta_type()), + "variants", + )]) } +} +impl SumType { pub fn as_value(&self) -> AlgebraicValue { self.serialize(ValueSerializer).unwrap_or_else(|x| match x {}) } diff --git a/crates/sats/src/sum_type/satn.rs b/crates/sats/src/sum_type/satn.rs deleted file mode 100644 index be2e7be379..0000000000 --- a/crates/sats/src/sum_type/satn.rs +++ /dev/null @@ -1,33 +0,0 @@ -use super::SumType; -use crate::algebraic_type; -use std::fmt::Display; - -pub struct Formatter<'a> { - ty: &'a SumType, -} - -impl<'a> Formatter<'a> { - pub fn new(ty: &'a SumType) -> Self { - Self { ty } - } -} - -impl<'a> Display for Formatter<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.ty.variants.is_empty() { - return write!(f, "(|)"); - } - write!(f, "(")?; - for (i, e) in self.ty.variants.iter().enumerate() { - if let Some(name) = &e.name { - write!(f, "{}", name)?; - write!(f, ": ")?; - } - write!(f, "{}", algebraic_type::satn::Formatter::new(&e.algebraic_type))?; - if i < self.ty.variants.len() - 1 { - write!(f, " | ")?; - } - } - write!(f, ")") - } -} diff --git a/crates/sats/src/sum_type_variant.rs b/crates/sats/src/sum_type_variant.rs index 37ad0b0e92..7d66b0f6ac 100644 --- a/crates/sats/src/sum_type_variant.rs +++ b/crates/sats/src/sum_type_variant.rs @@ -1,24 +1,63 @@ use crate::algebraic_type::AlgebraicType; +use crate::meta_type::MetaType; use crate::{de::Deserialize, ser::Serialize}; +use crate::{AlgebraicTypeRef, ProductTypeElement}; +/// A variant of a sum type. +/// /// NOTE: Each element has an implicit element tag based on its order. /// Uniquely identifies an element similarly to protobuf tags. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] #[sats(crate = crate)] pub struct SumTypeVariant { + /// The name of the variant, if any. pub name: Option, + /// The type of the variant. + /// + /// Unlike a language like Rust, + /// where we can have `enum _ { V1 { foo: A, bar: B, .. }, .. }`, + /// the product within the variant `V1`, i.e., `{ foo: A, bar: B, .. }` + /// is separated out in SATS into a separate product type. + /// So we would express this as `{ V1({ foo: A, bar: B, .. }), .. }`. pub algebraic_type: AlgebraicType, } impl SumTypeVariant { - pub fn new(algebraic_type: AlgebraicType, name: Option) -> Self { + /// Returns a sum type variant with an optional `name` and `algebraic_type`. + pub const fn new(algebraic_type: AlgebraicType, name: Option) -> Self { Self { algebraic_type, name } } + /// Returns a sum type variant with `name` and `algebraic_type`. pub fn new_named(algebraic_type: AlgebraicType, name: impl AsRef) -> Self { Self { algebraic_type, name: Some(name.as_ref().to_owned()), } } + + /// Returns the name of the variant. + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + /// Returns whether the variant has the given name. + pub fn has_name(&self, name: &str) -> bool { + self.name() == Some(name) + } +} + +impl MetaType for SumTypeVariant { + fn meta_type() -> AlgebraicType { + AlgebraicType::product(vec![ + ProductTypeElement::new_named(AlgebraicType::option(AlgebraicType::String), "name"), + ProductTypeElement::new_named(AlgebraicType::Ref(AlgebraicTypeRef(0)), "algebraic_type"), + ]) + } +} + +impl From for SumTypeVariant { + fn from(algebraic_type: AlgebraicType) -> Self { + Self::new(algebraic_type, None) + } } diff --git a/crates/sats/src/sum_value.rs b/crates/sats/src/sum_value.rs index 159d9cd509..00c68bb237 100644 --- a/crates/sats/src/sum_value.rs +++ b/crates/sats/src/sum_value.rs @@ -1,9 +1,13 @@ use crate::algebraic_value::AlgebraicValue; use crate::sum_type::SumType; +/// A value of a sum type chosing a specific variant of the type. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct SumValue { + /// A tag representing the choice of one variant of the sum type's variants. pub tag: u8, + /// Given a variant `Var(Ty)` in a sum type `{ Var(Ty), ... }`, + /// this provides the `value` for `Ty`. pub value: Box, } diff --git a/crates/sats/src/typespace.rs b/crates/sats/src/typespace.rs index d44311537e..44581aa704 100644 --- a/crates/sats/src/typespace.rs +++ b/crates/sats/src/typespace.rs @@ -3,12 +3,26 @@ use std::ops::{Index, IndexMut}; use crate::algebraic_type::AlgebraicType; use crate::algebraic_type_ref::AlgebraicTypeRef; +use crate::WithTypespace; use crate::{de::Deserialize, ser::Serialize}; -use crate::{ArrayType, BuiltinType, TypeInSpace}; +/// A `Typespace` represents the typing context in SATS. +/// +/// That is, this is the `Δ` or `Γ` you'll see in type theory litterature. +/// +/// We use (sort of) [deBrujin indices](https://en.wikipedia.org/wiki/De_Bruijn_index) +/// to represent our type variables, +/// but notably, these are given for the entire module +/// and there are no universal quantifiers (i.e., `Δ, α ⊢ τ | Δ ⊢ ∀ α. τ`) +/// nor are there type lambdas (i.e., `Λτ. v`). +/// +/// There are however recursive types in SATs, +/// e.g., `&0 = { Cons({ v: U8, t: &0 }), Nil }` represents a basic cons list +/// where `&0` is the type reference at index `0`. #[derive(Debug, Clone, Deserialize, Serialize)] #[sats(crate = crate)] pub struct Typespace { + /// The types in our typing context that can be referred to with [`AlgebraicTypeRef`]s. pub types: Vec, } @@ -32,12 +46,14 @@ impl IndexMut for Typespace { } impl Typespace { + /// Returns a context ([`Typespace`]) with the given `types`. pub const fn new(types: Vec) -> Self { Self { types } } + /// Returns the [`AlgebraicType`] referred to by `r` within this context. pub fn get(&self, r: AlgebraicTypeRef) -> Option<&AlgebraicType> { - self.types.get(r.0 as usize) + self.types.get(r.idx()) } /// Inserts an `AlgebraicType` into the typespace @@ -48,24 +64,39 @@ impl Typespace { /// /// You can also use this to later change the meaning of the returned `AlgebraicTypeRef` /// when you cannot provide the full definition of the type yet. + /// + /// Panics if the number of type references exceeds an `u32`. pub fn add(&mut self, ty: AlgebraicType) -> AlgebraicTypeRef { - let i = self.types.len(); + let index = self + .types + .len() + .try_into() + .expect("ran out of space for `AlgebraicTypeRef`s"); + self.types.push(ty); - AlgebraicTypeRef(i as u32) + AlgebraicTypeRef(index) } - pub fn with_type<'a, T: ?Sized>(&'a self, ty: &'a T) -> TypeInSpace<'a, T> { - TypeInSpace::new(self, ty) + /// Returns `ty` combined with the context `self`. + pub const fn with_type<'a, T: ?Sized>(&'a self, ty: &'a T) -> WithTypespace<'a, T> { + WithTypespace::new(self, ty) } } +/// A trait for types that can be represented as an `AlgebraicType` +/// provided a typing context `typespace`. pub trait SpacetimeType { + /// Returns an `AlgebraicType` representing the type for `Self` in SATS + /// and in the typing context in `typespace`. fn make_type(typespace: &mut S) -> AlgebraicType; } pub use spacetimedb_bindings_macro::SpacetimeType; +/// A trait for types that can build a [`Typespace`]. pub trait TypespaceBuilder { + /// Returns and adds a representation of type `T: 'static` as an `AlgebraicType` + /// with an optional `name` to the typing context in `self`. fn add( &mut self, typeid: TypeId, @@ -74,15 +105,35 @@ pub trait TypespaceBuilder { ) -> AlgebraicType; } +/// Implements [`SpacetimeType`] for a type in a simplified manner. +/// +/// An example: +/// ```ignore +/// struct Foo<'a, T>(&'a T, u8); +/// impl_st!( +/// // Type parameters Impl type +/// // v v +/// // -------------------- ---------- +/// ['a, T: SpacetimeType] Foo<'a, T>, +/// // The `make_type` implementation where `ts: impl TypespaceBuilder` +/// // and the expression right of `=>` is an `AlgebraicType`. +/// ts => AlgebraicType::product(vec![T::make_type(ts).into(), AlgebraicType::U8.into()]) +/// ); +/// ``` +#[macro_export] +macro_rules! impl_st { + ([ $($rgenerics:tt)* ] $rty:ty, $ts:ident => $stty:expr) => { + impl<$($rgenerics)*> $crate::SpacetimeType for $rty { + fn make_type($ts: &mut S) -> $crate::AlgebraicType { + $stty + } + } + }; +} + macro_rules! impl_primitives { ($($t:ty => $x:ident,)*) => { - $( - impl SpacetimeType for $t { - fn make_type(_ts: &mut S) -> AlgebraicType { - AlgebraicType::$x - } - } - )* + $(impl_st!([] $t, _ts => AlgebraicType::$x);)* }; } @@ -103,27 +154,7 @@ impl_primitives! { String => String, } -impl SpacetimeType for () { - fn make_type(_ts: &mut S) -> AlgebraicType { - AlgebraicType::UNIT_TYPE - } -} -impl SpacetimeType for &str { - fn make_type(_ts: &mut S) -> AlgebraicType { - AlgebraicType::String - } -} - -impl SpacetimeType for Vec { - fn make_type(typespace: &mut S) -> AlgebraicType { - AlgebraicType::Builtin(BuiltinType::Array(ArrayType { - elem_ty: Box::new(T::make_type(typespace)), - })) - } -} - -impl SpacetimeType for Option { - fn make_type(typespace: &mut S) -> AlgebraicType { - AlgebraicType::make_option_type(T::make_type(typespace)) - } -} +impl_st!([] (), _ts => AlgebraicType::UNIT_TYPE); +impl_st!([] &str, _ts => AlgebraicType::String); +impl_st!([T: SpacetimeType] Vec, ts => AlgebraicType::array(T::make_type(ts))); +impl_st!([T: SpacetimeType] Option, ts => AlgebraicType::option(T::make_type(ts))); diff --git a/crates/sats/tests/encoding_roundtrip.rs b/crates/sats/tests/encoding_roundtrip.rs index aa441ed11a..8f5487b468 100644 --- a/crates/sats/tests/encoding_roundtrip.rs +++ b/crates/sats/tests/encoding_roundtrip.rs @@ -3,12 +3,13 @@ use proptest::proptest; use spacetimedb_sats::buffer::DecodeError; use spacetimedb_sats::builtin_value::{F32, F64}; use spacetimedb_sats::{ - product, AlgebraicType, AlgebraicValue, BuiltinValue, ProductType, ProductTypeElement, ProductValue, + meta_type::MetaType, product, AlgebraicType, AlgebraicValue, BuiltinValue, ProductType, ProductTypeElement, + ProductValue, }; #[test] fn type_to_binary_equivalent() { - check_type(&AlgebraicType::make_meta_type()); + check_type(&AlgebraicType::meta_type()); } #[track_caller] diff --git a/crates/sqltest/src/space.rs b/crates/sqltest/src/space.rs index 4fba98ea59..a7b477fbd7 100644 --- a/crates/sqltest/src/space.rs +++ b/crates/sqltest/src/space.rs @@ -6,6 +6,7 @@ use spacetimedb::sql::compiler::compile_sql; use spacetimedb::sql::execute::execute_sql; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::relation::MemTable; +use spacetimedb_sats::meta_type::MetaType; use spacetimedb_sats::satn::Satn; use spacetimedb_sats::{AlgebraicType, AlgebraicValue, BuiltinType, BuiltinValue}; use sqllogictest::{AsyncDB, ColumnType, DBOutput}; @@ -23,7 +24,7 @@ impl ColumnType for Kind { 'T' => Some(Kind(AlgebraicType::String)), 'I' => Some(Kind(AlgebraicType::I64)), 'R' => Some(Kind(AlgebraicType::F32)), - _ => Some(Kind(AlgebraicType::make_meta_type())), + _ => Some(Kind(AlgebraicType::meta_type())), } } diff --git a/crates/sqltest/src/sqlite.rs b/crates/sqltest/src/sqlite.rs index 6606db5398..33588f1894 100644 --- a/crates/sqltest/src/sqlite.rs +++ b/crates/sqltest/src/sqlite.rs @@ -2,7 +2,7 @@ use crate::db::DBRunner; use crate::space::Kind; use async_trait::async_trait; use rusqlite::types::Value; -use spacetimedb_sats::AlgebraicType; +use spacetimedb_sats::{meta_type::MetaType, AlgebraicType}; use sqllogictest::{AsyncDB, DBOutput}; use std::path::PathBuf; use tempdir::TempDir; @@ -31,7 +31,7 @@ fn columns(stmt: &mut rusqlite::Statement) -> Vec<(String, AlgebraicType)> { stmt.columns() .iter() .map(|col| { - let kind = col.decl_type().map(kind).unwrap_or_else(AlgebraicType::make_meta_type); + let kind = col.decl_type().map(kind).unwrap_or_else(AlgebraicType::meta_type); (col.name().to_string(), kind) }) @@ -83,7 +83,7 @@ impl AsyncDB for Sqlite { let mut columns = columns(&mut stmt); let mut rows = stmt.query([])?; let mut data = Vec::new(); - let mut meta = AlgebraicType::make_meta_type(); + let mut meta = AlgebraicType::meta_type(); while let Some(row) = rows.next()? { let mut new = Vec::with_capacity(columns.len()); @@ -91,7 +91,7 @@ impl AsyncDB for Sqlite { for (name, dectype) in &mut columns { let value = row.get::<_, Value>(name.as_str())?; let (value, kind) = match value { - Value::Null => ("null".into(), AlgebraicType::make_never_type()), + Value::Null => ("null".into(), AlgebraicType::NEVER_TYPE), Value::Integer(x) => (x.to_string(), AlgebraicType::I64), Value::Real(x) => (format!("{:?}", x), AlgebraicType::F64), Value::Text(x) => (format!("'{}'", x), AlgebraicType::String), diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index a2a2d1435e..19ef273500 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -480,14 +480,14 @@ pub fn run_ast(p: &mut P, ast: Expr) -> Code { // Used internally for testing recursion #[doc(hidden)] pub fn fibo(input: u64) -> Expr { - let kind = AlgebraicType::Builtin(BuiltinType::U64); + let ty = AlgebraicType::U64; let less = |val: u64| bin_op(OpMath::Minus, var("n"), scalar(val)); let f = Function::new( "fib", - &[Param::new("n", kind.clone())], - kind, + &[Param::new("n", ty.clone())], + ty, &[if_( bin_op(OpCmp::Lt, var("n"), scalar(2u64)), var("n"), @@ -679,11 +679,11 @@ mod tests { #[test] fn test_fun() { let p = &mut Program::new(AuthCtx::for_testing()); - let kind = AlgebraicType::Builtin(BuiltinType::U64); + let ty = AlgebraicType::U64; let f = Function::new( "sum", - &[Param::new("a", kind.clone()), Param::new("b", kind.clone())], - kind, + &[Param::new("a", ty.clone()), Param::new("b", ty.clone())], + ty, &[bin_op(OpMath::Add, var("a"), var("b"))], ); diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index 844dc872fa..09707ce53d 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -11,7 +11,7 @@ use spacetimedb_lib::relation::{ use spacetimedb_sats::algebraic_type::AlgebraicType; use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::satn::Satn; -use spacetimedb_sats::{ProductValue, TypeInSpace, Typespace}; +use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; use crate::errors::{ErrorKind, ErrorLang, ErrorType, ErrorVm}; use crate::functions::{FunDef, Param}; @@ -466,7 +466,7 @@ pub enum ExprOpt { pub(crate) fn fmt_value(ty: &AlgebraicType, val: &AlgebraicValue) -> String { let ts = Typespace::new(vec![]); - TypeInSpace::new(&ts, ty).with_value(val).to_satn() + WithTypespace::new(&ts, ty).with_value(val).to_satn() } impl fmt::Display for SourceExpr { diff --git a/crates/vm/src/typecheck.rs b/crates/vm/src/typecheck.rs index 1f921fabf0..fd2e186091 100644 --- a/crates/vm/src/typecheck.rs +++ b/crates/vm/src/typecheck.rs @@ -3,7 +3,6 @@ use crate::errors::ErrorType; use crate::expr::{CrudExprOpt, ExprOpt, SourceExprOpt}; use crate::types::Ty; use spacetimedb_sats::algebraic_type::AlgebraicType; -use spacetimedb_sats::builtin_type::BuiltinType; fn get_type<'a>(_env: &'a mut EnvTy, node: &'a ExprOpt) -> &'a Ty { match node { @@ -58,7 +57,7 @@ pub(crate) fn check_types(env: &mut EnvTy, ast: &ExprOpt) -> Result { if op.of.is_logical() { - return Ok(Ty::Val(AlgebraicType::Builtin(BuiltinType::Bool))); + return Ok(Ty::Val(AlgebraicType::Bool)); } let expects = match &op.ty { @@ -91,7 +90,7 @@ pub(crate) fn check_types(env: &mut EnvTy, ast: &ExprOpt) -> Result { let (test, if_true, if_false) = &**inner; - let expect = Ty::Val(AlgebraicType::Builtin(BuiltinType::Bool)); + let expect = Ty::Val(AlgebraicType::Bool); let found = check_types(env, test)?; if check_types(env, test)? == expect { let lhs = check_types(env, if_true)?; diff --git a/crates/vm/src/types.rs b/crates/vm/src/types.rs index 949bcd0e02..15948cb413 100644 --- a/crates/vm/src/types.rs +++ b/crates/vm/src/types.rs @@ -2,7 +2,7 @@ use std::fmt; use crate::operator::*; -use spacetimedb_sats::algebraic_type::map_notation::Formatter; +use spacetimedb_sats::algebraic_type::map_notation::fmt_algebraic_type; use spacetimedb_sats::algebraic_type::AlgebraicType; use spacetimedb_sats::algebraic_value::AlgebraicValue; use spacetimedb_sats::builtin_type::BuiltinType; @@ -20,19 +20,9 @@ pub enum Ty { impl fmt::Display for Ty { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Ty::Unknown => { - write!(f, "{self:?}") - } - Ty::Val(x) => { - let x = Formatter::new(x); - write!(f, "{x}") - } - Ty::Multi(options) => { - for x in options { - write!(f, "{x}")?; - } - Ok(()) - } + Ty::Unknown => write!(f, "{self:?}"), + Ty::Val(ty) => write!(f, "{}", fmt_algebraic_type(ty)), + Ty::Multi(options) => options.iter().try_for_each(|x| write!(f, "{x}")), Ty::Fun { params, result } => { write!(f, "(")?; for (pos, x) in params.iter().enumerate() {