Skip to content

Commit

Permalink
feat: make bytes type customizable
Browse files Browse the repository at this point in the history
Add a new option to `stef-build` that allows to choose between the
default `Vec<u8>` and `bytes::Bytes` as type used for the STEF `bytes`
type in Rust.
  • Loading branch information
dnaka91 committed Nov 5, 2023
1 parent 5042fc1 commit eaa06ae
Show file tree
Hide file tree
Showing 13 changed files with 288 additions and 176 deletions.
91 changes: 55 additions & 36 deletions crates/stef-build/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use stef_parser::{
DataType, Enum, Fields, Generics, NamedField, Struct, Type, UnnamedField, Variant,
};

pub fn compile_struct(
use crate::{BytesType, Opts};

pub(super) fn compile_struct(
opts: &Opts,
Struct {
comment: _,
attributes: _,
Expand All @@ -15,8 +18,8 @@ pub fn compile_struct(
) -> TokenStream {
let name = Ident::new(name.get(), Span::call_site());
let (generics, generics_where) = compile_generics(generics);
let field_vars = compile_field_vars(fields);
let field_matches = compile_field_matches(fields);
let field_vars = compile_field_vars(opts, fields);
let field_matches = compile_field_matches(opts, fields);
let field_assigns = compile_field_assigns(fields);

let body = if matches!(fields, Fields::Unit) {
Expand Down Expand Up @@ -48,7 +51,8 @@ pub fn compile_struct(
}
}

pub fn compile_enum(
pub(super) fn compile_enum(
opts: &Opts,
Enum {
comment: _,
attributes: _,
Expand All @@ -59,7 +63,7 @@ pub fn compile_enum(
) -> TokenStream {
let name = Ident::new(name.get(), Span::call_site());
let (generics, generics_where) = compile_generics(generics);
let variants = variants.iter().map(compile_variant);
let variants = variants.iter().map(|v| compile_variant(opts, v));

quote! {
#[automatically_derived]
Expand All @@ -76,6 +80,7 @@ pub fn compile_enum(
}

fn compile_variant(
opts: &Opts,
Variant {
comment: _,
name,
Expand All @@ -86,8 +91,8 @@ fn compile_variant(
) -> TokenStream {
let id = proc_macro2::Literal::u32_unsuffixed(id.get());
let name = Ident::new(name.get(), Span::call_site());
let field_vars = compile_field_vars(fields);
let field_matches = compile_field_matches(fields);
let field_vars = compile_field_vars(opts, fields);
let field_matches = compile_field_matches(opts, fields);
let field_assigns = compile_field_assigns(fields);

if matches!(fields, Fields::Unit) {
Expand All @@ -111,7 +116,7 @@ fn compile_variant(
}
}

fn compile_field_vars(fields: &Fields<'_>) -> TokenStream {
fn compile_field_vars(opts: &Opts, fields: &Fields<'_>) -> TokenStream {
let vars: Box<dyn Iterator<Item = _>> = match fields {
Fields::Named(named) => Box::new(named.iter().map(|named| {
let name = Ident::new(named.name.get(), Span::call_site());
Expand All @@ -125,7 +130,7 @@ fn compile_field_vars(fields: &Fields<'_>) -> TokenStream {
};

let vars = vars.map(|(name, ty)| {
let ty_ident = super::definition::compile_data_type(ty);
let ty_ident = super::definition::compile_data_type(opts, ty);

if matches!(ty.value, DataType::Option(_)) {
quote! { let mut #name: #ty_ident = None; }
Expand All @@ -137,7 +142,7 @@ fn compile_field_vars(fields: &Fields<'_>) -> TokenStream {
quote! { #(#vars)* }
}

fn compile_field_matches(fields: &Fields<'_>) -> TokenStream {
fn compile_field_matches(opts: &Opts, fields: &Fields<'_>) -> TokenStream {
match fields {
Fields::Named(named) => {
let calls = named.iter().map(
Expand All @@ -150,11 +155,14 @@ fn compile_field_matches(fields: &Fields<'_>) -> TokenStream {
}| {
let id = proc_macro2::Literal::u32_unsuffixed(id.get());
let name = proc_macro2::Ident::new(name.get(), Span::call_site());
let ty = compile_data_type(if let DataType::Option(ty) = &ty.value {
ty
} else {
ty
});
let ty = compile_data_type(
opts,
if let DataType::Option(ty) = &ty.value {
ty
} else {
ty
},
);

quote! { #id => #name = Some(#ty?) }
},
Expand All @@ -169,11 +177,14 @@ fn compile_field_matches(fields: &Fields<'_>) -> TokenStream {
.map(|(idx, UnnamedField { ty, id, .. })| {
let id = proc_macro2::Literal::u32_unsuffixed(id.get());
let name = Ident::new(&format!("n{idx}"), Span::call_site());
let ty = compile_data_type(if let DataType::Option(ty) = &ty.value {
ty
} else {
ty
});
let ty = compile_data_type(
opts,
if let DataType::Option(ty) = &ty.value {
ty
} else {
ty
},
);

quote! { #id => #name = Some(#ty?) }
});
Expand Down Expand Up @@ -242,7 +253,7 @@ fn compile_generics(Generics(types): &Generics<'_>) -> (TokenStream, TokenStream
}

#[allow(clippy::needless_pass_by_value)]
fn compile_data_type(ty: &Type<'_>) -> TokenStream {
fn compile_data_type(opts: &Opts, ty: &Type<'_>) -> TokenStream {
match &ty.value {
DataType::Bool => quote! { ::stef::buf::decode_bool(r) },
DataType::U8 => quote! { ::stef::buf::decode_u8(r) },
Expand All @@ -258,22 +269,25 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream {
DataType::F32 => quote! { ::stef::buf::decode_f32(r) },
DataType::F64 => quote! { ::stef::buf::decode_f64(r) },
DataType::String | DataType::StringRef => quote! { ::stef::buf::decode_string(r) },
DataType::Bytes | DataType::BytesRef => quote! { ::stef::buf::decode_bytes(r) },
DataType::Bytes | DataType::BytesRef => match opts.bytes_type {
BytesType::VecU8 => quote! { ::stef::buf::decode_bytes_std(r) },
BytesType::Bytes => quote! { ::stef::buf::decode_bytes_bytes(r) },
},
DataType::Vec(ty) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_vec(r, |r| { #ty }) }
}
DataType::HashMap(kv) => {
let ty_k = compile_data_type(&kv.0);
let ty_v = compile_data_type(&kv.1);
let ty_k = compile_data_type(opts, &kv.0);
let ty_v = compile_data_type(opts, &kv.1);
quote! { ::stef::buf::decode_hash_map(r, |r| { #ty_k }, |r| { #ty_v }) }
}
DataType::HashSet(ty) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_hash_set(r, |r| { #ty }) }
}
DataType::Option(ty) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_option(r, |r| { #ty }) }
}
DataType::NonZero(ty) => match &ty.value {
Expand All @@ -290,20 +304,25 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream {
DataType::String | DataType::StringRef => {
quote! { ::stef::buf::decode_non_zero_string(r) }
}
DataType::Bytes | DataType::BytesRef => {
quote! { ::stef::buf::decode_non_zero_bytes(r) }
}
DataType::Bytes | DataType::BytesRef => match opts.bytes_type {
BytesType::VecU8 => {
quote! { ::stef::buf::decode_non_zero_bytes_std(r) }
}
BytesType::Bytes => {
quote! { ::stef::buf::decode_non_zero_bytes_bytes(r) }
}
},
DataType::Vec(ty) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_non_zero_vec(r, |r| { #ty }) }
}
DataType::HashMap(kv) => {
let ty_k = compile_data_type(&kv.0);
let ty_v = compile_data_type(&kv.1);
let ty_k = compile_data_type(opts, &kv.0);
let ty_v = compile_data_type(opts, &kv.1);
quote! { ::stef::buf::decode_non_zero_hash_map(r, |r| { #ty_k }, |r| { #ty_v }) }
}
DataType::HashSet(ty) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_non_zero_hash_set(r, |r| { #ty }) }
}
ty => todo!("compiler should catch invalid {ty:?} type"),
Expand All @@ -312,15 +331,15 @@ fn compile_data_type(ty: &Type<'_>) -> TokenStream {
DataType::BoxBytes => quote! { Box::<[u8]>::decode(r) },
DataType::Tuple(types) => match types.len() {
2..=12 => {
let types = types.iter().map(|ty| compile_data_type(ty));
let types = types.iter().map(|ty| compile_data_type(opts, ty));
quote! { { Ok::<_, ::stef::buf::Error>((#(#types?,)*)) } }
}
0 => panic!("tuple with zero elements"),
1 => panic!("tuple with single element"),
_ => panic!("tuple with more than 12 elements"),
},
DataType::Array(ty, _size) => {
let ty = compile_data_type(ty);
let ty = compile_data_type(opts, ty);
quote! { ::stef::buf::decode_array(r, |r| { #ty }) }
}
DataType::External(ty) => {
Expand Down
Loading

0 comments on commit eaa06ae

Please sign in to comment.