Skip to content

Commit

Permalink
Use a robust abort handling mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
AzureMarker committed Dec 21, 2021
1 parent 8e2aed5 commit c4fa5c9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 56 deletions.
2 changes: 1 addition & 1 deletion tokio-util/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ codec = []
time = ["tokio/time","slab"]
io = []
io-util = ["io", "tokio/rt", "tokio/io-util"]
rt = ["tokio/rt", "tokio/sync"]
rt = ["tokio/rt", "tokio/sync", "futures-util"]

__docs_rs = ["futures-util"]

Expand Down
83 changes: 28 additions & 55 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
Expand Down Expand Up @@ -96,7 +95,7 @@ impl LocalPool {
let (sender, receiver) = oneshot::channel();

let worker = self.find_and_incr_least_burdened_worker();
let job_guard = JobGuard(Arc::clone(&worker.task_count));
let job_guard = JobCountGuard(Arc::clone(&worker.task_count));
let worker_spawner = worker.spawner.clone();

// Spawn a future onto the worker's runtime so can immediately return
Expand All @@ -105,12 +104,19 @@ impl LocalPool {
// Move the job guard into the task
let _job_guard = job_guard;

// Propagate aborts via Abortable/AbortHandle
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let _abort_guard = AbortGuard(abort_handle);

// Inside the future we can't run spawn_local yet because we're not
// in the context of a LocalSet. We need to send create_task to the
// LocalSet task for spawning.
let spawn_task = Box::new(move || {
// Once we're in the LocalSet context we can call spawn_local
let join_handle = spawn_local(async move { create_task().await });
let join_handle =
spawn_local(
async move { Abortable::new(create_task(), abort_registration).await },
);

// Send the join handle back to the spawner. If sending fails,
// we assume the parent task was canceled, so cancel this task
Expand All @@ -126,9 +132,8 @@ impl LocalPool {
panic!("Failed to send job to worker: {}", e);
}

// Wait for the task's join handle. Forward task cancellation in
// case this task gets canceled (via ReceiverCancelGuard).
let join_handle = match ReceiverCancelGuard(receiver).await {
// Wait for the task's join handle
let join_handle = match receiver.await {
Ok(handle) => handle,
Err(e) => {
// We sent the task successfully, but failed to get its
Expand All @@ -139,12 +144,19 @@ impl LocalPool {
}
};

// Wait for the task to complete. Forward task cancellation in case
// this task gets canceled.
let join_result = JoinHandleCancelGuard(join_handle).await;
// Wait for the task to complete
let join_result = join_handle.await;

match join_result {
Ok(output) => output,
Ok(Ok(output)) => output,
Ok(Err(_)) => {
// Pinned task was aborted. But that only happens if this
// task is aborted. So this is an impossible branch.
unreachable!(
"Reaching this branch means this task was previously \
aborted but it continued running anyways"
)
}
Err(e) => {
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
Expand Down Expand Up @@ -196,63 +208,24 @@ impl LocalPool {

/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobGuard(Arc<AtomicUsize>);
struct JobCountGuard(Arc<AtomicUsize>);

impl Drop for JobGuard {
impl Drop for JobCountGuard {
fn drop(&mut self) {
// Decrement the job count
self.0.fetch_sub(1, Ordering::SeqCst);
}
}

/// Automatically abort/cancel the task when this guard gets dropped. This will
/// forward a cancellation from one task to another.
///
/// This implements Future by polling the join handle, so just await it.
struct JoinHandleCancelGuard<T>(JoinHandle<T>);

impl<T> Future for JoinHandleCancelGuard<T> {
type Output = <JoinHandle<T> as Future>::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let join_handle = Pin::new(&mut self.0);
join_handle.poll(cx)
}
}
/// Calls abort on the handle when dropped.
struct AbortGuard(AbortHandle);

impl<T> Drop for JoinHandleCancelGuard<T> {
impl Drop for AbortGuard {
fn drop(&mut self) {
// Attempt to abort the task. This does nothing if the task has already
// completed.
self.0.abort();
}
}

/// If the task is canceled while waiting for the join handle, this guard will
/// check if the join handle was sent (in-transit so it wasn't aborted on the
/// worker side) and abort it if so.
struct ReceiverCancelGuard<T>(oneshot::Receiver<JoinHandle<T>>);

impl<T> Future for ReceiverCancelGuard<T> {
type Output = <oneshot::Receiver<JoinHandle<T>> as Future>::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let receiver = Pin::new(&mut self.0);
receiver.poll(cx)
}
}

impl<T> Drop for ReceiverCancelGuard<T> {
fn drop(&mut self) {
// If task is canceled while waiting for the join handle, and the join
// handle was already "sent" by the worker, then it's in a limbo state
// and needs to be manually canceled here.
if let Ok(join_handle) = self.0.try_recv() {
join_handle.abort();
}
}
}

type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;

struct LocalWorkerHandle {
Expand Down

0 comments on commit c4fa5c9

Please sign in to comment.