From be9c28339c10d1a3313517f448613e880e563bd4 Mon Sep 17 00:00:00 2001 From: Abutalib Aghayev Date: Sat, 5 Nov 2022 19:52:07 -0400 Subject: [PATCH] rt: add a method to retrieve task id --- tokio/src/runtime/context.rs | 15 ++++++++ tokio/src/runtime/task/harness.rs | 20 +++++++++- tokio/src/runtime/task/mod.rs | 17 ++++++++- tokio/src/task/mod.rs | 1 + tokio/tests/task_local.rs | 61 +++++++++++++++++++++++++++++++ tokio/tests/task_panic.rs | 13 +++++++ 6 files changed, 124 insertions(+), 3 deletions(-) diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 4427c8a2efc..295222f8645 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,4 +1,5 @@ use crate::runtime::coop; +use crate::runtime::task::Id; use std::cell::Cell; @@ -17,6 +18,7 @@ struct Context { /// Handle to the runtime scheduler running on the current thread. #[cfg(feature = "rt")] scheduler: RefCell>, + current_task_id: Cell>, #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand, @@ -31,6 +33,7 @@ tokio_thread_local! { Context { #[cfg(feature = "rt")] scheduler: RefCell::new(None), + current_task_id: Cell::new(None), #[cfg(any(feature = "rt", feature = "macros"))] rng: FastRand::new(RngSeed::new()), @@ -85,6 +88,18 @@ cfg_rt! { pub(crate) struct DisallowBlockInPlaceGuard(bool); + pub(crate) fn set_current_task_id(id: Option) { + CONTEXT.with(|ctx| ctx.current_task_id.replace(id)); + } + + #[track_caller] + pub(crate) fn current_task_id() -> Id { + match CONTEXT.try_with(|ctx| ctx.current_task_id.get()) { + Ok(Some(id)) => id, + _ => panic!("can't get a task id when not inside a task"), + } + } + pub(crate) fn try_current() -> Result { match CONTEXT.try_with(|ctx| ctx.scheduler.borrow().clone()) { Ok(Some(handle)) => Ok(handle), diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 206cdf2695a..65844167fa8 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -1,8 +1,9 @@ use crate::future::Future; +use crate::runtime::context; use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; use crate::runtime::task::state::Snapshot; use crate::runtime::task::waker::waker_ref; -use crate::runtime::task::{JoinError, Notified, Schedule, Task}; +use crate::runtime::task::{Id, JoinError, Notified, Schedule, Task}; use std::mem; use std::mem::ManuallyDrop; @@ -439,10 +440,26 @@ enum PollFuture { Dealloc, } +/// Guard that sets and clears the task id in the context during task execution +/// and cancellation. +struct TaskIdGuard {} +impl TaskIdGuard { + fn new(id: Id) -> Self { + context::set_current_task_id(Some(id)); + TaskIdGuard {} + } +} +impl Drop for TaskIdGuard { + fn drop(&mut self) { + context::set_current_task_id(None); + } +} + /// Cancels the task and store the appropriate error in the stage field. fn cancel_task(stage: &CoreStage, id: super::Id) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + let _task_id_guard = TaskIdGuard::new(id); stage.drop_future_or_output(); })); @@ -476,6 +493,7 @@ fn poll_future( self.core.drop_future_or_output(); } } + let _task_id_guard = TaskIdGuard::new(id); let guard = Guard { core }; let res = guard.core.poll(cx); mem::forget(guard); diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 3d5b1cbf373..9a1d06075d6 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -201,10 +201,23 @@ use std::{fmt, mem}; /// [unstable]: crate#unstable-features #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] -// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well... -#[derive(Clone, Debug, Hash, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub struct Id(u64); +/// Returns the `Id` of the task. +/// +/// # Panics +/// +/// This function panics if called from outside a task or if called from a +/// future passed to `block_on` call. +/// +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[track_caller] +pub fn id() -> Id { + use crate::runtime::context; + context::current_task_id() +} + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task { diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f1683f7e07f..a8d69060706 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -319,6 +319,7 @@ cfg_rt! { cfg_unstable! { pub use crate::runtime::task::Id; + pub use crate::runtime::task::id; } cfg_trace! { diff --git a/tokio/tests/task_local.rs b/tokio/tests/task_local.rs index a1fab08950d..b1d633e46b5 100644 --- a/tokio/tests/task_local.rs +++ b/tokio/tests/task_local.rs @@ -116,3 +116,64 @@ async fn task_local_available_on_completion_drop() { assert_eq!(rx.await.unwrap(), 42); h.await.unwrap(); } + +#[tokio::test(flavor = "current_thread")] +async fn task_id() { + use tokio::task; + + let handle = tokio::spawn(async { println!("task id: {}", task::id()) }); + + handle.await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_id_collision_multi_thread() { + use tokio::task; + + let handle1 = tokio::spawn(async { task::id() }); + let handle2 = tokio::spawn(async { task::id() }); + + let (id1, id2) = tokio::join!(handle1, handle2); + assert_ne!(id1.unwrap(), id2.unwrap()); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_id_collision_current_thread() { + use tokio::task; + + let handle1 = tokio::spawn(async { task::id() }); + let handle2 = tokio::spawn(async { task::id() }); + + let (id1, id2) = tokio::join!(handle1, handle2); + assert_ne!(id1.unwrap(), id2.unwrap()); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn task_ids_match_current_thread() { + use tokio::{sync::oneshot, task}; + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(async { + let id = rx.await.unwrap(); + assert_eq!(id, task::id()); + }); + tx.send(handle.id()).unwrap(); + handle.await.unwrap(); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "multi_thread")] +async fn task_ids_match_multi_thread() { + use tokio::{sync::oneshot, task}; + + let (tx, rx) = oneshot::channel(); + let handle = tokio::spawn(async { + let id = rx.await.unwrap(); + assert_eq!(id, task::id()); + }); + tx.send(handle.id()).unwrap(); + handle.await.unwrap(); +} diff --git a/tokio/tests/task_panic.rs b/tokio/tests/task_panic.rs index 126195222e5..377e40ceadc 100644 --- a/tokio/tests/task_panic.rs +++ b/tokio/tests/task_panic.rs @@ -120,3 +120,16 @@ fn local_key_get_panic_caller() -> Result<(), Box> { Ok(()) } + +#[cfg(tokio_unstable)] +#[test] +fn task_id_handle_panic_caller() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let _ = task::id(); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +}