Skip to content

Commit

Permalink
Add proc_macro to generate BonusType methods
Browse files Browse the repository at this point in the history
  • Loading branch information
nmeylan committed Mar 24, 2024
1 parent 39b071d commit 087aeff
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 466 deletions.
165 changes: 150 additions & 15 deletions lib/enum_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use proc_macro::TokenStream;
use proc_macro2::{Ident as Ident2, TokenStream as TokenStream2};
use quote::quote;
use syn::Data::Enum;
use syn::{parse_macro_input, DeriveInput, Variant, Ident, Meta, Type};
use syn::{parse_macro_input, DeriveInput, Variant, Ident, Meta, Type, Field};
use syn::__private::Span;

#[proc_macro_derive(WithNumberValue, attributes(value))]
pub fn with_number_value(input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -422,7 +423,6 @@ pub fn with_wq(input: TokenStream) -> TokenStream {
let variant_offset = if let Some(attribute) = variant.attrs.iter().find(|attr| {
attr.path().is_ident("value_comparison_offset")
}) {
println!("get number for variant_offset");
get_number_value(variant, "value_comparison_offset").unwrap_or(0)
} else {
0
Expand All @@ -433,21 +433,9 @@ pub fn with_wq(input: TokenStream) -> TokenStream {
if args1.len() > 1 {
args1[variant_offset] = quote! {variant1};
args2[variant_offset] = quote! {variant2};
match &fields.unnamed[variant_offset].ty {
Type::Path(p) => {
let is_numeric = match p.path.get_ident().unwrap().to_string().as_str() {
"u8" | "i8" | "u16" | "i16" | "u16" | "i32" | "u32" | "i64" | "u64" | "i128" | "u128" => true,
_ => false
};
if is_numeric {
should_deref = true;
}
}
_ => {}
}
should_deref = is_numeric(&fields.unnamed[variant_offset]);
}
// let field_types = fields.unnamed.iter().map(|field| &field.ty);
println!("field len {}, variant_offset {}", fields.unnamed.len(), variant_offset);
if fields.unnamed.len() > 1 {
if should_deref {
quote! {(#enum_name::#variant_name(#(#args1,)*), #enum_name::#variant_name(#(#args2,)*)) => *variant1 == *variant2, }
Expand Down Expand Up @@ -483,3 +471,150 @@ pub fn with_wq(input: TokenStream) -> TokenStream {
TokenStream::from(quote! {})
}
}

fn field_type(field: &Field) -> Option<String> {
match &field.ty {
Type::Path(p) => {
Some(p.path.get_ident().unwrap().to_string())
}
_ => None
}
}

fn is_numeric(field: &Field) -> bool {
match &field.ty {
Type::Path(p) => {
let is_numeric = match field_type(field).unwrap().as_str() {
"u8" | "i8" | "u16" | "i16" | "i32" | "u32" | "i64" | "u64" | "i128" | "u128" => true,
_ => false
};
return is_numeric;
}
_ => false
}
}

#[proc_macro_derive(WithStackable, attributes(skip_sum, value_offset))]
pub fn stackable_enum(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let enum_name = &input.ident;

if let Enum(enum_data) = &input.data {
let get_value_sum_arms = enum_data.variants.iter().filter(|variant| matches!(&variant.fields, syn::Fields::Unnamed(..))).map(|variant| {
let variant_name = variant.ident.clone();
if let syn::Fields::Unnamed(fields) = &variant.fields {
let value_offset = if let Some(_) = variant.attrs.iter().find(|attr| {
attr.path().is_ident("value_offset")
}) {
get_number_value(variant, "value_offset").unwrap_or(1)
} else {
if fields.unnamed.len() == 1 {
0
} else {
1
}
};
if is_numeric(&fields.unnamed[value_offset]) {
let mut args1 = fields.unnamed.iter().map(|_| quote! {_}).collect::<Vec<TokenStream2>>();
args1[value_offset] = quote! {val};
quote! {#enum_name::#variant_name(#(#args1,)*) => Some(*val as f32), }
} else {
quote! {}
}
} else {
panic!("patterns `Fields::Named(_)` and `Fields::Unit` not covered")
}
});
let get_value_sum_return_arms = enum_data.variants.iter().map(|variant| {
let variant_name = variant.ident.clone();
if let syn::Fields::Unnamed(fields) = &variant.fields {
let value_offset = if let Some(_) = variant.attrs.iter().find(|attr| {
attr.path().is_ident("value_offset")
}) {
get_number_value(variant, "value_offset").unwrap_or(1)
} else {
if fields.unnamed.len() == 1 {
0
} else {
1
}
};
if is_numeric(&fields.unnamed[value_offset]) {
let mut args1 = fields.unnamed.iter().enumerate().map(|(i, _)| {
let v = Ident::new(format!("value{}", i).as_str(), Span::call_site());
quote! {#v}
}).collect::<Vec<TokenStream2>>();
let mut args2 = args1.clone();
let val_type = Ident::new(field_type(&fields.unnamed[value_offset]).unwrap().as_str(), Span::call_site());
args1[value_offset] = quote! {_};
args2[value_offset] = quote! {&val as #val_type};
quote! {#enum_name::#variant_name(#(#args1,)*) => #enum_name::#variant_name(#(*#args2,)*), }
} else {
let mut args1 = fields.unnamed.iter().enumerate().map(|(i, _)| {
let v = Ident::new(format!("value{}", i).as_str(), Span::call_site());
quote! {#v}
}).collect::<Vec<TokenStream2>>();
quote! {#enum_name::#variant_name(#(#args1,)*) => #enum_name::#variant_name(#(*#args1,)*), }
}

} else {
quote! {#enum_name::#variant_name => #enum_name::#variant_name, }
}
});
let get_enum_value = enum_data.variants.iter().filter(|variant| matches!(&variant.fields, syn::Fields::Unnamed(..))).map(|variant| {
let variant_name = variant.ident.clone();
if let syn::Fields::Unnamed(fields) = &variant.fields {
let value_offset = if let Some(_) = variant.attrs.iter().find(|attr| {
attr.path().is_ident("value_offset")
}) {
get_number_value(variant, "value_offset").unwrap_or(1)
} else {
if fields.unnamed.len() == 1 {
0
} else {
1
}
};
if is_numeric(&fields.unnamed[value_offset]) {
let mut args1 = fields.unnamed.iter().enumerate().map(|(i, _)| {
quote! {_}
}).collect::<Vec<TokenStream2>>();
args1[value_offset] = quote! {val};
quote! {#enum_name::#variant_name(#(#args1,)*) => *val as f32, }
} else {
quote!{}
}

} else {
panic!("patterns `Fields::Named(_)` and `Fields::Unit` not covered")
}
});
TokenStream::from(quote! {
impl EnumStackable<#enum_name> for #enum_name {
fn get_value_sum(single_enum: &#enum_name, enums: &Vec<#enum_name>) -> #enum_name {
let val: f32 = enums.into_iter().filter_map(|e|
if e == single_enum {
match e {
#(#get_value_sum_arms)*
_ => None
}
} else {
None
}
).sum();
match single_enum {
#(#get_value_sum_return_arms)*
}
}
fn get_enum_value<'a>(single_enum: &#enum_name, enums: &'a Vec<&#enum_name>) -> Option<f32> {
Self::get_enum(single_enum, enums).map(|b| match b {
#(#get_enum_value)*
_ => 0.0
})
}
}
})
} else {
TokenStream::from(quote! {})
}
}
Loading

0 comments on commit 087aeff

Please sign in to comment.