diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 12c3c813656..86db89167af 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -43,6 +43,7 @@ futures-util = { version = "0.3.0", optional = true } pin-project-lite = "0.2.11" slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } +libc = "0.2" [target.'cfg(tokio_unstable)'.dependencies] hashbrown = { version = "0.14.0", optional = true } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 78e6bf50d62..f4eeb775984 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -318,6 +318,17 @@ impl Builder { } } + /// Returns true if kind is "CurrentThread" of this [`Builder`]. False otherwise. + pub fn is_current_threaded(&self) -> bool { + match &self.kind { + Kind::CurrentThread => true, + #[cfg(all(feature = "rt-multi-thread", not(target_os = "wasi")))] + Kind::MultiThread => false, + #[cfg(all(tokio_unstable, feature = "rt-multi-thread", not(target_os = "wasi")))] + Kind::MultiThreadAlt => false, + } + } + /// Enables both I/O and time drivers. /// /// Doing this is a shorthand for calling `enable_io` and `enable_time` diff --git a/tokio/src/runtime/lrtd.rs b/tokio/src/runtime/lrtd.rs new file mode 100644 index 00000000000..b8e2fb55304 --- /dev/null +++ b/tokio/src/runtime/lrtd.rs @@ -0,0 +1,233 @@ +//! Utility to help with "really nice to add a warning for tasks that might be blocking" +use libc; +use std::collections::HashSet; +use std::sync::mpsc; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use std::{env, thread}; + +use crate::runtime::{Builder, Runtime}; +use crate::util::rand::FastRand; + +const PANIC_WORKER_BLOCK_DURATION_DEFAULT: Duration = Duration::from_secs(60); + +fn get_panic_worker_block_duration() -> Duration { + let duration_str = env::var("MY_DURATION_ENV").unwrap_or_else(|_| "60".to_string()); + duration_str + .parse::() + .map(Duration::from_secs) + .unwrap_or(PANIC_WORKER_BLOCK_DURATION_DEFAULT) +} + +fn get_thread_id() -> libc::pthread_t { + unsafe { libc::pthread_self() } +} + +/// A trait for handling actions when blocking is detected. +/// +/// This trait provides a method for handling the detection of a blocking action. +pub trait BlockingActionHandler: Send + Sync { + /// Called when a blocking action is detected and prior to thread signaling. + /// + /// # Arguments + /// + /// * `workers` - The list of thread IDs of the tokio runtime worker threads. /// # Returns + /// + fn blocking_detected(&self, workers: &[libc::pthread_t]); +} + +struct StdErrBlockingActionHandler; + +/// BlockingActionHandler implementation that writes blocker details to standard error. +impl BlockingActionHandler for StdErrBlockingActionHandler { + fn blocking_detected(&self, workers: &[libc::pthread_t]) { + eprintln!("Detected blocking in worker threads: {:?}", workers); + } +} + +#[derive(Debug)] +struct WorkerSet { + inner: Mutex>, +} + +impl WorkerSet { + fn new() -> Self { + WorkerSet { + inner: Mutex::new(HashSet::new()), + } + } + + fn add(&self, pid: libc::pthread_t) { + let mut set = self.inner.lock().unwrap(); + set.insert(pid); + } + + fn remove(&self, pid: libc::pthread_t) { + let mut set = self.inner.lock().unwrap(); + set.remove(&pid); + } + + fn get_all(&self) -> Vec { + let set = self.inner.lock().unwrap(); + set.iter().cloned().collect() + } +} + +/// Utility to help with "really nice to add a warning for tasks that might be blocking" +#[derive(Debug)] +pub struct LongRunningTaskDetector { + interval: Duration, + detection_time: Duration, + stop_flag: Arc>, + workers: Arc, +} + +async fn do_nothing(tx: mpsc::Sender<()>) { + // signal I am done + tx.send(()).unwrap(); +} + +fn probe( + tokio_runtime: &Arc, + detection_time: Duration, + workers: &Arc, + action: &Arc, +) { + let (tx, rx) = mpsc::channel(); + let _nothing_handle = tokio_runtime.spawn(do_nothing(tx)); + let is_probe_success = match rx.recv_timeout(detection_time) { + Ok(_result) => true, + Err(_) => false, + }; + if !is_probe_success { + let targets = workers.get_all(); + action.blocking_detected(&targets); + rx.recv_timeout(get_panic_worker_block_duration()).unwrap(); + } +} + +/// Utility to help with "really nice to add a warning for tasks that might be blocking" +/// Example use: +/// ``` +/// use std::sync::Arc; +/// use tokio::runtime::lrtd::LongRunningTaskDetector; +/// +/// let mut builder = tokio::runtime::Builder::new_multi_thread(); +/// let mutable_builder = builder.worker_threads(2); +/// let lrtd = LongRunningTaskDetector::new( +/// std::time::Duration::from_millis(10), +/// std::time::Duration::from_millis(100), +/// mutable_builder, +/// ); +/// let runtime = builder.enable_all().build().unwrap(); +/// let arc_runtime = Arc::new(runtime); +/// let arc_runtime2 = arc_runtime.clone(); +/// lrtd.start(arc_runtime); +/// arc_runtime2.block_on(async { +/// print!("my async code") +/// }); +/// +/// ``` +/// +/// The above will allow you to get details on what is blocking your tokio worker threads for longer that 100ms. +/// The detail will look like: +/// +/// ```text +/// Detected blocking in worker threads: [123145318232064, 123145320341504] +/// ``` +/// +/// To get more details(like stack traces) start LongRunningTaskDetector with start_with_custom_action and provide a custom handler. +/// +impl LongRunningTaskDetector { + /// Creates a new `LongRunningTaskDetector` instance. + /// + /// # Arguments + /// + /// * `interval` - The interval between probes. This interval is randomized. + /// * `detection_time` - The maximum time allowed for a probe to succeed. + /// A probe running for longer indicates something is blocking the worker threads. + /// * `runtime_builder` - A mutable reference to a `tokio::runtime::Builder`. + /// + /// # Returns + /// + /// Returns a new `LongRunningTaskDetector` instance. + pub fn new( + interval: Duration, + detection_time: Duration, + runtime_builder: &mut Builder, + ) -> Self { + let workers = Arc::new(WorkerSet::new()); + if runtime_builder.is_current_threaded() { + workers.add(get_thread_id()); + } else { + let workers_clone = Arc::clone(&workers); + let workers_clone2 = Arc::clone(&workers); + runtime_builder + .on_thread_start(move || { + let pid = get_thread_id(); + workers_clone.add(pid); + }) + .on_thread_stop(move || { + let pid = get_thread_id(); + workers_clone2.remove(pid); + }); + } + LongRunningTaskDetector { + interval, + detection_time, + stop_flag: Arc::new(Mutex::new(true)), + workers, + } + } + + /// Starts the monitoring thread with default action handlers (that write details to std err). + /// + /// # Arguments + /// + /// * `runtime` - An `Arc` reference to a `tokio::runtime::Runtime`. + pub fn start(&self, runtime: Arc) { + self.start_with_custom_action(runtime, Arc::new(StdErrBlockingActionHandler)) + } + + /// Starts the monitoring process with custom action handlers that + /// allow you to customize what happens when blocking is detected. + /// + /// # Arguments + /// + /// * `runtime` - An `Arc` reference to a `tokio::runtime::Runtime`. + /// * `action` - An `Arc` reference to a custom `BlockingActionHandler`. + /// * `thread_action` - An `Arc` reference to a custom `ThreadStateHandler`. + pub fn start_with_custom_action( + &self, + runtime: Arc, + action: Arc, + ) { + *self.stop_flag.lock().unwrap() = false; + let stop_flag = Arc::clone(&self.stop_flag); + let detection_time = self.detection_time; + let interval = self.interval; + let workers = Arc::clone(&self.workers); + thread::spawn(move || { + let mut rnd = FastRand::new(); + let max: u32 = >::try_into(interval.as_micros()).unwrap() - 10; + while !*stop_flag.lock().unwrap() { + probe(&runtime, detection_time, &workers, &action); + thread::sleep(Duration::from_micros(rnd.fastrand_n(max).into())); + } + }); + } + + /// Stops the monitoring thread. Does nothing if LRTD is already stopped. + pub fn stop(&self) { + let mut sf = self.stop_flag.lock().unwrap(); + if !(*sf) { + *sf = true; + } + } +} + +impl Drop for LongRunningTaskDetector { + fn drop(&mut self) { + self.stop(); + } +} diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 3d333960f3d..2cdaae5bb8d 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -379,6 +379,9 @@ cfg_rt! { pub use dump::Dump; } + #[cfg(unix)] + pub mod lrtd; + mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/tests/lrtd.rs b/tokio/tests/lrtd.rs new file mode 100644 index 00000000000..7b765409956 --- /dev/null +++ b/tokio/tests/lrtd.rs @@ -0,0 +1,197 @@ +#![cfg(unix)] +mod lrtd_tests { + use std::backtrace::Backtrace; + use std::collections::HashMap; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{Arc, Mutex}; + use std::thread; + use std::time::{Duration, Instant}; + use tokio::runtime::lrtd::{BlockingActionHandler, LongRunningTaskDetector}; + + async fn run_blocking_stuff() { + println!("slow start"); + thread::sleep(Duration::from_secs(1)); + println!("slow done"); + } + + #[test] + fn test_blocking_detection_multi() { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + let mutable_builder = builder.worker_threads(2); + let lrtd = LongRunningTaskDetector::new( + Duration::from_millis(10), + Duration::from_millis(100), + mutable_builder, + ); + let runtime = builder.enable_all().build().unwrap(); + let arc_runtime = Arc::new(runtime); + let arc_runtime2 = arc_runtime.clone(); + lrtd.start(arc_runtime); + arc_runtime2.spawn(run_blocking_stuff()); + arc_runtime2.spawn(run_blocking_stuff()); + arc_runtime2.block_on(async { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + println!("Done"); + }); + } + + #[test] + fn test_blocking_detection_current() { + let mut builder = tokio::runtime::Builder::new_current_thread(); + let mutable_builder = builder.enable_all(); + let lrtd = LongRunningTaskDetector::new( + Duration::from_millis(10), + Duration::from_millis(100), + mutable_builder, + ); + let runtime = mutable_builder.build().unwrap(); + let arc_runtime = Arc::new(runtime); + let arc_runtime2 = arc_runtime.clone(); + lrtd.start(arc_runtime); + arc_runtime2.block_on(async { + run_blocking_stuff().await; + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + println!("Done"); + }); + } + + #[test] + fn test_blocking_detection_stop_unstarted() { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + let mutable_builder = builder.worker_threads(2); + let _lrtd = LongRunningTaskDetector::new( + Duration::from_millis(10), + Duration::from_millis(100), + mutable_builder, + ); + } + + fn get_thread_id() -> libc::pthread_t { + unsafe { libc::pthread_self() } + } + + static SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0); + + static THREAD_DUMPS: Mutex>> = Mutex::new(None); + + extern "C" fn signal_handler(_: i32) { + // not signal safe, this needs to be rewritten to avoid mem allocations and use a pre-allocated buffer. + let backtrace = Backtrace::force_capture(); + let name = thread::current() + .name() + .map(|n| format!(" for thread \"{}\"", n)) + .unwrap_or_else(|| "".to_owned()); + let tid = get_thread_id(); + let detail = format!("Stack trace{}:{}\n{}", name, tid, backtrace); + let mut omap = THREAD_DUMPS.lock().unwrap(); + let map = omap.as_mut().unwrap(); + (*map).insert(tid, detail); + SIGNAL_COUNTER.fetch_sub(1, Ordering::SeqCst); + } + + fn install_thread_stack_stace_handler(signal: libc::c_int) { + unsafe { + libc::signal(signal, signal_handler as libc::sighandler_t); + } + } + + static GTI_MUTEX: Mutex<()> = Mutex::new(()); + + /// A naive stack trace capture implementation for threads for DEMO/TEST only purposes. + fn get_thread_info( + signal: libc::c_int, + targets: &[libc::pthread_t], + ) -> HashMap { + let _lock = GTI_MUTEX.lock(); + { + let mut omap = THREAD_DUMPS.lock().unwrap(); + *omap = Some(HashMap::new()); + SIGNAL_COUNTER.store(targets.len(), Ordering::SeqCst); + } + for thread_id in targets { + let result = unsafe { libc::pthread_kill(*thread_id, signal) }; + if result != 0 { + eprintln!("Error sending signal: {:?}", result); + } + } + let time_limit = Duration::from_secs(1); + let start_time = Instant::now(); + loop { + let signal_count = SIGNAL_COUNTER.load(Ordering::SeqCst); + if signal_count == 0 { + break; + } + if Instant::now() - start_time >= time_limit { + break; + } + std::thread::sleep(std::time::Duration::from_micros(10)); + } + { + let omap = THREAD_DUMPS.lock().unwrap(); + omap.clone().unwrap() + } + } + + struct DetailedCaptureBlockingActionHandler { + inner: Mutex>>, + } + + impl DetailedCaptureBlockingActionHandler { + fn new() -> Self { + DetailedCaptureBlockingActionHandler { + inner: Mutex::new(None), + } + } + + fn contains_symbol(&self, symbol_name: &str) -> bool { + // Iterate over the frames in the backtrace + let omap = self.inner.lock().unwrap(); + match omap.as_ref() { + Some(map) => { + if map.is_empty() { + false + } else { + let bt_str = map.values().next().unwrap(); + bt_str.contains(symbol_name) + } + } + None => false, + } + } + } + + impl BlockingActionHandler for DetailedCaptureBlockingActionHandler { + fn blocking_detected(&self, workers: &[libc::pthread_t]) { + let mut map = self.inner.lock().unwrap(); + let tinfo = get_thread_info(libc::SIGUSR1, workers); + eprintln!("Blocking detected with details: {:?}", tinfo); + *map = Some(tinfo); + } + } + + #[test] + fn test_blocking_detection_multi_capture() { + install_thread_stack_stace_handler(libc::SIGUSR1); + let mut builder = tokio::runtime::Builder::new_multi_thread(); + let mutable_builder = builder.worker_threads(2); + let lrtd = LongRunningTaskDetector::new( + Duration::from_millis(10), + Duration::from_millis(100), + mutable_builder, + ); + let runtime = builder.enable_all().build().unwrap(); + let arc_runtime = Arc::new(runtime); + let arc_runtime2 = arc_runtime.clone(); + let blocking_action = Arc::new(DetailedCaptureBlockingActionHandler::new()); + let to_assert_blocking = blocking_action.clone(); + lrtd.start_with_custom_action(arc_runtime, blocking_action); + arc_runtime2.spawn(run_blocking_stuff()); + arc_runtime2.spawn(run_blocking_stuff()); + arc_runtime2.block_on(async { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + println!("Hello world"); + }); + assert!(to_assert_blocking.contains_symbol("std::thread::sleep")); + lrtd.stop() + } +}