Skip to content

Commit

Permalink
await tasks to cancel (#6696)
Browse files Browse the repository at this point in the history
# Objective

- Fixes #6603

## Solution

- `Task`s will cancel when dropped, but wait until they return Pending before they actually get canceled. That means that if a task panics, it's possible for that error to get propagated to the scope and the scope gets dropped, while scoped tasks in other threads are still running. This is a big problem since scoped task can hold life-timed values that are dropped as the scope is dropped leading to UB.

---

## Changelog

- changed `Scope` to use `FallibleTask` and await the cancellation of all remaining tasks when it's dropped.
  • Loading branch information
hymm authored and cart committed Nov 30, 2022
1 parent c7ad98a commit 90e76a1
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::{
future::Future,
marker::PhantomData,
mem,
pin::Pin,
sync::Arc,
thread::{self, JoinHandle},
};

use async_task::FallibleTask;
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, pin, FutureExt};

Expand Down Expand Up @@ -248,8 +248,8 @@ impl TaskPool {
let task_scope_executor = &async_executor::Executor::default();
let task_scope_executor: &'env async_executor::Executor =
unsafe { mem::transmute(task_scope_executor) };
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
unsafe { mem::transmute(&spawned) };

let scope = Scope {
Expand All @@ -267,10 +267,10 @@ impl TaskPool {
if spawned.is_empty() {
Vec::new()
} else {
let get_results = async move {
let mut results = Vec::with_capacity(spawned.len());
while let Ok(task) = spawned.pop() {
results.push(task.await);
let get_results = async {
let mut results = Vec::with_capacity(spawned_ref.len());
while let Ok(task) = spawned_ref.pop() {
results.push(task.await.unwrap());
}

results
Expand All @@ -279,23 +279,8 @@ impl TaskPool {
// Pin the futures on the stack.
pin!(get_results);

// SAFETY: This function blocks until all futures complete, so we do not read/write
// the data from futures outside of the 'scope lifetime. However,
// rust has no way of knowing this so we must convert to 'static
// here to appease the compiler as it is unable to validate safety.
let get_results: Pin<&mut (dyn Future<Output = Vec<T>> + 'static + Send)> = get_results;
let get_results: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static + Send)> =
unsafe { mem::transmute(get_results) };

// The thread that calls scope() will participate in driving tasks in the pool
// forward until the tasks that are spawned by this scope() call
// complete. (If the caller of scope() happens to be a thread in
// this thread pool, and we only have one thread in the pool, then
// simply calling future::block_on(spawned) would deadlock.)
let mut spawned = task_scope_executor.spawn(get_results);

loop {
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
break result;
};

Expand Down Expand Up @@ -378,7 +363,7 @@ impl Drop for TaskPool {
pub struct Scope<'scope, 'env: 'scope, T> {
executor: &'scope async_executor::Executor<'scope>,
task_scope_executor: &'scope async_executor::Executor<'scope>,
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
// make `Scope` invariant over 'scope and 'env
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
Expand All @@ -394,7 +379,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.executor.spawn(f);
let task = self.executor.spawn(f).fallible();
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbouded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
Expand All @@ -407,13 +392,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.task_scope_executor.spawn(f);
let task = self.task_scope_executor.spawn(f).fallible();
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbouded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
}
}

impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
where
T: 'scope,
{
fn drop(&mut self) {
future::block_on(async {
while let Ok(task) = self.spawned.pop() {
task.cancel().await;
}
});
}
}

#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
Expand Down

0 comments on commit 90e76a1

Please sign in to comment.