Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Remove task_pool parameter from par_for_each(_mut) #4705

Closed
wants to merge 11 commits into from
8 changes: 4 additions & 4 deletions benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bevy_ecs::prelude::*;
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use glam::*;

#[derive(Component, Copy, Clone)]
Expand Down Expand Up @@ -29,8 +29,8 @@ impl Benchmark {
)
}));

fn sys(task_pool: Res<TaskPool>, mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(&task_pool, 128, |(mut pos, mut mat)| {
fn sys(mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(128, |(mut pos, mut mat)| {
for _ in 0..100 {
mat.0 = mat.0.inverse();
}
Expand All @@ -39,7 +39,7 @@ impl Benchmark {
});
}

world.insert_resource(TaskPool::default());
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let mut system = IntoSystem::into_system(sys);
system.initialize(&mut world);
system.update_archetype_component_access(&world);
Expand Down
10 changes: 4 additions & 6 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod tests {
query::{Added, ChangeTrackers, Changed, FilteredAccess, With, Without, WorldQuery},
world::{Mut, World},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use std::{
any::TypeId,
sync::{
Expand Down Expand Up @@ -376,7 +376,7 @@ mod tests {
#[test]
fn par_for_each_dense() {
let mut world = World::new();
let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(A(1)).id();
let e2 = world.spawn().insert(A(2)).id();
let e3 = world.spawn().insert(A(3)).id();
Expand All @@ -385,7 +385,7 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world
.query::<(Entity, &A)>()
.par_for_each(&world, &task_pool, 2, |(e, &A(i))| {
.par_for_each(&world, 2, |(e, &A(i))| {
results.lock().unwrap().push((e, i));
});
results.lock().unwrap().sort();
Expand All @@ -398,8 +398,7 @@ mod tests {
#[test]
fn par_for_each_sparse() {
let mut world = World::new();

let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(SparseStored(1)).id();
let e2 = world.spawn().insert(SparseStored(2)).id();
let e3 = world.spawn().insert(SparseStored(3)).id();
Expand All @@ -408,7 +407,6 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world.query::<(Entity, &SparseStored)>().par_for_each(
&world,
&task_pool,
2,
|(e, &SparseStored(i))| results.lock().unwrap().push((e, i)),
);
Expand Down
216 changes: 120 additions & 96 deletions crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ use crate::{
storage::TableId,
world::{World, WorldId},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
#[cfg(feature = "trace")]
use bevy_utils::tracing::Instrument;
use fixedbitset::FixedBitSet;
use std::fmt;
use std::{fmt, ops::Deref};

use super::{QueryFetch, QueryItem, ROQueryFetch, ROQueryItem};

/// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter.
pub struct QueryState<Q: WorldQuery, F: WorldQuery = ()> {
world_id: WorldId,
task_pool: Option<TaskPool>,
pub(crate) archetype_generation: ArchetypeGeneration,
pub(crate) matched_tables: FixedBitSet,
pub(crate) matched_archetypes: FixedBitSet,
Expand Down Expand Up @@ -61,6 +62,9 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {

let mut state = Self {
world_id: world.id(),
task_pool: world
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This clone is going to cause a subtle bug if anyone tries to replace the compute pool during running. The par_for_each is going ignore the pool being replaced and continue to use the original compute pool.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think about this, the less such an operation makes sense. Even more reason to make TaskPools global and initialized once instead of allowing arbitrary reconfiguration mid execution.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, but since this is one of the objections against global task pools, I'm not sure we can merge this without more clarity there.

Copy link
Member Author

@james7132 james7132 May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking on this a bit more, this isn't too different from the status quo. You can already spawn a task, embed a clone of the TaskPool it's running on, and have it run forever. Replacing the resource would spawn a new TaskPool, but the former pool wouldn't be cleaned up.

Is this a problem? Yes. Is it one solvable within this PR? No. Does integrating this into ECS exacerbate the issue? Probably. Is it worth blocking this change on it? I don't think so. However, in turn, we should prioritize reviewing and merging #2250 or some variant of it to address this issue.

.get_resource::<ComputeTaskPool>()
.map(|task_pool| task_pool.deref().clone()),
james7132 marked this conversation as resolved.
Show resolved Hide resolved
archetype_generation: ArchetypeGeneration::initial(),
matched_table_ids: Vec::new(),
matched_archetype_ids: Vec::new(),
Expand Down Expand Up @@ -685,15 +689,18 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
);
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for
/// write-queries.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
Expand All @@ -702,7 +709,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<ROQueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand All @@ -711,12 +717,15 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w mut World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
Expand All @@ -725,7 +734,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand All @@ -734,10 +742,14 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries.
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
Expand All @@ -746,14 +758,12 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
pub unsafe fn par_for_each_unchecked<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand Down Expand Up @@ -829,6 +839,10 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
/// the current change tick are given. This is faster than the equivalent
/// iter() method, but cannot be chained like a normal [`Iterator`].
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
Expand All @@ -842,103 +856,113 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
>(
&self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
last_change_tick: u32,
change_tick: u32,
) {
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
task_pool.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
self.task_pool
.as_ref()
.expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.")
.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
}
let item = fetch.table_fetch(table_index);
func(item);
}
let item = fetch.table_fetch(table_index);
func(item);
}
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);

for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);

for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
}
func(fetch.archetype_fetch(archetype_index));
}
func(fetch.archetype_fetch(archetype_index));
}
};

#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);

scope.spawn(task);
offset += batch_size;
};

#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);

scope.spawn(task);
offset += batch_size;
}
}
}
}
});
});
}
}

Expand Down
Loading