diff --git a/strum_macros/src/lib.rs b/strum_macros/src/lib.rs index cc1c2a1d..a37ccdcc 100644 --- a/strum_macros/src/lib.rs +++ b/strum_macros/src/lib.rs @@ -384,6 +384,31 @@ pub fn enum_iter(input: proc_macro::TokenStream) -> proc_macro::TokenStream { toks.into() } +/// Generated `is_*()` methods for each variant. +/// E.g. `Color.is_red()`. +/// +/// ``` +/// +/// use strum_macros::EnumIs; +/// +/// #[derive(EnumIs, Debug)] +/// enum Color { +/// Red, +/// Green { range: usize }, +/// } +/// +/// assert!(Color::Red.is_red()); +/// assert!(Color::Green{range: 0}.is_green()); +/// ``` +#[proc_macro_derive(EnumIs, attributes(strum))] +pub fn enum_is(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ast = syn::parse_macro_input!(input as DeriveInput); + + let toks = macros::enum_is::enum_is_inner(&ast).unwrap_or_else(|err| err.to_compile_error()); + debug_print_generated(&ast, &toks); + toks.into() +} + /// Add a function to enum that allows accessing variants by its discriminant /// /// This macro adds a standalone function to obtain an enum variant by its discriminant. The macro adds diff --git a/strum_macros/src/macros/enum_is.rs b/strum_macros/src/macros/enum_is.rs new file mode 100644 index 00000000..bde38519 --- /dev/null +++ b/strum_macros/src/macros/enum_is.rs @@ -0,0 +1,61 @@ +use crate::helpers::{non_enum_error, HasStrumVariantProperties}; +use heck::ToSnakeCase; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{Data, DeriveInput}; + +pub fn enum_is_inner(ast: &DeriveInput) -> syn::Result { + let variants = match &ast.data { + Data::Enum(v) => &v.variants, + _ => return Err(non_enum_error()), + }; + + let enum_name = &ast.ident; + + let variants: Vec<_> = variants + .iter() + .filter_map(|variant| { + if variant.get_variant_properties().ok()?.disabled.is_some() { + return None; + } + + let variant_name = &variant.ident; + let fn_name = format_ident!("is_{}", snakify(&variant_name.to_string())); + + Some(quote! { + #[must_use] + #[inline] + pub const fn #fn_name(&self) -> bool { + match self { + &#enum_name::#variant_name { .. } => true, + _ => false + } + } + }) + }) + .collect(); + + Ok(quote! { + impl #enum_name { + #(#variants)* + } + } + .into()) +} + +/// heck doesn't treat numbers as new words, but this function does. +/// E.g. for input `Hello2You`, heck would output `hello2_you`, and snakify would output `hello_2_you`. +fn snakify(s: &str) -> String { + let mut output: Vec = s.to_string().to_snake_case().chars().collect(); + let mut num_starts = vec![]; + for (pos, c) in output.iter().enumerate() { + if c.is_digit(10) && pos != 0 && !output[pos - 1].is_digit(10) { + num_starts.push(pos); + } + } + // need to do in reverse, because after inserting, all chars after the point of insertion are off + for i in num_starts.into_iter().rev() { + output.insert(i, '_') + } + output.into_iter().collect() +} diff --git a/strum_macros/src/macros/mod.rs b/strum_macros/src/macros/mod.rs index b4129697..a44be083 100644 --- a/strum_macros/src/macros/mod.rs +++ b/strum_macros/src/macros/mod.rs @@ -1,5 +1,6 @@ pub mod enum_count; pub mod enum_discriminants; +pub mod enum_is; pub mod enum_iter; pub mod enum_messages; pub mod enum_properties; diff --git a/strum_tests/tests/enum_is.rs b/strum_tests/tests/enum_is.rs new file mode 100644 index 00000000..186cb23c --- /dev/null +++ b/strum_tests/tests/enum_is.rs @@ -0,0 +1,70 @@ +use strum::EnumIs; + +#[derive(EnumIs)] +enum Foo { + Unit, + Named0 {}, + Named1 { _a: char }, + Named2 { _a: u32, _b: String }, + Unnamed0(), + Unnamed1(Option), + Unnamed2(bool, u8), + MultiWordName, + #[strum(disabled)] + #[allow(dead_code)] + Disabled, +} + +#[test] +fn simple_test() { + assert!(Foo::Unit.is_unit()); +} + +#[test] +fn named_0() { + assert!(Foo::Named0 {}.is_named_0()); +} + +#[test] +fn named_1() { + let foo = Foo::Named1 { + _a: Default::default(), + }; + assert!(foo.is_named_1()); +} + +#[test] +fn named_2() { + let foo = Foo::Named2 { + _a: Default::default(), + _b: Default::default(), + }; + assert!(foo.is_named_2()); +} + +#[test] +fn unnamed_0() { + assert!(Foo::Unnamed0().is_unnamed_0()); +} + +#[test] +fn unnamed_1() { + let foo = Foo::Unnamed1(Default::default()); + assert!(foo.is_unnamed_1()); +} + +#[test] +fn unnamed_2() { + let foo = Foo::Unnamed2(Default::default(), Default::default()); + assert!(foo.is_unnamed_2()); +} + +#[test] +fn multi_word() { + assert!(Foo::MultiWordName.is_multi_word_name()); +} + +#[test] +fn doesnt_match_other_variations() { + assert!(!Foo::Unit.is_multi_word_name()); +}