diff --git a/crates/bevy_asset/src/handle.rs b/crates/bevy_asset/src/handle.rs index f124bbde02762..ea534c6d89d25 100644 --- a/crates/bevy_asset/src/handle.rs +++ b/crates/bevy_asset/src/handle.rs @@ -97,7 +97,7 @@ impl HandleId { /// #[derive(Component, Reflect)] #[reflect(Component, Default)] -#[reflect(ignore_params(T))] +#[reflect(custom_where(T: Asset))] pub struct Handle where T: Asset, diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs b/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs index 0401a4720948d..e34688a177ab7 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs @@ -9,12 +9,11 @@ use crate::fq_std::{FQAny, FQOption}; use crate::utility; use proc_macro2::{Ident, Span}; use quote::quote_spanned; -use std::collections::HashSet; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::token::Comma; -use syn::{Generics, LitBool, Meta, Path}; +use syn::{LitBool, Meta, Path, Token, WherePredicate}; // The "special" trait idents that are used internally for reflection. // Received via attributes like `#[reflect(PartialEq, Hash, ...)]` @@ -30,7 +29,7 @@ pub(crate) const REFLECT_DEFAULT: &str = "ReflectDefault"; const FROM_REFLECT_ATTR: &str = "from_reflect"; // Attributes for `Reflect` implementation -const IGNORE_PARAMS_ATTR: &str = "ignore_params"; +const CUSTOM_WHERE_ATTR: &str = "custom_where"; // The error message to show when a trait/type is specified multiple times const CONFLICTING_TYPE_DATA_MESSAGE: &str = "conflicting type data registration"; @@ -170,7 +169,7 @@ pub(crate) struct ReflectTraits { hash: TraitImpl, partial_eq: TraitImpl, from_reflect: FromReflectAttrs, - ignored_params: HashSet, + custom_where: Option>, idents: Vec, } @@ -211,11 +210,11 @@ impl ReflectTraits { } } } - // Handles `#[reflect(ignore_params(T, U))]` - Meta::List(list) if list.path.is_ident(IGNORE_PARAMS_ATTR) => { - let params: Punctuated = + // Handles `#[reflect(custom_where(T: Trait, U::Assoc: Trait))]` + Meta::List(list) if list.path.is_ident(CUSTOM_WHERE_ATTR) => { + let predicate: Punctuated = list.parse_args_with(Punctuated::parse_separated_nonempty)?; - traits.ignored_params.extend(params); + traits.merge_custom_where(Some(predicate)); } // Handles `#[reflect( Debug(custom_debug_fn) )]` Meta::List(list) if list.path.is_ident(DEBUG_ATTR) => { @@ -244,7 +243,7 @@ impl ReflectTraits { Meta::List(list) => { return Err(syn::Error::new_spanned( list, - format!("expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}, {IGNORE_PARAMS_ATTR:?}]") + format!("expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}, {CUSTOM_WHERE_ATTR:?}]") )); } Meta::NameValue(pair) => { @@ -364,8 +363,8 @@ impl ReflectTraits { } } - pub fn ignore_param(&self, param: &Ident) -> bool { - self.ignored_params.contains(param) + pub fn custom_where(&self) -> Option<&Punctuated> { + self.custom_where.as_ref() } /// Merges the trait implementations of this [`ReflectTraits`] with another one. @@ -376,43 +375,25 @@ impl ReflectTraits { self.hash.merge(other.hash)?; self.partial_eq.merge(other.partial_eq)?; self.from_reflect.merge(other.from_reflect)?; - self.ignored_params.extend(other.ignored_params); + + self.merge_custom_where(other.custom_where); + for ident in other.idents { add_unique_ident(&mut self.idents, ident)?; } Ok(()) } - /// Validates that any ignored type parameters are valid for the given set of generics. - pub fn validate_ignored_params(&self, generics: &Generics) -> Result<(), syn::Error> { - if self.ignored_params.is_empty() { - return Ok(()); - } - - let mut params = self.ignored_params.clone(); - - for param in generics.type_params() { - params.remove(¶m.ident); - } - - if params.is_empty() { - return Ok(()); - } - - let mut errors: Option = None; - for param in params { - let err = syn::Error::new_spanned( - ¶m, - format!("`{}` is not a valid type parameter", param), - ); - if let Some(error) = &mut errors { - error.combine(err); - } else { - errors = Some(err); + fn merge_custom_where(&mut self, other: Option>) { + match (&mut self.custom_where, other) { + (Some(this), Some(other)) => { + this.extend(other); + } + (None, Some(other)) => { + self.custom_where = Some(other); } + _ => {} } - - Err(errors.unwrap()) } } diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs b/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs index 3cfc3cde476b3..8421402556f70 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs @@ -247,8 +247,6 @@ impl<'a> ReflectDerive<'a> { _ => (), } - traits.validate_ignored_params(&input.generics)?; - let type_path = ReflectTypePath::Internal { ident: &input.ident, custom_path, diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs index fc77d484159f1..323fc7c340689 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs @@ -75,17 +75,6 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name"; /// This is often used with traits that have been marked by the [`#[reflect_trait]`](macro@reflect_trait) /// macro in order to register the type's implementation of that trait. /// -/// ## `#[reflect(ignore_params(T, U, ...))]` -/// -/// This derive macro will automatically add the necessary bounds to any generic type parameters -/// in order to make them compatible with reflection. -/// However, this may not always be desired and some type paramters are not meant to be reflected -/// (i.e. their usages in fields are ignored or they're only used for their associated types). -/// -/// Using this attribute, type parameters can opt out of receiving these automatic bounds. -/// Note that they will receive the bounds that are still considered absolutely necessary, -/// such as `Send`, `Sync`, `Any`, and `TypePath`. -/// /// ### Default Registrations /// /// The following types are automatically registered when deriving `Reflect`: @@ -139,6 +128,37 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name"; /// /// Note that in the latter case, `ReflectFromReflect` will no longer be automatically registered. /// +/// ## `#[reflect(custom_where(T: Trait, U::Assoc: Trait, ...))]` +/// +/// By default, the derive macro will automatically add the necessary bounds to any generic type parameters +/// in order to make them compatible with reflection. +/// However, this may not always be desired, and some type paramaters can't or shouldn't require those bounds +/// (i.e. their usages in fields are ignored or they're only used for their associated types). +/// +/// With this attribute, you can specify a custom `where` clause to be used instead of the default. +/// Any parameter used in the attribute will not be given the default bounds, +/// and use the ones defined in the attribute instead. +/// Type parameters not used in the attribute will still receive the default bounds. +/// +/// Note that all type parameters will receive the bounds that are still considered absolutely necessary, +/// such as `Send`, `Sync`, `Any`, and `TypePath`. +/// +/// ### Example +/// +/// ```ignore +/// trait Trait { +/// type Assoc; +/// } +/// +/// #[derive(Reflect)] +/// // Note: We add the `T: 'static` bound here to fully opt-out of the automatic bounds on `T`. +/// // We could use almost any trait here, but `'static` is the easiest as it's always available. +/// #[reflect(custom_where(T: 'static, T::Assoc: FromReflect))] +/// struct Foo { +/// value: T::Assoc, +/// } +/// ``` +/// /// # Field Attributes /// /// Along with the container attributes, this macro comes with some attributes that may be applied @@ -152,6 +172,10 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name"; /// which may be useful for maintaining invariants, keeping certain data private, /// or allowing the use of types that do not implement `Reflect` within the container. /// +/// If the field contains a generic type parameter, you will likely need to add a +/// [`#[reflect(custom_where(...))]`](#reflectcustom_wheret-trait-uassoc-trait-) +/// attribute to the container in order to avoid the default bounds being applied to the type parameter. +/// /// ## `#[reflect(skip_serializing)]` /// /// This works similar to `#[reflect(ignore)]`, but rather than opting out of _all_ of reflection, diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs index 4314510c76ac7..e1d85bb10947c 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs @@ -7,7 +7,11 @@ use bevy_macro_utils::BevyManifest; use bit_set::BitSet; use proc_macro2::{Ident, Span}; use quote::{quote, ToTokens}; -use syn::{spanned::Spanned, LitStr, Member, Path, TypeParam, WhereClause}; +use std::collections::HashSet; +use syn::punctuated::Punctuated; +use syn::{ + spanned::Spanned, LitStr, Member, Path, Token, Type, TypeParam, WhereClause, WherePredicate, +}; /// Returns the correct path for `bevy_reflect`. pub(crate) fn get_bevy_reflect_path() -> Path { @@ -70,6 +74,7 @@ pub(crate) struct WhereClauseOptions { ignored_types: Vec, /// Trait bounds to add to the ignored types ignored_trait_bounds: Vec, + custom_where: Option>, } impl Default for WhereClauseOptions { @@ -80,6 +85,7 @@ impl Default for WhereClauseOptions { ignored_types: Vec::new(), active_trait_bounds: Vec::new(), ignored_trait_bounds: Vec::new(), + custom_where: None, } } } @@ -125,9 +131,24 @@ impl WhereClauseOptions { ) -> Self { let mut options = WhereClauseOptions::default(); + let skip_params = if let Some(custom_where) = meta.traits().custom_where() { + custom_where + .iter() + .filter_map(|predicate| match predicate { + WherePredicate::Type(predicate_ty) => match &predicate_ty.bounded_ty { + Type::Path(ty_path) => ty_path.path.get_ident().cloned(), + _ => None, + }, + _ => None, + }) + .collect() + } else { + HashSet::new() + }; + for param in meta.type_path().generics().type_params() { let ident = param.ident.clone(); - let ignored = meta.traits().ignore_param(&ident); + let ignored = skip_params.contains(&ident); if ignored { let bounds = ignored_bounds(param).unwrap_or_default(); @@ -142,6 +163,8 @@ impl WhereClauseOptions { } } + options.custom_where = meta.traits().custom_where().cloned(); + options } } @@ -150,8 +173,8 @@ impl WhereClauseOptions { /// /// This is mostly used to add additional bounds to reflected objects with generic types. /// For reflection purposes, we usually have: -/// * active_trait_bounds: `Reflect + TypePath` or `FromReflect + TypePath` -/// * ignored_trait_bounds: `TypePath + Any + Send + Sync` +/// * `active_trait_bounds`: `Reflect + TypePath` or `FromReflect + TypePath` +/// * `ignored_trait_bounds`: `TypePath + Any + Send + Sync` /// /// # Arguments /// @@ -193,9 +216,12 @@ pub(crate) fn extend_where_clause( quote!(where Self: 'static,) }; + let custom_where = &where_clause_options.custom_where; + generic_where_clause.extend(quote! { #(#active_types: #active_trait_bounds,)* #(#ignored_types: #ignored_trait_bounds,)* + #custom_where }); generic_where_clause } diff --git a/crates/bevy_reflect/src/lib.rs b/crates/bevy_reflect/src/lib.rs index 5326a2ca811a8..27f772e1987ed 100644 --- a/crates/bevy_reflect/src/lib.rs +++ b/crates/bevy_reflect/src/lib.rs @@ -1878,9 +1878,9 @@ bevy_reflect::tests::should_reflect_debug::Test { } #[test] - fn should_allow_ignored_params() { + fn should_allow_custom_where() { #[derive(Reflect)] - #[reflect(ignore_params(T))] + #[reflect(custom_where(T: 'static))] struct Foo(String, #[reflect(ignore)] PhantomData); #[derive(TypePath)] @@ -1894,14 +1894,28 @@ bevy_reflect::tests::should_reflect_debug::Test { } #[test] - fn should_allow_ignored_params_wtih_assoc_type() { + fn should_allow_multiple_custom_where() { + #[derive(Reflect)] + #[reflect(custom_where(T: Default + FromReflect))] + #[reflect(custom_where(U: std::ops::Add + FromReflect))] + struct Foo(T, U); + + #[derive(Reflect)] + struct Baz { + a: Foo, + b: Foo, + } + } + + #[test] + fn should_allow_custom_where_wtih_assoc_type() { trait Trait { - type Assoc: Reflect + FromReflect; + type Assoc; } // We don't need `T` to be `Reflect` since we only care about `T::Assoc` #[derive(Reflect)] - #[reflect(ignore_params(T))] + #[reflect(custom_where(T: 'static, T::Assoc: FromReflect))] struct Foo(T::Assoc); #[derive(TypePath)]