Skip to content

Commit

Permalink
Simplify trait hierarchy for SystemParam (#6865)
Browse files Browse the repository at this point in the history
# Objective

* Implementing a custom `SystemParam` by hand requires implementing three traits -- four if it is read-only.
* The trait `SystemParamFetch<'w, 's>` is a workaround from before we had generic associated types, and is no longer necessary.

## Solution

* Combine the trait `SystemParamFetch` with `SystemParamState`.
    * I decided to remove the `Fetch` name and keep the `State` name, since the former was consistently conflated with the latter.
* Replace the trait `ReadOnlySystemParamFetch` with `ReadOnlySystemParam`, which simplifies trait bounds in generic code.

---

## Changelog

- Removed the trait `SystemParamFetch`, moving its functionality to `SystemParamState`.
- Replaced the trait `ReadOnlySystemParamFetch` with `ReadOnlySystemParam`.

## Migration Guide

The trait `SystemParamFetch` has been removed, and its functionality has been transferred to `SystemParamState`.

```rust
// Before
impl SystemParamState for MyParamState {
    fn init(world: &mut World, system_meta: &mut SystemMeta) -> Self { ... }
}
impl<'w, 's> SystemParamFetch<'w, 's> for MyParamState {
    type Item = MyParam<'w, 's>;
    fn get_param(...) -> Self::Item;
}

// After
impl SystemParamState for MyParamState {
    type Item<'w, 's> = MyParam<'w, 's>; // Generic associated types!
    fn init(world: &mut World, system_meta: &mut SystemMeta) -> Self { ... }
    fn get_param<'w, 's>(...) -> Self::Item<'w, 's>;
}
```

The trait `ReadOnlySystemParamFetch` has been replaced with `ReadOnlySystemParam`.

```rust
// Before
unsafe impl ReadOnlySystemParamFetch for MyParamState {}

// After
unsafe impl<'w, 's> ReadOnlySystemParam for MyParam<'w, 's> {}
```
  • Loading branch information
JoJoJet committed Dec 11, 2022
1 parent c16791c commit 1af7362
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 386 deletions.
78 changes: 42 additions & 36 deletions crates/bevy_ecs/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,54 +209,56 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
let max_params = 8;
let params = get_idents(|i| format!("P{i}"), max_params);
let params_fetch = get_idents(|i| format!("PF{i}"), max_params);
let params_state = get_idents(|i| format!("PF{i}"), max_params);
let metas = get_idents(|i| format!("m{i}"), max_params);
let mut param_fn_muts = Vec::new();
for (i, param) in params.iter().enumerate() {
let fn_name = Ident::new(&format!("p{i}"), Span::call_site());
let index = Index::from(i);
param_fn_muts.push(quote! {
pub fn #fn_name<'a>(&'a mut self) -> <#param::Fetch as SystemParamFetch<'a, 'a>>::Item {
pub fn #fn_name<'a>(&'a mut self) -> SystemParamItem<'a, 'a, #param> {
// SAFETY: systems run without conflicts with other systems.
// Conflicting params in ParamSet are not accessible at the same time
// ParamSets are guaranteed to not conflict with other SystemParams
unsafe {
<#param::Fetch as SystemParamFetch<'a, 'a>>::get_param(&mut self.param_states.#index, &self.system_meta, self.world, self.change_tick)
<#param::State as SystemParamState>::get_param(&mut self.param_states.#index, &self.system_meta, self.world, self.change_tick)
}
}
});
}

for param_count in 1..=max_params {
let param = &params[0..param_count];
let param_fetch = &params_fetch[0..param_count];
let param_state = &params_state[0..param_count];
let meta = &metas[0..param_count];
let param_fn_mut = &param_fn_muts[0..param_count];
tokens.extend(TokenStream::from(quote! {
impl<'w, 's, #(#param: SystemParam,)*> SystemParam for ParamSet<'w, 's, (#(#param,)*)>
{
type Fetch = ParamSetState<(#(#param::Fetch,)*)>;
type State = ParamSetState<(#(#param::State,)*)>;
}

// SAFETY: All parameters are constrained to ReadOnlyFetch, so World is only read
// SAFETY: All parameters are constrained to ReadOnlyState, so World is only read

unsafe impl<#(#param_fetch: for<'w1, 's1> SystemParamFetch<'w1, 's1>,)*> ReadOnlySystemParamFetch for ParamSetState<(#(#param_fetch,)*)>
where #(#param_fetch: ReadOnlySystemParamFetch,)*
unsafe impl<'w, 's, #(#param,)*> ReadOnlySystemParam for ParamSet<'w, 's, (#(#param,)*)>
where #(#param: ReadOnlySystemParam,)*
{ }

// SAFETY: Relevant parameter ComponentId and ArchetypeComponentId access is applied to SystemMeta. If any ParamState conflicts
// with any prior access, a panic will occur.

unsafe impl<#(#param_fetch: for<'w1, 's1> SystemParamFetch<'w1, 's1>,)*> SystemParamState for ParamSetState<(#(#param_fetch,)*)>
unsafe impl<#(#param_state: SystemParamState,)*> SystemParamState for ParamSetState<(#(#param_state,)*)>
{
type Item<'w, 's> = ParamSet<'w, 's, (#(<#param_state as SystemParamState>::Item::<'w, 's>,)*)>;

fn init(world: &mut World, system_meta: &mut SystemMeta) -> Self {
#(
// Pretend to add each param to the system alone, see if it conflicts
let mut #meta = system_meta.clone();
#meta.component_access_set.clear();
#meta.archetype_component_access.clear();
#param_fetch::init(world, &mut #meta);
let #param = #param_fetch::init(world, &mut system_meta.clone());
#param_state::init(world, &mut #meta);
let #param = #param_state::init(world, &mut system_meta.clone());
)*
#(
system_meta
Expand All @@ -279,21 +281,14 @@ pub fn impl_param_set(_input: TokenStream) -> TokenStream {
fn apply(&mut self, world: &mut World) {
self.0.apply(world)
}
}



impl<'w, 's, #(#param_fetch: for<'w1, 's1> SystemParamFetch<'w1, 's1>,)*> SystemParamFetch<'w, 's> for ParamSetState<(#(#param_fetch,)*)>
{
type Item = ParamSet<'w, 's, (#(<#param_fetch as SystemParamFetch<'w, 's>>::Item,)*)>;

#[inline]
unsafe fn get_param(
unsafe fn get_param<'w, 's>(
state: &'s mut Self,
system_meta: &SystemMeta,
world: &'w World,
change_tick: u32,
) -> Self::Item {
) -> Self::Item<'w, 's> {
ParamSet {
param_states: &mut state.0,
system_meta: system_meta.clone(),
Expand Down Expand Up @@ -414,34 +409,48 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
_ => unreachable!(),
}));

// Create a where clause for the `ReadOnlySystemParam` impl.
// Ensure that each field implements `ReadOnlySystemParam`.
let mut read_only_generics = generics.clone();
let read_only_where_clause = read_only_generics.make_where_clause();
for field_type in &field_types {
read_only_where_clause
.predicates
.push(syn::parse_quote!(#field_type: #path::system::ReadOnlySystemParam));
}

let struct_name = &ast.ident;
let fetch_struct_visibility = &ast.vis;
let state_struct_visibility = &ast.vis;

TokenStream::from(quote! {
// We define the FetchState struct in an anonymous scope to avoid polluting the user namespace.
// The struct can still be accessed via SystemParam::Fetch, e.g. EventReaderState can be accessed via
// <EventReader<'static, 'static, T> as SystemParam>::Fetch
// The struct can still be accessed via SystemParam::State, e.g. EventReaderState can be accessed via
// <EventReader<'static, 'static, T> as SystemParam>::State
const _: () = {
impl<'w, 's, #punctuated_generics> #path::system::SystemParam for #struct_name #ty_generics #where_clause {
type Fetch = State<'w, 's, #punctuated_generic_idents>;
type State = State<'w, 's, #punctuated_generic_idents>;
}

#[doc(hidden)]
type State<'w, 's, #punctuated_generic_idents> = FetchState<
(#(<#field_types as #path::system::SystemParam>::Fetch,)*),
(#(<#field_types as #path::system::SystemParam>::State,)*),
#punctuated_generic_idents
>;

#[doc(hidden)]
#fetch_struct_visibility struct FetchState <TSystemParamState, #punctuated_generic_idents> {
#state_struct_visibility struct FetchState <TSystemParamState, #punctuated_generic_idents> {
state: TSystemParamState,
marker: std::marker::PhantomData<fn()->(#punctuated_generic_idents)>
}

unsafe impl<TSystemParamState: #path::system::SystemParamState, #punctuated_generics> #path::system::SystemParamState for FetchState <TSystemParamState, #punctuated_generic_idents> #where_clause {
unsafe impl<'__w, '__s, #punctuated_generics> #path::system::SystemParamState for
State<'__w, '__s, #punctuated_generic_idents>
#where_clause {
type Item<'w, 's> = #struct_name #ty_generics;

fn init(world: &mut #path::world::World, system_meta: &mut #path::system::SystemMeta) -> Self {
Self {
state: TSystemParamState::init(world, system_meta),
state: #path::system::SystemParamState::init(world, system_meta),
marker: std::marker::PhantomData,
}
}
Expand All @@ -453,25 +462,22 @@ pub fn derive_system_param(input: TokenStream) -> TokenStream {
fn apply(&mut self, world: &mut #path::world::World) {
self.state.apply(world)
}
}

impl<'w, 's, #punctuated_generics> #path::system::SystemParamFetch<'w, 's> for State<'w, 's, #punctuated_generic_idents> #where_clause {
type Item = #struct_name #ty_generics;
unsafe fn get_param(
unsafe fn get_param<'w, 's>(
state: &'s mut Self,
system_meta: &#path::system::SystemMeta,
world: &'w #path::world::World,
change_tick: u32,
) -> Self::Item {
) -> Self::Item<'w, 's> {
#struct_name {
#(#fields: <<#field_types as #path::system::SystemParam>::Fetch as #path::system::SystemParamFetch>::get_param(&mut state.state.#field_indices, system_meta, world, change_tick),)*
#(#fields: <<#field_types as #path::system::SystemParam>::State as #path::system::SystemParamState>::get_param(&mut state.state.#field_indices, system_meta, world, change_tick),)*
#(#ignored_fields: <#ignored_field_types>::default(),)*
}
}
}

// Safety: The `ParamState` is `ReadOnlySystemParamFetch`, so this can only read from the `World`
unsafe impl<TSystemParamState: #path::system::SystemParamState + #path::system::ReadOnlySystemParamFetch, #punctuated_generics> #path::system::ReadOnlySystemParamFetch for FetchState <TSystemParamState, #punctuated_generic_idents> #where_clause {}
// Safety: Each field is `ReadOnlySystemParam`, so this can only read from the `World`
unsafe impl<'w, 's, #punctuated_generics> #path::system::ReadOnlySystemParam for #struct_name #ty_generics #read_only_where_clause {}
};
})
}
Expand Down
34 changes: 16 additions & 18 deletions crates/bevy_ecs/src/system/commands/parallel_scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use thread_local::ThreadLocal;
use crate::{
entity::Entities,
prelude::World,
system::{SystemParam, SystemParamFetch, SystemParamState},
system::{SystemParam, SystemParamState},
};

use super::{CommandQueue, Commands};
Expand Down Expand Up @@ -49,27 +49,13 @@ pub struct ParallelCommands<'w, 's> {
}

impl SystemParam for ParallelCommands<'_, '_> {
type Fetch = ParallelCommandsState;
}

impl<'w, 's> SystemParamFetch<'w, 's> for ParallelCommandsState {
type Item = ParallelCommands<'w, 's>;

unsafe fn get_param(
state: &'s mut Self,
_: &crate::system::SystemMeta,
world: &'w World,
_: u32,
) -> Self::Item {
ParallelCommands {
state,
entities: world.entities(),
}
}
type State = ParallelCommandsState;
}

// SAFETY: no component or resource access to report
unsafe impl SystemParamState for ParallelCommandsState {
type Item<'w, 's> = ParallelCommands<'w, 's>;

fn init(_: &mut World, _: &mut crate::system::SystemMeta) -> Self {
Self::default()
}
Expand All @@ -79,6 +65,18 @@ unsafe impl SystemParamState for ParallelCommandsState {
cq.get_mut().apply(world);
}
}

unsafe fn get_param<'w, 's>(
state: &'s mut Self,
_: &crate::system::SystemMeta,
world: &'w World,
_: u32,
) -> Self::Item<'w, 's> {
ParallelCommands {
state,
entities: world.entities(),
}
}
}

impl<'w, 's> ParallelCommands<'w, 's> {
Expand Down
11 changes: 5 additions & 6 deletions crates/bevy_ecs/src/system/exclusive_function_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ use crate::{
query::Access,
schedule::{SystemLabel, SystemLabelId},
system::{
check_system_change_tick, AsSystemLabel, ExclusiveSystemParam, ExclusiveSystemParamFetch,
ExclusiveSystemParamItem, ExclusiveSystemParamState, IntoSystem, System, SystemMeta,
SystemTypeIdLabel,
check_system_change_tick, AsSystemLabel, ExclusiveSystemParam, ExclusiveSystemParamItem,
ExclusiveSystemParamState, IntoSystem, System, SystemMeta, SystemTypeIdLabel,
},
world::{World, WorldId},
};
Expand All @@ -25,7 +24,7 @@ where
Param: ExclusiveSystemParam,
{
func: F,
param_state: Option<Param::Fetch>,
param_state: Option<Param::State>,
system_meta: SystemMeta,
world_id: Option<WorldId>,
// NOTE: PhantomData<fn()-> T> gives this safe Send/Sync impls
Expand Down Expand Up @@ -95,7 +94,7 @@ where
let saved_last_tick = world.last_change_tick;
world.last_change_tick = self.system_meta.last_change_tick;

let params = <Param as ExclusiveSystemParam>::Fetch::get_param(
let params = <Param as ExclusiveSystemParam>::State::get_param(
self.param_state.as_mut().expect(PARAM_MESSAGE),
&self.system_meta,
);
Expand Down Expand Up @@ -130,7 +129,7 @@ where
fn initialize(&mut self, world: &mut World) {
self.world_id = Some(world.id());
self.system_meta.last_change_tick = world.change_tick().wrapping_sub(MAX_CHANGE_AGE);
self.param_state = Some(<Param::Fetch as ExclusiveSystemParamState>::init(
self.param_state = Some(<Param::State as ExclusiveSystemParamState>::init(
world,
&mut self.system_meta,
));
Expand Down
Loading

0 comments on commit 1af7362

Please sign in to comment.