From 6f385babd770dc2225f5aa94ab3fbf9d69096457 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Fri, 26 May 2023 11:22:07 -0400 Subject: [PATCH 1/4] Make `RwLockReadGuard` covariant And refactor `RwLock` implementation into non-generic core and generic interface --- src/mutex.rs | 13 ++ src/rwlock.rs | 553 ++++++++++++++-------------------------------- src/rwlock/raw.rs | 542 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 721 insertions(+), 387 deletions(-) create mode 100644 src/rwlock/raw.rs diff --git a/src/mutex.rs b/src/mutex.rs index 639ae23..df892de 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -163,6 +163,19 @@ impl Mutex { pub fn get_mut(&mut self) -> &mut T { unsafe { &mut *self.data.get() } } + + /// Unlocks the mutex directly. + /// + /// # Safety + /// + /// This function is intended to be used only in the case where the mutex is locked, + /// and the guard is subsequently forgotten. Calling this while you don't hold a lock + /// on the mutex will likely lead to UB. + pub(crate) unsafe fn unlock_unchecked(&self) { + // Remove the last bit and notify a waiting lock operation. + self.state.fetch_sub(1, Ordering::Release); + self.lock_ops.notify(1); + } } impl Mutex { diff --git a/src/rwlock.rs b/src/rwlock.rs index 7346367..0584043 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -1,20 +1,14 @@ use std::cell::UnsafeCell; use std::fmt; use std::future::Future; -use std::mem; +use std::mem::ManuallyDrop; use std::ops::{Deref, DerefMut}; use std::pin::Pin; -use std::process; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::{Context, Poll}; -use event_listener::{Event, EventListener}; +mod raw; -use crate::futures::Lock; -use crate::{Mutex, MutexGuard}; - -const WRITER_BIT: usize = 1; -const ONE_READER: usize = 2; +use raw::*; /// An async reader-writer lock. /// @@ -45,23 +39,8 @@ const ONE_READER: usize = 2; /// # }) /// ``` pub struct RwLock { - /// Acquired by the writer. - mutex: Mutex<()>, - - /// Event triggered when the last reader is dropped. - no_readers: Event, - - /// Event triggered when the writer is dropped. - no_writer: Event, - - /// Current state of the lock. - /// - /// The least significant bit (`WRITER_BIT`) is set to 1 when a writer is holding the lock or - /// trying to acquire it. - /// - /// The upper bits contain the number of currently active readers. Each active reader - /// increments the state by `ONE_READER`. - state: AtomicUsize, + /// The locking implementation. + raw: RawRwLock, /// The inner value. value: UnsafeCell, @@ -82,10 +61,7 @@ impl RwLock { /// ``` pub const fn new(t: T) -> RwLock { RwLock { - mutex: Mutex::new(()), - no_readers: Event::new(), - no_writer: Event::new(), - state: AtomicUsize::new(0), + raw: RawRwLock::new(), value: UnsafeCell::new(t), } } @@ -100,6 +76,7 @@ impl RwLock { /// let lock = RwLock::new(5); /// assert_eq!(lock.into_inner(), 5); /// ``` + #[inline] pub fn into_inner(self) -> T { self.value.into_inner() } @@ -125,31 +102,15 @@ impl RwLock { /// assert!(lock.try_read().is_some()); /// # }) /// ``` + #[inline] pub fn try_read(&self) -> Option> { - let mut state = self.state.load(Ordering::Acquire); - - loop { - // If there's a writer holding the lock or attempting to acquire it, we cannot acquire - // a read lock here. - if state & WRITER_BIT != 0 { - return None; - } - - // Make sure the number of readers doesn't overflow. - if state > std::isize::MAX as usize { - process::abort(); - } - - // Increment the number of readers. - match self.state.compare_exchange( - state, - state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => return Some(RwLockReadGuard(self)), - Err(s) => state = s, - } + if self.raw.try_read() { + Some(RwLockReadGuard { + lock: &self.raw, + value: self.value.get(), + }) + } else { + None } } @@ -174,11 +135,11 @@ impl RwLock { /// assert!(lock.try_read().is_some()); /// # }) /// ``` + #[inline] pub fn read(&self) -> Read<'_, T> { Read { - lock: self, - state: self.state.load(Ordering::Acquire), - listener: None, + raw: self.raw.read(), + value: self.value.get(), } } @@ -206,33 +167,15 @@ impl RwLock { /// *writer = 2; /// # }) /// ``` + #[inline] pub fn try_upgradable_read(&self) -> Option> { - // First try grabbing the mutex. - let lock = self.mutex.try_lock()?; - - let mut state = self.state.load(Ordering::Acquire); - - // Make sure the number of readers doesn't overflow. - if state > std::isize::MAX as usize { - process::abort(); - } - - // Increment the number of readers. - loop { - match self.state.compare_exchange( - state, - state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - return Some(RwLockUpgradableReadGuard { - reader: RwLockReadGuard(self), - reserved: lock, - }) - } - Err(s) => state = s, - } + if self.raw.try_upgradable_read() { + Some(RwLockUpgradableReadGuard { + lock: &self.raw, + value: self.value.get(), + }) + } else { + None } } @@ -262,10 +205,11 @@ impl RwLock { /// *writer = 2; /// # }) /// ``` + #[inline] pub fn upgradable_read(&self) -> UpgradableRead<'_, T> { UpgradableRead { - lock: self, - acquire: self.mutex.lock(), + raw: self.raw.upgradable_read(), + value: self.value.get(), } } @@ -288,18 +232,10 @@ impl RwLock { /// # }) /// ``` pub fn try_write(&self) -> Option> { - // First try grabbing the mutex. - let lock = self.mutex.try_lock()?; - - // If there are no readers, grab the write lock. - if self - .state - .compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - { + if self.raw.try_write() { Some(RwLockWriteGuard { - writer: RwLockWriteGuardInner(self), - reserved: lock, + lock: &self.raw, + value: self.value.get(), }) } else { None @@ -324,8 +260,8 @@ impl RwLock { /// ``` pub fn write(&self) -> Write<'_, T> { Write { - lock: self, - state: WriteState::Acquiring(self.mutex.lock()), + raw: self.raw.write(), + value: self.value.get(), } } @@ -381,16 +317,14 @@ impl Default for RwLock { /// The future returned by [`RwLock::read`]. pub struct Read<'a, T: ?Sized> { - /// The lock that is being acquired. - lock: &'a RwLock, - - /// The last-observed state of the lock. - state: usize, + raw: RawRead<'a>, - /// The listener for the "no writers" event. - listener: Option, + value: *const T, } +unsafe impl Send for Read<'_, T> {} +unsafe impl Sync for Read<'_, T> {} + impl fmt::Debug for Read<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("Read { .. }") @@ -402,66 +336,26 @@ impl Unpin for Read<'_, T> {} impl<'a, T: ?Sized> Future for Read<'a, T> { type Output = RwLockReadGuard<'a, T>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - loop { - if this.state & WRITER_BIT == 0 { - // Make sure the number of readers doesn't overflow. - if this.state > std::isize::MAX as usize { - process::abort(); - } - - // If nobody is holding a write lock or attempting to acquire it, increment the - // number of readers. - match this.lock.state.compare_exchange( - this.state, - this.state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => return Poll::Ready(RwLockReadGuard(this.lock)), - Err(s) => this.state = s, - } - } else { - // Start listening for "no writer" events. - let load_ordering = match &mut this.listener { - None => { - this.listener = Some(this.lock.no_writer.listen()); - - // Make sure there really is no writer. - Ordering::SeqCst - } - - Some(ref mut listener) => { - // Wait for the writer to finish. - ready!(Pin::new(listener).poll(cx)); - this.listener = None; - - // Notify the next reader waiting in list. - this.lock.no_writer.notify(1); - - // Check the state again. - Ordering::Acquire - } - }; - - // Reload the state. - this.state = this.lock.state.load(load_ordering); - } - } + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + ready!(Pin::new(&mut self.raw).poll(cx)); + + Poll::Ready(RwLockReadGuard { + lock: self.raw.lock, + value: self.value, + }) } } /// The future returned by [`RwLock::upgradable_read`]. pub struct UpgradableRead<'a, T: ?Sized> { - /// The lock that is being acquired. - lock: &'a RwLock, - - /// The mutex we are trying to acquire. - acquire: Lock<'a, ()>, + raw: RawUpgradableRead<'a>, + value: *mut T, } +unsafe impl Send for UpgradableRead<'_, T> {} +unsafe impl Sync for UpgradableRead<'_, T> {} + impl fmt::Debug for UpgradableRead<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("UpgradableRead { .. }") @@ -473,61 +367,25 @@ impl Unpin for UpgradableRead<'_, T> {} impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> { type Output = RwLockUpgradableReadGuard<'a, T>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + ready!(Pin::new(&mut self.raw).poll(cx)); - // Acquire the mutex. - let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx)); - - let mut state = this.lock.state.load(Ordering::Acquire); - - // Make sure the number of readers doesn't overflow. - if state > std::isize::MAX as usize { - process::abort(); - } - - // Increment the number of readers. - loop { - match this.lock.state.compare_exchange( - state, - state + ONE_READER, - Ordering::AcqRel, - Ordering::Acquire, - ) { - Ok(_) => { - return Poll::Ready(RwLockUpgradableReadGuard { - reader: RwLockReadGuard(this.lock), - reserved: mutex_guard, - }); - } - Err(s) => state = s, - } - } + Poll::Ready(RwLockUpgradableReadGuard { + lock: self.raw.lock, + value: self.value, + }) } } /// The future returned by [`RwLock::write`]. pub struct Write<'a, T: ?Sized> { - /// The lock that is being acquired. - lock: &'a RwLock, - - /// Current state fof this future. - state: WriteState<'a, T>, + raw: RawWrite<'a>, + value: *mut T, } -enum WriteState<'a, T: ?Sized> { - /// We are currently acquiring the inner mutex. - Acquiring(Lock<'a, ()>), - - /// We are currently waiting for readers to finish. - WaitingReaders { - /// Our current write guard. - guard: Option>, - - /// The listener for the "no readers" event. - listener: Option, - }, -} +unsafe impl Send for Write<'_, T> {} +unsafe impl Sync for Write<'_, T> {} impl fmt::Debug for Write<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -540,82 +398,33 @@ impl Unpin for Write<'_, T> {} impl<'a, T: ?Sized> Future for Write<'a, T> { type Output = RwLockWriteGuard<'a, T>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - loop { - match &mut this.state { - WriteState::Acquiring(lock) => { - // First grab the mutex. - let mutex_guard = ready!(Pin::new(lock).poll(cx)); - - // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled. - let new_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst); - let guard = RwLockWriteGuard { - writer: RwLockWriteGuardInner(this.lock), - reserved: mutex_guard, - }; - - // If we just acquired the writer lock, return it. - if new_state == WRITER_BIT { - return Poll::Ready(guard); - } - - // Start waiting for the readers to finish. - this.state = WriteState::WaitingReaders { - guard: Some(guard), - listener: Some(this.lock.no_readers.listen()), - }; - } - - WriteState::WaitingReaders { - guard, - ref mut listener, - } => { - let load_ordering = if listener.is_some() { - Ordering::Acquire - } else { - Ordering::SeqCst - }; - - // Check the state again. - if this.lock.state.load(load_ordering) == WRITER_BIT { - // We are the only ones holding the lock, return it. - return Poll::Ready(guard.take().unwrap()); - } - - // Wait for the readers to finish. - match listener { - None => { - // Register a listener. - *listener = Some(this.lock.no_readers.listen()); - } - - Some(ref mut evl) => { - // Wait for the readers to finish. - ready!(Pin::new(evl).poll(cx)); - *listener = None; - } - }; - } - } - } + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + ready!(Pin::new(&mut self.raw).poll(cx)); + + Poll::Ready(RwLockWriteGuard { + lock: self.raw.lock, + value: self.value, + }) } } /// A guard that releases the read lock when dropped. #[clippy::has_significant_drop] -pub struct RwLockReadGuard<'a, T: ?Sized>(&'a RwLock); +pub struct RwLockReadGuard<'a, T: ?Sized> { + lock: &'a RawRwLock, + value: *const T, +} unsafe impl Send for RwLockReadGuard<'_, T> {} unsafe impl Sync for RwLockReadGuard<'_, T> {} impl Drop for RwLockReadGuard<'_, T> { + #[inline] fn drop(&mut self) { - // Decrement the number of readers. - if self.0.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER { - // If this was the last reader, trigger the "no readers" event. - self.0.no_readers.notify(1); + // SAFETY: we are dropping a read guard. + unsafe { + self.lock.read_unlock(); } } } @@ -636,31 +445,32 @@ impl Deref for RwLockReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.0.value.get() } + unsafe { &*self.value } } } /// A guard that releases the upgradable read lock when dropped. #[clippy::has_significant_drop] pub struct RwLockUpgradableReadGuard<'a, T: ?Sized> { - reader: RwLockReadGuard<'a, T>, - reserved: MutexGuard<'a, ()>, + // The guard holds a lock on the mutex! + lock: &'a RawRwLock, + value: *mut T, +} + +impl<'a, T: ?Sized> Drop for RwLockUpgradableReadGuard<'a, T> { + #[inline] + fn drop(&mut self) { + // SAFETY: we are dropping an upgradable read guard. + unsafe { + self.lock.upgradable_read_unlock(); + } + } } unsafe impl Send for RwLockUpgradableReadGuard<'_, T> {} unsafe impl Sync for RwLockUpgradableReadGuard<'_, T> {} impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { - /// Converts this guard into a writer guard. - fn into_writer(self) -> RwLockWriteGuard<'a, T> { - let writer = RwLockWriteGuard { - writer: RwLockWriteGuardInner(self.reader.0), - reserved: self.reserved, - }; - mem::forget(self.reader); - writer - } - /// Downgrades into a regular reader guard. /// /// # Examples @@ -681,8 +491,19 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { /// assert!(lock.try_upgradable_read().is_some()); /// # }) /// ``` + #[inline] pub fn downgrade(guard: Self) -> RwLockReadGuard<'a, T> { - guard.reader + let upgradable = ManuallyDrop::new(guard); + + // SAFETY: `guard` is an upgradable read lock. + unsafe { + upgradable.lock.downgrade_upgradable_read(); + }; + + RwLockReadGuard { + lock: upgradable.lock, + value: upgradable.value, + } } /// Attempts to upgrade into a write lock. @@ -710,16 +531,17 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { /// let writer = RwLockUpgradableReadGuard::try_upgrade(reader).unwrap(); /// # }) /// ``` + #[inline] pub fn try_upgrade(guard: Self) -> Result, Self> { // If there are no readers, grab the write lock. - if guard - .reader - .0 - .state - .compare_exchange(ONE_READER, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - { - Ok(guard.into_writer()) + // SAFETY: `guard` is an upgradable read guard + if unsafe { guard.lock.try_upgrade() } { + let reader = ManuallyDrop::new(guard); + + Ok(RwLockWriteGuard { + lock: reader.lock, + value: reader.value, + }) } else { Err(guard) } @@ -742,20 +564,14 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { /// *writer = 2; /// # }) /// ``` + #[inline] pub fn upgrade(guard: Self) -> Upgrade<'a, T> { - // Set `WRITER_BIT` and decrement the number of readers at the same time. - guard - .reader - .0 - .state - .fetch_sub(ONE_READER - WRITER_BIT, Ordering::SeqCst); - - // Convert into a write guard that unsets `WRITER_BIT` in case this future is canceled. - let guard = guard.into_writer(); + let reader = ManuallyDrop::new(guard); Upgrade { - guard: Some(guard), - listener: None, + // SAFETY: `reader` is an upgradable read guard + raw: unsafe { reader.lock.upgrade() }, + value: reader.value, } } } @@ -776,17 +592,14 @@ impl Deref for RwLockUpgradableReadGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.reader.0.value.get() } + unsafe { &*self.value } } } /// The future returned by [`RwLockUpgradableReadGuard::upgrade`]. pub struct Upgrade<'a, T: ?Sized> { - /// The guard that we are upgrading to. - guard: Option>, - - /// The event listener we are waiting on. - listener: Option, + raw: RawUpgrade<'a>, + value: *mut T, } impl fmt::Debug for Upgrade<'_, T> { @@ -800,68 +613,38 @@ impl Unpin for Upgrade<'_, T> {} impl<'a, T: ?Sized> Future for Upgrade<'a, T> { type Output = RwLockWriteGuard<'a, T>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - let guard = this - .guard - .as_mut() - .expect("cannot poll future after completion"); - - // If there are readers, we need to wait for them to finish. - loop { - let load_ordering = if this.listener.is_some() { - Ordering::Acquire - } else { - Ordering::SeqCst - }; - - // See if the number of readers is zero. - let state = guard.writer.0.state.load(load_ordering); - if state == WRITER_BIT { - break; - } + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let lock = ready!(Pin::new(&mut self.raw).poll(cx)); - // If there are readers, wait for them to finish. - match &mut this.listener { - None => { - // Start listening for "no readers" events. - this.listener = Some(guard.writer.0.no_readers.listen()); - } - - Some(ref mut listener) => { - // Wait for the readers to finish. - ready!(Pin::new(listener).poll(cx)); - this.listener = None; - } - } - } - - // We are done. - Poll::Ready(this.guard.take().unwrap()) - } -} - -struct RwLockWriteGuardInner<'a, T: ?Sized>(&'a RwLock); - -impl Drop for RwLockWriteGuardInner<'_, T> { - fn drop(&mut self) { - // Unset `WRITER_BIT`. - self.0.state.fetch_and(!WRITER_BIT, Ordering::SeqCst); - // Trigger the "no writer" event. - self.0.no_writer.notify(1); + Poll::Ready(RwLockWriteGuard { + lock, + value: self.value, + }) } } /// A guard that releases the write lock when dropped. #[clippy::has_significant_drop] pub struct RwLockWriteGuard<'a, T: ?Sized> { - writer: RwLockWriteGuardInner<'a, T>, - reserved: MutexGuard<'a, ()>, + // The guard holds a lock on the mutex! + lock: &'a RawRwLock, + value: *mut T, } unsafe impl Send for RwLockWriteGuard<'_, T> {} unsafe impl Sync for RwLockWriteGuard<'_, T> {} +impl<'a, T: ?Sized> Drop for RwLockWriteGuard<'a, T> { + #[inline] + fn drop(&mut self) { + // SAFETY: we are dropping a write lock + unsafe { + self.lock.write_unlock(); + } + } +} + impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { /// Downgrades into a regular reader guard. /// @@ -884,21 +667,19 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { /// assert!(lock.try_read().is_some()); /// # }) /// ``` + #[inline] pub fn downgrade(guard: Self) -> RwLockReadGuard<'a, T> { - // Atomically downgrade state. - guard - .writer - .0 - .state - .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst); + let write = ManuallyDrop::new(guard); - // Trigger the "no writer" event. - guard.writer.0.no_writer.notify(1); + // SAFETY: `write` is a write guard + unsafe { + write.lock.downgrade_write(); + } - // Convert into a read guard and return. - let new_guard = RwLockReadGuard(guard.writer.0); - mem::forget(guard.writer); // `RwLockWriteGuardInner::drop()` should not be called! - new_guard + RwLockReadGuard { + lock: write.lock, + value: write.value, + } } /// Downgrades into an upgradable reader guard. @@ -925,21 +706,19 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { /// assert!(RwLockUpgradableReadGuard::try_upgrade(reader).is_ok()) /// # }) /// ``` + #[inline] pub fn downgrade_to_upgradable(guard: Self) -> RwLockUpgradableReadGuard<'a, T> { - // Atomically downgrade state. - guard - .writer - .0 - .state - .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst); - - // Convert into an upgradable read guard and return. - let new_guard = RwLockUpgradableReadGuard { - reader: RwLockReadGuard(guard.writer.0), - reserved: guard.reserved, - }; - mem::forget(guard.writer); // `RwLockWriteGuardInner::drop()` should not be called! - new_guard + let write = ManuallyDrop::new(guard); + + // SAFETY: `write` is a write guard + unsafe { + write.lock.downgrade_to_upgradable(); + } + + RwLockUpgradableReadGuard { + lock: write.lock, + value: write.value, + } } } @@ -959,12 +738,12 @@ impl Deref for RwLockWriteGuard<'_, T> { type Target = T; fn deref(&self) -> &T { - unsafe { &*self.writer.0.value.get() } + unsafe { &*self.value } } } impl DerefMut for RwLockWriteGuard<'_, T> { fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.writer.0.value.get() } + unsafe { &mut *self.value } } } diff --git a/src/rwlock/raw.rs b/src/rwlock/raw.rs new file mode 100644 index 0000000..53b831a --- /dev/null +++ b/src/rwlock/raw.rs @@ -0,0 +1,542 @@ +use std::future::Future; +use std::mem::forget; +use std::pin::Pin; +use std::process; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; + +use event_listener::{Event, EventListener}; + +use crate::futures::Lock; +use crate::Mutex; + +const WRITER_BIT: usize = 1; +const ONE_READER: usize = 2; + +/// A "raw" RwLock that doesn't hold any data. +pub(super) struct RawRwLock { + /// Acquired by the writer. + mutex: Mutex<()>, + + /// Event triggered when the last reader is dropped. + no_readers: Event, + + /// Event triggered when the writer is dropped. + no_writer: Event, + + /// Current state of the lock. + /// + /// The least significant bit (`WRITER_BIT`) is set to 1 when a writer is holding the lock or + /// trying to acquire it. + /// + /// The upper bits contain the number of currently active readers. Each active reader + /// increments the state by `ONE_READER`. + state: AtomicUsize, +} + +impl RawRwLock { + #[inline] + pub(super) const fn new() -> Self { + RawRwLock { + mutex: Mutex::new(()), + no_readers: Event::new(), + no_writer: Event::new(), + state: AtomicUsize::new(0), + } + } + + /// Returns `true` iff a read lock was successfully acquired. + + pub(super) fn try_read(&self) -> bool { + let mut state = self.state.load(Ordering::Acquire); + + loop { + // If there's a writer holding the lock or attempting to acquire it, we cannot acquire + // a read lock here. + if state & WRITER_BIT != 0 { + return false; + } + + // Make sure the number of readers doesn't overflow. + if state > std::isize::MAX as usize { + process::abort(); + } + + // Increment the number of readers. + match self.state.compare_exchange( + state, + state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return true, + Err(s) => state = s, + } + } + } + + #[inline] + + pub(super) fn read(&self) -> RawRead<'_> { + RawRead { + lock: self, + state: self.state.load(Ordering::Acquire), + listener: None, + } + } + + /// Returns `true` iff an upgradable read lock was successfully acquired. + + pub(super) fn try_upgradable_read(&self) -> bool { + // First try grabbing the mutex. + let lock = if let Some(lock) = self.mutex.try_lock() { + lock + } else { + return false; + }; + + forget(lock); + + let mut state = self.state.load(Ordering::Acquire); + + // Make sure the number of readers doesn't overflow. + if state > std::isize::MAX as usize { + process::abort(); + } + + // Increment the number of readers. + loop { + match self.state.compare_exchange( + state, + state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return true, + Err(s) => state = s, + } + } + } + + #[inline] + + pub(super) fn upgradable_read(&self) -> RawUpgradableRead<'_> { + RawUpgradableRead { + lock: self, + acquire: self.mutex.lock(), + } + } + + /// Returs `true` iff a write lock was successfully acquired. + + pub(super) fn try_write(&self) -> bool { + // First try grabbing the mutex. + let lock = if let Some(lock) = self.mutex.try_lock() { + lock + } else { + return false; + }; + + // If there are no readers, grab the write lock. + if self + .state + .compare_exchange(0, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + forget(lock); + true + } else { + drop(lock); + false + } + } + + #[inline] + + pub(super) fn write(&self) -> RawWrite<'_> { + RawWrite { + lock: self, + state: WriteState::Acquiring(self.mutex.lock()), + } + } + + /// Returns `true` iff a the upgradable read lock was successfully upgraded to a write lock. + /// + /// # Safety + /// + /// Caller must hold an upgradable read lock. + /// This will attempt to upgrade it to a write lock. + + pub(super) unsafe fn try_upgrade(&self) -> bool { + self.state + .compare_exchange(ONE_READER, WRITER_BIT, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + } + + /// # Safety + /// + /// Caller must hold an upgradable read lock. + /// This will upgrade it to a write lock. + + pub(super) unsafe fn upgrade(&self) -> RawUpgrade<'_> { + // Set `WRITER_BIT` and decrement the number of readers at the same time. + self.state + .fetch_sub(ONE_READER - WRITER_BIT, Ordering::SeqCst); + + RawUpgrade { + lock: Some(self), + listener: None, + } + } + + /// # Safety + /// + /// Caller must hold an upgradable read lock. + /// This will downgrade it to a stadard read lock. + #[inline] + + pub(super) unsafe fn downgrade_upgradable_read(&self) { + self.mutex.unlock_unchecked(); + } + + /// # Safety + /// + /// Caller must hold a write lock. + /// This will downgrade it to a read lock. + + pub(super) unsafe fn downgrade_write(&self) { + // Atomically downgrade state. + self.state + .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst); + + // Release the writer mutex. + self.mutex.unlock_unchecked(); + + // Trigger the "no writer" event. + self.no_writer.notify(1); + } + + /// # Safety + /// + /// Caller must hold a write lock. + /// This will downgrade it to an upgradable read lock. + + pub(super) unsafe fn downgrade_to_upgradable(&self) { + // Atomically downgrade state. + self.state + .fetch_add(ONE_READER - WRITER_BIT, Ordering::SeqCst); + } + + /// # Safety + /// + /// Caller must hold a read lock . + /// This will unlock that lock. + + pub(super) unsafe fn read_unlock(&self) { + // Decrement the number of readers. + if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER { + // If this was the last reader, trigger the "no readers" event. + self.no_readers.notify(1); + } + } + + /// # Safety + /// + /// Caller must hold an upgradable read lock. + /// This will unlock that lock. + + pub(super) unsafe fn upgradable_read_unlock(&self) { + // Decrement the number of readers. + if self.state.fetch_sub(ONE_READER, Ordering::SeqCst) & !WRITER_BIT == ONE_READER { + // If this was the last reader, trigger the "no readers" event. + self.no_readers.notify(1); + } + + // SAFETY: upgradable read guards acquire the writer mutex upon creation. + self.mutex.unlock_unchecked(); + } + + /// # Safety + /// + /// Caller must hold a write lock. + /// This will unlock that lock. + + pub(super) unsafe fn write_unlock(&self) { + // Unset `WRITER_BIT`. + self.state.fetch_and(!WRITER_BIT, Ordering::SeqCst); + // Trigger the "no writer" event. + self.no_writer.notify(1); + + // Release the writer lock. + // SAFETY: `RwLockWriteGuard` always holds a lock on writer mutex. + self.mutex.unlock_unchecked(); + } +} + +/// The future returned by [`RawRwLock::read`]. + +pub(super) struct RawRead<'a> { + /// The lock that is being acquired. + pub(super) lock: &'a RawRwLock, + + /// The last-observed state of the lock. + state: usize, + + /// The listener for the "no writers" event. + listener: Option, +} + +impl<'a> Future for RawRead<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.get_mut(); + + loop { + if this.state & WRITER_BIT == 0 { + // Make sure the number of readers doesn't overflow. + if this.state > std::isize::MAX as usize { + process::abort(); + } + + // If nobody is holding a write lock or attempting to acquire it, increment the + // number of readers. + match this.lock.state.compare_exchange( + this.state, + this.state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Poll::Ready(()), + Err(s) => this.state = s, + } + } else { + // Start listening for "no writer" events. + let load_ordering = match &mut this.listener { + None => { + this.listener = Some(this.lock.no_writer.listen()); + + // Make sure there really is no writer. + Ordering::SeqCst + } + + Some(ref mut listener) => { + // Wait for the writer to finish. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + + // Notify the next reader waiting in list. + this.lock.no_writer.notify(1); + + // Check the state again. + Ordering::Acquire + } + }; + + // Reload the state. + this.state = this.lock.state.load(load_ordering); + } + } + } +} + +/// The future returned by [`RawRwLock::upgradable_read`]. + +pub(super) struct RawUpgradableRead<'a> { + /// The lock that is being acquired. + pub(super) lock: &'a RawRwLock, + + /// The mutex we are trying to acquire. + acquire: Lock<'a, ()>, +} + +impl<'a> Future for RawUpgradableRead<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.get_mut(); + + // Acquire the mutex. + let mutex_guard = ready!(Pin::new(&mut this.acquire).poll(cx)); + forget(mutex_guard); + + let mut state = this.lock.state.load(Ordering::Acquire); + + // Make sure the number of readers doesn't overflow. + if state > std::isize::MAX as usize { + process::abort(); + } + + // Increment the number of readers. + loop { + match this.lock.state.compare_exchange( + state, + state + ONE_READER, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + return Poll::Ready(()); + } + Err(s) => state = s, + } + } + } +} + +/// The future returned by [`RawRwLock::write`]. + +pub(super) struct RawWrite<'a> { + /// The lock that is being acquired. + pub(super) lock: &'a RawRwLock, + + /// Current state fof this future. + state: WriteState<'a>, +} + +enum WriteState<'a> { + /// We are currently acquiring the inner mutex. + Acquiring(Lock<'a, ()>), + + /// We are currently waiting for readers to finish. + WaitingReaders { + /// The listener for the "no readers" event. + listener: Option, + }, + /// The future has completed. + Acquired, +} + +impl<'a> Future for RawWrite<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.get_mut(); + + loop { + match &mut this.state { + WriteState::Acquiring(lock) => { + // First grab the mutex. + let mutex_guard = ready!(Pin::new(lock).poll(cx)); + forget(mutex_guard); + + // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled. + let new_state = this.lock.state.fetch_or(WRITER_BIT, Ordering::SeqCst); + + // If we just acquired the lock, return. + if new_state == WRITER_BIT { + this.state = WriteState::Acquired; + return Poll::Ready(()); + } + + // Start waiting for the readers to finish. + this.state = WriteState::WaitingReaders { + listener: Some(this.lock.no_readers.listen()), + }; + } + + WriteState::WaitingReaders { ref mut listener } => { + let load_ordering = if listener.is_some() { + Ordering::Acquire + } else { + Ordering::SeqCst + }; + + // Check the state again. + if this.lock.state.load(load_ordering) == WRITER_BIT { + // We are the only ones holding the lock, return `Ready`. + this.state = WriteState::Acquired; + return Poll::Ready(()); + } + + // Wait for the readers to finish. + match listener { + None => { + // Register a listener. + *listener = Some(this.lock.no_readers.listen()); + } + + Some(ref mut evl) => { + // Wait for the readers to finish. + ready!(Pin::new(evl).poll(cx)); + *listener = None; + } + }; + } + WriteState::Acquired => panic!("Write lock already acquired"), + } + } + } +} + +impl<'a> Drop for RawWrite<'a> { + fn drop(&mut self) { + if matches!(self.state, WriteState::WaitingReaders { .. }) { + // Safety: we hold a write lock, more or less. + unsafe { + self.lock.write_unlock(); + } + } + } +} + +/// The future returned by [`RawRwLock::upgrade`]. + +pub(super) struct RawUpgrade<'a> { + lock: Option<&'a RawRwLock>, + + /// The event listener we are waiting on. + listener: Option, +} + +impl<'a> Future for RawUpgrade<'a> { + type Output = &'a RawRwLock; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'a RawRwLock> { + let this = self.get_mut(); + let lock = this.lock.expect("cannot poll future after completion"); + + // If there are readers, we need to wait for them to finish. + loop { + let load_ordering = if this.listener.is_some() { + Ordering::Acquire + } else { + Ordering::SeqCst + }; + + // See if the number of readers is zero. + let state = lock.state.load(load_ordering); + if state == WRITER_BIT { + break; + } + + // If there are readers, wait for them to finish. + match &mut this.listener { + None => { + // Start listening for "no readers" events. + this.listener = Some(lock.no_readers.listen()); + } + + Some(ref mut listener) => { + // Wait for the readers to finish. + ready!(Pin::new(listener).poll(cx)); + this.listener = None; + } + } + } + + // We are done. + Poll::Ready(this.lock.take().unwrap()) + } +} + +impl<'a> Drop for RawUpgrade<'a> { + fn drop(&mut self) { + if let Some(lock) = self.lock { + // Unset `WRITER_BIT`. + lock.state.fetch_and(!WRITER_BIT, Ordering::SeqCst); + // Trigger the "no writer" event. + lock.no_writer.notify(1); + } + } +} From 8005eec100d58c47f793eaaa53e375a26344b940 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Fri, 26 May 2023 21:48:04 -0400 Subject: [PATCH 2/4] Add test for `RwLockReadGuard` covariance --- tests/rwlock.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/rwlock.rs b/tests/rwlock.rs index 737a8f5..8e48785 100644 --- a/tests/rwlock.rs +++ b/tests/rwlock.rs @@ -14,7 +14,7 @@ use std::thread; use futures_lite::future; -use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use async_lock::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard}; #[cfg(target_family = "wasm")] use wasm_bindgen_test::wasm_bindgen_test as test; @@ -277,3 +277,8 @@ fn yields_when_contended() { RwLockUpgradableReadGuard::upgrade(upgradable), ); } + +// We are testing that this compiles. +fn _covariance_test<'g>(guard: RwLockReadGuard<'g, &'static ()>) { + let _: RwLockReadGuard<'g, &'g ()> = guard; +} From e954bfd06999421d1ef573ad6204a8142a78c434 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Fri, 26 May 2023 23:17:06 -0400 Subject: [PATCH 3/4] Address review comments --- src/mutex.rs | 16 ++++++++++------ src/rwlock.rs | 35 ++++++++++++++++++++++++++++++----- src/rwlock/raw.rs | 10 +++++++++- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/mutex.rs b/src/mutex.rs index df892de..fbf6bc1 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -575,10 +575,12 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { } impl Drop for MutexGuard<'_, T> { + #[inline] fn drop(&mut self) { - // Remove the last bit and notify a waiting lock operation. - self.0.state.fetch_sub(1, Ordering::Release); - self.0.lock_ops.notify(1); + // SAFETY: we are droppig the mutex guard, therefore unlocking the mutex. + unsafe { + self.0.unlock_unchecked(); + } } } @@ -636,10 +638,12 @@ impl MutexGuardArc { } impl Drop for MutexGuardArc { + #[inline] fn drop(&mut self) { - // Remove the last bit and notify a waiting lock operation. - self.0.state.fetch_sub(1, Ordering::Release); - self.0.lock_ops.notify(1); + // SAFETY: we are droppig the mutex guard, therefore unlocking the mutex. + unsafe { + self.0.unlock_unchecked(); + } } } diff --git a/src/rwlock.rs b/src/rwlock.rs index 0584043..2b6d54e 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -8,8 +8,7 @@ use std::task::{Context, Poll}; mod raw; -use raw::*; - +use self::raw::{RawRead, RawRwLock, RawUpgradableRead, RawUpgrade, RawWrite}; /// An async reader-writer lock. /// /// This type of lock allows multiple readers or one writer at any point in time. @@ -39,7 +38,8 @@ use raw::*; /// # }) /// ``` pub struct RwLock { - /// The locking implementation. + /// The underlying locking implementation. + /// Doesn't depend on `T`. raw: RawRwLock, /// The inner value. @@ -317,8 +317,10 @@ impl Default for RwLock { /// The future returned by [`RwLock::read`]. pub struct Read<'a, T: ?Sized> { + /// Raw read lock acquisition future, doesn't depend on `T`. raw: RawRead<'a>, + /// Pointer to the value protected by the lock. Covariant in `T`. value: *const T, } @@ -349,7 +351,11 @@ impl<'a, T: ?Sized> Future for Read<'a, T> { /// The future returned by [`RwLock::upgradable_read`]. pub struct UpgradableRead<'a, T: ?Sized> { + /// Raw upgradable read lock acquisition future, doesn't depend on `T`. raw: RawUpgradableRead<'a>, + + /// Pointer to the value protected by the lock. Invariant in `T` + /// as the upgradable lock could provide write access. value: *mut T, } @@ -380,7 +386,10 @@ impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> { /// The future returned by [`RwLock::write`]. pub struct Write<'a, T: ?Sized> { + /// Raw write lock acquisition future, doesn't depend on `T`. raw: RawWrite<'a>, + + /// Pointer to the value protected by the lock. Invariant in `T`. value: *mut T, } @@ -412,7 +421,11 @@ impl<'a, T: ?Sized> Future for Write<'a, T> { /// A guard that releases the read lock when dropped. #[clippy::has_significant_drop] pub struct RwLockReadGuard<'a, T: ?Sized> { + /// Reference to underlying locking implementation. + /// Doesn't depend on `T`. lock: &'a RawRwLock, + + /// Pointer to the value protected by the lock. Covariant in `T`. value: *const T, } @@ -452,8 +465,13 @@ impl Deref for RwLockReadGuard<'_, T> { /// A guard that releases the upgradable read lock when dropped. #[clippy::has_significant_drop] pub struct RwLockUpgradableReadGuard<'a, T: ?Sized> { - // The guard holds a lock on the mutex! + /// Reference to underlying locking implementation. + /// Doesn't depend on `T`. + /// This guard holds a lock on the witer mutex! lock: &'a RawRwLock, + + /// Pointer to the value protected by the lock. Invariant in `T` + /// as the upgradable lock could provide write access. value: *mut T, } @@ -598,7 +616,10 @@ impl Deref for RwLockUpgradableReadGuard<'_, T> { /// The future returned by [`RwLockUpgradableReadGuard::upgrade`]. pub struct Upgrade<'a, T: ?Sized> { + /// Raw read lock upgrade future, doesn't depend on `T`. raw: RawUpgrade<'a>, + + /// Pointer to the value protected by the lock. Invariant in `T`. value: *mut T, } @@ -627,8 +648,12 @@ impl<'a, T: ?Sized> Future for Upgrade<'a, T> { /// A guard that releases the write lock when dropped. #[clippy::has_significant_drop] pub struct RwLockWriteGuard<'a, T: ?Sized> { - // The guard holds a lock on the mutex! + /// Reference to underlying locking implementation. + /// Doesn't depend on `T`. + /// This guard holds a lock on the witer mutex! lock: &'a RawRwLock, + + /// Pointer to the value protected by the lock. Invariant in `T`. value: *mut T, } diff --git a/src/rwlock/raw.rs b/src/rwlock/raw.rs index 53b831a..0172e2f 100644 --- a/src/rwlock/raw.rs +++ b/src/rwlock/raw.rs @@ -1,3 +1,11 @@ +//! Raw, unsafe reader-writer locking implementation, +//! doesn't depend on the data protected by the lock. +//! [`RwLock`](super::RwLock) is implemented in terms of this. +//! +//! Splitting the implementation this way allows instantiating +//! the locking code only once, and also lets us make +//! [`RwLockReadGuard`](super::RwLockReadGuard) covariant in `T`. + use std::future::Future; use std::mem::forget; use std::pin::Pin; @@ -76,7 +84,6 @@ impl RawRwLock { } #[inline] - pub(super) fn read(&self) -> RawRead<'_> { RawRead { lock: self, @@ -403,6 +410,7 @@ enum WriteState<'a> { /// The listener for the "no readers" event. listener: Option, }, + /// The future has completed. Acquired, } From b48048ce17ef852a3be2252369d68f5f5f2c69b3 Mon Sep 17 00:00:00 2001 From: Jules Bertholet Date: Sat, 27 May 2023 00:31:49 -0400 Subject: [PATCH 4/4] Fix typo in comment --- src/mutex.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mutex.rs b/src/mutex.rs index fbf6bc1..1a360a6 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -577,7 +577,7 @@ impl<'a, T: ?Sized> MutexGuard<'a, T> { impl Drop for MutexGuard<'_, T> { #[inline] fn drop(&mut self) { - // SAFETY: we are droppig the mutex guard, therefore unlocking the mutex. + // SAFETY: we are dropping the mutex guard, therefore unlocking the mutex. unsafe { self.0.unlock_unchecked(); } @@ -640,7 +640,7 @@ impl MutexGuardArc { impl Drop for MutexGuardArc { #[inline] fn drop(&mut self) { - // SAFETY: we are droppig the mutex guard, therefore unlocking the mutex. + // SAFETY: we are dropping the mutex guard, therefore unlocking the mutex. unsafe { self.0.unlock_unchecked(); }