Skip to content

Commit

Permalink
use atomic bool another
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Apr 5, 2024
1 parent 7ad141a commit adfd055
Showing 1 changed file with 8 additions and 28 deletions.
36 changes: 8 additions & 28 deletions crates/polars-io/src/pl_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,25 +215,9 @@ where
callable().await
}

struct BlockingMarkers(
*mut AtomicBool,
// Own the underlying vec to keep it alive
Vec<AtomicBool>,
);
unsafe impl Send for BlockingMarkers {}
unsafe impl Sync for BlockingMarkers {}

impl BlockingMarkers {
#[allow(clippy::mut_from_ref)]
unsafe fn get_unchecked_mut(&self, index: usize) -> &mut bool {
debug_assert!(index < self.1.len());
(&mut *self.0.add(index)).get_mut()
}
}

pub struct RuntimeManager {
rt: Runtime,
blocking_markers: BlockingMarkers,
blocking_markers: Vec<AtomicBool>,
}

impl RuntimeManager {
Expand All @@ -245,13 +229,9 @@ impl RuntimeManager {
.build()
.unwrap();

let n_threads = POOL.current_num_threads();

let mut blocking_markers = Vec::<AtomicBool>::with_capacity(n_threads);
for _ in 0..n_threads {
blocking_markers.push(AtomicBool::new(false));
}
let blocking_markers = BlockingMarkers(blocking_markers.as_mut_ptr(), blocking_markers);
let blocking_markers = (0..POOL.current_num_threads())
.map(|_| AtomicBool::new(false))
.collect();

Self {
rt,
Expand All @@ -274,17 +254,17 @@ impl RuntimeManager {
self.rt.block_on(future)
} else {
let thread_idx = thread_idx.unwrap();
let tokio_entered = unsafe { self.blocking_markers.get_unchecked_mut(thread_idx) };
let tokio_entered = unsafe { self.blocking_markers.get_unchecked(thread_idx) };

if !*tokio_entered {
*tokio_entered = true;
if tokio_entered.load(Ordering::Relaxed) {
tokio_entered.store(true, Ordering::Relaxed);
let out = self.rt.block_on(future);
debug_assert_eq!(
thread_idx,
POOL.current_thread_index().unwrap(),
"execution after block_on should be by the same rayon thread"
);
*tokio_entered = false;
tokio_entered.store(false, Ordering::Relaxed);
out
} else {
// Safety: The tokio runtime flavor is multi-threaded
Expand Down

0 comments on commit adfd055

Please sign in to comment.