Skip to content

Commit

Permalink
Replace ignore_params with custom_where
Browse files Browse the repository at this point in the history
  • Loading branch information
MrGVSV committed Aug 1, 2023
1 parent b5f8c15 commit 37eafe9
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 63 deletions.
2 changes: 1 addition & 1 deletion crates/bevy_asset/src/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl HandleId {
///
#[derive(Component, Reflect)]
#[reflect(Component, Default)]
#[reflect(ignore_params(T))]
#[reflect(custom_where(T: Asset))]
pub struct Handle<T>
where
T: Asset,
Expand Down
61 changes: 21 additions & 40 deletions crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ use crate::fq_std::{FQAny, FQOption};
use crate::utility;
use proc_macro2::{Ident, Span};
use quote::quote_spanned;
use std::collections::HashSet;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{Generics, LitBool, Meta, Path};
use syn::{LitBool, Meta, Path, Token, WherePredicate};

// The "special" trait idents that are used internally for reflection.
// Received via attributes like `#[reflect(PartialEq, Hash, ...)]`
Expand All @@ -30,7 +29,7 @@ pub(crate) const REFLECT_DEFAULT: &str = "ReflectDefault";
const FROM_REFLECT_ATTR: &str = "from_reflect";

// Attributes for `Reflect` implementation
const IGNORE_PARAMS_ATTR: &str = "ignore_params";
const CUSTOM_WHERE_ATTR: &str = "custom_where";

// The error message to show when a trait/type is specified multiple times
const CONFLICTING_TYPE_DATA_MESSAGE: &str = "conflicting type data registration";
Expand Down Expand Up @@ -170,7 +169,7 @@ pub(crate) struct ReflectTraits {
hash: TraitImpl,
partial_eq: TraitImpl,
from_reflect: FromReflectAttrs,
ignored_params: HashSet<Ident>,
custom_where: Option<Punctuated<WherePredicate, Token![,]>>,
idents: Vec<Ident>,
}

Expand Down Expand Up @@ -211,11 +210,11 @@ impl ReflectTraits {
}
}
}
// Handles `#[reflect(ignore_params(T, U))]`
Meta::List(list) if list.path.is_ident(IGNORE_PARAMS_ATTR) => {
let params: Punctuated<Ident, Comma> =
// Handles `#[reflect(custom_where(T: Trait, U::Assoc: Trait))]`
Meta::List(list) if list.path.is_ident(CUSTOM_WHERE_ATTR) => {
let predicate: Punctuated<WherePredicate, Token![,]> =
list.parse_args_with(Punctuated::parse_separated_nonempty)?;
traits.ignored_params.extend(params);
traits.merge_custom_where(Some(predicate));
}
// Handles `#[reflect( Debug(custom_debug_fn) )]`
Meta::List(list) if list.path.is_ident(DEBUG_ATTR) => {
Expand Down Expand Up @@ -244,7 +243,7 @@ impl ReflectTraits {
Meta::List(list) => {
return Err(syn::Error::new_spanned(
list,
format!("expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}, {IGNORE_PARAMS_ATTR:?}]")
format!("expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}, {CUSTOM_WHERE_ATTR:?}]")
));
}
Meta::NameValue(pair) => {
Expand Down Expand Up @@ -364,8 +363,8 @@ impl ReflectTraits {
}
}

pub fn ignore_param(&self, param: &Ident) -> bool {
self.ignored_params.contains(param)
pub fn custom_where(&self) -> Option<&Punctuated<WherePredicate, Token![,]>> {
self.custom_where.as_ref()
}

/// Merges the trait implementations of this [`ReflectTraits`] with another one.
Expand All @@ -376,43 +375,25 @@ impl ReflectTraits {
self.hash.merge(other.hash)?;
self.partial_eq.merge(other.partial_eq)?;
self.from_reflect.merge(other.from_reflect)?;
self.ignored_params.extend(other.ignored_params);

self.merge_custom_where(other.custom_where);

for ident in other.idents {
add_unique_ident(&mut self.idents, ident)?;
}
Ok(())
}

/// Validates that any ignored type parameters are valid for the given set of generics.
pub fn validate_ignored_params(&self, generics: &Generics) -> Result<(), syn::Error> {
if self.ignored_params.is_empty() {
return Ok(());
}

let mut params = self.ignored_params.clone();

for param in generics.type_params() {
params.remove(&param.ident);
}

if params.is_empty() {
return Ok(());
}

let mut errors: Option<syn::Error> = None;
for param in params {
let err = syn::Error::new_spanned(
&param,
format!("`{}` is not a valid type parameter", param),
);
if let Some(error) = &mut errors {
error.combine(err);
} else {
errors = Some(err);
fn merge_custom_where(&mut self, other: Option<Punctuated<WherePredicate, Token![,]>>) {
match (&mut self.custom_where, other) {
(Some(this), Some(other)) => {
this.extend(other);
}
(None, Some(other)) => {
self.custom_where = Some(other);
}
_ => {}
}

Err(errors.unwrap())
}
}

Expand Down
2 changes: 0 additions & 2 deletions crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ impl<'a> ReflectDerive<'a> {
_ => (),
}

traits.validate_ignored_params(&input.generics)?;

let type_path = ReflectTypePath::Internal {
ident: &input.ident,
custom_path,
Expand Down
46 changes: 35 additions & 11 deletions crates/bevy_reflect/bevy_reflect_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,6 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// This is often used with traits that have been marked by the [`#[reflect_trait]`](macro@reflect_trait)
/// macro in order to register the type's implementation of that trait.
///
/// ## `#[reflect(ignore_params(T, U, ...))]`
///
/// This derive macro will automatically add the necessary bounds to any generic type parameters
/// in order to make them compatible with reflection.
/// However, this may not always be desired and some type paramters are not meant to be reflected
/// (i.e. their usages in fields are ignored or they're only used for their associated types).
///
/// Using this attribute, type parameters can opt out of receiving these automatic bounds.
/// Note that they will receive the bounds that are still considered absolutely necessary,
/// such as `Send`, `Sync`, `Any`, and `TypePath`.
///
/// ### Default Registrations
///
/// The following types are automatically registered when deriving `Reflect`:
Expand Down Expand Up @@ -139,6 +128,37 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
///
/// Note that in the latter case, `ReflectFromReflect` will no longer be automatically registered.
///
/// ## `#[reflect(custom_where(T: Trait, U::Assoc: Trait, ...))]`
///
/// By default, the derive macro will automatically add the necessary bounds to any generic type parameters
/// in order to make them compatible with reflection.
/// However, this may not always be desired, and some type paramaters can't or shouldn't require those bounds
/// (i.e. their usages in fields are ignored or they're only used for their associated types).
///
/// With this attribute, you can specify a custom `where` clause to be used instead of the default.
/// Any parameter used in the attribute will not be given the default bounds,
/// and use the ones defined in the attribute instead.
/// Type parameters not used in the attribute will still receive the default bounds.
///
/// Note that all type parameters will receive the bounds that are still considered absolutely necessary,
/// such as `Send`, `Sync`, `Any`, and `TypePath`.
///
/// ### Example
///
/// ```ignore
/// trait Trait {
/// type Assoc;
/// }
///
/// #[derive(Reflect)]
/// // Note: We add the `T: 'static` bound here to fully opt-out of the automatic bounds on `T`.
/// // We could use almost any trait here, but `'static` is the easiest as it's always available.
/// #[reflect(custom_where(T: 'static, T::Assoc: FromReflect))]
/// struct Foo<T: Trait> {
/// value: T::Assoc,
/// }
/// ```
///
/// # Field Attributes
///
/// Along with the container attributes, this macro comes with some attributes that may be applied
Expand All @@ -152,6 +172,10 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// which may be useful for maintaining invariants, keeping certain data private,
/// or allowing the use of types that do not implement `Reflect` within the container.
///
/// If the field contains a generic type parameter, you will likely need to add a
/// [`#[reflect(custom_where(...))]`](#reflectcustom_wheret-trait-uassoc-trait-)
/// attribute to the container in order to avoid the default bounds being applied to the type parameter.
///
/// ## `#[reflect(skip_serializing)]`
///
/// This works similar to `#[reflect(ignore)]`, but rather than opting out of _all_ of reflection,
Expand Down
34 changes: 30 additions & 4 deletions crates/bevy_reflect/bevy_reflect_derive/src/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ use bevy_macro_utils::BevyManifest;
use bit_set::BitSet;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::{spanned::Spanned, LitStr, Member, Path, TypeParam, WhereClause};
use std::collections::HashSet;
use syn::punctuated::Punctuated;
use syn::{
spanned::Spanned, LitStr, Member, Path, Token, Type, TypeParam, WhereClause, WherePredicate,
};

/// Returns the correct path for `bevy_reflect`.
pub(crate) fn get_bevy_reflect_path() -> Path {
Expand Down Expand Up @@ -70,6 +74,7 @@ pub(crate) struct WhereClauseOptions {
ignored_types: Vec<Ident>,
/// Trait bounds to add to the ignored types
ignored_trait_bounds: Vec<proc_macro2::TokenStream>,
custom_where: Option<Punctuated<WherePredicate, Token![,]>>,
}

impl Default for WhereClauseOptions {
Expand All @@ -80,6 +85,7 @@ impl Default for WhereClauseOptions {
ignored_types: Vec::new(),
active_trait_bounds: Vec::new(),
ignored_trait_bounds: Vec::new(),
custom_where: None,
}
}
}
Expand Down Expand Up @@ -125,9 +131,24 @@ impl WhereClauseOptions {
) -> Self {
let mut options = WhereClauseOptions::default();

let skip_params = if let Some(custom_where) = meta.traits().custom_where() {
custom_where
.iter()
.filter_map(|predicate| match predicate {
WherePredicate::Type(predicate_ty) => match &predicate_ty.bounded_ty {
Type::Path(ty_path) => ty_path.path.get_ident().cloned(),
_ => None,
},
_ => None,
})
.collect()
} else {
HashSet::new()
};

for param in meta.type_path().generics().type_params() {
let ident = param.ident.clone();
let ignored = meta.traits().ignore_param(&ident);
let ignored = skip_params.contains(&ident);

if ignored {
let bounds = ignored_bounds(param).unwrap_or_default();
Expand All @@ -142,6 +163,8 @@ impl WhereClauseOptions {
}
}

options.custom_where = meta.traits().custom_where().cloned();

options
}
}
Expand All @@ -150,8 +173,8 @@ impl WhereClauseOptions {
///
/// This is mostly used to add additional bounds to reflected objects with generic types.
/// For reflection purposes, we usually have:
/// * active_trait_bounds: `Reflect + TypePath` or `FromReflect + TypePath`
/// * ignored_trait_bounds: `TypePath + Any + Send + Sync`
/// * `active_trait_bounds`: `Reflect + TypePath` or `FromReflect + TypePath`
/// * `ignored_trait_bounds`: `TypePath + Any + Send + Sync`
///
/// # Arguments
///
Expand Down Expand Up @@ -193,9 +216,12 @@ pub(crate) fn extend_where_clause(
quote!(where Self: 'static,)
};

let custom_where = &where_clause_options.custom_where;

generic_where_clause.extend(quote! {
#(#active_types: #active_trait_bounds,)*
#(#ignored_types: #ignored_trait_bounds,)*
#custom_where
});
generic_where_clause
}
Expand Down
24 changes: 19 additions & 5 deletions crates/bevy_reflect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1878,9 +1878,9 @@ bevy_reflect::tests::should_reflect_debug::Test {
}

#[test]
fn should_allow_ignored_params() {
fn should_allow_custom_where() {
#[derive(Reflect)]
#[reflect(ignore_params(T))]
#[reflect(custom_where(T: 'static))]
struct Foo<T>(String, #[reflect(ignore)] PhantomData<T>);

#[derive(TypePath)]
Expand All @@ -1894,14 +1894,28 @@ bevy_reflect::tests::should_reflect_debug::Test {
}

#[test]
fn should_allow_ignored_params_wtih_assoc_type() {
fn should_allow_multiple_custom_where() {
#[derive(Reflect)]
#[reflect(custom_where(T: Default + FromReflect))]
#[reflect(custom_where(U: std::ops::Add<T> + FromReflect))]
struct Foo<T, U>(T, U);

#[derive(Reflect)]
struct Baz {
a: Foo<i32, i32>,
b: Foo<u32, u32>,
}
}

#[test]
fn should_allow_custom_where_wtih_assoc_type() {
trait Trait {
type Assoc: Reflect + FromReflect;
type Assoc;
}

// We don't need `T` to be `Reflect` since we only care about `T::Assoc`
#[derive(Reflect)]
#[reflect(ignore_params(T))]
#[reflect(custom_where(T: 'static, T::Assoc: FromReflect))]
struct Foo<T: Trait>(T::Assoc);

#[derive(TypePath)]
Expand Down

0 comments on commit 37eafe9

Please sign in to comment.