diff --git a/mea/Cargo.toml b/mea/Cargo.toml index 4cc3f47..7f46bb1 100644 --- a/mea/Cargo.toml +++ b/mea/Cargo.toml @@ -32,12 +32,7 @@ rust-version.workspace = true all-features = true rustdoc-args = ["--cfg", "docsrs"] -[features] -default = ["std"] -std = [] - [dependencies] -atomic-wait = { version = "1.1.0", optional = true } slab = { version = "0.4.9" } [dev-dependencies] diff --git a/mea/src/internal/mod.rs b/mea/src/internal/mod.rs index 4a6e3f7..a5f0ac6 100644 --- a/mea/src/internal/mod.rs +++ b/mea/src/internal/mod.rs @@ -18,5 +18,11 @@ pub(crate) use countdown::*; mod mutex; pub(crate) use mutex::*; +mod semaphore; +pub(crate) use semaphore::*; + +mod waitlist; +pub(crate) use waitlist::*; + mod waitset; pub(crate) use waitset::*; diff --git a/mea/src/internal/semaphore.rs b/mea/src/internal/semaphore.rs new file mode 100644 index 0000000..ec18807 --- /dev/null +++ b/mea/src/internal/semaphore.rs @@ -0,0 +1,279 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering; +use std::sync::MutexGuard; +use std::task::Context; +use std::task::Poll; +use std::task::Waker; + +use slab::Slab; + +use crate::internal::Mutex; +use crate::internal::WaitList; + +/// The internal semaphore that provides low-level async primitives. +#[derive(Debug)] +pub(crate) struct Semaphore { + /// The current number of available permits in the semaphore. + permits: AtomicU32, + waiters: Mutex>, +} + +#[derive(Debug)] +struct WaitNode { + permits: u32, + waker: Option, +} + +impl Semaphore { + pub(crate) fn new(permits: u32) -> Self { + Self { + permits: AtomicU32::new(permits), + waiters: Mutex::new(WaitList::new()), + } + } + + /// Returns the current number of available permits. + pub(crate) fn available_permits(&self) -> u32 { + self.permits.load(Ordering::Acquire) + } + + /// Tries to acquire `n` permits from the semaphore. + /// + /// Returns `true` if the permits were acquired, `false` otherwise. + pub(crate) fn try_acquire(&self, n: u32) -> bool { + let mut current = self.permits.load(Ordering::Acquire); + loop { + if current < n { + return false; + } + + let next = current - n; + match self + .permits + .compare_exchange(current, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => return true, + Err(actual) => current = actual, + } + } + } + + /// Decrease a semaphore's permits by a maximum of `n`. + /// + /// Return the number of permits that were actually reduced. + pub(crate) fn forget(&self, n: u32) -> u32 { + if n == 0 { + return 0; + } + + let mut current = self.permits.load(Ordering::Acquire); + loop { + let new = current.saturating_sub(n); + match self.permits.compare_exchange_weak( + current, + new, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return n.min(current), + Err(actual) => current = actual, + } + } + } + + /// Acquires `n` permits from the semaphore. + pub(crate) fn acquire(&self, n: u32) -> Acquire<'_> { + Acquire { + permits: n, + index: None, + semaphore: self, + } + } + + /// Adds `n` new permits to the semaphore. + pub(crate) fn release(&self, n: u32) { + if n != 0 { + self.insert_permits_with_lock(n, self.waiters.lock()); + } + } + + fn insert_permits_with_lock(&self, mut rem: u32, waiters: MutexGuard<'_, WaitList>) { + const NUM_WAKER: usize = 32; + let mut wakers = Slab::with_capacity(NUM_WAKER); + + let mut lock = Some(waiters); + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); + while wakers.len() < NUM_WAKER { + match waiters.remove_first_waiter(|node| { + if node.permits <= rem { + rem -= node.permits; + node.permits = 0; + true + } else { + node.permits -= rem; + rem = 0; + false + } + }) { + None => break, + Some(waiter) => { + if let Some(waker) = waiter.waker.take() { + wakers.insert(waker); + } + } + } + } + + if rem > 0 && waiters.is_empty() { + let permits = rem; + let prev = self.permits.fetch_add(permits, Ordering::Release); + // never happens; permits is always zero and permits is no more than u32::MAX + debug_assert!(prev.checked_add(permits).is_some()); + rem = 0; + } + + drop(waiters); + for w in wakers.drain() { + w.wake(); + } + } + } +} + +#[derive(Debug)] +pub(crate) struct Acquire<'a> { + permits: u32, + index: Option, + semaphore: &'a Semaphore, +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + if let Some(index) = self.index { + let mut waiters = self.semaphore.waiters.lock(); + let mut acquired = 0; + waiters.remove_waiter(index, |node| { + acquired = node.permits; + node.permits = 0; + true + }); + waiters.with_mut(index, |_| true); // drop + if acquired > 0 { + self.semaphore.insert_permits_with_lock(acquired, waiters); + } + } + } +} + +impl Future for Acquire<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Self { + permits, + index, + semaphore, + } = self.get_mut(); + + match index { + Some(idx) => { + let mut waiters = semaphore.waiters.lock(); + let mut ready = false; + waiters.with_mut(*idx, |node| { + if node.permits > 0 { + let update_waker = node + .waker + .as_ref() + .map_or(true, |w| !w.will_wake(cx.waker())); + if update_waker { + node.waker = Some(cx.waker().clone()); + } + false + } else { + ready = true; + true + } + }); + + if ready { + *index = None; + return Poll::Ready(()); + } + } + None => { + // not yet enqueued + let needed = *permits; + + let mut acquired = 0; + let mut current = semaphore.permits.load(Ordering::Acquire); + let mut lock = None; + + let mut waiters = loop { + let mut remaining = 0; + let total = current.checked_add(acquired).expect("permits overflow"); + let (next, acq) = if total >= needed { + let next = current - (needed - acquired); + (next, needed - acquired) + } else { + remaining = (needed - acquired) - current; + (0, current) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(semaphore.waiters.lock()); + } + + match semaphore.permits.compare_exchange( + current, + next, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + acquired += acq; + if remaining == 0 { + return Poll::Ready(()); + } + break lock.expect("lock not acquired"); + } + Err(actual) => current = actual, + } + }; + + waiters.register_waiter(index, |node| match node { + None => Some(WaitNode { + permits: needed - acquired, + waker: Some(cx.waker().clone()), + }), + Some(node) => unreachable!("unexpected node: {:?}", node), + }); + } + }; + + Poll::Pending + } +} diff --git a/mea/src/internal/waitlist.rs b/mea/src/internal/waitlist.rs new file mode 100644 index 0000000..cde7f4c --- /dev/null +++ b/mea/src/internal/waitlist.rs @@ -0,0 +1,120 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use slab::Slab; + +/// A guarded linked list. +/// +/// * `guard`'s `next` points to the first node (regular head). +/// * `guard`'s `prev` points to the last node (regular tail). +#[derive(Debug)] +pub(crate) struct WaitList { + guard: usize, + nodes: Slab>, +} + +#[derive(Debug)] +struct Node { + prev: usize, + next: usize, + stat: Option, +} + +impl WaitList { + pub(crate) fn new() -> Self { + let mut nodes = Slab::new(); + let first = nodes.vacant_entry(); + let guard = first.key(); + first.insert(Node { + prev: guard, + next: guard, + stat: None, + }); + Self { guard, nodes } + } + + /// Registers a waiter to the tail of the wait list. + pub(crate) fn register_waiter( + &mut self, + idx: &mut Option, + f: impl FnOnce(Option<&T>) -> Option, + ) { + match *idx { + None => { + let stat = f(None); + let prev_tail = self.nodes[self.guard].prev; + let new_node = Node { + prev: prev_tail, + next: self.guard, + stat, + }; + let new_key = self.nodes.insert(new_node); + self.nodes[self.guard].prev = new_key; + self.nodes[prev_tail].next = new_key; + *idx = Some(new_key); + } + Some(key) => { + debug_assert_ne!(key, self.guard); + if let Some(stat) = f(self.nodes[key].stat.as_ref()) { + self.nodes[key].stat = Some(stat); + } + } + } + } + + /// Removes a previously registered waker from the wait list. + pub(crate) fn remove_waiter( + &mut self, + idx: usize, + f: impl FnOnce(&mut T) -> bool, + ) -> Option<&mut T> { + debug_assert_ne!(idx, self.guard); + // SAFETY: `idx` is a valid key + non-guard node always has `Some(stat)` + fn retrieve_stat(node: &mut Node) -> &mut T { + node.stat.as_mut().unwrap() + } + + if f(retrieve_stat(&mut self.nodes[idx])) { + let prev = self.nodes[idx].prev; + let next = self.nodes[idx].next; + self.nodes[prev].next = next; + self.nodes[next].prev = prev; + Some(retrieve_stat(&mut self.nodes[idx])) + } else { + None + } + } + + /// Removes the first waiter from the wait list. + pub(crate) fn remove_first_waiter(&mut self, f: impl FnOnce(&mut T) -> bool) -> Option<&mut T> { + let first = self.nodes[self.guard].next; + if first != self.guard { + self.remove_waiter(first, f) + } else { + None + } + } + + /// Returns `true` if the wait list is empty. + pub(crate) fn is_empty(&self) -> bool { + self.nodes[self.guard].next == self.guard + } + + pub(crate) fn with_mut(&mut self, idx: usize, drop: impl FnOnce(&mut T) -> bool) { + let node = &mut self.nodes[idx]; + if drop(node.stat.as_mut().unwrap()) { + self.nodes.remove(idx); + } + } +} diff --git a/mea/src/internal/waitset.rs b/mea/src/internal/waitset.rs index f7e50f1..f910014 100644 --- a/mea/src/internal/waitset.rs +++ b/mea/src/internal/waitset.rs @@ -39,9 +39,8 @@ impl WaitSet { /// Drain and wake up all waiters. pub(crate) fn wake_all(&mut self) { - let waiters = std::mem::take(&mut self.waiters); - for (_, waker) in waiters.into_iter() { - waker.wake(); + for w in self.waiters.drain() { + w.wake(); } } diff --git a/mea/src/lib.rs b/mea/src/lib.rs index 1a26582..7b6f27b 100644 --- a/mea/src/lib.rs +++ b/mea/src/lib.rs @@ -18,6 +18,8 @@ mod internal; pub mod barrier; pub mod latch; +pub mod mutex; +pub mod semaphore; pub mod waitgroup; #[cfg(test)] diff --git a/mea/src/mutex/mod.rs b/mea/src/mutex/mod.rs new file mode 100644 index 0000000..37d56b7 --- /dev/null +++ b/mea/src/mutex/mod.rs @@ -0,0 +1,149 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cell::UnsafeCell; +use std::ops::Deref; +use std::ops::DerefMut; + +use crate::internal; + +pub struct Mutex { + s: internal::Semaphore, + c: UnsafeCell, +} + +impl Mutex { + /// Creates a new lock in an unlocked state ready for use. + /// + /// # Examples + /// + /// ``` + /// use mea::mutex::Mutex; + /// + /// let lock = Mutex::new(5); + /// ``` + pub fn new(t: T) -> Self { + let s = internal::Semaphore::new(1); + let c = UnsafeCell::new(t); + Self { s, c } + } + + /// Consumes the mutex, returning the underlying data. + /// + /// # Examples + /// + /// ``` + /// use mea::mutex::Mutex; + /// + /// let mutex = Mutex::new(1); + /// let n = mutex.into_inner(); + /// assert_eq!(n, 1); + /// ``` + pub fn into_inner(self) -> T { + self.c.into_inner() + } +} + +impl Mutex { + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// If the mutex is available to be acquired immediately, then this call + /// will typically not yield to the runtime. However, this is not guaranteed + /// under all circumstances. + /// + /// # Examples + /// + /// ``` + /// use mea::mutex::Mutex; + /// use pollster::FutureExt; + /// + /// let mutex = Mutex::new(1); + /// let mut n = mutex.lock().block_on(); + /// *n = 2; + /// assert_eq!(*n, 2); + /// ``` + pub async fn lock(&self) -> MutexGuard<'_, T> { + let fut = async { + self.s.acquire(1).await; + MutexGuard { lock: self } + }; + fut.await + } + + /// Attempts to acquire the lock, and returns `None` if the lock is currently held somewhere + /// else. + /// + /// # Examples + /// + /// ``` + /// use mea::mutex::Mutex; + /// + /// let mutex = Mutex::new(1); + /// let n = mutex.try_lock().unwrap(); + /// assert_eq!(*n, 1); + /// ``` + pub fn try_lock(&self) -> Option> { + if self.s.try_acquire(1) { + let guard = MutexGuard { lock: self }; + Some(guard) + } else { + None + } + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the `Mutex` mutably, no actual locking needs to + /// take place, i.e., the mutable borrow statically guarantees no locks exist. + /// + /// # Examples + /// + /// ``` + /// use mea::mutex::Mutex; + /// + /// let mut mutex = Mutex::new(1); + /// let mut n = mutex.get_mut(); + /// *n = 2; + /// assert_eq!(*n, 2); + /// ``` + pub fn get_mut(&mut self) -> &mut T { + self.c.get_mut() + } +} + +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct MutexGuard<'a, T: ?Sized> { + lock: &'a Mutex, +} + +impl Drop for MutexGuard<'_, T> { + fn drop(&mut self) { + self.lock.s.release(1); + } +} + +impl Deref for MutexGuard<'_, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl DerefMut for MutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} diff --git a/mea/src/semaphore/mod.rs b/mea/src/semaphore/mod.rs new file mode 100644 index 0000000..777025b --- /dev/null +++ b/mea/src/semaphore/mod.rs @@ -0,0 +1,120 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::internal; + +#[cfg(test)] +mod tests; + +#[derive(Debug)] +pub struct Semaphore { + s: internal::Semaphore, +} + +impl Semaphore { + /// Constructs a `Semaphore` initialized with the given permits. + pub fn new(permits: u32) -> Self { + Self { + s: internal::Semaphore::new(permits), + } + } + + /// Returns the current number of permits available in this semaphore. + /// + /// This method is typically used for debugging and testing purposes. + pub fn available_permits(&self) -> u32 { + self.s.available_permits() + } + + /// Decrease a semaphore's permits by a maximum of `n`. + /// + /// If there are insufficient permits, and it's not possible to reduce by `n`, + /// return the number of permits that were actually reduced. + pub fn forget(&self, n: u32) -> u32 { + self.s.forget(n) + } + + /// Attempts to acquire `n` permits from this semaphore. + pub fn try_acquire(&self, permits: u32) -> Option> { + self.s + .try_acquire(permits) + .then_some(SemaphorePermit { sem: self, permits }) + } + + /// Adds `n` new permits to the semaphore. + /// + /// # Panics + /// + /// This function panics if the semaphore would overflow. + pub fn release(&self, permits: u32) { + self.s.release(permits); + } + + /// Acquires `n` permits from the semaphore. + pub async fn acquire(&self, permits: u32) -> SemaphorePermit<'_> { + self.s.acquire(permits).await; + SemaphorePermit { sem: self, permits } + } +} + +/// A permit from the semaphore. +/// +/// This type is created by the [`acquire`] method. +/// +/// [`acquire`]: Semaphore::acquire() +#[must_use] +#[derive(Debug)] +pub struct SemaphorePermit<'a> { + sem: &'a Semaphore, + permits: u32, +} + +impl SemaphorePermit<'_> { + /// Forgets the permit **without** releasing it back to the semaphore. + /// This can be used to reduce the amount of permits available from a + /// semaphore. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// + /// use mea::semaphore::Semaphore; + /// + /// let sem = Arc::new(Semaphore::new(10)); + /// { + /// let permit = sem.try_acquire(5).unwrap(); + /// assert_eq!(sem.available_permits(), 5); + /// permit.forget(); + /// } + /// + /// // Since we forgot the permit, available permits won't go back to its initial value + /// // even after the permit is dropped. + /// assert_eq!(sem.available_permits(), 5); + /// ``` + pub fn forget(mut self) { + self.permits = 0; + } + + /// Returns the number of permits held by `self`. + pub fn permits(&self) -> u32 { + self.permits + } +} + +impl Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.sem.release(self.permits); + } +} diff --git a/mea/src/semaphore/tests.rs b/mea/src/semaphore/tests.rs new file mode 100644 index 0000000..eddd077 --- /dev/null +++ b/mea/src/semaphore/tests.rs @@ -0,0 +1,125 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::vec::Vec; + +use super::*; + +#[test] +fn no_permits() { + // this should not panic + Semaphore::new(0); +} + +#[test] +fn try_acquire() { + let sem = Semaphore::new(1); + { + let p1 = sem.try_acquire(1); + assert!(p1.is_some()); + let p2 = sem.try_acquire(1); + assert!(p2.is_none()); + } + let p3 = sem.try_acquire(1); + assert!(p3.is_some()); +} + +#[tokio::test] +async fn acquire() { + let sem = Arc::new(Semaphore::new(1)); + let p1 = sem.try_acquire(1).unwrap(); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire(1).await; + }); + drop(p1); + j.await.unwrap(); +} + +#[tokio::test] +async fn add_permits() { + let sem = Arc::new(Semaphore::new(0)); + let sem_clone = sem.clone(); + let j = tokio::spawn(async move { + let _p2 = sem_clone.acquire(1).await; + }); + sem.release(1); + j.await.unwrap(); +} + +#[test] +fn forget() { + let sem = Arc::new(Semaphore::new(1)); + { + let p = sem.try_acquire(1).unwrap(); + assert_eq!(sem.available_permits(), 0); + p.forget(); + assert_eq!(sem.available_permits(), 0); + } + assert_eq!(sem.available_permits(), 0); + assert!(sem.try_acquire(1).is_none()); +} + +#[tokio::test] +async fn stress_test() { + let sem = Arc::new(Semaphore::new(5)); + let mut join_handles = Vec::new(); + for i in 0..100 { + let sem_clone = sem.clone(); + join_handles.push(tokio::spawn(async move { + let _p = sem_clone.acquire(1).await; + tokio::time::sleep(std::time::Duration::from_millis(100 - i)).await; + })); + } + for j in join_handles { + j.await.unwrap(); + } + // there should be exactly 5 semaphores available now + let _p1 = sem.try_acquire(1).unwrap(); + let _p2 = sem.try_acquire(1).unwrap(); + let _p3 = sem.try_acquire(1).unwrap(); + let _p4 = sem.try_acquire(1).unwrap(); + let _p5 = sem.try_acquire(1).unwrap(); + assert!(sem.try_acquire(1).is_none()); +} + +#[test] +fn add_max_amount_permits() { + let s = Semaphore::new(0); + s.release(u32::MAX); + assert_eq!(s.available_permits(), u32::MAX); +} + +#[test] +#[should_panic] +fn add_more_than_max_amount_permits1() { + let s = Semaphore::new(1); + s.release(u32::MAX); +} + +#[test] +#[should_panic] +fn add_more_than_max_amount_permits2() { + let s = Semaphore::new(u32::MAX - 1); + s.release(1); + s.release(1); +} + +#[test] +fn no_panic_at_max_permits() { + let _ = Semaphore::new(u32::MAX); + let s = Semaphore::new(u32::MAX - 1); + s.release(1); +}