From 1ac8dff213937088616dc84de9adc92b4b68c49a Mon Sep 17 00:00:00 2001 From: Rafael Bachmann Date: Tue, 20 Aug 2024 10:58:04 +0200 Subject: [PATCH] task: add `AbortOnDropHandle` type (#6786) --- tokio-util/src/task/abort_on_drop.rs | 63 ++++++++++++++++++++++++++++ tokio-util/src/task/mod.rs | 3 ++ tokio-util/tests/abort_on_drop.rs | 27 ++++++++++++ 3 files changed, 93 insertions(+) create mode 100644 tokio-util/src/task/abort_on_drop.rs create mode 100644 tokio-util/tests/abort_on_drop.rs diff --git a/tokio-util/src/task/abort_on_drop.rs b/tokio-util/src/task/abort_on_drop.rs new file mode 100644 index 00000000000..3739bbfa9fa --- /dev/null +++ b/tokio-util/src/task/abort_on_drop.rs @@ -0,0 +1,63 @@ +//! An [`AbortOnDropHandle`] is like a [`JoinHandle`], except that it +//! will abort the task as soon as it is dropped. + +use tokio::task::{AbortHandle, JoinError, JoinHandle}; + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +/// A wrapper around a [`tokio::task::JoinHandle`], +/// which [aborts] the task when it is dropped. +/// +/// [aborts]: tokio::task::JoinHandle::abort +#[must_use = "Dropping the handle aborts the task immediately"] +#[derive(Debug)] +pub struct AbortOnDropHandle(JoinHandle); + +impl Drop for AbortOnDropHandle { + fn drop(&mut self) { + self.0.abort() + } +} + +impl AbortOnDropHandle { + /// Create an [`AbortOnDropHandle`] from a [`JoinHandle`]. + pub fn new(handle: JoinHandle) -> Self { + Self(handle) + } + + /// Abort the task associated with this handle, + /// equivalent to [`JoinHandle::abort`]. + pub fn abort(&self) { + self.0.abort() + } + + /// Checks if the task associated with this handle is finished, + /// equivalent to [`JoinHandle::is_finished`]. + pub fn is_finished(&self) -> bool { + self.0.is_finished() + } + + /// Returns a new [`AbortHandle`] that can be used to remotely abort this task, + /// equivalent to [`JoinHandle::abort_handle`]. + pub fn abort_handle(&self) -> AbortHandle { + self.0.abort_handle() + } +} + +impl Future for AbortOnDropHandle { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl AsRef> for AbortOnDropHandle { + fn as_ref(&self) -> &JoinHandle { + &self.0 + } +} diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index 1ab3ff13dbe..6d0c379fe20 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -11,3 +11,6 @@ pub use join_map::{JoinMap, JoinMapKeys}; pub mod task_tracker; pub use task_tracker::TaskTracker; + +mod abort_on_drop; +pub use abort_on_drop::AbortOnDropHandle; diff --git a/tokio-util/tests/abort_on_drop.rs b/tokio-util/tests/abort_on_drop.rs new file mode 100644 index 00000000000..c7dcee35aac --- /dev/null +++ b/tokio-util/tests/abort_on_drop.rs @@ -0,0 +1,27 @@ +use tokio::sync::oneshot; +use tokio_util::task::AbortOnDropHandle; + +#[tokio::test] +async fn aborts_task_on_drop() { + let (mut tx, rx) = oneshot::channel::(); + let handle = tokio::spawn(async move { + let _ = rx.await; + }); + let handle = AbortOnDropHandle::new(handle); + drop(handle); + tx.closed().await; + assert!(tx.is_closed()); +} + +#[tokio::test] +async fn aborts_task_directly() { + let (mut tx, rx) = oneshot::channel::(); + let handle = tokio::spawn(async move { + let _ = rx.await; + }); + let handle = AbortOnDropHandle::new(handle); + handle.abort(); + tx.closed().await; + assert!(tx.is_closed()); + assert!(handle.is_finished()); +}