Skip to content

Commit

Permalink
refactor: Replace std thread spawn with tokio block_in_place for nest…
Browse files Browse the repository at this point in the history
…ed block_on
  • Loading branch information
nameexhaustion committed Apr 5, 2024
1 parent 273adf5 commit b709128
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
14 changes: 4 additions & 10 deletions crates/polars-io/src/parquet/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,22 +691,16 @@ impl BatchedParquetReader {
};

// Spawn the task and wait on it asynchronously.
let (dfs, rows_read, limit) = if POOL.current_thread_index().is_some() {
if POOL.current_thread_index().is_some() {
// We are a rayon thread, so we can't use POOL.spawn as it would mean we spawn a task and block until
// another rayon thread executes it - we would deadlock if all rayon threads did this.

// Activate another tokio thread to poll futures. There should be at least 1 tokio thread that is
// not a rayon thread.
let handle = tokio::spawn(async { rx.await.unwrap() });
// Now spawn the task onto rayon and participate in executing it. The current thread will no longer
// poll async futures until this rayon task is complete.
POOL.install(f);
handle.await.unwrap()
tokio::task::block_in_place(f);
} else {
POOL.spawn(f);
rx.await.unwrap()
};

let (dfs, rows_read, limit) = rx.await.unwrap();

self.rows_read = rows_read;
self.limit = limit;
dfs
Expand Down
39 changes: 26 additions & 13 deletions crates/polars-io/src/pl_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
use std::sync::RwLock;
use std::thread::ThreadId;

use once_cell::sync::Lazy;
use polars_core::config::verbose;
use polars_core::POOL;
use polars_utils::aliases::PlHashSet;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::Semaphore;

Expand Down Expand Up @@ -220,7 +217,7 @@ where

pub struct RuntimeManager {
rt: Runtime,
blocking_threads: RwLock<PlHashSet<ThreadId>>,
blocking_markers: Vec<AtomicBool>,
}

impl RuntimeManager {
Expand All @@ -232,9 +229,13 @@ impl RuntimeManager {
.build()
.unwrap();

let blocking_markers = (0..POOL.current_num_threads())
.map(|_| AtomicBool::new(false))
.collect();

Self {
rt,
blocking_threads: Default::default(),
blocking_markers,
}
}

Expand All @@ -245,17 +246,29 @@ impl RuntimeManager {
pub fn block_on_potential_spawn<F>(&'static self, future: F) -> F::Output
where
F: Future + Send,
F::Output: Send,
F::Output: Send + 'static,
{
let thread_id = std::thread::current().id();
let rayon_idx = POOL.current_thread_index();

if self.blocking_threads.read().unwrap().contains(&thread_id) {
std::thread::scope(|s| s.spawn(|| self.rt.block_on(future)).join().unwrap())
if rayon_idx.is_none() {
self.rt.block_on(future)
} else {
self.blocking_threads.write().unwrap().insert(thread_id);
let out = self.rt.block_on(future);
self.blocking_threads.write().unwrap().remove(&thread_id);
out
let rayon_idx = rayon_idx.unwrap();
let tokio_entered = unsafe { self.blocking_markers.get_unchecked(rayon_idx) };

if tokio_entered.load(Ordering::Relaxed) {
tokio_entered.store(true, Ordering::Relaxed);
let out = self.rt.block_on(future);
debug_assert_eq!(
rayon_idx,
POOL.current_thread_index().unwrap(),
"execution after block_on should be by the same rayon thread"
);
tokio_entered.store(false, Ordering::Relaxed);
out
} else {
tokio::task::block_in_place(|| self.rt.block_on(future))
}
}
}

Expand Down

0 comments on commit b709128

Please sign in to comment.