diff --git a/benches/Cargo.toml b/benches/Cargo.toml index 1eea2e04489..c581055cf65 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -26,6 +26,11 @@ name = "spawn" path = "spawn.rs" harness = false +[[bench]] +name = "sync_broadcast" +path = "sync_broadcast.rs" +harness = false + [[bench]] name = "sync_mpsc" path = "sync_mpsc.rs" diff --git a/benches/sync_broadcast.rs b/benches/sync_broadcast.rs new file mode 100644 index 00000000000..38a2141387b --- /dev/null +++ b/benches/sync_broadcast.rs @@ -0,0 +1,82 @@ +use rand::{Rng, RngCore, SeedableRng}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::{broadcast, Notify}; + +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion}; + +fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(6) + .build() + .unwrap() +} + +fn do_work(rng: &mut impl RngCore) -> u32 { + use std::fmt::Write; + let mut message = String::new(); + for i in 1..=10 { + let _ = write!(&mut message, " {i}={}", rng.gen::()); + } + message + .as_bytes() + .iter() + .map(|&c| c as u32) + .fold(0, u32::wrapping_add) +} + +fn contention_impl(g: &mut BenchmarkGroup) { + let rt = rt(); + + let (tx, _rx) = broadcast::channel::(1000); + let wg = Arc::new((AtomicUsize::new(0), Notify::new())); + + for n in 0..N_TASKS { + let wg = wg.clone(); + let mut rx = tx.subscribe(); + let mut rng = rand::rngs::StdRng::seed_from_u64(n as u64); + rt.spawn(async move { + while let Ok(_) = rx.recv().await { + let r = do_work(&mut rng); + let _ = black_box(r); + if wg.0.fetch_sub(1, Ordering::Relaxed) == 1 { + wg.1.notify_one(); + } + } + }); + } + + const N_ITERS: usize = 100; + + g.bench_function(N_TASKS.to_string(), |b| { + b.iter(|| { + rt.block_on({ + let wg = wg.clone(); + let tx = tx.clone(); + async move { + for i in 0..N_ITERS { + assert_eq!(wg.0.fetch_add(N_TASKS, Ordering::Relaxed), 0); + tx.send(i).unwrap(); + while wg.0.load(Ordering::Relaxed) > 0 { + wg.1.notified().await; + } + } + } + }) + }) + }); +} + +fn bench_contention(c: &mut Criterion) { + let mut group = c.benchmark_group("contention"); + contention_impl::<10>(&mut group); + contention_impl::<100>(&mut group); + contention_impl::<500>(&mut group); + contention_impl::<1000>(&mut group); + group.finish(); +} + +criterion_group!(contention, bench_contention); + +criterion_main!(contention); diff --git a/tokio/src/loom/mocked.rs b/tokio/src/loom/mocked.rs index d40e2c1f8ea..5d96ab71f7d 100644 --- a/tokio/src/loom/mocked.rs +++ b/tokio/src/loom/mocked.rs @@ -1,12 +1,25 @@ pub(crate) use loom::*; pub(crate) mod sync { + use std::{ + ops::{Deref, DerefMut}, + sync::{LockResult, PoisonError}, + }; pub(crate) use loom::sync::MutexGuard; #[derive(Debug)] pub(crate) struct Mutex(loom::sync::Mutex); + #[derive(Debug)] + pub(crate) struct RwLock(loom::sync::RwLock); + + #[derive(Debug)] + pub(crate) struct RwLockReadGuard<'a, T>(loom::sync::RwLockReadGuard<'a, T>); + + #[derive(Debug)] + pub(crate) struct RwLockWriteGuard<'a, T>(loom::sync::RwLockWriteGuard<'a, T>); + #[allow(dead_code)] impl Mutex { #[inline] @@ -25,6 +38,63 @@ pub(crate) mod sync { self.0.try_lock().ok() } } + + #[allow(dead_code)] + impl RwLock { + #[inline] + pub(crate) fn new(t: T) -> RwLock { + RwLock(loom::sync::RwLock::new(t)) + } + + #[inline] + pub(crate) fn read(&self) -> LockResult> { + match self.0.read() { + Ok(inner) => Ok(RwLockReadGuard(inner)), + Err(err) => 
Err(PoisonError::new(RwLockReadGuard(err.into_inner()))), + } + } + + #[inline] + pub(crate) fn write(&self) -> LockResult> { + match self.0.write() { + Ok(inner) => Ok(RwLockWriteGuard(inner)), + Err(err) => Err(PoisonError::new(RwLockWriteGuard(err.into_inner()))), + } + } + } + + impl<'a, T> Deref for RwLockReadGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.0.deref() + } + } + + #[allow(dead_code)] + impl<'a, T> RwLockWriteGuard<'a, T> { + pub(crate) fn downgrade( + s: Self, + rwlock: &'a RwLock, + ) -> LockResult> { + // Std rwlock does not support downgrading. + drop(s); + rwlock.read() + } + } + + impl<'a, T> Deref for RwLockWriteGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.0.deref() + } + } + + impl<'a, T> DerefMut for RwLockWriteGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.0.deref_mut() + } + } + pub(crate) use loom::sync::*; pub(crate) mod atomic { diff --git a/tokio/src/loom/std/mod.rs b/tokio/src/loom/std/mod.rs index 0c611af162a..7c68ea3ef3e 100644 --- a/tokio/src/loom/std/mod.rs +++ b/tokio/src/loom/std/mod.rs @@ -8,6 +8,7 @@ mod barrier; mod mutex; #[cfg(all(feature = "parking_lot", not(miri)))] mod parking_lot; +mod rwlock; mod unsafe_cell; pub(crate) mod cell { @@ -59,15 +60,18 @@ pub(crate) mod sync { #[cfg(all(feature = "parking_lot", not(miri)))] #[allow(unused_imports)] pub(crate) use crate::loom::std::parking_lot::{ - Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult, + Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult, }; #[cfg(not(all(feature = "parking_lot", not(miri))))] #[allow(unused_imports)] - pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult}; + pub(crate) use std::sync::{Condvar, MutexGuard, WaitTimeoutResult}; #[cfg(not(all(feature = "parking_lot", not(miri))))] - pub(crate) use crate::loom::std::mutex::Mutex; + pub(crate) use crate::loom::std::{ + mutex::Mutex, + rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}, + }; pub(crate) mod atomic { pub(crate) use crate::loom::std::atomic_u16::AtomicU16; diff --git a/tokio/src/loom/std/parking_lot.rs b/tokio/src/loom/std/parking_lot.rs index 9b9a81d35b0..41de799709f 100644 --- a/tokio/src/loom/std/parking_lot.rs +++ b/tokio/src/loom/std/parking_lot.rs @@ -112,6 +112,16 @@ impl<'a, T: ?Sized> Deref for RwLockReadGuard<'a, T> { } } +impl<'a, T> RwLockWriteGuard<'a, T> { + // The corresponding std method requires the rwlock. + pub(crate) fn downgrade(s: Self, _rwlock: &'a RwLock) -> LockResult> { + Ok(RwLockReadGuard( + PhantomData, + parking_lot::RwLockWriteGuard::downgrade(s.1), + )) + } +} + impl<'a, T: ?Sized> Deref for RwLockWriteGuard<'a, T> { type Target = T; fn deref(&self) -> &T { diff --git a/tokio/src/loom/std/rwlock.rs b/tokio/src/loom/std/rwlock.rs new file mode 100644 index 00000000000..f14f8ecea2b --- /dev/null +++ b/tokio/src/loom/std/rwlock.rs @@ -0,0 +1,67 @@ +use std::{ + ops::{Deref, DerefMut}, + sync::{self, LockResult, PoisonError}, +}; + +/// Adapter for std::RwLock that adds `downgrade` method. 
+#[derive(Debug)] +pub(crate) struct RwLock(sync::RwLock); + +#[derive(Debug)] +pub(crate) struct RwLockReadGuard<'a, T>(sync::RwLockReadGuard<'a, T>); + +#[derive(Debug)] +pub(crate) struct RwLockWriteGuard<'a, T>(sync::RwLockWriteGuard<'a, T>); + +#[allow(dead_code)] +impl RwLock { + #[inline] + pub(crate) fn new(t: T) -> RwLock { + RwLock(sync::RwLock::new(t)) + } + + #[inline] + pub(crate) fn read(&self) -> LockResult> { + match self.0.read() { + Ok(inner) => Ok(RwLockReadGuard(inner)), + Err(err) => Err(PoisonError::new(RwLockReadGuard(err.into_inner()))), + } + } + + #[inline] + pub(crate) fn write(&self) -> LockResult> { + match self.0.write() { + Ok(inner) => Ok(RwLockWriteGuard(inner)), + Err(err) => Err(PoisonError::new(RwLockWriteGuard(err.into_inner()))), + } + } +} + +impl<'a, T> Deref for RwLockReadGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.0.deref() + } +} + +#[allow(dead_code)] +impl<'a, T> RwLockWriteGuard<'a, T> { + pub(crate) fn downgrade(s: Self, rwlock: &'a RwLock) -> LockResult> { + // Std rwlock does not support downgrading. + drop(s); + rwlock.read() + } +} + +impl<'a, T> Deref for RwLockWriteGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.0.deref() + } +} + +impl<'a, T> DerefMut for RwLockWriteGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.0.deref_mut() + } +} diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 568a50bd59b..dab58ad49a5 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -118,8 +118,8 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard}; -use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; +use crate::loom::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crate::util::linked_list::{self, ConcurrentPushLinkedList, GuardedLinkedList}; use crate::util::WakeList; use std::fmt; @@ -127,7 +127,8 @@ use std::future::Future; use std::marker::PhantomPinned; use std::pin::Pin; use std::ptr::NonNull; -use std::sync::atomic::Ordering::SeqCst; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst}; use std::task::{Context, Poll, Waker}; use std::usize; @@ -310,7 +311,7 @@ struct Shared { mask: usize, /// Tail of the queue. Includes the rx wait list. - tail: Mutex, + tail: RwLock, /// Number of outstanding Sender handles. num_tx: AtomicUsize, @@ -328,7 +329,7 @@ struct Tail { closed: bool, /// Receivers waiting for a value. - waiters: LinkedList::Target>, + waiters: ConcurrentPushLinkedList::Target>, } /// Slot in the buffer. @@ -354,7 +355,7 @@ struct Slot { /// An entry in the wait queue. struct Waiter { /// True if queued. - queued: bool, + queued: AtomicBool, /// Task waiting on the broadcast channel. 
waker: Option, @@ -369,7 +370,7 @@ struct Waiter { impl Waiter { fn new() -> Self { Self { - queued: false, + queued: AtomicBool::new(false), waker: None, pointers: linked_list::Pointers::new(), _p: PhantomPinned, @@ -521,11 +522,11 @@ impl Sender { let shared = Arc::new(Shared { buffer: buffer.into_boxed_slice(), mask: capacity - 1, - tail: Mutex::new(Tail { + tail: RwLock::new(Tail { pos: 0, rx_cnt: receiver_count, closed: false, - waiters: LinkedList::new(), + waiters: ConcurrentPushLinkedList::new(), }), num_tx: AtomicUsize::new(1), }); @@ -585,7 +586,7 @@ impl Sender { /// } /// ``` pub fn send(&self, value: T) -> Result> { - let mut tail = self.shared.tail.lock(); + let mut tail = self.shared.tail.write().unwrap(); if tail.rx_cnt == 0 { return Err(SendError(value)); @@ -688,7 +689,7 @@ impl Sender { /// } /// ``` pub fn len(&self) -> usize { - let tail = self.shared.tail.lock(); + let tail = self.shared.tail.read().unwrap(); let base_idx = (tail.pos & self.shared.mask as u64) as usize; let mut low = 0; @@ -735,7 +736,7 @@ impl Sender { /// } /// ``` pub fn is_empty(&self) -> bool { - let tail = self.shared.tail.lock(); + let tail = self.shared.tail.read().unwrap(); let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize; self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 @@ -778,7 +779,7 @@ impl Sender { /// } /// ``` pub fn receiver_count(&self) -> usize { - let tail = self.shared.tail.lock(); + let tail = self.shared.tail.read().unwrap(); tail.rx_cnt } @@ -806,7 +807,7 @@ impl Sender { } fn close_channel(&self) { - let mut tail = self.shared.tail.lock(); + let mut tail = self.shared.tail.write().unwrap(); tail.closed = true; self.shared.notify_rx(tail); @@ -815,7 +816,7 @@ impl Sender { /// Create a new `Receiver` which reads starting from the tail. fn new_receiver(shared: Arc>) -> Receiver { - let mut tail = shared.tail.lock(); + let mut tail = shared.tail.write().unwrap(); assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers"); @@ -842,7 +843,7 @@ impl<'a, T> Drop for WaitersList<'a, T> { // If the list is not empty, we unlink all waiters from it. // We do not wake the waiters to avoid double panics. if !self.is_empty { - let _lock_guard = self.shared.tail.lock(); + let _lock_guard = self.shared.tail.write().unwrap(); while self.list.pop_back().is_some() {} } } @@ -850,12 +851,12 @@ impl<'a, T> Drop for WaitersList<'a, T> { impl<'a, T> WaitersList<'a, T> { fn new( - unguarded_list: LinkedList::Target>, + unguarded_list: ConcurrentPushLinkedList::Target>, guard: Pin<&'a Waiter>, shared: &'a Shared, ) -> Self { let guard_ptr = NonNull::from(guard.get_ref()); - let list = unguarded_list.into_guarded(guard_ptr); + let list = unguarded_list.into_list().into_guarded(guard_ptr); WaitersList { list, is_empty: false, @@ -864,8 +865,8 @@ impl<'a, T> WaitersList<'a, T> { } /// Removes the last element from the guarded list. Modifying this list - /// requires an exclusive access to the main list in `Notify`. - fn pop_back_locked(&mut self, _tail: &mut Tail) -> Option> { + /// requires a read lock on the main list. 
+ fn pop_back_locked(&mut self, _tail: &Tail) -> Option> { let result = self.list.pop_back(); if result.is_none() { // Save information about emptiness to avoid waiting for lock @@ -877,7 +878,7 @@ impl<'a, T> WaitersList<'a, T> { } impl Shared { - fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: MutexGuard<'a, Tail>) { + fn notify_rx<'a, 'b: 'a>(&'b self, mut tail: RwLockWriteGuard<'a, Tail>) { // It is critical for `GuardedLinkedList` safety that the guard node is // pinned in memory and is not dropped until the guarded list is dropped. let guard = Waiter::new(); @@ -887,26 +888,45 @@ impl Shared { // underneath to allow every waiter to safely remove itself from it. // // * This list will be still guarded by the `waiters` lock. - // `NotifyWaitersList` wrapper makes sure we hold the lock to modify it. + // `WaitersList` wrapper makes sure we hold a read lock to modify it. // * This wrapper will empty the list on drop. It is critical for safety // that we will not leave any list entry with a pointer to the local // guard node after this function returns / panics. let mut list = WaitersList::new(std::mem::take(&mut tail.waiters), guard.as_ref(), self); + // From now on, read lock suffices: we own our own copy of waiters list, + // and we only need to guard against concurrent waiter removals. + // Except us, waiter removals are done by `Recv::drop` and it takes + // a write lock to do it. + let mut tail = RwLockWriteGuard::downgrade(tail, &self.tail).unwrap(); + let mut wakers = WakeList::new(); 'outer: loop { while wakers.can_push() { - match list.pop_back_locked(&mut tail) { + match list.pop_back_locked(&tail) { Some(mut waiter) => { - // Safety: `tail` lock is still held. - let waiter = unsafe { waiter.as_mut() }; - - assert!(waiter.queued); - waiter.queued = false; - - if let Some(waker) = waiter.waker.take() { + // Safety: except us, `waiter.waker` is accessed only + // by `Receiver::recv_ref`. As this waiter is already + // queued, `Receiver::recv_ref` would take a write lock. + let waker = unsafe { waiter.as_mut().waker.take() }; + if let Some(waker) = waker { wakers.push(waker); } + + // Mark the waiter as not queued. + // It is critical to do it **after** the waker was extracted, + // otherwise we might data race with `Receiver::recv_ref`. + // + // Safety: + // - Read lock on tail is held, so `waiter` cannot + // be concurrently removed, + // - `waiter.queued` is atomic, so read lock suffices. + let queued = unsafe { &(*waiter.as_ptr()).queued }; + // Release memory order is required to establish a happens-before + // relationship between us writing to `waiter.waker` and + // `Receiver::recv_ref`/`Recv::drop` accessing it. + let prev_queued = queued.swap(false, Release); + assert!(prev_queued); } None => { break 'outer; @@ -925,7 +945,7 @@ impl Shared { wakers.wake_all(); // Acquire the lock again. - tail = self.tail.lock(); + tail = self.tail.read().unwrap(); } // Release the lock before waking. @@ -987,7 +1007,7 @@ impl Receiver { /// } /// ``` pub fn len(&self) -> usize { - let next_send_pos = self.shared.tail.lock().pos; + let next_send_pos = self.shared.tail.read().unwrap().pos; (next_send_pos - self.next) as usize } @@ -1055,24 +1075,46 @@ impl Receiver { if slot.pos != self.next { // Release the `slot` lock before attempting to acquire the `tail` - // lock. This is required because `send2` acquires the tail lock + // lock. This is required because `send` acquires the tail lock // first followed by the slot lock. 
Acquiring the locks in reverse
             // order here would result in a potential deadlock: `recv_ref`
             // acquires the `slot` lock and attempts to acquire the `tail` lock
-            // while `send2` acquired the `tail` lock and attempts to acquire
+            // while `send` acquired the `tail` lock and attempts to acquire
             // the slot lock.
             drop(slot);
 
             let mut old_waker = None;
-            let mut tail = self.shared.tail.lock();
+            let queued = waiter
+                .map(|(waiter, _)| {
+                    waiter.with(|ptr| {
+                        // Safety: waiter.queued is atomic.
+                        // Acquire is needed to synchronize with `Shared::notify_rx`.
+                        unsafe { (*ptr).queued.load(Acquire) }
+                    })
+                })
+                .unwrap_or(false);
+
+            // If `queued` is false, then we are the sole owner of the waiter,
+            // so a read lock on tail suffices.
+            // If `queued` is true, the waiter might be accessed concurrently,
+            // so we need a write lock.
+            let mut tail_read = None;
+            let mut tail_write = None;
+            let tail = if queued {
+                tail_write = Some(self.shared.tail.write().unwrap());
+                tail_write.as_deref().unwrap()
+            } else {
+                tail_read = Some(self.shared.tail.read().unwrap());
+                tail_read.as_deref().unwrap()
+            };
 
             // Acquire slot lock again
             slot = self.shared.buffer[idx].read().unwrap();
 
             // Make sure the position did not change. This could happen in the
             // unlikely event that the buffer is wrapped between dropping the
-            // read lock and acquiring the tail lock.
+            // slot lock and acquiring the tail lock.
             if slot.pos != self.next {
                 let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
 
@@ -1086,7 +1128,11 @@
                     // Store the waker
                     if let Some((waiter, waker)) = waiter {
-                        // Safety: called while locked.
+                        // Safety: called while holding a lock on tail.
+                        // If the waiter is not queued, then we hold a read lock
+                        // on tail and can safely mutate `waiter` since we
+                        // are the only owner.
+                        // If the waiter is queued, then we hold a write lock on tail.
                         unsafe {
                             // Only queue if not already queued
                             waiter.with_mut(|ptr| {
@@ -1104,8 +1150,16 @@
                                     }
                                 }
 
-                                if !(*ptr).queued {
-                                    (*ptr).queued = true;
+                                // If the waiter is not already queued, enqueue it.
+                                // Relaxed memory order suffices because, if `queued`
+                                // is `false`, then we are the sole owner of the waiter,
+                                // and all future accesses will happen with the tail lock held.
+                                if !(*ptr).queued.swap(true, Relaxed) {
+                                    // Safety:
+                                    // - `waiter` is not already queued,
+                                    // - calling `recv_ref` with a waiter implies ownership
+                                    //   of its `Recv`. As such, this waiter cannot be pushed
+                                    //   concurrently by some other thread.
                                     tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
                                 }
                             });
@@ -1114,7 +1168,8 @@
                     // Drop the old waker after releasing the locks.
                     drop(slot);
-                    drop(tail);
+                    drop(tail_read);
+                    drop(tail_write);
                     drop(old_waker);
 
                     return Err(TryRecvError::Empty);
@@ -1129,7 +1184,8 @@
                 let missed = next.wrapping_sub(self.next);
 
-                drop(tail);
+                drop(tail_read);
+                drop(tail_write);
 
                 // The receiver is slow but no values have been missed
                 if missed == 0 {
@@ -1331,7 +1387,7 @@ impl<T> Receiver<T> {
 impl<T> Drop for Receiver<T> {
     fn drop(&mut self) {
-        let mut tail = self.shared.tail.lock();
+        let mut tail = self.shared.tail.write().unwrap();
 
         tail.rx_cnt -= 1;
         let until = tail.pos;
@@ -1357,7 +1413,7 @@ impl<'a, T> Recv<'a, T> {
         Recv {
             receiver,
             waiter: UnsafeCell::new(Waiter {
-                queued: false,
+                queued: AtomicBool::new(false),
                 waker: None,
                 pointers: linked_list::Pointers::new(),
                 _p: PhantomPinned,
@@ -1402,22 +1458,38 @@ where
 impl<'a, T> Drop for Recv<'a, T> {
     fn drop(&mut self) {
-        // Acquire the tail lock. This is required for safety before accessing
-        // the waiter node.
-        let mut tail = self.receiver.shared.tail.lock();
-
-        // safety: tail lock is held
-        let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
-
+        // Safety: `waiter.queued` is atomic.
+        // Acquire ordering is required to synchronize with
+        // `Shared::notify_rx` before we drop the object.
+        let queued = self
+            .waiter
+            .with(|ptr| unsafe { (*ptr).queued.load(Acquire) });
+
+        // If `queued` is false, it cannot become true again
+        // (concurrent calls to `Receiver::recv_ref` with this
+        // waiter are impossible as they imply an exclusive
+        // reference to this `Recv`, which we now have).
+        // If `queued` is true, we need to take a write lock
+        // and check again.
         if queued {
-            // Remove the node
-            //
-            // safety: tail lock is held and the wait node is verified to be in
-            // the list.
-            unsafe {
-                self.waiter.with_mut(|ptr| {
-                    tail.waiters.remove((&mut *ptr).into());
-                });
+            let mut tail = self.receiver.shared.tail.write().unwrap();
+
+            // Safety: `waiter.queued` is atomic.
+            // Relaxed order suffices because we hold a write lock on tail.
+            let queued = self
+                .waiter
+                .with(|ptr| unsafe { (*ptr).queued.load(Relaxed) });
+
+            if queued {
+                // Remove the node.
+                //
+                // Safety: tail write lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    self.waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
             }
         }
     }
 }
diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs
index 0ed2b616456..400c42573ca 100644
--- a/tokio/src/util/linked_list.rs
+++ b/tokio/src/util/linked_list.rs
@@ -6,6 +6,11 @@
 //! structure's APIs are `unsafe` as they require the caller to ensure the
 //! specified node is actually contained by the list.
 
+#[cfg(feature = "sync")]
+mod concurrent_push;
+#[cfg(feature = "sync")]
+pub(crate) use self::concurrent_push::ConcurrentPushLinkedList;
+
 use core::cell::UnsafeCell;
 use core::fmt;
 use core::marker::{PhantomData, PhantomPinned};
@@ -108,6 +113,52 @@ struct PointersInner<T> {
 unsafe impl<T> Send for Pointers<T> {}
 unsafe impl<T> Sync for Pointers<T> {}
 
+// ===== LinkedListBase =====
+
+// Common methods between LinkedList and ConcurrentPushLinkedList.
+trait LinkedListBase<L: Link> {
+    // NB: exclusive reference is important for ConcurrentPushLinkedList safety guarantees.
+ fn head(&mut self) -> Option>; + fn tail(&mut self) -> Option>; + + fn set_head(&mut self, node: Option>); + fn set_tail(&mut self, node: Option>); + + unsafe fn remove(&mut self, node: NonNull) -> Option { + if let Some(prev) = L::pointers(node).as_ref().get_prev() { + debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); + L::pointers(prev) + .as_mut() + .set_next(L::pointers(node).as_ref().get_next()); + } else { + if self.head() != Some(node) { + return None; + } + + self.set_head(L::pointers(node).as_ref().get_next()); + } + + if let Some(next) = L::pointers(node).as_ref().get_next() { + debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); + L::pointers(next) + .as_mut() + .set_prev(L::pointers(node).as_ref().get_prev()); + } else { + // This might be the last item in the list + if self.tail() != Some(node) { + return None; + } + + self.set_tail(L::pointers(node).as_ref().get_prev()); + } + + L::pointers(node).as_mut().set_next(None); + L::pointers(node).as_mut().set_prev(None); + + Some(L::from_raw(node)) + } +} + // ===== impl LinkedList ===== impl LinkedList { @@ -121,6 +172,24 @@ impl LinkedList { } } +impl LinkedListBase for LinkedList { + fn head(&mut self) -> Option::Target>> { + self.head + } + + fn tail(&mut self) -> Option::Target>> { + self.tail + } + + fn set_head(&mut self, node: Option::Target>>) { + self.head = node; + } + + fn set_tail(&mut self, node: Option::Target>>) { + self.tail = node; + } +} + impl LinkedList { /// Adds an element first in the list. pub(crate) fn push_front(&mut self, val: L::Handle) { @@ -185,37 +254,7 @@ impl LinkedList { /// the caller has an exclusive access to that list. This condition is /// used by the linked list in `sync::Notify`. pub(crate) unsafe fn remove(&mut self, node: NonNull) -> Option { - if let Some(prev) = L::pointers(node).as_ref().get_prev() { - debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); - L::pointers(prev) - .as_mut() - .set_next(L::pointers(node).as_ref().get_next()); - } else { - if self.head != Some(node) { - return None; - } - - self.head = L::pointers(node).as_ref().get_next(); - } - - if let Some(next) = L::pointers(node).as_ref().get_next() { - debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); - L::pointers(next) - .as_mut() - .set_prev(L::pointers(node).as_ref().get_prev()); - } else { - // This might be the last item in the list - if self.tail != Some(node) { - return None; - } - - self.tail = L::pointers(node).as_ref().get_prev(); - } - - L::pointers(node).as_mut().set_next(None); - L::pointers(node).as_mut().set_prev(None); - - Some(L::from_raw(node)) + LinkedListBase::remove(self, node) } } @@ -474,9 +513,9 @@ pub(crate) mod tests { #[derive(Debug)] #[repr(C)] - struct Entry { + pub(crate) struct Entry { pointers: Pointers, - val: i32, + pub(crate) val: i32, } unsafe impl<'a> Link for &'a Entry { @@ -496,7 +535,7 @@ pub(crate) mod tests { } } - fn entry(val: i32) -> Pin> { + pub(crate) fn entry(val: i32) -> Pin> { Box::pin(Entry { pointers: Pointers::new(), val, diff --git a/tokio/src/util/linked_list/concurrent_push.rs b/tokio/src/util/linked_list/concurrent_push.rs new file mode 100644 index 00000000000..74fe56bae5c --- /dev/null +++ b/tokio/src/util/linked_list/concurrent_push.rs @@ -0,0 +1,200 @@ +use super::{Link, LinkedList, LinkedListBase}; + +use core::cell::UnsafeCell; +use core::marker::PhantomData; +use core::mem::ManuallyDrop; +use core::ptr::NonNull; +use core::sync::atomic::{ + AtomicPtr, + Ordering::{AcqRel, Relaxed}, 
+};
+
+/// A linked list that supports adding new nodes concurrently.
+/// Note that all other operations, e.g. node removals,
+/// require external synchronization.
+/// The simplest way to achieve this is to use an RwLock:
+/// pushing nodes only requires a read lock,
+/// while removing nodes requires a write lock.
+pub(crate) struct ConcurrentPushLinkedList<L, T> {
+    /// Linked list head.
+    head: AtomicPtr<T>,
+
+    /// Linked list tail.
+    tail: UnsafeCell<Option<NonNull<T>>>,
+
+    /// Node type marker.
+    _marker: PhantomData<*const L>,
+}
+
+unsafe impl<L: Link> Send for ConcurrentPushLinkedList<L, L::Target> where L::Target: Send {}
+unsafe impl<L: Link> Sync for ConcurrentPushLinkedList<L, L::Target> where L::Target: Sync {}
+
+impl<L: Link> Default for ConcurrentPushLinkedList<L, L::Target> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<L, T> ConcurrentPushLinkedList<L, T> {
+    /// Creates an empty concurrent push linked list.
+    pub(crate) const fn new() -> ConcurrentPushLinkedList<L, T> {
+        ConcurrentPushLinkedList {
+            head: AtomicPtr::new(core::ptr::null_mut()),
+            tail: UnsafeCell::new(None),
+            _marker: PhantomData,
+        }
+    }
+
+    /// Converts a concurrent push linked list into a regular `LinkedList`.
+    pub(crate) fn into_list(mut self) -> LinkedList<L, T> {
+        LinkedList {
+            head: NonNull::new(*self.head.get_mut()),
+            tail: *self.tail.get_mut(),
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl<L: Link> LinkedListBase<L> for ConcurrentPushLinkedList<L, L::Target> {
+    fn head(&mut self) -> Option<NonNull<L::Target>> {
+        NonNull::new(*self.head.get_mut())
+    }
+
+    fn tail(&mut self) -> Option<NonNull<L::Target>> {
+        *self.tail.get_mut()
+    }
+
+    fn set_head(&mut self, node: Option<NonNull<L::Target>>) {
+        *self.head.get_mut() = match node {
+            Some(ptr) => ptr.as_ptr(),
+            None => core::ptr::null_mut(),
+        };
+    }
+
+    fn set_tail(&mut self, node: Option<NonNull<L::Target>>) {
+        *self.tail.get_mut() = node;
+    }
+}
+
+impl<L: Link> ConcurrentPushLinkedList<L, L::Target> {
+    /// Atomically adds an element first in the list.
+    /// This method can be called concurrently by multiple threads.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that:
+    /// - `val` is not pushed concurrently by multiple threads,
+    /// - `val` is not already part of some list.
+    pub(crate) unsafe fn push_front(&self, val: L::Handle) {
+        // Note that removing nodes from the list still requires
+        // an exclusive reference, so we need not worry about
+        // concurrent node removals.
+
+        // The value should not be dropped, it is being inserted into the list.
+        let val = ManuallyDrop::new(val);
+        let new_head = L::as_raw(&val);
+
+        // Safety: due to the function contract, no concurrent `push_front`
+        // is called on this particular element, so we are safe to assume
+        // ownership.
+        L::pointers(new_head).as_mut().set_prev(None);
+
+        let mut old_head = self.head.load(Relaxed);
+        loop {
+            // Safety: due to the function contract, no concurrent `push_front`
+            // is called on this particular element, and we have not
+            // inserted it into the list, so we can still assume ownership.
+            L::pointers(new_head)
+                .as_mut()
+                .set_next(NonNull::new(old_head));
+
+            if let Err(actual_head) =
+                self.head
+                    .compare_exchange_weak(old_head, new_head.as_ptr(), AcqRel, Relaxed)
+            {
+                old_head = actual_head;
+            } else {
+                break;
+            };
+        }
+
+        if old_head.is_null() {
+            // Safety: only the thread that successfully inserted the first
+            // element is granted the right to update tail.
+            *self.tail.get() = Some(new_head);
+        } else {
+            // Safety:
+            // 1. Only the thread that successfully inserted the new element
+            //    is granted the right to update the previous head's `prev`,
+            // 2. Upon successful insertion, we have synchronized with all the
+            //    previous insertions (due to `AcqRel` memory ordering), so all
+            //    the previous `set_prev` calls on `old_head` happen-before this call,
+            // 3. Due to the `push_front` contract, we can assume that `old_head`
+            //    is not pushed concurrently by another thread, as it is already
+            //    in the list. Thus, no data race on `set_prev` can happen.
+            L::pointers(NonNull::new_unchecked(old_head))
+                .as_mut()
+                .set_prev(Some(new_head));
+        }
+    }
+
+    /// See [LinkedList::remove].
+    ///
+    /// Note that `&mut self` implies that this call is somehow
+    /// synchronized with `push_front` (e.g. with RwLock).
+    /// In terms of the memory model, there has to be an established
+    /// happens-before relationship between any given `push_front`
+    /// and any given `remove`. The relation can go either way.
+    pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
+        LinkedListBase::remove(self, node)
+    }
+}
+
+#[cfg(test)]
+#[cfg(not(loom))]
+#[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
+pub(crate) mod tests {
+    use super::super::tests::*;
+    use super::*;
+
+    use std::sync::Arc;
+
+    #[test]
+    fn concurrent_push_front() {
+        let atomic_list =
+            Arc::new(ConcurrentPushLinkedList::<&Entry, <&Entry as Link>::Target>::new());
+
+        let _entries = [5, 7]
+            .into_iter()
+            .map(|x| {
+                std::thread::spawn({
+                    let atomic_list = atomic_list.clone();
+                    move || {
+                        let list_entry = entry(x);
+                        unsafe {
+                            atomic_list.push_front(list_entry.as_ref());
+                        }
+                        list_entry
+                    }
+                })
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+            .map(|handle| handle.join().unwrap())
+            .collect::<Vec<_>>();
+
+        let mut list = Arc::into_inner(atomic_list).unwrap().into_list();
+
+        assert!(!list.is_empty());
+
+        let first = list.pop_back().unwrap();
+        assert!(first.val == 5 || first.val == 7);
+
+        let second = list.pop_back().unwrap();
+        assert!(second.val == 5 || second.val == 7);
+        assert_ne!(first.val, second.val);
+
+        assert!(list.is_empty());
+        assert!(list.pop_back().is_none());
+    }
+}
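
For readers coming to this change cold, here is a minimal, self-contained sketch (not part of the patch) of the single-sender, many-receivers pattern that the new `sync_broadcast` benchmark stresses: every `send` has to wake every parked subscriber, which is exactly the path where the `tail` lock is contended. It only uses the public `tokio::sync::broadcast` API and assumes a `tokio` dependency with the `macros`, `rt-multi-thread`, and `sync` features; the receiver count, channel capacity, and message count are arbitrary illustration values.

use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    // Capacity comfortably above the number of messages, so no receiver lags.
    let (tx, _rx) = broadcast::channel::<u64>(2048);

    let mut handles = Vec::new();
    for _ in 0..8 {
        let mut rx = tx.subscribe();
        handles.push(tokio::spawn(async move {
            let mut sum = 0u64;
            // Every subscriber sees every value sent after it subscribed;
            // `recv` returns `Err(Closed)` once all senders are dropped and
            // the backlog is drained, which ends the loop.
            while let Ok(v) = rx.recv().await {
                sum = sum.wrapping_add(v);
            }
            sum
        }));
    }

    for v in 0..1_000u64 {
        // Each send stores one value and wakes every waiting receiver.
        tx.send(v).unwrap();
    }

    // Dropping the last sender closes the channel for all receivers.
    drop(tx);

    for handle in handles {
        assert_eq!(handle.await.unwrap(), (0..1_000u64).sum::<u64>());
    }
}

Under this pattern, the old `Mutex<Tail>` made every waiter registration and every notification pass serialize on one lock. With this patch, a receiver whose waiter is not yet queued registers it under a read lock (the push itself is the lock-free `push_front`), `notify_rx` downgrades to a read lock while draining its own copy of the waiter list, and the write lock is needed chiefly for `send` and for touching a waiter that may already be queued.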