Skip to content

Commit

Permalink
Replace custom_where attribute with where
Browse files Browse the repository at this point in the history
  • Loading branch information
MrGVSV committed Jan 28, 2024
1 parent 76cb513 commit 6ea40de
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 56 deletions.
2 changes: 1 addition & 1 deletion crates/bevy_asset/src/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl std::fmt::Debug for StrongHandle {
/// [`Handle::Strong`] also provides access to useful [`Asset`] metadata, such as the [`AssetPath`] (if it exists).
#[derive(Component, Reflect)]
#[reflect(Component)]
#[reflect(custom_where(A: Asset))]
#[reflect(where A: Asset)]
pub enum Handle<A: Asset> {
/// A "strong" reference to a live (or loading) [`Asset`]. If a [`Handle`] is [`Handle::Strong`], the [`Asset`] will be kept
/// alive until the [`Handle`] is dropped. Strong handles also provide access to additional asset metadata.
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_asset/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use thiserror::Error;
///
/// For an "untyped" / "generic-less" id, see [`UntypedAssetId`].
#[derive(Reflect)]
#[reflect(custom_where(A: Asset))]
#[reflect(where A: Asset)]
pub enum AssetId<A: Asset> {
/// A small / efficient runtime identifier that can be used to efficiently look up an asset stored in [`Assets`]. This is
/// the "default" identifier used for assets. The alternative(s) (ex: [`AssetId::Uuid`]) will only be used if assets are
Expand Down
45 changes: 28 additions & 17 deletions crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

use crate::utility;
use bevy_macro_utils::fq_std::{FQAny, FQOption};
use proc_macro2::{Ident, Span};
use proc_macro2::{Ident, Span, TokenTree};
use quote::quote_spanned;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{Expr, LitBool, Meta, Path, Token, WherePredicate};
use syn::{Expr, LitBool, Meta, MetaList, Path, WhereClause};

// The "special" trait idents that are used internally for reflection.
// Received via attributes like `#[reflect(PartialEq, Hash, ...)]`
Expand All @@ -31,9 +31,6 @@ const FROM_REFLECT_ATTR: &str = "from_reflect";
// Attributes for `TypePath` implementation
const TYPE_PATH_ATTR: &str = "type_path";

// Attributes for `Reflect` implementation
const CUSTOM_WHERE_ATTR: &str = "custom_where";

// The error message to show when a trait/type is specified multiple times
const CONFLICTING_TYPE_DATA_MESSAGE: &str = "conflicting type data registration";

Expand Down Expand Up @@ -214,12 +211,30 @@ pub(crate) struct ReflectTraits {
partial_eq: TraitImpl,
from_reflect_attrs: FromReflectAttrs,
type_path_attrs: TypePathAttrs,
custom_where: Option<Punctuated<WherePredicate, Token![,]>>,
custom_where: Option<WhereClause>,
idents: Vec<Ident>,
}

impl ReflectTraits {
pub fn from_metas(
pub fn from_meta_list(
meta: &MetaList,
is_from_reflect_derive: bool,
) -> Result<Self, syn::Error> {
match meta.tokens.clone().into_iter().next() {
// Handles `#[reflect(where T: Trait, U::Assoc: Trait)]`
Some(TokenTree::Ident(ident)) if ident == "where" => {
let mut traits = ReflectTraits::default();
traits.custom_where = Some(meta.parse_args::<WhereClause>()?);
Ok(traits)
}
_ => Self::from_metas(
meta.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)?,
is_from_reflect_derive,
),
}
}

fn from_metas(
metas: Punctuated<Meta, Comma>,
is_from_reflect_derive: bool,
) -> Result<Self, syn::Error> {
Expand Down Expand Up @@ -257,12 +272,6 @@ impl ReflectTraits {
}
}
}
// Handles `#[reflect(custom_where(T: Trait, U::Assoc: Trait))]`
Meta::List(list) if list.path.is_ident(CUSTOM_WHERE_ATTR) => {
let predicate: Punctuated<WherePredicate, Token![,]> =
list.parse_args_with(Punctuated::parse_terminated)?;
traits.merge_custom_where(Some(predicate));
}
// Handles `#[reflect( Debug(custom_debug_fn) )]`
Meta::List(list) if list.path.is_ident(DEBUG_ATTR) => {
let ident = list.path.get_ident().unwrap();
Expand Down Expand Up @@ -290,7 +299,9 @@ impl ReflectTraits {
Meta::List(list) => {
return Err(syn::Error::new_spanned(
list,
format!("expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}, {CUSTOM_WHERE_ATTR:?}]")
format!(
"expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}]"
),
));
}
Meta::NameValue(pair) => {
Expand Down Expand Up @@ -408,7 +419,7 @@ impl ReflectTraits {
}
}

pub fn custom_where(&self) -> Option<&Punctuated<WherePredicate, Token![,]>> {
pub fn custom_where(&self) -> Option<&WhereClause> {
self.custom_where.as_ref()
}

Expand All @@ -430,10 +441,10 @@ impl ReflectTraits {
Ok(())
}

fn merge_custom_where(&mut self, other: Option<Punctuated<WherePredicate, Token![,]>>) {
fn merge_custom_where(&mut self, other: Option<WhereClause>) {
match (&mut self.custom_where, other) {
(Some(this), Some(other)) => {
this.extend(other);
this.predicates.extend(other.predicates);
}
(None, Some(other)) => {
self.custom_where = Some(other);
Expand Down
12 changes: 4 additions & 8 deletions crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ impl<'a> ReflectDerive<'a> {
}

reflect_mode = Some(ReflectMode::Normal);
let new_traits = ReflectTraits::from_metas(
meta_list.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)?,
is_from_reflect_derive,
)?;
let new_traits =
ReflectTraits::from_meta_list(meta_list, is_from_reflect_derive)?;
traits.merge(new_traits)?;
}
Meta::List(meta_list) if meta_list.path.is_ident(REFLECT_VALUE_ATTRIBUTE_NAME) => {
Expand All @@ -182,10 +180,8 @@ impl<'a> ReflectDerive<'a> {
}

reflect_mode = Some(ReflectMode::Value);
let new_traits = ReflectTraits::from_metas(
meta_list.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)?,
is_from_reflect_derive,
)?;
let new_traits =
ReflectTraits::from_meta_list(meta_list, is_from_reflect_derive)?;
traits.merge(new_traits)?;
}
Meta::Path(path) if path.is_ident(REFLECT_VALUE_ATTRIBUTE_NAME) => {
Expand Down
8 changes: 4 additions & 4 deletions crates/bevy_reflect/bevy_reflect_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// This is useful for when a type can't or shouldn't implement `TypePath`,
/// or if a manual implementation is desired.
///
/// ## `#[reflect(custom_where(T: Trait, U::Assoc: Trait, ...))]`
/// ## `#[reflect(where T: Trait, U::Assoc: Trait, ...)]`
///
/// By default, the derive macro will automatically add certain trait bounds to all generic type parameters
/// in order to make them compatible with reflection without the user needing to add them manually.
Expand All @@ -147,7 +147,7 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// in general.
///
/// This means that if you want to opt-out of the default bounds for _all_ type parameters,
/// you can add `#[reflect(custom_where())]` to the container item to indicate
/// you can add `#[reflect(where)]` to the container item to indicate
/// that an empty `where` clause should be used.
///
/// ### Example
Expand All @@ -158,7 +158,7 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// }
///
/// #[derive(Reflect)]
/// #[reflect(custom_where(T::Assoc: FromReflect))]
/// #[reflect(where T::Assoc: FromReflect)]
/// struct Foo<T: Trait> where T::Assoc: Default {
/// value: T::Assoc,
/// }
Expand Down Expand Up @@ -192,7 +192,7 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name";
/// or allowing the use of types that do not implement `Reflect` within the container.
///
/// If the field contains a generic type parameter, you will likely need to add a
/// [`#[reflect(custom_where())]`](#reflectcustom_wheret-trait-uassoc-trait-)
/// [`#[reflect(where)]`](#reflectwheret-trait-uassoc-trait-)
/// attribute to the container in order to avoid the default bounds being applied to the type parameter.
///
/// ## `#[reflect(skip_serializing)]`
Expand Down
14 changes: 9 additions & 5 deletions crates/bevy_reflect/bevy_reflect_derive/src/utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<'a, 'b> WhereClauseOptions<'a, 'b> {
///
/// This will only add bounds for generic type parameters.
///
/// If the container has a `#[reflect(custom_where(...))]` attribute,
/// If the container has a `#[reflect(where)]` attribute,
/// this method will extend the type parameters with the _required_ bounds.
/// If the attribute is not present, it will extend the type parameters with the _additional_ bounds.
///
Expand All @@ -138,7 +138,7 @@ impl<'a, 'b> WhereClauseOptions<'a, 'b> {
///
/// It has type parameters `T` and `U`.
///
/// Since there is no `#[reflect(custom_where(...))]` attribute, this method will extend the type parameters
/// Since there is no `#[reflect(where)]` attribute, this method will extend the type parameters
/// with the additional bounds:
///
/// ```ignore (bevy_reflect is not accessible from this crate)
Expand All @@ -150,15 +150,15 @@ impl<'a, 'b> WhereClauseOptions<'a, 'b> {
/// If we had this struct:
/// ```ignore (bevy_reflect is not accessible from this crate)
/// #[derive(Reflect)]
/// #[reflect(custom_where(T: FromReflect + Default))]
/// #[reflect(where T: FromReflect + Default)]
/// struct Foo<T, U> {
/// a: T,
/// #[reflect(ignore)]
/// b: U
/// }
/// ```
///
/// Since there is a `#[reflect(custom_where(...))]` attribute, this method will extend the type parameters
/// Since there is a `#[reflect(where)]` attribute, this method will extend the type parameters
/// with _just_ the required bounds along with the predicates specified in the attribute:
///
/// ```ignore (bevy_reflect is not accessible from this crate)
Expand All @@ -181,7 +181,11 @@ impl<'a, 'b> WhereClauseOptions<'a, 'b> {

// Add additional reflection trait bounds
let types = self.type_param_idents();
let custom_where = self.meta.traits().custom_where();
let custom_where = self
.meta
.traits()
.custom_where()
.map(|clause| &clause.predicates);
let trait_bounds = self.trait_bounds();

generic_where_clause.extend(quote! {
Expand Down
37 changes: 17 additions & 20 deletions crates/bevy_reflect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ mod tests {
ser::{to_string_pretty, PrettyConfig},
Deserializer,
};
use static_assertions::{assert_impl_all, assert_not_impl_all};
use std::{
any::TypeId,
borrow::Cow,
Expand Down Expand Up @@ -1866,47 +1867,46 @@ bevy_reflect::tests::Test {
#[test]
fn should_allow_custom_where() {
#[derive(Reflect)]
#[reflect(custom_where(T: Default))]
#[reflect(where T: Default)]
struct Foo<T>(String, #[reflect(ignore)] PhantomData<T>);

#[derive(Default, TypePath)]
struct Bar;

#[derive(Reflect)]
struct Baz {
a: Foo<Bar>,
b: Foo<usize>,
}
#[derive(TypePath)]
struct Baz;

assert_impl_all!(Foo<Bar>: Reflect);
assert_not_impl_all!(Foo<Baz>: Reflect);
}

#[test]
fn should_allow_empty_custom_where() {
#[derive(Reflect)]
#[reflect(custom_where())]
#[reflect(where)]
struct Foo<T>(String, #[reflect(ignore)] PhantomData<T>);

#[derive(TypePath)]
struct Bar;

#[derive(Reflect)]
struct Baz {
a: Foo<Bar>,
b: Foo<usize>,
}
assert_impl_all!(Foo<Bar>: Reflect);
}

#[test]
fn should_allow_multiple_custom_where() {
#[derive(Reflect)]
#[reflect(custom_where(T: Default + FromReflect))]
#[reflect(custom_where(U: std::ops::Add<T> + FromReflect))]
#[reflect(where T: Default + FromReflect)]
#[reflect(where U: std::ops::Add<T> + FromReflect)]
struct Foo<T, U>(T, U);

#[derive(Reflect)]
struct Baz {
a: Foo<i32, i32>,
b: Foo<u32, u32>,
}

assert_impl_all!(Foo<i32, i32>: Reflect);
assert_not_impl_all!(Foo<i32, usize>: Reflect);
}

#[test]
Expand All @@ -1917,7 +1917,7 @@ bevy_reflect::tests::Test {

// We don't need `T` to be `Reflect` since we only care about `T::Assoc`
#[derive(Reflect)]
#[reflect(custom_where(T::Assoc: FromReflect))]
#[reflect(where T::Assoc: FromReflect)]
struct Foo<T: Trait>(T::Assoc);

#[derive(TypePath)]
Expand All @@ -1927,10 +1927,7 @@ bevy_reflect::tests::Test {
type Assoc = usize;
}

#[derive(Reflect)]
struct Baz {
a: Foo<Bar>,
}
assert_impl_all!(Foo<Bar>: Reflect);
}

#[test]
Expand Down Expand Up @@ -1969,7 +1966,7 @@ bevy_reflect::tests::Test {
fn can_opt_out_type_path() {
#[derive(Reflect)]
#[reflect(type_path = false)]
#[reflect(custom_where())]
#[reflect(where)]
struct Foo<T> {
#[reflect(ignore)]
_marker: PhantomData<T>,
Expand Down

0 comments on commit 6ea40de

Please sign in to comment.