Skip to content

Commit

Permalink
Handle TaskPool panicking threads
Browse files Browse the repository at this point in the history
  • Loading branch information
SarthakSingh31 committed Jun 15, 2022
1 parent 32cd989 commit fba1605
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 26 deletions.
70 changes: 48 additions & 22 deletions crates/bevy_core/src/task_pool_options.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_tasks::{
AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder, TaskPoolThreadPanicPolicy,
};
use bevy_utils::tracing::trace;

/// Defines a simple way to determine how many threads to use given the number of remaining cores
Expand Down Expand Up @@ -30,6 +32,15 @@ impl TaskPoolThreadAssignmentPolicy {
}
}

/// The set of policies describing how the according task pool behaves
#[derive(Clone)]
pub struct TaskPoolPolicies {
/// Used to determine number of threads to allocate
pub assignment_policy: TaskPoolThreadAssignmentPolicy,
/// Used to determine the panic policy of the task pool
pub panic_policy: TaskPoolThreadPanicPolicy,
}

/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// insert the default task pools into the resource map manually. If the pools are already inserted,
/// this helper will do nothing.
Expand All @@ -42,12 +53,12 @@ pub struct DefaultTaskPoolOptions {
/// max_total_threads
pub max_total_threads: usize,

/// Used to determine number of IO threads to allocate
pub io: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of async compute threads to allocate
pub async_compute: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of compute threads to allocate
pub compute: TaskPoolThreadAssignmentPolicy,
/// Used to configure the IOTaskPool's inner policies
pub io: TaskPoolPolicies,
/// Used to configure the AsyncTaskPool's inner policies
pub async_compute: TaskPoolPolicies,
/// Used to configure the ComputeTaskPool's inner policies
pub compute: TaskPoolPolicies,
}

impl Default for DefaultTaskPoolOptions {
Expand All @@ -57,25 +68,34 @@ impl Default for DefaultTaskPoolOptions {
min_total_threads: 1,
max_total_threads: std::usize::MAX,

// Use 25% of cores for IO, at least 1, no more than 4
io: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
io: TaskPoolPolicies {
// Use 25% of cores for IO, at least 1, no more than 4
assignment_policy: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},
panic_policy: TaskPoolThreadPanicPolicy::Restart,
},

// Use 25% of cores for async compute, at least 1, no more than 4
async_compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
async_compute: TaskPoolPolicies {
// Use 25% of cores for async compute, at least 1, no more than 4
assignment_policy: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},
panic_policy: TaskPoolThreadPanicPolicy::Propagate,
},

// Use all remaining cores for compute (at least 1)
compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: std::usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over"
compute: TaskPoolPolicies {
// Use all remaining cores for compute (at least 1)
assignment_policy: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: std::usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over"
},
panic_policy: TaskPoolThreadPanicPolicy::Propagate,
},
}
}
Expand Down Expand Up @@ -103,6 +123,7 @@ impl DefaultTaskPoolOptions {
// Determine the number of IO threads we will use
let io_threads = self
.io
.assignment_policy
.get_number_of_threads(remaining_threads, total_threads);

trace!("IO Threads: {}", io_threads);
Expand All @@ -112,6 +133,7 @@ impl DefaultTaskPoolOptions {
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string())
.panic_policy(self.io.panic_policy)
.build()
});
}
Expand All @@ -120,6 +142,7 @@ impl DefaultTaskPoolOptions {
// Determine the number of async compute threads we will use
let async_compute_threads = self
.async_compute
.assignment_policy
.get_number_of_threads(remaining_threads, total_threads);

trace!("Async Compute Threads: {}", async_compute_threads);
Expand All @@ -129,6 +152,7 @@ impl DefaultTaskPoolOptions {
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string())
.panic_policy(self.async_compute.panic_policy)
.build()
});
}
Expand All @@ -138,6 +162,7 @@ impl DefaultTaskPoolOptions {
// This is intentionally last so that an end user can specify 1.0 as the percent
let compute_threads = self
.compute
.assignment_policy
.get_number_of_threads(remaining_threads, total_threads);

trace!("Compute Threads: {}", compute_threads);
Expand All @@ -146,6 +171,7 @@ impl DefaultTaskPoolOptions {
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string())
.panic_policy(self.compute.panic_policy)
.build()
});
}
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use task::Task;
#[cfg(not(target_arch = "wasm32"))]
mod task_pool;
#[cfg(not(target_arch = "wasm32"))]
pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};
pub use task_pool::{Scope, TaskPool, TaskPoolBuilder, TaskPoolThreadPanicPolicy};

#[cfg(target_arch = "wasm32")]
mod single_threaded_task_pool;
Expand Down
94 changes: 91 additions & 3 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub struct TaskPoolBuilder {
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
thread_name: Option<String>,
/// Used to determine the panic policy of the task pool
panic_policy: TaskPoolThreadPanicPolicy,
}

impl TaskPoolBuilder {
Expand Down Expand Up @@ -50,16 +52,40 @@ impl TaskPoolBuilder {
self
}

/// Override the panic policy of the task pool
pub fn panic_policy(mut self, panic_policy: TaskPoolThreadPanicPolicy) -> Self {
self.panic_policy = panic_policy;
self
}

/// Creates a new [`TaskPool`] based on the current options.
pub fn build(self) -> TaskPool {
TaskPool::new_internal(
self.num_threads,
self.stack_size,
self.thread_name.as_deref(),
self.panic_policy,
)
}
}

/// The policy used when a task pool's internal thread panics
#[derive(Copy, Clone, Debug)]
pub enum TaskPoolThreadPanicPolicy {
/// Propagate the panic to the main thread, causing the main
/// thread to panic as well.
Propagate,
/// Restart the thread by joining the panicked thread and
/// spawning another one in it's place.
Restart,
}

impl Default for TaskPoolThreadPanicPolicy {
fn default() -> Self {
TaskPoolThreadPanicPolicy::Propagate
}
}

/// A thread pool for executing tasks. Tasks are futures that are being automatically driven by
/// the pool on threads owned by the pool.
#[derive(Debug)]
Expand Down Expand Up @@ -90,6 +116,7 @@ impl TaskPool {
num_threads: Option<usize>,
stack_size: Option<usize>,
thread_name: Option<&str>,
panic_policy: TaskPoolThreadPanicPolicy,
) -> Self {
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();

Expand Down Expand Up @@ -126,9 +153,22 @@ impl TaskPool {

thread_builder
.spawn(move || {
let shutdown_future = ex.run(shutdown_rx.recv());
// Use unwrap_err because we expect a Closed error
future::block_on(shutdown_future).unwrap_err();
let wait_for_shutdown = || {
let shutdown_future = ex.run(shutdown_rx.recv());
// Use unwrap_err because we expect a Closed error
future::block_on(shutdown_future).unwrap_err();
};

match panic_policy {
TaskPoolThreadPanicPolicy::Propagate => wait_for_shutdown(),
TaskPoolThreadPanicPolicy::Restart => loop {
let res = std::panic::catch_unwind(wait_for_shutdown);

if res.is_ok() {
break;
}
},
}
})
.expect("Failed to spawn thread.")
})
Expand Down Expand Up @@ -418,4 +458,52 @@ mod tests {
assert!(!thread_check_failed.load(Ordering::Acquire));
assert_eq!(count.load(Ordering::Acquire), 200);
}

#[test]
#[should_panic]
fn test_propogate_panic_policy_handling() {
const COUNT: usize = 100;

let pool = TaskPoolBuilder::default()
.panic_policy(TaskPoolThreadPanicPolicy::Propagate)
.build();

for i in 0..COUNT {
pool.spawn(async move {
if i % 2 == 0 {
panic!("Half of the tasks panic");
}
})
.detach();
}

drop(pool);
}

#[test]
fn test_restart_panic_policy_handling() {
const COUNT: usize = 100;

let pool = TaskPoolBuilder::default()
.panic_policy(TaskPoolThreadPanicPolicy::Restart)
.build();

let (tx, rx) = std::sync::mpsc::sync_channel::<usize>(COUNT / 2);

for i in 0..COUNT {
let tx = tx.clone();
pool.spawn(async move {
if i % 2 == 0 {
panic!("Half of the tasks panic");
}

tx.send(i).unwrap();
})
.detach();
}

for _ in 0..(COUNT / 2) {
assert!(rx.recv().unwrap() % 2 == 1);
}
}
}

0 comments on commit fba1605

Please sign in to comment.