diff --git a/strum_macros/src/helpers/metadata.rs b/strum_macros/src/helpers/metadata.rs index 5a44c78..f53d99b 100644 --- a/strum_macros/src/helpers/metadata.rs +++ b/strum_macros/src/helpers/metadata.rs @@ -18,6 +18,8 @@ pub mod kw { custom_keyword!(serialize_all); custom_keyword!(use_phf); custom_keyword!(prefix); + custom_keyword!(parse_err_ty); + custom_keyword!(parse_err_fn); // enum discriminant metadata custom_keyword!(derive); @@ -51,6 +53,14 @@ pub enum EnumMeta { kw: kw::prefix, prefix: LitStr, }, + ParseErrTy { + kw: kw::parse_err_ty, + path: Path, + }, + ParseErrFn { + kw: kw::parse_err_fn, + path: Path, + }, } impl Parse for EnumMeta { @@ -80,6 +90,20 @@ impl Parse for EnumMeta { input.parse::()?; let prefix = input.parse()?; Ok(EnumMeta::Prefix { kw, prefix }) + } else if lookahead.peek(kw::parse_err_ty) { + let kw = input.parse::()?; + input.parse::()?; + let path_str: LitStr = input.parse()?; + let path_tokens = parse_str(&path_str.value())?; + let path = parse2(path_tokens)?; + Ok(EnumMeta::ParseErrTy { kw, path }) + } else if lookahead.peek(kw::parse_err_fn) { + let kw = input.parse::()?; + input.parse::()?; + let path_str: LitStr = input.parse()?; + let path_tokens = parse_str(&path_str.value())?; + let path = parse2(path_tokens)?; + Ok(EnumMeta::ParseErrFn { kw, path }) } else { Err(lookahead.error()) } diff --git a/strum_macros/src/helpers/mod.rs b/strum_macros/src/helpers/mod.rs index 23d60b5..a00e8e4 100644 --- a/strum_macros/src/helpers/mod.rs +++ b/strum_macros/src/helpers/mod.rs @@ -13,6 +13,13 @@ use proc_macro2::Span; use quote::ToTokens; use syn::spanned::Spanned; +pub fn missing_parse_err_attr_error() -> syn::Error { + syn::Error::new( + Span::call_site(), + "`parse_err_ty` and `parse_err_fn` attribute is both required.", + ) +} + pub fn non_enum_error() -> syn::Error { syn::Error::new(Span::call_site(), "This macro only supports enums.") } diff --git a/strum_macros/src/helpers/type_props.rs b/strum_macros/src/helpers/type_props.rs index cdc7a8c..60ce244 100644 --- a/strum_macros/src/helpers/type_props.rs +++ b/strum_macros/src/helpers/type_props.rs @@ -13,6 +13,8 @@ pub trait HasTypeProperties { #[derive(Clone, Default)] pub struct StrumTypeProperties { + pub parse_err_ty: Option, + pub parse_err_fn: Option, pub case_style: Option, pub ascii_case_insensitive: bool, pub crate_module_path: Option, @@ -32,6 +34,8 @@ impl HasTypeProperties for DeriveInput { let strum_meta = self.get_metadata()?; let discriminants_meta = self.get_discriminants_metadata()?; + let mut parse_err_ty_kw = None; + let mut parse_err_fn_kw = None; let mut serialize_all_kw = None; let mut ascii_case_insensitive_kw = None; let mut use_phf_kw = None; @@ -82,6 +86,22 @@ impl HasTypeProperties for DeriveInput { prefix_kw = Some(kw); output.prefix = Some(prefix); } + EnumMeta::ParseErrTy { path, kw } => { + if let Some(fst_kw) = parse_err_ty_kw { + return Err(occurrence_error(fst_kw, kw, "parse_err_ty")); + } + + parse_err_ty_kw = Some(kw); + output.parse_err_ty = Some(path); + } + EnumMeta::ParseErrFn { path, kw } => { + if let Some(fst_kw) = parse_err_fn_kw { + return Err(occurrence_error(fst_kw, kw, "parse_err_fn")); + } + + parse_err_fn_kw = Some(kw); + output.parse_err_fn = Some(path); + } } } diff --git a/strum_macros/src/macros/strings/from_string.rs b/strum_macros/src/macros/strings/from_string.rs index 21307df..9b95567 100644 --- a/strum_macros/src/macros/strings/from_string.rs +++ b/strum_macros/src/macros/strings/from_string.rs @@ -1,10 +1,10 @@ use proc_macro2::TokenStream; use quote::quote; -use syn::{Data, DeriveInput, Fields}; +use syn::{parse_quote, Data, DeriveInput, Fields, Path}; use crate::helpers::{ - non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties, - HasTypeProperties, + missing_parse_err_attr_error, non_enum_error, occurrence_error, HasInnerVariantProperties, + HasStrumVariantProperties, HasTypeProperties, }; pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { @@ -19,9 +19,25 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { let strum_module_path = type_properties.crate_module_path(); let mut default_kw = None; - let mut default = - quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }; - + let (mut default_err_ty, mut default) = match ( + type_properties.parse_err_ty, + type_properties.parse_err_fn, + ) { + (None, None) => ( + quote! { #strum_module_path::ParseError }, + quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }, + ), + (Some(ty), Some(f)) => { + let ty_path: Path = parse_quote!(#ty); + let fn_path: Path = parse_quote!(#f); + + ( + quote! { #ty_path }, + quote! { ::core::result::Result::Err(#fn_path(s)) }, + ) + } + _ => return Err(missing_parse_err_attr_error()), + }; let mut phf_exact_match_arms = Vec::new(); let mut standard_match_arms = Vec::new(); for variant in variants { @@ -47,6 +63,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { } } default_kw = Some(kw); + default_err_ty = quote! { #strum_module_path::ParseError }; default = quote! { ::core::result::Result::Ok(#name::#ident(s.into())) }; @@ -146,7 +163,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { let from_str = quote! { #[allow(clippy::use_self)] impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause { - type Err = #strum_module_path::ParseError; + type Err = #default_err_ty; #[inline] fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , ::Err> { @@ -160,7 +177,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { &impl_generics, &ty_generics, where_clause, - &strum_module_path, + &default_err_ty, ); Ok(quote! { @@ -186,12 +203,12 @@ fn try_from_str( impl_generics: &syn::ImplGenerics, ty_generics: &syn::TypeGenerics, where_clause: Option<&syn::WhereClause>, - strum_module_path: &syn::Path, + default_err_ty: &TokenStream, ) -> TokenStream { quote! { #[allow(clippy::use_self)] impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause { - type Error = #strum_module_path::ParseError; + type Error = #default_err_ty; #[inline] fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , >::Error> { diff --git a/strum_tests/tests/from_str.rs b/strum_tests/tests/from_str.rs index 734282b..ff4ffeb 100644 --- a/strum_tests/tests/from_str.rs +++ b/strum_tests/tests/from_str.rs @@ -229,3 +229,36 @@ fn color_default_with_white() { } } } + +#[derive(Debug, EnumString)] +#[strum( + parse_err_fn = "some_enum_not_found_err", + parse_err_ty = "CaseCustomParseErrorNotFoundError" +)] +enum CaseCustomParseErrorEnum { + #[strum(serialize = "red")] + Red, + #[strum(serialize = "blue")] + Blue, +} +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +struct CaseCustomParseErrorNotFoundError(String); +impl std::fmt::Display for CaseCustomParseErrorNotFoundError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "not found `{}`", self.0) + } +} +impl std::error::Error for CaseCustomParseErrorNotFoundError {} +fn some_enum_not_found_err(s: &str) -> CaseCustomParseErrorNotFoundError { + CaseCustomParseErrorNotFoundError(s.to_string()) +} + +#[test] +fn case_custom_parse_error() { + let r = "yellow".parse::(); + assert!(r.is_err()); + assert_eq!( + CaseCustomParseErrorNotFoundError("yellow".to_string()), + r.unwrap_err() + ); +}