diff --git a/crates/bindings/src/rt.rs b/crates/bindings/src/rt.rs index 6d0288009c..55060def87 100644 --- a/crates/bindings/src/rt.rs +++ b/crates/bindings/src/rt.rs @@ -81,8 +81,7 @@ pub fn invoke_connection_func( /// Creates a reducer context from the given `sender` and `timestamp`. fn assemble_context(sender: Buffer, timestamp: u64) -> ReducerContext { - let sender = sender.read_array::<32>(); - let sender = Identity { data: sender }; + let sender = Identity::from_byte_array(sender.read_array::<32>()); let timestamp = Timestamp::UNIX_EPOCH + Duration::from_micros(timestamp); diff --git a/crates/cli/src/subcommands/generate/csharp.rs b/crates/cli/src/subcommands/generate/csharp.rs index d15ba15d23..dbb272df0b 100644 --- a/crates/cli/src/subcommands/generate/csharp.rs +++ b/crates/cli/src/subcommands/generate/csharp.rs @@ -44,8 +44,8 @@ fn ty_fmt<'a>(ctx: &'a GenCtx, ty: &'a AlgebraicType, namespace: &'a str) -> imp fmt_fn(move |f| match ty { AlgebraicType::Sum(sum_type) => { // This better be an option type - if is_option_type(sum_type) { - match &sum_type.variants[0].algebraic_type { + if let Some(inner_ty) = sum_type.as_option() { + match inner_ty { Builtin(b) => match b { BuiltinType::Bool | BuiltinType::I8 @@ -60,22 +60,29 @@ fn ty_fmt<'a>(ctx: &'a GenCtx, ty: &'a AlgebraicType, namespace: &'a str) -> imp | BuiltinType::U128 | BuiltinType::F32 | BuiltinType::F64 => { - // This has to be a nullable type - write!(f, "{}?", ty_fmt(ctx, &sum_type.variants[0].algebraic_type, namespace)) + // This has to be a nullable type. + write!(f, "{}?", ty_fmt(ctx, inner_ty, namespace)) } _ => { - write!(f, "{}", ty_fmt(ctx, &sum_type.variants[0].algebraic_type, namespace)) + write!(f, "{}", ty_fmt(ctx, inner_ty, namespace)) } }, _ => { - write!(f, "{}", ty_fmt(ctx, &sum_type.variants[0].algebraic_type, namespace)) + write!(f, "{}", ty_fmt(ctx, inner_ty, namespace)) } } } else { unimplemented!() } } - AlgebraicType::Product(_) => unimplemented!(), + AlgebraicType::Product(prod) => { + // The only type that is allowed here is the identity type. All other types should fail. + if prod.is_identity() { + write!(f, "SpacetimeDB.Identity") + } else { + unimplemented!() + } + } AlgebraicType::Builtin(b) => match maybe_primitive(b) { MaybePrimitive::Primitive(p) => f.write_str(p), MaybePrimitive::Array(ArrayType { elem_ty }) if **elem_ty == AlgebraicType::U8 => f.write_str("byte[]"), @@ -169,10 +176,20 @@ fn convert_type<'a>( namespace: &'a str, ) -> impl fmt::Display + 'a { fmt_fn(move |f| match ty { - AlgebraicType::Product(_) => unimplemented!(), + AlgebraicType::Product(product) => { + if product.is_identity() { + write!( + f, + "SpacetimeDB.Identity.From({}.AsProductValue().elements[0].AsBytes())", + value + ) + } else { + unimplemented!() + } + } AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { - match &sum_type.variants[0].algebraic_type { + if let Some(inner_ty) = sum_type.as_option() { + match inner_ty { Builtin(ty) => match ty { BuiltinType::Bool | BuiltinType::I8 @@ -190,15 +207,15 @@ fn convert_type<'a>( f, "{}.AsSumValue().tag == 1 ? null : new {}?({}.AsSumValue().value{})", value, - ty_fmt(ctx, &sum_type.variants[0].algebraic_type, namespace), + ty_fmt(ctx, inner_ty, namespace), value, - &convert_type(ctx, vecnest, &sum_type.variants[0].algebraic_type, "", namespace), + &convert_type(ctx, vecnest, inner_ty, "", namespace), ), _ => fmt::Display::fmt( &convert_type( ctx, vecnest, - &sum_type.variants[0].algebraic_type, + inner_ty, format_args!("{}.AsSumValue().tag == 1 ? null : {}.AsSumValue().value", value, value), namespace, ), @@ -209,7 +226,7 @@ fn convert_type<'a>( &convert_type( ctx, vecnest, - &sum_type.variants[0].algebraic_type, + inner_ty, format_args!("{}.AsSumValue().tag == 1 ? null : {}.AsSumValue().value", value, value), namespace, ), @@ -261,24 +278,6 @@ fn csharp_typename(ctx: &GenCtx, typeref: AlgebraicTypeRef) -> &str { ctx.names[typeref.idx()].as_deref().expect("tuples should have names") } -fn is_option_type(ty: &SumType) -> bool { - if ty.variants.len() != 2 { - return false; - } - - if ty.variants[0].name.clone().expect("Variants should have names!") != "some" - || ty.variants[1].name.clone().expect("Variants should have names!") != "none" - { - return false; - } - - if let AlgebraicType::Product(none_type) = &ty.variants[1].algebraic_type { - none_type.elements.is_empty() - } else { - false - } -} - macro_rules! indent_scope { ($x:ident) => { let mut $x = $x.indented(1); @@ -605,7 +604,7 @@ fn autogen_csharp_product_table_common( } } AlgebraicType::Sum(sum) => { - if is_option_type(sum) { + if sum.as_option().is_some() { writeln!(output, "[SpacetimeDB.Some]").unwrap(); } else { unimplemented!() @@ -656,7 +655,7 @@ fn autogen_csharp_product_table_common( output, "private static Dictionary<{type_name}, {name}> {field_name}_Index = new Dictionary<{type_name}, {name}>(16{comparer});" ) - .unwrap(); + .unwrap(); } writeln!(output).unwrap(); // OnInsert method for updating indexes @@ -772,7 +771,7 @@ fn autogen_csharp_product_table_common( output, "OnUpdate?.Invoke(({name})oldValue,({name})newValue,(ReducerEvent)dbEvent?.FunctionCall.CallInfo);" ) - .unwrap(); + .unwrap(); } writeln!(output, "}}").unwrap(); writeln!(output).unwrap(); @@ -815,7 +814,7 @@ fn autogen_csharp_product_table_common( output, "public static void OnRowUpdateEvent(SpacetimeDBClient.TableOp op, object oldValue, object newValue, ClientApi.Event dbEvent)" ) - .unwrap(); + .unwrap(); writeln!(output, "{{").unwrap(); { indent_scope!(output); @@ -823,7 +822,7 @@ fn autogen_csharp_product_table_common( output, "OnRowUpdate?.Invoke(op, ({name})oldValue,({name})newValue,(ReducerEvent)dbEvent?.FunctionCall.CallInfo);" ) - .unwrap(); + .unwrap(); } writeln!(output, "}}").unwrap(); } @@ -877,7 +876,7 @@ fn autogen_csharp_product_value_to_struct( 0, field_type, format_args!("productValue.elements[{idx}]"), - namespace + namespace, ) ) .unwrap(); @@ -950,12 +949,18 @@ fn autogen_csharp_access_funcs_for_struct( let csharp_field_name_pascal = field_name.replace("r#", "").to_case(Case::Pascal); let (field_type, csharp_field_type) = match field_type { - AlgebraicType::Product(_) | AlgebraicType::Ref(_) => { - // TODO: We don't allow filtering on tuples right now, its possible we may consider it for the future. - continue; + AlgebraicType::Product(product) => { + if product.is_identity() { + ("Identity".into(), "SpacetimeDB.Identity") + } else { + // TODO: We don't allow filtering on tuples right now, + // it's possible we may consider it for the future. + continue; + } } - AlgebraicType::Sum(_) => { - // TODO: We don't allow filtering on enums right now, its possible we may consider it for the future. + AlgebraicType::Ref(_) | AlgebraicType::Sum(_) => { + // TODO: We don't allow filtering on enums or tuples right now; + // it's possible we may consider it for the future. continue; } AlgebraicType::Builtin(b) => match maybe_primitive(b) { @@ -1012,12 +1017,21 @@ fn autogen_csharp_access_funcs_for_struct( { indent_scope!(output); writeln!(output, "var productValue = entry.Item1.AsProductValue();").unwrap(); - writeln!( - output, - "var compareValue = ({})productValue.elements[{}].As{}();", - csharp_field_type, col_i, field_type - ) - .unwrap(); + if field_type == "Identity" { + writeln!( + output, + "var compareValue = Identity.From(productValue.elements[{}].AsProductValue().elements[0].AsBytes());", + col_i + ) + .unwrap(); + } else { + writeln!( + output, + "var compareValue = ({})productValue.elements[{}].As{}();", + csharp_field_type, col_i, field_type + ) + .unwrap(); + } if csharp_field_type == "byte[]" { writeln!( output, @@ -1076,7 +1090,7 @@ fn autogen_csharp_access_funcs_for_struct( output, "public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2)" ) - .unwrap(); + .unwrap(); writeln!(output, "{{").unwrap(); { indent_scope!(output); @@ -1096,7 +1110,7 @@ fn autogen_csharp_access_funcs_for_struct( output, "return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2);" ) - .unwrap(); + .unwrap(); } writeln!(output, "}}").unwrap(); } else { @@ -1104,7 +1118,7 @@ fn autogen_csharp_access_funcs_for_struct( output, "public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)" ) - .unwrap(); + .unwrap(); writeln!(output, "{{").unwrap(); { indent_scope!(output); @@ -1180,10 +1194,10 @@ pub fn autogen_csharp_reducer(ctx: &GenCtx, reducer: &ReducerDef, namespace: &st match &arg.algebraic_type { AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { - json_args.push_str(format!("new SomeWrapper({})", arg_name).as_str()); + if sum_type.as_option().is_some() { + json_args.push_str(&format!("new SomeWrapper({})", arg_name)); } else { - json_args.push_str(arg_name.as_str()); + json_args.push_str(&arg_name); } } AlgebraicType::Product(_) => { diff --git a/crates/cli/src/subcommands/generate/python.rs b/crates/cli/src/subcommands/generate/python.rs index f7b0f5bbe3..f120376a8a 100644 --- a/crates/cli/src/subcommands/generate/python.rs +++ b/crates/cli/src/subcommands/generate/python.rs @@ -67,20 +67,14 @@ fn convert_type<'a>( ) -> impl fmt::Display + 'a { fmt_fn(move |f| match ty { AlgebraicType::Product(_) => unreachable!(), - AlgebraicType::Sum(sum_type) if is_option_type(sum_type) => { - write!( + AlgebraicType::Sum(sum_type) => match sum_type.as_option() { + Some(inner_ty) => write!( f, "{} if '0' in {value} else None", - convert_type( - ctx, - vecnest, - &sum_type.variants[0].algebraic_type, - format!("{value}['0']"), - ref_prefix - ) - ) - } - AlgebraicType::Sum(_sum_type) => unimplemented!(), + convert_type(ctx, vecnest, inner_ty, format!("{value}['0']"), ref_prefix), + ), + None => unimplemented!(), + }, AlgebraicType::Builtin(b) => fmt::Display::fmt(&convert_builtintype(ctx, vecnest, b, &value, ref_prefix), f), AlgebraicType::Ref(r) => { let name = python_typename(ctx, *r); @@ -139,24 +133,6 @@ fn python_filename(ctx: &GenCtx, typeref: AlgebraicTypeRef) -> String { .to_case(Case::Snake) } -fn is_option_type(ty: &SumType) -> bool { - if ty.variants.len() != 2 { - return false; - } - - if ty.variants[0].name.clone().expect("Variants should have names!") != "some" - || ty.variants[1].name.clone().expect("Variants should have names!") != "none" - { - return false; - } - - if let AlgebraicType::Product(none_type) = &ty.variants[1].algebraic_type { - none_type.elements.is_empty() - } else { - false - } -} - pub fn autogen_python_table(ctx: &GenCtx, table: &TableDef) -> String { let tuple = ctx.typespace[table.data].as_product().unwrap(); autogen_python_product_table_common(ctx, &table.name, tuple, Some(&table.column_attrs)) @@ -176,10 +152,12 @@ fn _generate_imports(ctx: &GenCtx, ty: &AlgebraicType, imports: &mut Vec _generate_imports(ctx, &map_type.key_ty, imports); _generate_imports(ctx, &map_type.ty, imports); } - _ => (), + _ => {} }, - AlgebraicType::Sum(sum_type) if is_option_type(sum_type) => { - _generate_imports(ctx, &sum_type.variants[0].algebraic_type, imports); + AlgebraicType::Sum(sum_type) => { + if let Some(inner_ty) = sum_type.as_option() { + _generate_imports(ctx, inner_ty, imports) + } } AlgebraicType::Ref(r) => { let class_name = python_typename(ctx, *r).to_string(); @@ -188,7 +166,7 @@ fn _generate_imports(ctx: &GenCtx, ty: &AlgebraicType, imports: &mut Vec let import = format!("from .{filename} import {class_name}"); imports.push(import); } - _ => (), + _ => {} } } @@ -302,7 +280,7 @@ fn autogen_python_product_table_common( continue; } AlgebraicType::Sum(ty) => { - if !is_option_type(ty) { + if ty.as_option().is_none() { // TODO: We don't allow filtering on enums right now, its possible we may consider it for the future. continue; } @@ -381,7 +359,7 @@ fn autogen_python_product_table_common( let python_field_name = field_name.to_string().replace("r#", ""); match &field.algebraic_type { - AlgebraicType::Sum(sum_type) if is_option_type(sum_type) => { + AlgebraicType::Sum(sum_type) if sum_type.as_option().is_some() => { reducer_args.push(format!("{{'0': [self.{}]}}", python_field_name)) } AlgebraicType::Sum(_) => unimplemented!(), @@ -494,20 +472,14 @@ pub fn encode_type<'a>( ) -> impl fmt::Display + 'a { fmt_fn(move |f| match ty { AlgebraicType::Product(_) => unreachable!(), - AlgebraicType::Sum(sum_type) if is_option_type(sum_type) => { - write!( + AlgebraicType::Sum(sum_type) => match sum_type.as_option() { + Some(inner_ty) => write!( f, "{{'0': {}}} if value is not None else {{}}", - encode_type( - ctx, - vecnest, - &sum_type.variants[0].algebraic_type, - format!("{value}"), - ref_prefix - ) - ) - } - AlgebraicType::Sum(_sum_type) => unimplemented!(), + encode_type(ctx, vecnest, inner_ty, format!("{value}"), ref_prefix), + ), + None => unimplemented!(), + }, AlgebraicType::Builtin(b) => fmt::Display::fmt(&encode_builtintype(ctx, vecnest, b, &value, ref_prefix), f), AlgebraicType::Ref(r) => { let algebraic_type = &ctx.typespace.types[r.idx()]; diff --git a/crates/cli/src/subcommands/generate/rust.rs b/crates/cli/src/subcommands/generate/rust.rs index 7773de3fe9..4216fbc3b5 100644 --- a/crates/cli/src/subcommands/generate/rust.rs +++ b/crates/cli/src/subcommands/generate/rust.rs @@ -41,26 +41,6 @@ fn maybe_primitive(b: &BuiltinType) -> MaybePrimitive { }) } -fn is_empty_product(ty: &AlgebraicType) -> bool { - if let AlgebraicType::Product(none_type) = ty { - none_type.elements.is_empty() - } else { - false - } -} - -// This function is duplicated in [typescript.rs] and [csharp.rs], and should maybe be -// lifted into a module, or be a part of SATS itself. -fn is_option_type(ty: &SumType) -> bool { - let name_is = |variant: &SumTypeVariant, name| variant.name.as_ref().expect("Variants should have names!") == name; - matches!( - &ty.variants[..], - [a, b] if name_is(a, "some") - && name_is(b, "none") - && is_empty_product(&b.algebraic_type) - ) -} - fn write_type_ctx(ctx: &GenCtx, out: &mut Indenter, ty: &AlgebraicType) { write_type(&|r| type_name(ctx, r), out, ty) } @@ -68,9 +48,9 @@ fn write_type_ctx(ctx: &GenCtx, out: &mut Indenter, ty: &AlgebraicType) { pub fn write_type(ctx: &impl Fn(AlgebraicTypeRef) -> String, out: &mut W, ty: &AlgebraicType) { match ty { AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { + if let Some(inner_ty) = sum_type.as_option() { write!(out, "Option::<").unwrap(); - write_type(ctx, out, &sum_type.variants[0].algebraic_type); + write_type(ctx, out, inner_ty); write!(out, ">").unwrap(); } else { write!(out, "enum ").unwrap(); @@ -82,6 +62,9 @@ pub fn write_type(ctx: &impl Fn(AlgebraicTypeRef) -> String, out: &mut }); } } + AlgebraicType::Product(p) if p.is_identity() => { + write!(out, "Identity").unwrap(); + } AlgebraicType::Product(ProductType { elements }) => { print_comma_sep_braced(out, elements, |out: &mut W, elem: &ProductTypeElement| { if let Some(name) = &elem.name { diff --git a/crates/cli/src/subcommands/generate/typescript.rs b/crates/cli/src/subcommands/generate/typescript.rs index baf98a658d..022c369308 100644 --- a/crates/cli/src/subcommands/generate/typescript.rs +++ b/crates/cli/src/subcommands/generate/typescript.rs @@ -21,51 +21,28 @@ enum MaybePrimitive<'a> { fn maybe_primitive(b: &BuiltinType) -> MaybePrimitive { MaybePrimitive::Primitive(match b { BuiltinType::Bool => "boolean", - BuiltinType::I8 => "number", - BuiltinType::U8 => "number", - BuiltinType::I16 => "number", - BuiltinType::U16 => "number", - BuiltinType::I32 => "number", - BuiltinType::U32 => "number", - BuiltinType::I64 => "number", - BuiltinType::U64 => "number", - BuiltinType::I128 => "BigInt", - BuiltinType::U128 => "BigInt", + BuiltinType::I8 + | BuiltinType::U8 + | BuiltinType::I16 + | BuiltinType::U16 + | BuiltinType::I32 + | BuiltinType::U32 + | BuiltinType::I64 + | BuiltinType::U64 + | BuiltinType::F32 + | BuiltinType::F64 => "number", + BuiltinType::I128 | BuiltinType::U128 => "BigInt", BuiltinType::String => "string", - BuiltinType::F32 => "number", - BuiltinType::F64 => "number", BuiltinType::Array(ty) => return MaybePrimitive::Array(ty), BuiltinType::Map(m) => return MaybePrimitive::Map(m), }) } -fn is_option_type(ty: &SumType) -> bool { - if ty.variants.len() != 2 { - return false; - } - - if ty.variants[0].name.clone().expect("Variants should have names!") != "some" - || ty.variants[1].name.clone().expect("Variants should have names!") != "none" - { - return false; - } - - if let AlgebraicType::Product(none_type) = &ty.variants[1].algebraic_type { - none_type.elements.is_empty() - } else { - false - } -} - fn ty_fmt<'a>(ctx: &'a GenCtx, ty: &'a AlgebraicType, ref_prefix: &'a str) -> impl fmt::Display + 'a { fmt_fn(move |f| match ty { AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { - write!( - f, - "{} | null", - ty_fmt(ctx, &sum_type.variants[0].algebraic_type, ref_prefix) - ) + if let Some(inner_ty) = sum_type.as_option() { + write!(f, "{} | null", ty_fmt(ctx, inner_ty, ref_prefix)) } else { unimplemented!() } @@ -146,8 +123,8 @@ fn convert_type<'a>( fmt_fn(move |f| match ty { AlgebraicType::Product(_) => unreachable!(), AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { - match &sum_type.variants[0].algebraic_type { + if let Some(inner_ty) = sum_type.as_option() { + match inner_ty { Builtin(ty) => match ty { BuiltinType::Bool | BuiltinType::I8 @@ -166,13 +143,13 @@ fn convert_type<'a>( "{}.asSumValue().tag == 1 ? null : {}.asSumValue().value{}", value, value, - &convert_type(ctx, vecnest, &sum_type.variants[0].algebraic_type, "", ref_prefix), + &convert_type(ctx, vecnest, inner_ty, "", ref_prefix), ), _ => fmt::Display::fmt( &convert_type( ctx, vecnest, - &sum_type.variants[0].algebraic_type, + inner_ty, format_args!("{value}.asSumValue().tag == 1 ? null : {value}.asSumValue().value"), ref_prefix ), @@ -187,7 +164,7 @@ fn convert_type<'a>( convert_type( ctx, vecnest, - &sum_type.variants[0].algebraic_type, + inner_ty, "value", ref_prefix ) @@ -198,7 +175,7 @@ fn convert_type<'a>( &convert_type( ctx, vecnest, - &sum_type.variants[0].algebraic_type, + inner_ty, format_args!("{value}.asSumValue().tag == 1 ? null : {value}.asSumValue().value"), ref_prefix ), @@ -294,21 +271,6 @@ fn convert_sum_type<'a>(ctx: &'a GenCtx, sum_type: &'a SumType, ref_prefix: &'a }) } -pub fn is_enum(sum_type: &SumType) -> bool { - for variant in sum_type.clone().variants { - match variant.algebraic_type { - AlgebraicType::Product(product) => { - if product.elements.is_empty() { - continue; - } - } - _ => return false, - } - } - - true -} - fn serialize_type<'a>( ctx: &'a GenCtx, ty: &'a AlgebraicType, @@ -318,11 +280,11 @@ fn serialize_type<'a>( fmt_fn(move |f| match ty { AlgebraicType::Product(_) => unreachable!(), AlgebraicType::Sum(sum_type) => { - if is_option_type(sum_type) { + if let Some(inner_ty) = sum_type.as_option() { write!( f, "{value} ? {{ \"some\": {} }} : {{ \"none\": [] }}", - serialize_type(ctx, &sum_type.variants[0].algebraic_type, value, prefix) + serialize_type(ctx, inner_ty, value, prefix) ) } else { unimplemented!() @@ -392,7 +354,7 @@ pub fn autogen_typescript_sum(ctx: &GenCtx, name: &str, sum_type: &SumType) -> S { indent_scope!(output); - if is_enum(sum_type) { + if sum_type.is_simple_enum() { // for a simple enum we can simplify the fromValue function writeln!(output, "const result: {{[key: string]: any}} = {{}};").unwrap(); writeln!(output, "result[value.tag] = [];").unwrap(); @@ -467,7 +429,7 @@ pub fn autogen_typescript_sum(ctx: &GenCtx, name: &str, sum_type: &SumType) -> S writeln!(output, "let sumValue = value.asSumValue();").unwrap(); - if is_enum(sum_type) { + if sum_type.is_simple_enum() { // for a simple enum we can simplify the fromValue function writeln!(output, "let tag = sumValue.tag;").unwrap(); writeln!( diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 63c18487b0..b2cbd2196e 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -843,7 +843,7 @@ pub async fn publish( let op = match control_ctx_find_database(&*ctx, &db_address).await? { Some(db) => { - if Identity::from_slice(db.identity.as_slice()) != auth.identity { + if db.identity != auth.identity { return Err((StatusCode::BAD_REQUEST, "Identity does not own this database.").into()); } diff --git a/crates/client-api/src/routes/identity.rs b/crates/client-api/src/routes/identity.rs index c5cfa42e65..0b96744107 100644 --- a/crates/client-api/src/routes/identity.rs +++ b/crates/client-api/src/routes/identity.rs @@ -5,6 +5,7 @@ use axum::response::IntoResponse; use http::StatusCode; use serde::{Deserialize, Serialize}; use spacetimedb::auth::identity::encode_token_with_expiry; +use spacetimedb_lib::de::serde::DeserializeWrapper; use spacetimedb_lib::Identity; use crate::auth::{SpacetimeAuth, SpacetimeAuthHeader}; @@ -87,9 +88,35 @@ pub async fn get_identity( Ok(axum::Json(identity_response)) } +/// A version of `Identity` appropriate for URL de/encoding. +/// +/// Because `Identity` is represented in SATS as a `ProductValue`, +/// its serialized format is somewhat gnarly. +/// When URL-encoding identities, we want to use only the hex string, +/// without wrapping it in a `ProductValue`. +/// This keeps our routes pretty, like `/identity/<64 hex chars>/set-email`. +/// +/// This newtype around `Identity` implements `Deserialize` +/// directly from the inner identity bytes, +/// without the enclosing `ProductValue` wrapper. +pub struct IdentityForUrl(Identity); + +impl From for Identity { + /// Consumes `self` returning the backing `Identity`. + fn from(IdentityForUrl(id): IdentityForUrl) -> Identity { + id + } +} + +impl<'de> serde::Deserialize<'de> for IdentityForUrl { + fn deserialize>(de: D) -> Result { + <_>::deserialize(de).map(|DeserializeWrapper(b)| IdentityForUrl(Identity::from_byte_array(b))) + } +} + #[derive(Deserialize)] pub struct SetEmailParams { - identity: Identity, + identity: IdentityForUrl, } #[derive(Deserialize)] @@ -103,6 +130,7 @@ pub async fn set_email( Query(SetEmailQueryParams { email }): Query, auth: SpacetimeAuthHeader, ) -> axum::response::Result { + let identity = identity.into(); let auth = auth.get().ok_or(StatusCode::BAD_REQUEST)?; if auth.identity != identity { @@ -119,7 +147,7 @@ pub async fn set_email( #[derive(Deserialize)] pub struct GetDatabasesParams { - identity: Identity, + identity: IdentityForUrl, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -131,6 +159,7 @@ pub async fn get_databases( State(ctx): State>, Path(GetDatabasesParams { identity }): Path, ) -> axum::response::Result { + let identity = identity.into(); // Linear scan for all databases that have this identity, and return their addresses let all_dbs = ctx.control_db().get_databases().await.map_err(|e| { log::error!("Failure when retrieving databases for search: {}", e); diff --git a/crates/client-api/src/routes/tracelog.rs b/crates/client-api/src/routes/tracelog.rs index 3194c17ef4..fda646ab53 100644 --- a/crates/client-api/src/routes/tracelog.rs +++ b/crates/client-api/src/routes/tracelog.rs @@ -74,10 +74,8 @@ pub async fn perform_tracelog_replay(body: Bytes) -> axum::response::Result Message { Message { r#type: Some(message::Type::IdentityToken(IdentityToken { - identity: self.identity.as_slice().to_vec(), + identity: self.identity.as_bytes().to_vec(), token: self.identity_token, })), } @@ -86,7 +86,7 @@ impl ServerMessage for TransactionUpdateMessage<'_> { let event = Event { timestamp: event.timestamp.0, status: status.into(), - caller_identity: event.caller_identity.data.to_vec(), + caller_identity: event.caller_identity.to_vec(), function_call: Some(FunctionCall { reducer: event.function_call.reducer.to_owned(), arg_bytes: event.function_call.args.get_bsatn().clone().into(), diff --git a/crates/core/src/control_db.rs b/crates/core/src/control_db.rs index 38480962b4..d2f26cd79e 100644 --- a/crates/core/src/control_db.rs +++ b/crates/core/src/control_db.rs @@ -174,7 +174,7 @@ impl ControlDb { } } None => { - tree.insert(tld.as_lowercase().as_bytes(), owner_identity.as_slice())?; + tree.insert(tld.as_lowercase().as_bytes(), owner_identity.as_bytes())?; Ok(RegisterTldResult::Success { domain: tld }) } } @@ -251,7 +251,7 @@ impl ControlDb { let name = b"clockworklabs:"; let bytes = [name, bytes].concat(); let hash = hash_bytes(bytes); - let address = Address::from_slice(&hash.as_slice()[0..16]); + let address = Address::from_slice(&hash.as_slice()[..16]); Ok(address) } @@ -262,7 +262,7 @@ impl ControlDb { let tree = self.db.open_tree("email")?; let identity_email = IdentityEmail { identity, email }; let buf = bsatn::to_vec(&identity_email).unwrap(); - tree.insert(identity.as_slice(), buf)?; + tree.insert(identity.as_bytes(), buf)?; Ok(()) } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 1214925c2c..0198ad80ea 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -753,14 +753,14 @@ impl WasmInstanceActor { arg_bytes, } => self .instance - .call_reducer(id, budget, &sender.data, timestamp, arg_bytes), + .call_reducer(id, budget, sender.as_bytes(), timestamp, arg_bytes), InstanceOp::ConnDisconn { conn, sender, timestamp, } => self .instance - .call_connect_disconnect(conn, budget, &sender.data, timestamp), + .call_connect_disconnect(conn, budget, sender.as_bytes(), timestamp), }); let ExecuteResult { diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 5ebdf7ceeb..13ed7fe789 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -286,7 +286,7 @@ mod tests { &db, &mut tx, &q, - AuthCtx::new(Identity::__dummy(), Identity::from_arr(&[1u8; 32])), + AuthCtx::new(Identity::__dummy(), Identity::from_byte_array([1u8; 32])), ) { Ok(_) => { panic!("it allows to execute against private table") diff --git a/crates/lib/src/hash.rs b/crates/lib/src/hash.rs index b6c0b91c91..9af1e8c53f 100644 --- a/crates/lib/src/hash.rs +++ b/crates/lib/src/hash.rs @@ -1,8 +1,4 @@ -#[cfg(feature = "serde")] -use crate::{de, ser}; - use core::fmt; - use sha3::{Digest, Keccak256}; use spacetimedb_sats::{impl_deserialize, impl_serialize, impl_st, AlgebraicType}; @@ -80,19 +76,13 @@ impl hex::FromHex for Hash { #[cfg(feature = "serde")] impl serde::Serialize for Hash { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - ser::serde::serialize_to(self, serializer) + fn serialize(&self, serializer: S) -> Result { + spacetimedb_sats::ser::serde::serialize_to(self, serializer) } } #[cfg(feature = "serde")] impl<'de> serde::Deserialize<'de> for Hash { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - de::serde::deserialize_from(deserializer) + fn deserialize>(deserializer: D) -> Result { + spacetimedb_sats::de::serde::deserialize_from(deserializer) } } diff --git a/crates/lib/src/identity.rs b/crates/lib/src/identity.rs index 7ce598a0f3..ff7dbd9b13 100644 --- a/crates/lib/src/identity.rs +++ b/crates/lib/src/identity.rs @@ -1,12 +1,7 @@ -#[cfg(feature = "serde")] -use crate::{de, ser}; - +use spacetimedb_bindings_macro::{Deserialize, Serialize}; +use spacetimedb_sats::{impl_st, AlgebraicType, ProductTypeElement}; use std::fmt; -use sats::{impl_deserialize, impl_serialize, impl_st}; - -use crate::sats; - #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct AuthCtx { pub owner: Identity, @@ -30,63 +25,64 @@ impl AuthCtx { } } -#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)] +#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash, Serialize, Deserialize)] pub struct Identity { - pub data: [u8; 32], -} - -impl Identity { - #[doc(hidden)] - pub fn __dummy() -> Self { - Self { data: [0; 32] } - } + __identity_bytes: [u8; 32], } -impl_st!([] Identity, _ts => sats::AlgebraicType::bytes()); -impl_serialize!([] Identity, (self, ser) => self.data.serialize(ser)); -impl_deserialize!([] Identity, de => Ok(Self { data: <_>::deserialize(de)? })); +impl_st!([] Identity, _ts => AlgebraicType::product(vec![ + ProductTypeElement::new_named(AlgebraicType::bytes(), "__identity_bytes") +])); impl Identity { const ABBREVIATION_LEN: usize = 16; - pub fn from_arr(arr: &[u8; 32]) -> Self { - Self { data: *arr } + /// Returns an `Identity` defined as the given `bytes` byte array. + pub fn from_byte_array(bytes: [u8; 32]) -> Self { + Self { + __identity_bytes: bytes, + } } + /// Returns an `Identity` defined as the given byte `slice`. pub fn from_slice(slice: &[u8]) -> Self { - Self { - data: slice.try_into().unwrap(), - } + Self::from_byte_array(slice.try_into().unwrap()) + } + + #[doc(hidden)] + pub fn __dummy() -> Self { + Self::from_byte_array([0; 32]) } + + /// Returns a borrowed view of the byte array defining this `Identity`. + pub fn as_bytes(&self) -> &[u8; 32] { + &self.__identity_bytes + } + pub fn to_vec(&self) -> Vec { - self.data.to_vec() + self.__identity_bytes.to_vec() } pub fn to_hex(&self) -> String { - hex::encode(self.data) + hex::encode(self.__identity_bytes) } pub fn to_abbreviated_hex(&self) -> String { self.to_hex()[0..Identity::ABBREVIATION_LEN].to_owned() } - pub fn as_slice(&self) -> &[u8] { - self.data.as_slice() - } - pub fn from_hex(hex: impl AsRef<[u8]>) -> Result { hex::FromHex::from_hex(hex) } pub fn from_hashing_bytes(bytes: impl AsRef<[u8]>) -> Self { - let hash = crate::hash::hash_bytes(bytes); - Identity { data: hash.data } + Identity::from_byte_array(crate::hash::hash_bytes(bytes).data) } } impl fmt::Display for Identity { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&hex::encode(self.data)) + f.write_str(&hex::encode(self.__identity_bytes)) } } @@ -101,25 +97,20 @@ impl hex::FromHex for Identity { fn from_hex>(hex: T) -> Result { let data = hex::FromHex::from_hex(hex)?; - Ok(Identity { data }) + Ok(Identity { __identity_bytes: data }) } } #[cfg(feature = "serde")] impl serde::Serialize for Identity { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - ser::serde::serialize_to(self, serializer) + fn serialize(&self, serializer: S) -> Result { + spacetimedb_sats::ser::serde::serialize_to(self, serializer) } } + #[cfg(feature = "serde")] impl<'de> serde::Deserialize<'de> for Identity { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - de::serde::deserialize_from(deserializer) + fn deserialize>(deserializer: D) -> Result { + spacetimedb_sats::de::serde::deserialize_from(deserializer) } } diff --git a/crates/replay/src/main.rs b/crates/replay/src/main.rs index 0b5cbc9d5f..d709a4fb99 100644 --- a/crates/replay/src/main.rs +++ b/crates/replay/src/main.rs @@ -24,10 +24,8 @@ pub fn main() { let logger_path = tmp_dir.path(); let scheduler_path = tmp_dir.path().join("scheduler"); - let identity = Identity { - data: hash_bytes(b"This is a fake identity.").data, - }; - let address = Address::from_slice(&identity.as_slice()[0..16]); + let identity = Identity::from_byte_array(hash_bytes(b"This is a fake identity.").data); + let address = Address::from_slice(&identity.as_bytes()[..16]); let dbic = DatabaseInstanceContext::new(0, 0, false, identity, address, db_path.to_path_buf(), logger_path); diff --git a/crates/sats/src/algebraic_type.rs b/crates/sats/src/algebraic_type.rs index 5d43ffafcd..c232f7c0b4 100644 --- a/crates/sats/src/algebraic_type.rs +++ b/crates/sats/src/algebraic_type.rs @@ -164,11 +164,6 @@ impl AlgebraicType { /// The canonical 0-variant "never" / "absurd" / "void" type. pub const NEVER_TYPE: Self = Self::sum(Vec::new()); - - /// A type representing an array of `U8`s. - pub fn bytes() -> Self { - Self::array(Self::U8) - } } impl MetaType for AlgebraicType { @@ -188,6 +183,18 @@ impl MetaType for AlgebraicType { } impl AlgebraicType { + /// A type representing an array of `U8`s. + pub fn bytes() -> Self { + Self::array(Self::U8) + } + + /// Returns whether this type is `AlgebraicType::bytes()`. + pub fn is_bytes(&self) -> bool { + matches!(self, AlgebraicType::Builtin(BuiltinType::Array(ArrayType { elem_ty })) + if **elem_ty == AlgebraicType::U8 + ) + } + /// Returns a sum type with the given `variants`. pub const fn sum(variants: Vec) -> Self { AlgebraicType::Sum(SumType { variants }) diff --git a/crates/sats/src/de/impls.rs b/crates/sats/src/de/impls.rs index bb19160299..b6d5569c02 100644 --- a/crates/sats/src/de/impls.rs +++ b/crates/sats/src/de/impls.rs @@ -296,7 +296,7 @@ impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> { } fn is_option(&self) -> bool { - self.ty().looks_like_option().is_some() + self.ty().as_option().is_some() } fn visit_sum>(self, data: A) -> Result { diff --git a/crates/sats/src/product_type.rs b/crates/sats/src/product_type.rs index 68826a9fed..a63589098b 100644 --- a/crates/sats/src/product_type.rs +++ b/crates/sats/src/product_type.rs @@ -42,6 +42,17 @@ impl ProductType { pub const fn new(elements: Vec) -> Self { Self { elements } } + + /// Returns whether this is the special case of `spacetimedb_lib::Identity`. + pub fn is_identity(&self) -> bool { + match &*self.elements { + [ProductTypeElement { + name: Some(name), + algebraic_type, + }] => name == "__identity_bytes" && algebraic_type.is_bytes(), + _ => false, + } + } } impl> FromIterator for ProductType { diff --git a/crates/sats/src/sum_type.rs b/crates/sats/src/sum_type.rs index 47bd6faa7b..f4f3f443e8 100644 --- a/crates/sats/src/sum_type.rs +++ b/crates/sats/src/sum_type.rs @@ -53,11 +53,10 @@ impl SumType { /// /// An option type has `some(T)` as its first variant and `none` as its second. /// That is, `{ some(T), none }` or `some: T | none` depending on your notation. - pub fn looks_like_option(&self) -> Option<&AlgebraicType> { + pub fn as_option(&self) -> Option<&AlgebraicType> { match &*self.variants { [first, second] - if second.algebraic_type == AlgebraicType::UNIT_TYPE - // ^-- Done first to avoid pointer indirection when it doesn't matter. + if second.is_unit() // Done first to avoid pointer indirection when it doesn't matter. && first.has_name("some") && second.has_name("none") => { @@ -66,6 +65,11 @@ impl SumType { _ => None, } } + + /// Returns whether this sum type is like on in C without data attached to the variants. + pub fn is_simple_enum(&self) -> bool { + self.variants.iter().all(SumTypeVariant::is_unit) + } } impl MetaType for SumType { diff --git a/crates/sats/src/sum_type_variant.rs b/crates/sats/src/sum_type_variant.rs index 7d66b0f6ac..9713c718eb 100644 --- a/crates/sats/src/sum_type_variant.rs +++ b/crates/sats/src/sum_type_variant.rs @@ -45,6 +45,11 @@ impl SumTypeVariant { pub fn has_name(&self, name: &str) -> bool { self.name() == Some(name) } + + /// Returns whether this is a unit variant. + pub fn is_unit(&self) -> bool { + self.algebraic_type == AlgebraicType::UNIT_TYPE + } } impl MetaType for SumTypeVariant { diff --git a/crates/sdk/src/callbacks.rs b/crates/sdk/src/callbacks.rs index 752015210b..3e19f53c98 100644 --- a/crates/sdk/src/callbacks.rs +++ b/crates/sdk/src/callbacks.rs @@ -715,7 +715,7 @@ impl ReducerCallbacks { log::warn!("Received Event with function_call of None"); return None; }; - let identity = Identity { bytes: caller_identity }; + let identity = Identity::from_bytes(caller_identity); let Some(status) = parse_status(status, message) else { log::warn!("Received Event with unknown status {:?}", status); return None; @@ -855,7 +855,7 @@ impl CredentialStore { } let creds = Credentials { - identity: Identity { bytes: identity }, + identity: Identity::from_bytes(identity), token: Token { string: token }, }; diff --git a/crates/sdk/src/identity.rs b/crates/sdk/src/identity.rs index fe000639c3..f9eebf80ec 100644 --- a/crates/sdk/src/identity.rs +++ b/crates/sdk/src/identity.rs @@ -7,10 +7,10 @@ use spacetimedb_sats::bsatn; // TODO: impl ser/de for `Identity`, `Token`, `Credentials` so that clients can stash them // to disk and use them to re-connect. -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] /// A unique public identifier for a client connected to a database. pub struct Identity { - pub(crate) bytes: Vec, + __identity_bytes: Vec, } impl Identity { @@ -27,14 +27,16 @@ impl Identity { /// As such, it is necessary to do e.g. /// `MyTable::filter_by_identity(some_identity.bytes().to_owned())`. pub fn bytes(&self) -> &[u8] { - &self.bytes + &self.__identity_bytes } /// Construct an `Identity` containing the `bytes`. /// /// This method does not verify that `bytes` represents a valid identity. pub fn from_bytes(bytes: Vec) -> Self { - Identity { bytes } + Self { + __identity_bytes: bytes, + } } } diff --git a/test/tests/filtering.sh b/test/tests/filtering.sh index 314d3d6bf7..dcb8695395 100644 --- a/test/tests/filtering.sh +++ b/test/tests/filtering.sh @@ -115,9 +115,9 @@ struct IdentifiedPerson { } fn identify(id_number: u64) -> Identity { - let mut identity = Identity { data: [0u8; 32] }; - identity.data[0..8].clone_from_slice(&id_number.to_le_bytes()); - identity + let mut bytes = [0u8; 32]; + bytes[..8].clone_from_slice(&id_number.to_le_bytes()); + Identity::from_byte_array(bytes) } #[spacetimedb(reducer)] @@ -278,4 +278,3 @@ run_test cargo run logs "$IDENT" 100 run_test cargo run call "$IDENT" insert_person_twice '[23, "Alice", "al"]' run_test cargo run logs "$IDENT" 100 [ ' UNIQUE CONSTRAINT VIOLATION ERROR: id 23: Alice' == "$(grep 'UNIQUE CONSTRAINT VIOLATION ERROR: id 23: Alice' "$TEST_OUT" | tail -n 4 | cut -d: -f4-)" ] -