diff --git a/crates/stef-build/src/decode.rs b/crates/stef-build/src/decode.rs index 00b562c..7420a19 100644 --- a/crates/stef-build/src/decode.rs +++ b/crates/stef-build/src/decode.rs @@ -4,7 +4,10 @@ use stef_parser::{ DataType, Enum, Fields, Generics, NamedField, Struct, Type, UnnamedField, Variant, }; -pub fn compile_struct( +use crate::{BytesType, Opts}; + +pub(super) fn compile_struct( + opts: &Opts, Struct { comment: _, attributes: _, @@ -15,8 +18,8 @@ pub fn compile_struct( ) -> TokenStream { let name = Ident::new(name.get(), Span::call_site()); let (generics, generics_where) = compile_generics(generics); - let field_vars = compile_field_vars(fields); - let field_matches = compile_field_matches(fields); + let field_vars = compile_field_vars(opts, fields); + let field_matches = compile_field_matches(opts, fields); let field_assigns = compile_field_assigns(fields); let body = if matches!(fields, Fields::Unit) { @@ -48,7 +51,8 @@ pub fn compile_struct( } } -pub fn compile_enum( +pub(super) fn compile_enum( + opts: &Opts, Enum { comment: _, attributes: _, @@ -59,7 +63,7 @@ pub fn compile_enum( ) -> TokenStream { let name = Ident::new(name.get(), Span::call_site()); let (generics, generics_where) = compile_generics(generics); - let variants = variants.iter().map(compile_variant); + let variants = variants.iter().map(|v| compile_variant(opts, v)); quote! { #[automatically_derived] @@ -76,6 +80,7 @@ pub fn compile_enum( } fn compile_variant( + opts: &Opts, Variant { comment: _, name, @@ -86,8 +91,8 @@ fn compile_variant( ) -> TokenStream { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let name = Ident::new(name.get(), Span::call_site()); - let field_vars = compile_field_vars(fields); - let field_matches = compile_field_matches(fields); + let field_vars = compile_field_vars(opts, fields); + let field_matches = compile_field_matches(opts, fields); let field_assigns = compile_field_assigns(fields); if matches!(fields, Fields::Unit) { @@ -111,7 +116,7 @@ fn compile_variant( } } -fn compile_field_vars(fields: &Fields<'_>) -> TokenStream { +fn compile_field_vars(opts: &Opts, fields: &Fields<'_>) -> TokenStream { let vars: Box> = match fields { Fields::Named(named) => Box::new(named.iter().map(|named| { let name = Ident::new(named.name.get(), Span::call_site()); @@ -125,7 +130,7 @@ fn compile_field_vars(fields: &Fields<'_>) -> TokenStream { }; let vars = vars.map(|(name, ty)| { - let ty_ident = super::definition::compile_data_type(ty); + let ty_ident = super::definition::compile_data_type(opts, ty); if matches!(ty.value, DataType::Option(_)) { quote! { let mut #name: #ty_ident = None; } @@ -137,7 +142,7 @@ fn compile_field_vars(fields: &Fields<'_>) -> TokenStream { quote! { #(#vars)* } } -fn compile_field_matches(fields: &Fields<'_>) -> TokenStream { +fn compile_field_matches(opts: &Opts, fields: &Fields<'_>) -> TokenStream { match fields { Fields::Named(named) => { let calls = named.iter().map( @@ -150,11 +155,14 @@ fn compile_field_matches(fields: &Fields<'_>) -> TokenStream { }| { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let name = proc_macro2::Ident::new(name.get(), Span::call_site()); - let ty = compile_data_type(if let DataType::Option(ty) = &ty.value { - ty - } else { - ty - }); + let ty = compile_data_type( + opts, + if let DataType::Option(ty) = &ty.value { + ty + } else { + ty + }, + ); quote! { #id => #name = Some(#ty?) } }, @@ -169,11 +177,14 @@ fn compile_field_matches(fields: &Fields<'_>) -> TokenStream { .map(|(idx, UnnamedField { ty, id, .. })| { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let name = Ident::new(&format!("n{idx}"), Span::call_site()); - let ty = compile_data_type(if let DataType::Option(ty) = &ty.value { - ty - } else { - ty - }); + let ty = compile_data_type( + opts, + if let DataType::Option(ty) = &ty.value { + ty + } else { + ty + }, + ); quote! { #id => #name = Some(#ty?) } }); @@ -242,7 +253,7 @@ fn compile_generics(Generics(types): &Generics<'_>) -> (TokenStream, TokenStream } #[allow(clippy::needless_pass_by_value)] -fn compile_data_type(ty: &Type<'_>) -> TokenStream { +fn compile_data_type(opts: &Opts, ty: &Type<'_>) -> TokenStream { match &ty.value { DataType::Bool => quote! { ::stef::buf::decode_bool(r) }, DataType::U8 => quote! { ::stef::buf::decode_u8(r) }, @@ -258,22 +269,25 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::F32 => quote! { ::stef::buf::decode_f32(r) }, DataType::F64 => quote! { ::stef::buf::decode_f64(r) }, DataType::String | DataType::StringRef => quote! { ::stef::buf::decode_string(r) }, - DataType::Bytes | DataType::BytesRef => quote! { ::stef::buf::decode_bytes(r) }, + DataType::Bytes | DataType::BytesRef => match opts.bytes_type { + BytesType::VecU8 => quote! { ::stef::buf::decode_bytes_std(r) }, + BytesType::Bytes => quote! { ::stef::buf::decode_bytes_bytes(r) }, + }, DataType::Vec(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_vec(r, |r| { #ty }) } } DataType::HashMap(kv) => { - let ty_k = compile_data_type(&kv.0); - let ty_v = compile_data_type(&kv.1); + let ty_k = compile_data_type(opts, &kv.0); + let ty_v = compile_data_type(opts, &kv.1); quote! { ::stef::buf::decode_hash_map(r, |r| { #ty_k }, |r| { #ty_v }) } } DataType::HashSet(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_hash_set(r, |r| { #ty }) } } DataType::Option(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_option(r, |r| { #ty }) } } DataType::NonZero(ty) => match &ty.value { @@ -290,20 +304,25 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::String | DataType::StringRef => { quote! { ::stef::buf::decode_non_zero_string(r) } } - DataType::Bytes | DataType::BytesRef => { - quote! { ::stef::buf::decode_non_zero_bytes(r) } - } + DataType::Bytes | DataType::BytesRef => match opts.bytes_type { + BytesType::VecU8 => { + quote! { ::stef::buf::decode_non_zero_bytes_std(r) } + } + BytesType::Bytes => { + quote! { ::stef::buf::decode_non_zero_bytes_bytes(r) } + } + }, DataType::Vec(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_non_zero_vec(r, |r| { #ty }) } } DataType::HashMap(kv) => { - let ty_k = compile_data_type(&kv.0); - let ty_v = compile_data_type(&kv.1); + let ty_k = compile_data_type(opts, &kv.0); + let ty_v = compile_data_type(opts, &kv.1); quote! { ::stef::buf::decode_non_zero_hash_map(r, |r| { #ty_k }, |r| { #ty_v }) } } DataType::HashSet(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_non_zero_hash_set(r, |r| { #ty }) } } ty => todo!("compiler should catch invalid {ty:?} type"), @@ -312,7 +331,7 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::BoxBytes => quote! { Box::<[u8]>::decode(r) }, DataType::Tuple(types) => match types.len() { 2..=12 => { - let types = types.iter().map(|ty| compile_data_type(ty)); + let types = types.iter().map(|ty| compile_data_type(opts, ty)); quote! { { Ok::<_, ::stef::buf::Error>((#(#types?,)*)) } } } 0 => panic!("tuple with zero elements"), @@ -320,7 +339,7 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream { _ => panic!("tuple with more than 12 elements"), }, DataType::Array(ty, _size) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::buf::decode_array(r, |r| { #ty }) } } DataType::External(ty) => { diff --git a/crates/stef-build/src/definition.rs b/crates/stef-build/src/definition.rs index c289c2c..af8bdb8 100644 --- a/crates/stef-build/src/definition.rs +++ b/crates/stef-build/src/definition.rs @@ -6,9 +6,11 @@ use stef_parser::{ }; use super::{decode, encode}; +use crate::{BytesType, Opts}; -pub fn compile_schema(Schema { definitions, .. }: &Schema<'_>) -> TokenStream { - let definitions = definitions.iter().map(compile_definition); +#[must_use] +pub fn compile_schema(opts: &Opts, Schema { definitions, .. }: &Schema<'_>) -> TokenStream { + let definitions = definitions.iter().map(|def| compile_definition(opts, def)); quote! { #[allow(unused_imports)] @@ -18,13 +20,13 @@ pub fn compile_schema(Schema { definitions, .. }: &Schema<'_>) -> TokenStream { } } -fn compile_definition(definition: &Definition<'_>) -> TokenStream { +fn compile_definition(opts: &Opts, definition: &Definition<'_>) -> TokenStream { let definition = match definition { - Definition::Module(m) => compile_module(m), + Definition::Module(m) => compile_module(opts, m), Definition::Struct(s) => { - let def = compile_struct(s); - let encode = encode::compile_struct(s); - let decode = decode::compile_struct(s); + let def = compile_struct(opts, s); + let encode = encode::compile_struct(opts, s); + let decode = decode::compile_struct(opts, s); quote! { #def @@ -33,9 +35,9 @@ fn compile_definition(definition: &Definition<'_>) -> TokenStream { } } Definition::Enum(e) => { - let def = compile_enum(e); - let encode = encode::compile_enum(e); - let decode = decode::compile_enum(e); + let def = compile_enum(opts, e); + let encode = encode::compile_enum(opts, e); + let decode = decode::compile_enum(opts, e); quote! { #def @@ -43,7 +45,7 @@ fn compile_definition(definition: &Definition<'_>) -> TokenStream { #decode } } - Definition::TypeAlias(a) => compile_alias(a), + Definition::TypeAlias(a) => compile_alias(opts, a), Definition::Const(c) => compile_const(c), Definition::Import(i) => compile_import(i), }; @@ -52,6 +54,7 @@ fn compile_definition(definition: &Definition<'_>) -> TokenStream { } fn compile_module( + opts: &Opts, Module { comment, name, @@ -60,7 +63,7 @@ fn compile_module( ) -> TokenStream { let comment = compile_comment(comment); let name = Ident::new(name.get(), Span::call_site()); - let definitions = definitions.iter().map(compile_definition); + let definitions = definitions.iter().map(|def| compile_definition(opts, def)); quote! { #comment @@ -74,6 +77,7 @@ fn compile_module( } fn compile_struct( + opts: &Opts, Struct { comment, attributes: _, @@ -85,7 +89,7 @@ fn compile_struct( let comment = compile_comment(comment); let name = Ident::new(name.get(), Span::call_site()); let generics = compile_generics(generics); - let fields = compile_fields(fields, true); + let fields = compile_fields(opts, fields, true); quote! { #comment @@ -96,6 +100,7 @@ fn compile_struct( } fn compile_enum( + opts: &Opts, Enum { comment, attributes: _, @@ -107,7 +112,7 @@ fn compile_enum( let comment = compile_comment(comment); let name = Ident::new(name.get(), Span::call_site()); let generics = compile_generics(generics); - let variants = variants.iter().map(compile_variant); + let variants = variants.iter().map(|v| compile_variant(opts, v)); quote! { #comment @@ -120,6 +125,7 @@ fn compile_enum( } fn compile_variant( + opts: &Opts, Variant { comment, name, @@ -129,7 +135,7 @@ fn compile_variant( ) -> TokenStream { let comment = compile_comment(comment); let name = Ident::new(name.get(), Span::call_site()); - let fields = compile_fields(fields, false); + let fields = compile_fields(opts, fields, false); quote! { #comment @@ -138,6 +144,7 @@ fn compile_variant( } fn compile_alias( + opts: &Opts, TypeAlias { comment, name, @@ -148,7 +155,7 @@ fn compile_alias( let comment = compile_comment(comment); let name = Ident::new(name.get(), Span::call_site()); let generics = compile_generics(generics); - let target = compile_data_type(target); + let target = compile_data_type(opts, target); quote! { #comment @@ -210,7 +217,7 @@ fn compile_generics(Generics(types): &Generics<'_>) -> Option { }) } -fn compile_fields(fields: &Fields<'_>, for_struct: bool) -> TokenStream { +fn compile_fields(opts: &Opts, fields: &Fields<'_>, for_struct: bool) -> TokenStream { match fields { Fields::Named(named) => { let fields = named.iter().map( @@ -220,7 +227,7 @@ fn compile_fields(fields: &Fields<'_>, for_struct: bool) -> TokenStream { let comment = compile_comment(comment); let public = for_struct.then(|| quote! { pub }); let name = Ident::new(name.get(), Span::call_site()); - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { #comment #public #name: #ty @@ -235,7 +242,7 @@ fn compile_fields(fields: &Fields<'_>, for_struct: bool) -> TokenStream { Fields::Unnamed(unnamed) => { let fields = unnamed.iter().map(|UnnamedField { ty, .. }| { let public = for_struct.then(|| quote! { pub }); - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { #public #ty } }); @@ -255,7 +262,7 @@ fn compile_fields(fields: &Fields<'_>, for_struct: bool) -> TokenStream { } } -pub(super) fn compile_data_type(ty: &Type<'_>) -> TokenStream { +pub(super) fn compile_data_type(opts: &Opts, ty: &Type<'_>) -> TokenStream { match &ty.value { DataType::Bool => quote! { bool }, DataType::U8 => quote! { u8 }, @@ -271,22 +278,25 @@ pub(super) fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::F32 => quote! { f32 }, DataType::F64 => quote! { f64 }, DataType::String | DataType::StringRef => quote! { String }, - DataType::Bytes | DataType::BytesRef => quote! { Vec }, + DataType::Bytes | DataType::BytesRef => match opts.bytes_type { + BytesType::VecU8 => quote! { Vec }, + BytesType::Bytes => quote! { ::stef::buf::Bytes }, + }, DataType::Vec(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { Vec<#ty> } } DataType::HashMap(kv) => { - let k = compile_data_type(&kv.0); - let v = compile_data_type(&kv.1); + let k = compile_data_type(opts, &kv.0); + let v = compile_data_type(opts, &kv.1); quote! { ::std::collections::HashMap<#k, #v> } } DataType::HashSet(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::std::collections::HashSet<#ty> } } DataType::Option(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { Option<#ty> } } DataType::NonZero(ty) => match &ty.value { @@ -301,18 +311,21 @@ pub(super) fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::I64 => quote! { ::std::num::NonZeroI64 }, DataType::I128 => quote! { ::std::num::NonZeroI128 }, DataType::String | DataType::StringRef => quote! { ::stef::NonZeroString }, - DataType::Bytes | DataType::BytesRef => quote! { ::stef::NonZeroBytes }, + DataType::Bytes | DataType::BytesRef => match opts.bytes_type { + BytesType::VecU8 => quote! { ::stef::NonZeroBytes }, + BytesType::Bytes => quote! { ::stef::NonZero<::stef::buf::Bytes> }, + }, DataType::Vec(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::NonZeroVec<#ty> } } DataType::HashMap(kv) => { - let k = compile_data_type(&kv.0); - let v = compile_data_type(&kv.1); + let k = compile_data_type(opts, &kv.0); + let v = compile_data_type(opts, &kv.1); quote! { ::stef::NonZeroHashMap<#k, #v> } } DataType::HashSet(ty) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); quote! { ::stef::NonZeroHashSet<#ty> } } ty => todo!("compiler should catch invalid {ty:?} type"), @@ -320,11 +333,11 @@ pub(super) fn compile_data_type(ty: &Type<'_>) -> TokenStream { DataType::BoxString => quote! { Box }, DataType::BoxBytes => quote! { Box<[u8]> }, DataType::Tuple(types) => { - let types = types.iter().map(compile_data_type); + let types = types.iter().map(|ty| compile_data_type(opts, ty)); quote! { (#(#types,)*) } } DataType::Array(ty, size) => { - let ty = compile_data_type(ty); + let ty = compile_data_type(opts, ty); let size = proc_macro2::Literal::u32_unsuffixed(*size); quote! { [#ty; #size] } } @@ -336,7 +349,7 @@ pub(super) fn compile_data_type(ty: &Type<'_>) -> TokenStream { let path = path.iter().map(Name::get); let name = Ident::new(name.get(), Span::call_site()); let generics = (!generics.is_empty()).then(|| { - let types = generics.iter().map(compile_data_type); + let types = generics.iter().map(|ty| compile_data_type(opts, ty)); quote! { <#(#types,)*> } }); diff --git a/crates/stef-build/src/encode.rs b/crates/stef-build/src/encode.rs index e8e555e..bd74915 100644 --- a/crates/stef-build/src/encode.rs +++ b/crates/stef-build/src/encode.rs @@ -4,7 +4,10 @@ use stef_parser::{ DataType, Enum, Fields, Generics, NamedField, Struct, Type, UnnamedField, Variant, }; -pub fn compile_struct( +use crate::{BytesType, Opts}; + +pub(super) fn compile_struct( + opts: &Opts, Struct { comment: _, attributes: _, @@ -15,7 +18,7 @@ pub fn compile_struct( ) -> TokenStream { let name = Ident::new(name.get(), Span::call_site()); let (generics, generics_where) = compile_generics(generics); - let fields = compile_struct_fields(fields); + let fields = compile_struct_fields(opts, fields); quote! { #[automatically_derived] @@ -33,7 +36,7 @@ pub fn compile_struct( } } -fn compile_struct_fields(fields: &Fields<'_>) -> TokenStream { +fn compile_struct_fields(opts: &Opts, fields: &Fields<'_>) -> TokenStream { match fields { Fields::Named(named) => { let calls = named.iter().map( @@ -48,14 +51,14 @@ fn compile_struct_fields(fields: &Fields<'_>) -> TokenStream { let name = proc_macro2::Ident::new(name.get(), Span::call_site()); if let DataType::Option(ty) = &ty.value { - let ty = compile_data_type(ty, if is_copy(&ty.value) { + let ty = compile_data_type(opts, ty, if is_copy(&ty.value) { quote! { *v } } else { quote! { v } }); quote! { ::stef::buf::encode_field_option(w, #id, &self.#name, |w, v| { #ty; }); } } else { - let ty = compile_data_type(ty, quote! { self.#name }); + let ty = compile_data_type(opts, ty, quote! { self.#name }); quote! { ::stef::buf::encode_field(w, #id, |w| { #ty; }); } } }, @@ -73,7 +76,7 @@ fn compile_struct_fields(fields: &Fields<'_>) -> TokenStream { .map(|(idx, UnnamedField { ty, id, .. })| { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let idx = proc_macro2::Literal::usize_unsuffixed(idx); - let ty = compile_data_type(ty, quote! { self.#idx }); + let ty = compile_data_type(opts, ty, quote! { self.#idx }); quote! { ::stef::buf::encode_field(w, #id, |w| { #ty }); } }); @@ -87,7 +90,8 @@ fn compile_struct_fields(fields: &Fields<'_>) -> TokenStream { } } -pub fn compile_enum( +pub(super) fn compile_enum( + opts: &Opts, Enum { comment: _, attributes: _, @@ -98,7 +102,7 @@ pub fn compile_enum( ) -> TokenStream { let name = Ident::new(name.get(), Span::call_site()); let (generics, generics_where) = compile_generics(generics); - let variants = variants.iter().map(compile_variant); + let variants = variants.iter().map(|v| compile_variant(opts, v)); quote! { #[automatically_derived] @@ -118,6 +122,7 @@ pub fn compile_enum( } fn compile_variant( + opts: &Opts, Variant { comment: _, name, @@ -128,7 +133,7 @@ fn compile_variant( ) -> TokenStream { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let name = Ident::new(name.get(), Span::call_site()); - let fields_body = compile_variant_fields(fields); + let fields_body = compile_variant_fields(opts, fields); match fields { Fields::Named(named) => { @@ -165,7 +170,7 @@ fn compile_variant( } } -fn compile_variant_fields(fields: &Fields<'_>) -> TokenStream { +fn compile_variant_fields(opts: &Opts, fields: &Fields<'_>) -> TokenStream { match fields { Fields::Named(named) => { let calls = named.iter().map( @@ -182,7 +187,7 @@ fn compile_variant_fields(fields: &Fields<'_>) -> TokenStream { if matches!(ty.value, DataType::Option(_)) { quote! { ::stef::buf::encode_field_option(w, #id, &#name); } } else { - let ty = compile_data_type(ty, quote! { *#name }); + let ty = compile_data_type(opts, ty, quote! { *#name }); quote! { ::stef::buf::encode_field(w, #id, |w| { #ty }); } } }, @@ -200,7 +205,7 @@ fn compile_variant_fields(fields: &Fields<'_>) -> TokenStream { .map(|(idx, UnnamedField { ty, id, .. })| { let id = proc_macro2::Literal::u32_unsuffixed(id.get()); let name = Ident::new(&format!("n{idx}"), Span::call_site()); - let ty = compile_data_type(ty, quote! { *#name }); + let ty = compile_data_type(opts, ty, quote! { *#name }); quote! { ::stef::buf::encode_field(w, #id, |w| { #ty }); } }); @@ -250,7 +255,7 @@ fn is_copy(ty: &DataType<'_>) -> bool { } #[allow(clippy::needless_pass_by_value, clippy::too_many_lines)] -fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { +fn compile_data_type(opts: &Opts, ty: &Type<'_>, name: TokenStream) -> TokenStream { match &ty.value { DataType::Bool => quote! { ::stef::buf::encode_bool(w, #name) }, DataType::U8 => quote! { ::stef::buf::encode_u8(w, #name) }, @@ -266,9 +271,13 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { DataType::F32 => quote! { ::stef::buf::encode_f32(w, #name) }, DataType::F64 => quote! { ::stef::buf::encode_f64(w, #name) }, DataType::String | DataType::StringRef => quote! { ::stef::buf::encode_string(w, &#name) }, - DataType::Bytes | DataType::BytesRef => quote! { ::stef::buf::encode_bytes(w, &#name) }, + DataType::Bytes | DataType::BytesRef => match opts.bytes_type { + BytesType::VecU8 => quote! { ::stef::buf::encode_bytes_std(w, &#name) }, + BytesType::Bytes => quote! { ::stef::buf::encode_bytes_bytes(w, &#name) }, + }, DataType::Vec(ty) => { let ty = compile_data_type( + opts, ty, if is_copy(&ty.value) { quote! { *v } @@ -280,6 +289,7 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { } DataType::HashMap(kv) => { let ty_k = compile_data_type( + opts, &kv.0, if is_copy(&kv.0.value) { quote! { *k } @@ -288,6 +298,7 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { }, ); let ty_v = compile_data_type( + opts, &kv.1, if is_copy(&kv.1.value) { quote! { *v } @@ -299,6 +310,7 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { } DataType::HashSet(ty) => { let ty = compile_data_type( + opts, ty, if is_copy(&ty.value) { quote! { *v } @@ -310,6 +322,7 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { } DataType::Option(ty) => { let ty = compile_data_type( + opts, ty, if is_copy(&ty.value) { quote! { *v } @@ -336,17 +349,18 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { | DataType::BytesRef | DataType::Vec(_) | DataType::HashMap(_) - | DataType::HashSet(_) => compile_data_type(ty, quote! { #name.get() }), + | DataType::HashSet(_) => compile_data_type(opts, ty, quote! { #name.get() }), ty => todo!("compiler should catch invalid {ty:?} type"), }, DataType::BoxString => quote! { ::stef::buf::encode_string(w, &*#name) }, - DataType::BoxBytes => quote! { ::stef::buf::encode_bytes(w, &*#name) }, + DataType::BoxBytes => quote! { ::stef::buf::encode_bytes_std(w, &*#name) }, DataType::Tuple(types) => match types.len() { 2..=12 => { let types = types.iter().enumerate().map(|(idx, ty)| { let idx = proc_macro2::Literal::usize_unsuffixed(idx); compile_data_type( + opts, ty, if is_copy(&ty.value) { quote! { #name.#idx } @@ -363,6 +377,7 @@ fn compile_data_type(ty: &Type<'_>, name: TokenStream) -> TokenStream { }, DataType::Array(ty, _size) => { let ty = compile_data_type( + opts, ty, if is_copy(&ty.value) { quote! { *v } diff --git a/crates/stef-build/src/lib.rs b/crates/stef-build/src/lib.rs index 8dffbed..cf4089f 100644 --- a/crates/stef-build/src/lib.rs +++ b/crates/stef-build/src/lib.rs @@ -3,11 +3,7 @@ #![warn(clippy::pedantic)] #![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)] -use std::{ - convert::AsRef, - fmt::Debug, - path::{Path, PathBuf}, -}; +use std::{convert::AsRef, fmt::Debug, path::PathBuf}; use miette::Report; use stef_parser::Schema; @@ -52,73 +48,106 @@ impl Debug for Error { } } -pub fn compile(schemas: &[impl AsRef], _includes: &[impl AsRef]) -> Result<()> { - miette::set_hook(Box::new(|_| { - Box::new( - miette::MietteHandlerOpts::new() - .color(true) - .context_lines(3) - .force_graphical(true) - .terminal_links(true) - .build(), - ) - })) - .ok(); +#[derive(Default)] +pub struct Compiler { + bytes_type: BytesType, +} - let out_dir = PathBuf::from(std::env::var_os("OUT_DIR").unwrap()); - let mut inputs = Vec::new(); - let mut validated = Vec::new(); +#[derive(Clone, Copy, Default)] +pub enum BytesType { + #[default] + VecU8, + Bytes, +} + +#[derive(Default)] +pub struct Opts { + bytes_type: BytesType, +} + +impl Compiler { + #[must_use] + pub fn with_bytes_type(mut self, value: BytesType) -> Self { + self.bytes_type = value; + self + } + + pub fn compile(&self, schemas: &[impl AsRef]) -> Result<()> { + init_miette(); - for schema in schemas.iter().map(AsRef::as_ref) { - for schema in glob::glob(schema).map_err(|source| Error::Pattern { - source, - glob: schema.to_owned(), - })? { - let path = schema.map_err(|e| Error::Glob { source: e })?; + let out_dir = PathBuf::from(std::env::var_os("OUT_DIR").unwrap()); + let mut inputs = Vec::new(); + let mut validated = Vec::new(); - let input = std::fs::read_to_string(&path).map_err(|source| Error::Read { + for schema in schemas.iter().map(AsRef::as_ref) { + for schema in glob::glob(schema).map_err(|source| Error::Pattern { source, + glob: schema.to_owned(), + })? { + let path = schema.map_err(|e| Error::Glob { source: e })?; + + let input = std::fs::read_to_string(&path).map_err(|source| Error::Read { + source, + file: path.clone(), + })?; + + inputs.push((path, input)); + } + } + + for (path, input) in &inputs { + let stem = path.file_stem().unwrap().to_str().unwrap(); + + let schema = Schema::parse(input, Some(path)).map_err(|e| Error::Parse { + report: Report::new(e), file: path.clone(), })?; - inputs.push((path, input)); + stef_compiler::validate_schema(&schema).map_err(|e| Error::Compile { + report: Report::new(e), + file: path.clone(), + })?; + + validated.push((stem, schema)); } - } - for (path, input) in &inputs { - let stem = path.file_stem().unwrap().to_str().unwrap(); + let validated = validated + .iter() + .map(|(name, schema)| (*name, schema)) + .collect::>(); - let schema = Schema::parse(input, Some(path)).map_err(|e| Error::Parse { + stef_compiler::resolve_schemas(&validated).map_err(|e| Error::Compile { report: Report::new(e), - file: path.clone(), + file: PathBuf::new(), })?; - stef_compiler::validate_schema(&schema).map_err(|e| Error::Compile { - report: Report::new(e), - file: path.clone(), - })?; + let opts = Opts { + bytes_type: self.bytes_type, + }; - validated.push((stem, schema)); - } - - let validated = validated - .iter() - .map(|(name, schema)| (*name, schema)) - .collect::>(); - - stef_compiler::resolve_schemas(&validated).map_err(|e| Error::Compile { - report: Report::new(e), - file: PathBuf::new(), - })?; + for (stem, schema) in validated { + let code = definition::compile_schema(&opts, schema); + let code = prettyplease::unparse(&syn::parse2(code).unwrap()); - for (stem, schema) in validated { - let code = definition::compile_schema(schema); - let code = prettyplease::unparse(&syn::parse2(code).unwrap()); + let out_file = out_dir.join(format!("{stem}.rs",)); - let out_file = out_dir.join(format!("{stem}.rs",)); + std::fs::write(out_file, code).unwrap(); + } - std::fs::write(out_file, code).unwrap(); + Ok(()) } +} - Ok(()) +fn init_miette() { + miette::set_hook(Box::new(|_| { + Box::new( + miette::MietteHandlerOpts::new() + .color(true) + .context_lines(3) + .force_graphical(true) + .terminal_links(true) + .build(), + ) + })) + .ok(); } diff --git a/crates/stef-build/tests/compiler.rs b/crates/stef-build/tests/compiler.rs index c868bf6..3fe2ff0 100644 --- a/crates/stef-build/tests/compiler.rs +++ b/crates/stef-build/tests/compiler.rs @@ -4,6 +4,7 @@ use std::{ }; use insta::{assert_snapshot, glob, with_settings}; +use stef_build::Opts; use stef_parser::Schema; fn strip_path(path: &Path) -> PathBuf { @@ -18,7 +19,7 @@ fn compile_schema() { glob!("inputs/*.stef", |path| { let input = fs::read_to_string(path).unwrap(); let value = Schema::parse(input.as_str(), Some(&strip_path(path))).unwrap(); - let value = stef_build::compile_schema(&value); + let value = stef_build::compile_schema(&Opts::default(), &value); let value = prettyplease::unparse(&syn::parse2(value.clone()).unwrap()); with_settings!({ @@ -35,7 +36,7 @@ fn compile_schema_extra() { glob!("inputs_extra/*.stef", |path| { let input = fs::read_to_string(path).unwrap(); let value = Schema::parse(input.as_str(), Some(&strip_path(path))).unwrap(); - let value = stef_build::compile_schema(&value); + let value = stef_build::compile_schema(&Opts::default(), &value); let value = prettyplease::unparse(&syn::parse2(value.clone()).unwrap()); with_settings!({ diff --git a/crates/stef-build/tests/snapshots/compiler__compile@types_basic.stef.snap b/crates/stef-build/tests/snapshots/compiler__compile@types_basic.stef.snap index 55571ec..a4bff2c 100644 --- a/crates/stef-build/tests/snapshots/compiler__compile@types_basic.stef.snap +++ b/crates/stef-build/tests/snapshots/compiler__compile@types_basic.stef.snap @@ -148,14 +148,14 @@ impl ::stef::Encode for Sample { w, 16, |w| { - ::stef::buf::encode_bytes(w, &self.f16); + ::stef::buf::encode_bytes_std(w, &self.f16); }, ); ::stef::buf::encode_field( w, 17, |w| { - ::stef::buf::encode_bytes(w, &self.f17); + ::stef::buf::encode_bytes_std(w, &self.f17); }, ); ::stef::buf::encode_field( @@ -169,7 +169,7 @@ impl ::stef::Encode for Sample { w, 19, |w| { - ::stef::buf::encode_bytes(w, &*self.f19); + ::stef::buf::encode_bytes_std(w, &*self.f19); }, ); ::stef::buf::encode_field( @@ -240,8 +240,8 @@ impl ::stef::Decode for Sample { 13 => f13 = Some(::stef::buf::decode_f64(r)?), 14 => f14 = Some(::stef::buf::decode_string(r)?), 15 => f15 = Some(::stef::buf::decode_string(r)?), - 16 => f16 = Some(::stef::buf::decode_bytes(r)?), - 17 => f17 = Some(::stef::buf::decode_bytes(r)?), + 16 => f16 = Some(::stef::buf::decode_bytes_std(r)?), + 17 => f17 = Some(::stef::buf::decode_bytes_std(r)?), 18 => f18 = Some(Box::::decode(r)?), 19 => f19 = Some(Box::<[u8]>::decode(r)?), 20 => { diff --git a/crates/stef-build/tests/snapshots/compiler__compile@types_non_zero.stef.snap b/crates/stef-build/tests/snapshots/compiler__compile@types_non_zero.stef.snap index 40db0f8..08f93a3 100644 --- a/crates/stef-build/tests/snapshots/compiler__compile@types_non_zero.stef.snap +++ b/crates/stef-build/tests/snapshots/compiler__compile@types_non_zero.stef.snap @@ -114,7 +114,7 @@ impl ::stef::Encode for Sample { w, 12, |w| { - ::stef::buf::encode_bytes(w, &self.f12.get()); + ::stef::buf::encode_bytes_std(w, &self.f12.get()); }, ); ::stef::buf::encode_field( @@ -141,7 +141,7 @@ impl ::stef::Encode for Sample { ::stef::buf::encode_string(w, &k); }, |w, v| { - ::stef::buf::encode_bytes(w, &v); + ::stef::buf::encode_bytes_std(w, &v); }, ); }, @@ -195,7 +195,7 @@ impl ::stef::Decode for Sample { 9 => f09 = Some(::stef::buf::decode_non_zero_i64(r)?), 10 => f10 = Some(::stef::buf::decode_non_zero_i128(r)?), 11 => f11 = Some(::stef::buf::decode_non_zero_string(r)?), - 12 => f12 = Some(::stef::buf::decode_non_zero_bytes(r)?), + 12 => f12 = Some(::stef::buf::decode_non_zero_bytes_std(r)?), 13 => { f13 = Some( ::stef::buf::decode_non_zero_vec( @@ -209,7 +209,7 @@ impl ::stef::Decode for Sample { ::stef::buf::decode_non_zero_hash_map( r, |r| { ::stef::buf::decode_string(r) }, - |r| { ::stef::buf::decode_bytes(r) }, + |r| { ::stef::buf::decode_bytes_std(r) }, )?, ); } diff --git a/crates/stef-build/tests/snapshots/compiler__compile_extra@struct.stef.snap b/crates/stef-build/tests/snapshots/compiler__compile_extra@struct.stef.snap index 976b632..e857444 100644 --- a/crates/stef-build/tests/snapshots/compiler__compile_extra@struct.stef.snap +++ b/crates/stef-build/tests/snapshots/compiler__compile_extra@struct.stef.snap @@ -33,7 +33,7 @@ impl ::stef::Encode for Sample { w, 2, |w| { - ::stef::buf::encode_bytes(w, &self.field2); + ::stef::buf::encode_bytes_std(w, &self.field2); }, ); ::stef::buf::encode_field( @@ -64,7 +64,7 @@ impl ::stef::Decode for Sample { match ::stef::buf::decode_id(r)? { ::stef::buf::END_MARKER => break, 1 => field1 = Some(::stef::buf::decode_u32(r)?), - 2 => field2 = Some(::stef::buf::decode_bytes(r)?), + 2 => field2 = Some(::stef::buf::decode_bytes_std(r)?), 3 => { field3 = Some( { diff --git a/crates/stef-playground/build.rs b/crates/stef-playground/build.rs index 687d864..e2c6d97 100644 --- a/crates/stef-playground/build.rs +++ b/crates/stef-playground/build.rs @@ -1,8 +1,6 @@ fn main() -> stef_build::Result<()> { - stef_build::compile(&["src/sample.stef"], &["src/"])?; - stef_build::compile( - &["schemas/*.stef", "src/other.stef", "src/second.stef"], - &["schemas/"], - )?; + let compiler = stef_build::Compiler::default(); + compiler.compile(&["src/sample.stef"])?; + compiler.compile(&["schemas/*.stef", "src/other.stef", "src/second.stef"])?; Ok(()) } diff --git a/crates/stef/src/buf/decode.rs b/crates/stef/src/buf/decode.rs index 97f8a1f..45e212b 100644 --- a/crates/stef/src/buf/decode.rs +++ b/crates/stef/src/buf/decode.rs @@ -6,7 +6,7 @@ use std::{ hash::Hash, }; -pub use bytes::Buf; +pub use bytes::{Buf, Bytes}; use crate::{varint, NonZero}; @@ -88,16 +88,23 @@ pub fn decode_f64(r: &mut impl Buf) -> Result { } pub fn decode_string(r: &mut impl Buf) -> Result { - String::from_utf8(decode_bytes(r)?).map_err(Into::into) + String::from_utf8(decode_bytes_std(r)?).map_err(Into::into) } -pub fn decode_bytes(r: &mut impl Buf) -> Result> { +pub fn decode_bytes_std(r: &mut impl Buf) -> Result> { let len = decode_u64(r)?; ensure_size!(r, len as usize); Ok(r.copy_to_bytes(len as usize).to_vec()) } +pub fn decode_bytes_bytes(r: &mut impl Buf) -> Result { + let len = decode_u64(r)?; + ensure_size!(r, len as usize); + + Ok(r.copy_to_bytes(len as usize)) +} + pub fn decode_vec(r: &mut R, decode: D) -> Result> where R: Buf, @@ -199,12 +206,12 @@ decode_non_zero_int!(u8, u16, u32, u64, u128); decode_non_zero_int!(i8, i16, i32, i64, i128); pub fn decode_non_zero_string(r: &mut impl Buf) -> Result> { - String::from_utf8(decode_non_zero_bytes(r)?.into_inner()) + String::from_utf8(decode_non_zero_bytes_std(r)?.into_inner()) .map(|v| NonZero::::new(v).unwrap()) .map_err(Into::into) } -pub fn decode_non_zero_bytes(r: &mut impl Buf) -> Result>> { +pub fn decode_non_zero_bytes_std(r: &mut impl Buf) -> Result>> { let len = decode_u64(r)?; ensure_not_empty!(len); ensure_size!(r, len as usize); @@ -212,6 +219,14 @@ pub fn decode_non_zero_bytes(r: &mut impl Buf) -> Result>> { Ok(NonZero::>::new(r.copy_to_bytes(len as usize).to_vec()).unwrap()) } +pub fn decode_non_zero_bytes_bytes(r: &mut impl Buf) -> Result> { + let len = decode_u64(r)?; + ensure_not_empty!(len); + ensure_size!(r, len as usize); + + Ok(NonZero::::new(r.copy_to_bytes(len as usize)).unwrap()) +} + pub fn decode_non_zero_vec(r: &mut R, decode: D) -> Result>> where R: Buf, @@ -308,7 +323,7 @@ impl Decode for Box { impl Decode for Box<[u8]> { #[inline(always)] fn decode(r: &mut impl Buf) -> Result { - decode_bytes(r).map(Vec::into_boxed_slice) + decode_bytes_std(r).map(Vec::into_boxed_slice) } } diff --git a/crates/stef/src/buf/encode.rs b/crates/stef/src/buf/encode.rs index 0c9f51f..1af64a2 100644 --- a/crates/stef/src/buf/encode.rs +++ b/crates/stef/src/buf/encode.rs @@ -1,6 +1,6 @@ use std::collections::{HashMap, HashSet}; -pub use bytes::BufMut; +pub use bytes::{BufMut, Bytes}; use crate::{varint, NonZero}; @@ -42,14 +42,18 @@ pub fn encode_f64(w: &mut impl BufMut, value: f64) { } pub fn encode_string(w: &mut impl BufMut, value: &str) { - encode_bytes(w, value.as_bytes()); + encode_bytes_std(w, value.as_bytes()); } -pub fn encode_bytes(w: &mut impl BufMut, value: &[u8]) { +pub fn encode_bytes_std(w: &mut impl BufMut, value: &[u8]) { encode_u64(w, value.len() as u64); w.put(value); } +pub fn encode_bytes_bytes(w: &mut impl BufMut, value: &Bytes) { + encode_bytes_std(w, value); +} + pub fn encode_vec(w: &mut W, vec: &[T], encode: E) where W: BufMut, @@ -186,7 +190,7 @@ impl Encode for Box { impl Encode for Box<[u8]> { #[inline(always)] fn encode(&self, w: &mut impl BufMut) { - encode_bytes(w, self); + encode_bytes_std(w, self); } } diff --git a/crates/stef/src/buf/mod.rs b/crates/stef/src/buf/mod.rs index 911a5e6..ad1e0d7 100644 --- a/crates/stef/src/buf/mod.rs +++ b/crates/stef/src/buf/mod.rs @@ -30,18 +30,35 @@ mod tests { } #[test] - fn non_zero_bytes_valid() { + fn non_zero_bytes_std_valid() { let mut buf = Vec::new(); - encode_bytes(&mut buf, &[1, 2, 3]); - assert!(decode_non_zero_bytes(&mut &*buf).is_ok()); + encode_bytes_std(&mut buf, &[1, 2, 3]); + assert!(decode_non_zero_bytes_std(&mut &*buf).is_ok()); } #[test] - fn non_zero_bytes_invalid() { + fn non_zero_bytes_std_invalid() { let mut buf = Vec::new(); - encode_bytes(&mut buf, &[]); + encode_bytes_std(&mut buf, &[]); assert!(matches!( - decode_non_zero_bytes(&mut &*buf), + decode_non_zero_bytes_std(&mut &*buf), + Err(Error::Zero), + )); + } + + #[test] + fn non_zero_bytes_bytes_valid() { + let mut buf = Vec::new(); + encode_bytes_bytes(&mut buf, &Bytes::from_static(&[1, 2, 3])); + assert!(decode_non_zero_bytes_bytes(&mut &*buf).is_ok()); + } + + #[test] + fn non_zero_bytes_bytes_invalid() { + let mut buf = Vec::new(); + encode_bytes_bytes(&mut buf, &Bytes::from_static(&[])); + assert!(matches!( + decode_non_zero_bytes_bytes(&mut &*buf), Err(Error::Zero), )); } diff --git a/crates/stef/src/lib.rs b/crates/stef/src/lib.rs index dc4a49a..743a353 100644 --- a/crates/stef/src/lib.rs +++ b/crates/stef/src/lib.rs @@ -15,7 +15,7 @@ use std::{ ops::Deref, }; -pub use buf::{Buf, BufMut, Decode, Encode}; +pub use buf::{Buf, BufMut, Bytes, Decode, Encode}; pub mod buf; pub mod varint; @@ -60,6 +60,7 @@ macro_rules! non_zero_collection { non_zero_collection!(String); non_zero_collection!(Vec); +non_zero_collection!(Bytes); non_zero_collection!(HashMap); non_zero_collection!(HashSet);