diff --git a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs index 71a8e6cc6a950..169f32f06500e 100644 --- a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs +++ b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs @@ -18,6 +18,8 @@ pub struct Benchmark(World, Box>); impl Benchmark { pub fn new() -> Self { + ComputeTaskPool::init(TaskPool::default); + let mut world = World::default(); world.spawn_batch((0..1000).map(|_| { @@ -39,7 +41,6 @@ impl Benchmark { }); } - world.insert_resource(ComputeTaskPool(TaskPool::default())); let mut system = IntoSystem::into_system(sys); system.initialize(&mut world); system.update_archetype_component_access(&world); diff --git a/crates/bevy_app/src/app.rs b/crates/bevy_app/src/app.rs index d587344d6549b..cdf1984c8b01e 100644 --- a/crates/bevy_app/src/app.rs +++ b/crates/bevy_app/src/app.rs @@ -10,7 +10,6 @@ use bevy_ecs::{ system::Resource, world::World, }; -use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; use bevy_utils::{tracing::debug, HashMap}; use std::fmt::Debug; @@ -863,18 +862,9 @@ impl App { pub fn add_sub_app( &mut self, label: impl AppLabel, - mut app: App, + app: App, sub_app_runner: impl Fn(&mut World, &mut App) + 'static, ) -> &mut Self { - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } self.sub_apps.insert( Box::new(label), SubApp { diff --git a/crates/bevy_asset/src/asset_server.rs b/crates/bevy_asset/src/asset_server.rs index 1612b406b4f73..bc7defc361c52 100644 --- a/crates/bevy_asset/src/asset_server.rs +++ b/crates/bevy_asset/src/asset_server.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use bevy_ecs::system::{Res, ResMut}; use bevy_log::warn; -use bevy_tasks::TaskPool; +use bevy_tasks::IoTaskPool; use bevy_utils::{Entry, HashMap, Uuid}; use crossbeam_channel::TryRecvError; use parking_lot::{Mutex, RwLock}; @@ -56,7 +56,6 @@ pub struct AssetServerInternal { loaders: RwLock>>, extension_to_loader_index: RwLock>, handle_to_path: Arc>>>, - task_pool: TaskPool, } /// Loads assets from the filesystem on background threads @@ -66,11 +65,11 @@ pub struct AssetServer { } impl AssetServer { - pub fn new(source_io: T, task_pool: TaskPool) -> Self { - Self::with_boxed_io(Box::new(source_io), task_pool) + pub fn new(source_io: T) -> Self { + Self::with_boxed_io(Box::new(source_io)) } - pub fn with_boxed_io(asset_io: Box, task_pool: TaskPool) -> Self { + pub fn with_boxed_io(asset_io: Box) -> Self { AssetServer { server: Arc::new(AssetServerInternal { loaders: Default::default(), @@ -79,7 +78,6 @@ impl AssetServer { asset_ref_counter: Default::default(), handle_to_path: Default::default(), asset_lifecycles: Default::default(), - task_pool, asset_io, }), } @@ -315,7 +313,6 @@ impl AssetServer { &self.server.asset_ref_counter.channel, self.asset_io(), version, - &self.server.task_pool, ); if let Err(err) = asset_loader @@ -377,8 +374,7 @@ impl AssetServer { pub(crate) fn load_untracked(&self, asset_path: AssetPath<'_>, force: bool) -> HandleId { let server = self.clone(); let owned_path = asset_path.to_owned(); - self.server - .task_pool + IoTaskPool::get() .spawn(async move { if let Err(err) = server.load_async(owned_path, force).await { warn!("{}", err); @@ -620,8 +616,8 @@ mod test { fn 
setup(asset_path: impl AsRef) -> AssetServer { use crate::FileAssetIo; - - AssetServer::new(FileAssetIo::new(asset_path, false), Default::default()) + IoTaskPool::init(Default::default); + AssetServer::new(FileAssetIo::new(asset_path, false)) } #[test] diff --git a/crates/bevy_asset/src/debug_asset_server.rs b/crates/bevy_asset/src/debug_asset_server.rs index e53646de95b9f..b7e3c71c85c89 100644 --- a/crates/bevy_asset/src/debug_asset_server.rs +++ b/crates/bevy_asset/src/debug_asset_server.rs @@ -58,14 +58,14 @@ impl Default for HandleMap { impl Plugin for DebugAssetServerPlugin { fn build(&self, app: &mut bevy_app::App) { + IoTaskPool::init(|| { + TaskPoolBuilder::default() + .num_threads(2) + .thread_name("Debug Asset Server IO Task Pool".to_string()) + .build() + }); let mut debug_asset_app = App::new(); debug_asset_app - .insert_resource(IoTaskPool( - TaskPoolBuilder::default() - .num_threads(2) - .thread_name("Debug Asset Server IO Task Pool".to_string()) - .build(), - )) .insert_resource(AssetServerSettings { asset_folder: "crates".to_string(), watch_for_changes: true, diff --git a/crates/bevy_asset/src/lib.rs b/crates/bevy_asset/src/lib.rs index 870f100d10306..b5ba1a1854d02 100644 --- a/crates/bevy_asset/src/lib.rs +++ b/crates/bevy_asset/src/lib.rs @@ -30,7 +30,6 @@ pub use path::*; use bevy_app::{prelude::Plugin, App}; use bevy_ecs::schedule::{StageLabel, SystemStage}; -use bevy_tasks::IoTaskPool; /// The names of asset stages in an App Schedule #[derive(Debug, Hash, PartialEq, Eq, Clone, StageLabel)] @@ -82,12 +81,8 @@ pub fn create_platform_default_asset_io(app: &mut App) -> Box { impl Plugin for AssetPlugin { fn build(&self, app: &mut App) { if !app.world.contains_resource::() { - let task_pool = app.world.resource::().0.clone(); - let source = create_platform_default_asset_io(app); - - let asset_server = AssetServer::with_boxed_io(source, task_pool); - + let asset_server = AssetServer::with_boxed_io(source); app.insert_resource(asset_server); } diff --git a/crates/bevy_asset/src/loader.rs b/crates/bevy_asset/src/loader.rs index 5a5de9b8c11eb..5d6b87d8388ba 100644 --- a/crates/bevy_asset/src/loader.rs +++ b/crates/bevy_asset/src/loader.rs @@ -5,7 +5,6 @@ use crate::{ use anyhow::Result; use bevy_ecs::system::{Res, ResMut}; use bevy_reflect::{TypeUuid, TypeUuidDynamic}; -use bevy_tasks::TaskPool; use bevy_utils::{BoxedFuture, HashMap}; use crossbeam_channel::{Receiver, Sender}; use downcast_rs::{impl_downcast, Downcast}; @@ -84,7 +83,6 @@ pub struct LoadContext<'a> { pub(crate) labeled_assets: HashMap, BoxedLoadedAsset>, pub(crate) path: &'a Path, pub(crate) version: usize, - pub(crate) task_pool: &'a TaskPool, } impl<'a> LoadContext<'a> { @@ -93,7 +91,6 @@ impl<'a> LoadContext<'a> { ref_change_channel: &'a RefChangeChannel, asset_io: &'a dyn AssetIo, version: usize, - task_pool: &'a TaskPool, ) -> Self { Self { ref_change_channel, @@ -101,7 +98,6 @@ impl<'a> LoadContext<'a> { labeled_assets: Default::default(), version, path, - task_pool, } } @@ -144,10 +140,6 @@ impl<'a> LoadContext<'a> { asset_metas } - pub fn task_pool(&self) -> &TaskPool { - self.task_pool - } - pub fn asset_io(&self) -> &dyn AssetIo { self.asset_io } diff --git a/crates/bevy_core/src/lib.rs b/crates/bevy_core/src/lib.rs index f4630bc668b90..7c7d193564f87 100644 --- a/crates/bevy_core/src/lib.rs +++ b/crates/bevy_core/src/lib.rs @@ -30,7 +30,7 @@ impl Plugin for CorePlugin { .get_resource::() .cloned() .unwrap_or_default() - .create_default_pools(&mut app.world); + .create_default_pools(); 
app.register_type::().register_type::(); diff --git a/crates/bevy_core/src/task_pool_options.rs b/crates/bevy_core/src/task_pool_options.rs index 19c9dad5bfe2f..152489b7cf7ae 100644 --- a/crates/bevy_core/src/task_pool_options.rs +++ b/crates/bevy_core/src/task_pool_options.rs @@ -1,4 +1,3 @@ -use bevy_ecs::world::World; use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; use bevy_utils::tracing::trace; @@ -93,14 +92,14 @@ impl DefaultTaskPoolOptions { } /// Inserts the default thread pools into the given resource map based on the configured values - pub fn create_default_pools(&self, world: &mut World) { + pub fn create_default_pools(&self) { let total_threads = bevy_tasks::logical_core_count().clamp(self.min_total_threads, self.max_total_threads); trace!("Assigning {} cores to default task pools", total_threads); let mut remaining_threads = total_threads; - if !world.contains_resource::() { + { // Determine the number of IO threads we will use let io_threads = self .io @@ -109,15 +108,15 @@ impl DefaultTaskPoolOptions { trace!("IO Threads: {}", io_threads); remaining_threads = remaining_threads.saturating_sub(io_threads); - world.insert_resource(IoTaskPool( + IoTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(io_threads) .thread_name("IO Task Pool".to_string()) - .build(), - )); + .build() + }); } - if !world.contains_resource::() { + { // Determine the number of async compute threads we will use let async_compute_threads = self .async_compute @@ -126,15 +125,15 @@ impl DefaultTaskPoolOptions { trace!("Async Compute Threads: {}", async_compute_threads); remaining_threads = remaining_threads.saturating_sub(async_compute_threads); - world.insert_resource(AsyncComputeTaskPool( + AsyncComputeTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(async_compute_threads) .thread_name("Async Compute Task Pool".to_string()) - .build(), - )); + .build() + }); } - if !world.contains_resource::() { + { // Determine the number of compute threads we will use // This is intentionally last so that an end user can specify 1.0 as the percent let compute_threads = self @@ -142,12 +141,13 @@ impl DefaultTaskPoolOptions { .get_number_of_threads(remaining_threads, total_threads); trace!("Compute Threads: {}", compute_threads); - world.insert_resource(ComputeTaskPool( + + ComputeTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(compute_threads) .thread_name("Compute Task Pool".to_string()) - .build(), - )); + .build() + }); } } } diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 56ac5db9acd94..30a373a9d5ae0 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -375,8 +375,8 @@ mod tests { #[test] fn par_for_each_dense() { + ComputeTaskPool::init(TaskPool::default); let mut world = World::new(); - world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(A(1)).id(); let e2 = world.spawn().insert(A(2)).id(); let e3 = world.spawn().insert(A(3)).id(); @@ -397,8 +397,8 @@ mod tests { #[test] fn par_for_each_sparse() { + ComputeTaskPool::init(TaskPool::default); let mut world = World::new(); - world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(SparseStored(1)).id(); let e2 = world.spawn().insert(SparseStored(2)).id(); let e3 = world.spawn().insert(SparseStored(3)).id(); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 64a80313c69f1..a01fc4019de4f 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ 
b/crates/bevy_ecs/src/query/state.rs @@ -10,18 +10,17 @@ use crate::{ storage::TableId, world::{World, WorldId}, }; -use bevy_tasks::{ComputeTaskPool, TaskPool}; +use bevy_tasks::ComputeTaskPool; #[cfg(feature = "trace")] use bevy_utils::tracing::Instrument; use fixedbitset::FixedBitSet; -use std::{borrow::Borrow, fmt, ops::Deref}; +use std::{borrow::Borrow, fmt}; use super::{QueryFetch, QueryItem, QueryManyIter, ROQueryFetch, ROQueryItem}; /// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter. pub struct QueryState { world_id: WorldId, - task_pool: Option, pub(crate) archetype_generation: ArchetypeGeneration, pub(crate) matched_tables: FixedBitSet, pub(crate) matched_archetypes: FixedBitSet, @@ -62,9 +61,6 @@ impl QueryState { let mut state = Self { world_id: world.id(), - task_pool: world - .get_resource::() - .map(|task_pool| task_pool.deref().clone()), archetype_generation: ArchetypeGeneration::initial(), matched_table_ids: Vec::new(), matched_archetype_ids: Vec::new(), @@ -754,8 +750,8 @@ impl QueryState { /// write-queries. /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, @@ -779,8 +775,8 @@ impl QueryState { /// Runs `func` on each query result in parallel. /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, @@ -806,8 +802,8 @@ impl QueryState { /// This can only be called for read-only queries. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety /// @@ -922,8 +918,8 @@ impl QueryState { /// iter() method, but cannot be chained like a normal [`Iterator`]. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety /// @@ -945,106 +941,95 @@ impl QueryState { ) { // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: // QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::many_for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual - self.task_pool - .as_ref() - .expect("Cannot iterate query in parallel. 
No ComputeTaskPool initialized.") - .scope(|scope| { - if QF::IS_DENSE && >::IS_DENSE { - let tables = &world.storages().tables; - for table_id in &self.matched_table_ids { - let table = &tables[*table_id]; - let mut offset = 0; - while offset < table.len() { - let func = func.clone(); - let len = batch_size.min(table.len() - offset); - let task = async move { - let mut fetch = QF::init( - world, - &self.fetch_state, - last_change_tick, - change_tick, - ); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, - ); - let tables = &world.storages().tables; - let table = &tables[*table_id]; - fetch.set_table(&self.fetch_state, table); - filter.set_table(&self.filter_state, table); - for table_index in offset..offset + len { - if !filter.table_filter_fetch(table_index) { - continue; - } - let item = fetch.table_fetch(table_index); - func(item); - } - }; - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, + ComputeTaskPool::get().scope(|scope| { + if QF::IS_DENSE && >::IS_DENSE { + let tables = &world.storages().tables; + for table_id in &self.matched_table_ids { + let table = &tables[*table_id]; + let mut offset = 0; + while offset < table.len() { + let func = func.clone(); + let len = batch_size.min(table.len() - offset); + let task = async move { + let mut fetch = + QF::init(world, &self.fetch_state, last_change_tick, change_tick); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - scope.spawn(task); - offset += batch_size; - } - } - } else { - let archetypes = &world.archetypes; - for archetype_id in &self.matched_archetype_ids { - let mut offset = 0; - let archetype = &archetypes[*archetype_id]; - while offset < archetype.len() { - let func = func.clone(); - let len = batch_size.min(archetype.len() - offset); - let task = async move { - let mut fetch = QF::init( - world, - &self.fetch_state, - last_change_tick, - change_tick, - ); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, - ); - let tables = &world.storages().tables; - let archetype = &world.archetypes[*archetype_id]; - fetch.set_archetype(&self.fetch_state, archetype, tables); - filter.set_archetype(&self.filter_state, archetype, tables); - - for archetype_index in offset..offset + len { - if !filter.archetype_filter_fetch(archetype_index) { - continue; - } - func(fetch.archetype_fetch(archetype_index)); + let tables = &world.storages().tables; + let table = &tables[*table_id]; + fetch.set_table(&self.fetch_state, table); + filter.set_table(&self.filter_state, table); + for table_index in offset..offset + len { + if !filter.table_filter_fetch(table_index) { + continue; } - }; - - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, + let item = fetch.table_fetch(table_index); + func(item); + } + }; + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, + ); + #[cfg(feature = "trace")] + let task = task.instrument(span); + scope.spawn(task); + offset += batch_size; + } + } + } else { + let archetypes = &world.archetypes; + for archetype_id in 
&self.matched_archetype_ids { + let mut offset = 0; + let archetype = &archetypes[*archetype_id]; + while offset < archetype.len() { + let func = func.clone(); + let len = batch_size.min(archetype.len() - offset); + let task = async move { + let mut fetch = + QF::init(world, &self.fetch_state, last_change_tick, change_tick); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - - scope.spawn(task); - offset += batch_size; - } + let tables = &world.storages().tables; + let archetype = &world.archetypes[*archetype_id]; + fetch.set_archetype(&self.fetch_state, archetype, tables); + filter.set_archetype(&self.filter_state, archetype, tables); + + for archetype_index in offset..offset + len { + if !filter.archetype_filter_fetch(archetype_index) { + continue; + } + func(fetch.archetype_fetch(archetype_index)); + } + }; + + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, + ); + #[cfg(feature = "trace")] + let task = task.instrument(span); + + scope.spawn(task); + offset += batch_size; } } - }); + } + }); } /// Runs `func` on each query result for the given [`World`] and list of [`Entity`]'s, where the last change and diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index 149d8d02bc2df..c82924b0e276c 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -123,10 +123,7 @@ impl ParallelSystemExecutor for ParallelExecutor { } } - let compute_pool = world - .get_resource_or_insert_with(|| ComputeTaskPool(TaskPool::default())) - .clone(); - compute_pool.scope(|scope| { + ComputeTaskPool::init(TaskPool::default).scope(|scope| { self.prepare_systems(scope, systems, world); let parallel_executor = async { // All systems have been ran if there are no queued or running systems. diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 484d119e4111d..0bfe2711dc630 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -587,8 +587,8 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { ///* `f` - The function to run on each item in the query /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] @@ -615,8 +615,8 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { /// See [`Self::par_for_each`] for more details. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. 
/// /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] diff --git a/crates/bevy_gltf/Cargo.toml b/crates/bevy_gltf/Cargo.toml index c35ac483c044a..5d9ab53ad21c6 100644 --- a/crates/bevy_gltf/Cargo.toml +++ b/crates/bevy_gltf/Cargo.toml @@ -24,6 +24,7 @@ bevy_reflect = { path = "../bevy_reflect", version = "0.8.0-dev", features = ["b bevy_render = { path = "../bevy_render", version = "0.8.0-dev" } bevy_scene = { path = "../bevy_scene", version = "0.8.0-dev" } bevy_transform = { path = "../bevy_transform", version = "0.8.0-dev" } +bevy_tasks = { path = "../bevy_tasks", version = "0.8.0-dev" } bevy_utils = { path = "../bevy_utils", version = "0.8.0-dev" } # other diff --git a/crates/bevy_gltf/src/loader.rs b/crates/bevy_gltf/src/loader.rs index c699a58945ad8..a5976530e74df 100644 --- a/crates/bevy_gltf/src/loader.rs +++ b/crates/bevy_gltf/src/loader.rs @@ -29,6 +29,7 @@ use bevy_render::{ view::VisibleEntities, }; use bevy_scene::Scene; +use bevy_tasks::IoTaskPool; use bevy_transform::{components::Transform, TransformBundle}; use bevy_utils::{HashMap, HashSet}; @@ -410,8 +411,7 @@ async fn load_gltf<'a, 'b>( } } else { #[cfg(not(target_arch = "wasm32"))] - load_context - .task_pool() + IoTaskPool::get() .scope(|scope| { gltf.textures().for_each(|gltf_texture| { let linear_textures = &linear_textures; diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 7b83b9bc344bd..06a0da456931a 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -13,7 +13,8 @@ futures-lite = "1.4.0" event-listener = "2.5.2" async-executor = "1.3.0" async-channel = "1.4.2" -num_cpus = "1.0.1" +num_cpus = "1" +once_cell = "1.7" [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index ebd6ba6b41f4c..1d0f86e7cb5ed 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -60,29 +60,9 @@ impl TaskPoolBuilder { } } -#[derive(Debug)] -struct TaskPoolInner { - threads: Vec>, - shutdown_tx: async_channel::Sender<()>, -} - -impl Drop for TaskPoolInner { - fn drop(&mut self) { - self.shutdown_tx.close(); - - let panicking = thread::panicking(); - for join_handle in self.threads.drain(..) { - let res = join_handle.join(); - if !panicking { - res.expect("Task thread panicked while executing."); - } - } - } -} - /// A thread pool for executing tasks. Tasks are futures that are being automatically driven by /// the pool on threads owned by the pool. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TaskPool { /// The executor for the pool /// @@ -92,7 +72,8 @@ pub struct TaskPool { executor: Arc>, /// Inner state of the pool - inner: Arc, + threads: Vec>, + shutdown_tx: async_channel::Sender<()>, } impl TaskPool { @@ -155,16 +136,14 @@ impl TaskPool { Self { executor, - inner: Arc::new(TaskPoolInner { - threads, - shutdown_tx, - }), + threads, + shutdown_tx, } } /// Return the number of threads owned by the task pool pub fn thread_num(&self) -> usize { - self.inner.threads.len() + self.threads.len() } /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback, @@ -268,6 +247,20 @@ impl Default for TaskPool { } } +impl Drop for TaskPool { + fn drop(&mut self) { + self.shutdown_tx.close(); + + let panicking = thread::panicking(); + for join_handle in self.threads.drain(..) 
{
+            let res = join_handle.join();
+            if !panicking {
+                res.expect("Task thread panicked while executing.");
+            }
+        }
+    }
+}
+
 /// A `TaskPool` scope for running one or more non-`'static` futures.
 ///
 /// For more information, see [`TaskPool::scope`].
diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs
index 923c1a7eb4eab..419d842f47168 100644
--- a/crates/bevy_tasks/src/usages.rs
+++ b/crates/bevy_tasks/src/usages.rs
@@ -11,12 +11,35 @@
 //! for consumption. (likely via channels)
 
 use super::TaskPool;
+use once_cell::sync::OnceCell;
 use std::ops::Deref;
 
+static COMPUTE_TASK_POOL: OnceCell<ComputeTaskPool> = OnceCell::new();
+static ASYNC_COMPUTE_TASK_POOL: OnceCell<AsyncComputeTaskPool> = OnceCell::new();
+static IO_TASK_POOL: OnceCell<IoTaskPool> = OnceCell::new();
+
 /// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next
 /// frame
-#[derive(Clone, Debug)]
-pub struct ComputeTaskPool(pub TaskPool);
+#[derive(Debug)]
+pub struct ComputeTaskPool(TaskPool);
+
+impl ComputeTaskPool {
+    /// Initializes the global [`ComputeTaskPool`] instance.
+    pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self {
+        COMPUTE_TASK_POOL.get_or_init(|| Self(f()))
+    }
+
+    /// Gets the global [`ComputeTaskPool`] instance.
+    ///
+    /// # Panics
+    /// Panics if no pool has been initialized yet.
+    pub fn get() -> &'static Self {
+        COMPUTE_TASK_POOL.get().expect(
+            "A ComputeTaskPool has not been initialized yet. Please call \
+             ComputeTaskPool::init beforehand.",
+        )
+    }
+}
 
 impl Deref for ComputeTaskPool {
     type Target = TaskPool;
@@ -27,8 +50,26 @@ impl Deref for ComputeTaskPool {
 }
 
 /// A newtype for a task pool for CPU-intensive work that may span across multiple frames
-#[derive(Clone, Debug)]
-pub struct AsyncComputeTaskPool(pub TaskPool);
+#[derive(Debug)]
+pub struct AsyncComputeTaskPool(TaskPool);
+
+impl AsyncComputeTaskPool {
+    /// Initializes the global [`AsyncComputeTaskPool`] instance.
+    pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self {
+        ASYNC_COMPUTE_TASK_POOL.get_or_init(|| Self(f()))
+    }
+
+    /// Gets the global [`AsyncComputeTaskPool`] instance.
+    ///
+    /// # Panics
+    /// Panics if no pool has been initialized yet.
+    pub fn get() -> &'static Self {
+        ASYNC_COMPUTE_TASK_POOL.get().expect(
+            "A AsyncComputeTaskPool has not been initialized yet. Please call \
+             AsyncComputeTaskPool::init beforehand.",
+        )
+    }
+}
 
 impl Deref for AsyncComputeTaskPool {
     type Target = TaskPool;
@@ -40,8 +81,26 @@
 
 /// A newtype for a task pool for IO-intensive work (i.e. tasks that spend very little time in a
 /// "woken" state)
-#[derive(Clone, Debug)]
-pub struct IoTaskPool(pub TaskPool);
+#[derive(Debug)]
+pub struct IoTaskPool(TaskPool);
+
+impl IoTaskPool {
+    /// Initializes the global [`IoTaskPool`] instance.
+    pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self {
+        IO_TASK_POOL.get_or_init(|| Self(f()))
+    }
+
+    /// Gets the global [`IoTaskPool`] instance.
+    ///
+    /// # Panics
+    /// Panics if no pool has been initialized yet.
+    pub fn get() -> &'static Self {
+        IO_TASK_POOL.get().expect(
+            "A IoTaskPool has not been initialized yet. Please call \
+             IoTaskPool::init beforehand.",
+        )
+    }
+}
 
 impl Deref for IoTaskPool {
     type Target = TaskPool;
diff --git a/examples/asset/custom_asset_io.rs b/examples/asset/custom_asset_io.rs
index f601ebad241f5..7a58750a1cbfa 100644
--- a/examples/asset/custom_asset_io.rs
+++ b/examples/asset/custom_asset_io.rs
@@ -51,10 +51,6 @@ struct CustomAssetIoPlugin;
 
 impl Plugin for CustomAssetIoPlugin {
     fn build(&self, app: &mut App) {
-        // must get a hold of the task pool in order to create the asset server
-
-        let task_pool = app.world.resource::<IoTaskPool>().0.clone();
-
         let asset_io = {
             // the platform default asset io requires a reference to the app
             // builder to find its configuration
@@ -68,7 +64,7 @@ impl Plugin for CustomAssetIoPlugin {
 
         // the asset server is constructed and added the resource manager
 
-        app.insert_resource(AssetServer::new(asset_io, task_pool));
+        app.insert_resource(AssetServer::new(asset_io));
     }
 }
 
diff --git a/examples/async_tasks/async_compute.rs b/examples/async_tasks/async_compute.rs
index e50e3ad982ccf..e01ebd4c9bcd7 100644
--- a/examples/async_tasks/async_compute.rs
+++ b/examples/async_tasks/async_compute.rs
@@ -50,7 +50,8 @@ struct ComputeTransform(Task<Transform>);
 /// work that potentially spans multiple frames/ticks. A separate
 /// system, `handle_tasks`, will poll the spawned tasks on subsequent
 /// frames/ticks, and use the results to spawn cubes
-fn spawn_tasks(mut commands: Commands, thread_pool: Res<AsyncComputeTaskPool>) {
+fn spawn_tasks(mut commands: Commands) {
+    let thread_pool = AsyncComputeTaskPool::get();
     for x in 0..NUM_CUBES {
         for y in 0..NUM_CUBES {
             for z in 0..NUM_CUBES {
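
As a usage note (not part of the diff): with the new API in `crates/bevy_tasks/src/usages.rs`, a pool is initialized once through `init` and then fetched anywhere with `get`. A minimal sketch of that flow in a plain binary; the thread count and thread name below are illustrative, not values from this PR.

```rust
use bevy_tasks::{ComputeTaskPool, TaskPoolBuilder};

fn main() {
    // Initialize the global pool once, early in startup. `init` is backed by a
    // `OnceCell`, so a second call simply returns the existing instance.
    ComputeTaskPool::init(|| {
        TaskPoolBuilder::default()
            .num_threads(4)
            .thread_name("Compute Task Pool".to_string())
            .build()
    });

    // Anywhere else (systems, plugins, tests), fetch the pool without touching
    // the ECS `World`. `get` panics if `init` was never called.
    let pool = ComputeTaskPool::get();
    pool.scope(|scope| {
        scope.spawn(async {
            // CPU-bound work that must finish before the frame ends
        });
    });
}
```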
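Similarly, a system that previously took `Res<AsyncComputeTaskPool>` (as in `examples/async_tasks/async_compute.rs` above) now grabs the global pool inside its body. A small sketch, with a hypothetical `ComputeSomething` component standing in for the example's task component:

```rust
use bevy_ecs::prelude::*;
use bevy_tasks::{AsyncComputeTaskPool, Task};

#[derive(Component)]
struct ComputeSomething(Task<u32>);

fn spawn_background_work(mut commands: Commands) {
    // Fetched from the global static instead of a `Res<AsyncComputeTaskPool>` parameter.
    let thread_pool = AsyncComputeTaskPool::get();
    let task = thread_pool.spawn(async move {
        // long-running work that may span several frames
        42u32
    });
    commands.spawn().insert(ComputeSomething(task));
}
```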
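The asset-server changes follow the same pattern: instead of carrying a cloned `TaskPool`, background IO is spawned on the global `IoTaskPool`. A rough sketch under the assumption that the pool has already been initialized elsewhere (e.g. by `CorePlugin`, or by a test calling `IoTaskPool::init`); `queue_load` and the blocking file read are placeholders, not the asset server's real logic:

```rust
use std::path::PathBuf;

use bevy_tasks::IoTaskPool;

fn queue_load(path: PathBuf) {
    // Panics if nothing has called `IoTaskPool::init` yet.
    IoTaskPool::get()
        .spawn(async move {
            // placeholder for the real asynchronous load
            let _bytes = std::fs::read(&path);
        })
        .detach();
}
```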
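Finally, on the ECS side, tests and benchmarks now initialize the compute pool directly instead of inserting a `ComputeTaskPool` resource, and `par_for_each` picks up the global pool afterwards. A small sketch with a hypothetical `Counter` component; the batch size of 16 is arbitrary:

```rust
use bevy_ecs::prelude::*;
use bevy_tasks::{ComputeTaskPool, TaskPool};

#[derive(Component)]
struct Counter(usize);

fn main() {
    // Replaces the old `world.insert_resource(ComputeTaskPool(TaskPool::default()))`.
    ComputeTaskPool::init(TaskPool::default);

    let mut world = World::new();
    for i in 0..100 {
        world.spawn().insert(Counter(i));
    }

    let mut query = world.query::<&Counter>();
    query.par_for_each(&world, 16, |counter| {
        // runs on the global compute pool, 16 entities per batch
        let _ = counter.0;
    });
}
```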