diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bfcd80b..508bbd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,10 +87,10 @@ jobs: # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. run: rustup update stable --no-self-update && rustup default stable - name: Test - run: cargo test --lib --no-default-features --features future + run: cargo test --lib --features future - tokio: - name: tokio + no_std: + name: no_std strategy: matrix: os: @@ -104,8 +104,23 @@ jobs: # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. run: rustup update stable --no-self-update && rustup default stable - name: Test - run: cargo test --lib --no-default-features --features tokio - + run: cargo build --no-default-features --features alloc + no_std_and_future: + name: no_std & future + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Install Rust + # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. + run: rustup update stable --no-self-update && rustup default stable + - name: Test + run: cargo build --no-default-features --features alloc,future sync: name: sync strategy: @@ -127,7 +142,7 @@ jobs: name: cargo tarpaulin runs-on: ubuntu-latest needs: - - tokio + - no_std - future - sync - build diff --git a/Cargo.toml b/Cargo.toml index 8accc12..077239c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,46 +6,38 @@ homepage = "https://github.com/al8n/wg" repository = "https://github.com/al8n/wg.git" documentation = "https://docs.rs/wg/" readme = "README.md" -version = "0.8.3" +version = "0.9.0" license = "MIT OR Apache-2.0" keywords = ["waitgroup", "async", "sync", "notify", "wake"] -categories = ["asynchronous", "concurrency", "data-structures"] +categories = ["asynchronous", "concurrency", "data-structures", "no-std"] edition = "2021" [features] default = ["std", "parking_lot", "triomphe"] -std = ["triomphe?/default", "event-listener?/default", "futures-core?/default", "tokio?/rt"] +alloc = ["event-listener"] +std = ["triomphe?/default", "event-listener?/default", "futures-core?/default"] triomphe = ["dep:triomphe"] parking_lot = ["dep:parking_lot"] - -future = ["event-listener", "pin-project-lite", "agnostic-lite"] -tokio = ["dep:tokio", "futures-core", "pin-project-lite", "agnostic-lite/tokio"] -smol = ["agnostic-lite/smol", "future"] -async-std = ["agnostic-lite/async-std", "future"] +future = ["event-listener", "pin-project-lite"] [dependencies] parking_lot = { version = "0.12", optional = true } triomphe = { version = "0.1", optional = true, default-features = false } -event-listener = { version = "5", optional = true, default-features = false } -pin-project-lite = { version = "0.2", optional = true } +event-listener = { version = "5", optional = true, default-features = false, features = ["portable-atomic"] } -tokio = { version = "1", optional = true, default-features = false, features = ["sync"] } +pin-project-lite = { version = "0.2", optional = true } futures-core = { version = "0.3", default-features = false, optional = true } -agnostic-lite = { version = "0.3", optional = true } [dev-dependencies] +agnostic-lite = { version = "0.3", features = ["smol", "async-std", "tokio", "time"] } tokio = { version = "1", features = ["full"] } async-std = { version = "1", features = ["attributes"] } +smol = "2" [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] -[[test]] -name = "tokio" -path = "tests/tokio.rs" -required-features = ["tokio"] - [[test]] name = "future" path = "tests/future.rs" diff --git a/README.md b/README.md index 9090342..2a6a28b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@
-Golang like WaitGroup implementation for sync/async Rust. +Golang like WaitGroup implementation for sync/async Rust, support `no_std` environment. [github][Github-url] [Build][CI-url] @@ -20,37 +20,42 @@ Golang like WaitGroup implementation for sync/async Rust. By default, blocking version `WaitGroup` is enabled. -If you are using `tokio`, you need to enable `tokio` feature in your `Cargo.toml` and use `wg::tokio::AsyncWaitGroup`. - If you are using other async runtime, you need to -enbale `future` feature in your `Cargo.toml` and use `wg::future::AsyncWaitGroup`. +enbale `future` feature in your `Cargo.toml` and use `wg::AsyncWaitGroup`. -### Sync -```toml -[dependencies] -wg = "0.8" -``` +## Installation -### `tokio` +- std -An async implementation for `tokio` runtime. + ```toml + [dependencies] + wg = "0.9" + ``` -```toml -[dependencies] -wg = { version = "0.8", features = ["tokio"] } -``` -### `future` +- `future` -A more generic async implementation. + ```toml + [dependencies] + wg = { version = "0.9", features = ["future"] } + ``` -```toml -[dependencies] -wg = { version = "0.8", features = ["future"] } -``` +- no_std -## Instruction + ```toml + [dependencies] + wg = { version = "0.9", default_features = false, features = ["alloc"] } + ``` + +- no_std & future + + ```toml + [dependencies] + wg = { version = "0.9", default_features = false, features = ["alloc", "future"] } + ``` + +## Examples ### Sync @@ -83,10 +88,10 @@ fn main() { } ``` -### `tokio` +### Async ```rust -use wg::tokio::AsyncWaitGroup; +use wg::AsyncWaitGroup; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::{spawn, time::{sleep, Duration}}; @@ -114,39 +119,6 @@ async fn main() { } ``` -### `async-io` - -```rust -use wg::future::AsyncWaitGroup; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::time::Duration; -use async_std::task::{spawn, block_on, sleep}; - -fn main() { - block_on(async { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let t_wg = wg.add(1); - spawn(async move { - // mock some time consuming task - sleep(Duration::from_millis(50)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - - // mock task is finished - t_wg.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); - }); -} -``` - ## Acknowledgements - Inspired by Golang sync.WaitGroup and [`crossbeam_utils::WaitGroup`]. @@ -158,6 +130,7 @@ Licensed under either of Ap 2.0 or MIT license at your option. + Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this project by you, as defined in the Apache-2.0 license, diff --git a/examples/future.rs b/examples/future.rs index 60ed1d5..8a33544 100644 --- a/examples/future.rs +++ b/examples/future.rs @@ -4,9 +4,10 @@ use tokio::{ spawn, time::{sleep, Duration}, }; -use wg::future::AsyncWaitGroup; +use wg::AsyncWaitGroup; -fn main() { +#[tokio::main] +async fn main() { async_std::task::block_on(async { let wg = AsyncWaitGroup::new(); let ctr = Arc::new(AtomicUsize::new(0)); diff --git a/examples/tokio.rs b/examples/tokio.rs deleted file mode 100644 index f10bc6b..0000000 --- a/examples/tokio.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use tokio::{ - spawn, - time::{sleep, Duration}, -}; -use wg::tokio::AsyncWaitGroup; - -#[tokio::main] -async fn main() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let t_wg = wg.add(1); - spawn(async move { - // mock some time consuming task - sleep(Duration::from_millis(50)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - - // mock task is finished - t_wg.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); -} diff --git a/src/future.rs b/src/future.rs index 29ddeb2..af15524 100644 --- a/src/future.rs +++ b/src/future.rs @@ -6,21 +6,13 @@ use core::{ task::{Context, Poll}, }; -pub use agnostic_lite::{AsyncSpawner, Detach}; +#[cfg(feature = "triomphe")] +use triomphe::Arc; -#[cfg(feature = "smol")] -pub use agnostic_lite::smol::SmolSpawner; - -#[cfg(feature = "tokio")] -pub use agnostic_lite::tokio::TokioSpawner; - -#[cfg(feature = "async-std")] -pub use agnostic_lite::async_std::AsyncStdSpawner; - -#[cfg(feature = "std")] +#[cfg(all(feature = "std", not(feature = "triomphe")))] use std::sync::Arc; -#[cfg(not(feature = "std"))] +#[cfg(all(not(feature = "std"), not(feature = "triomphe")))] use alloc::sync::Arc; #[derive(Debug)] @@ -40,7 +32,7 @@ struct AsyncInner { /// # Example /// /// ```rust -/// use wg::future::AsyncWaitGroup; +/// use wg::AsyncWaitGroup; /// use std::sync::Arc; /// use std::sync::atomic::{AtomicUsize, Ordering}; /// use std::time::Duration; @@ -132,7 +124,7 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::future::AsyncWaitGroup; + /// use wg::AsyncWaitGroup; /// use async_std::task::spawn; /// /// # async_std::task::block_on(async { @@ -160,12 +152,12 @@ impl AsyncWaitGroup { } } - /// done decrements the WaitGroup counter by one. + /// Decrements the AsyncWaitGroup counter by one, returning the remaining count. /// /// # Example /// /// ```rust - /// use wg::future::AsyncWaitGroup; + /// use wg::AsyncWaitGroup; /// use async_std::task::spawn; /// /// # async_std::task::block_on(async { @@ -178,9 +170,25 @@ impl AsyncWaitGroup { /// }); /// # }) /// ``` - pub fn done(&self) { - if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { - self.inner.event.notify(usize::MAX); + pub fn done(&self) -> usize { + match self + .inner + .counter + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| { + if v != 0 { + Some(v - 1) + } else { + None + } + }) { + Ok(x) => { + self.inner.event.notify(usize::MAX); + x + } + Err(x) => { + assert_eq!(x, 0); + x + } } } @@ -194,7 +202,7 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::future::AsyncWaitGroup; + /// use wg::AsyncWaitGroup; /// use async_std::task::spawn; /// /// # async_std::task::block_on(async { @@ -228,7 +236,7 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::future::{AsyncWaitGroup, AsyncStdSpawner}; + /// use wg::AsyncWaitGroup; /// use async_std::task::spawn; /// /// # async_std::task::block_on(async { @@ -242,24 +250,22 @@ impl AsyncWaitGroup { /// }); /// /// // wait other thread completes - /// wg.block_wait::(); + /// wg.wait_blocking(); /// # }) /// ``` #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn block_wait(&self) - where - S: agnostic_lite::AsyncSpawner, - { - let this = self.clone(); - let (tx, rx) = std::sync::mpsc::channel(); - - S::spawn_detach(async move { - this.wait().await; - let _ = tx.send(()); - }); + pub fn wait_blocking(&self) { + use event_listener::Listener; - let _ = rx.recv(); + while self.inner.counter.load(Ordering::SeqCst) != 0 { + let ln = self.inner.event.listen(); + // Check the flag again after creating the listener. + if self.inner.counter.load(Ordering::SeqCst) == 0 { + return; + } + ln.wait(); + } } } @@ -285,16 +291,21 @@ impl<'a> core::future::Future for WaitGroupFuture<'a> { return Poll::Ready(()); } - let this = self.project(); - match this.notified.poll(cx) { + let mut this = self.project(); + match this.notified.as_mut().poll(cx) { Poll::Pending => { - cx.waker().wake_by_ref(); - Poll::Pending + if this.inner.inner.counter.load(Ordering::SeqCst) == 0 { + Poll::Ready(()) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } } Poll::Ready(_) => { if this.inner.inner.counter.load(Ordering::SeqCst) == 0 { Poll::Ready(()) } else { + *this.notified = this.inner.inner.event.listen(); cx.waker().wake_by_ref(); Poll::Pending } diff --git a/src/lib.rs b/src/lib.rs index eab6fce..1ab6a59 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,274 +18,33 @@ * limitations under the License. */ #![doc = include_str!("../README.md")] -#![cfg_attr(not(feature = "std"), no_std)] -#![deny(missing_docs)] +#![cfg_attr(not(all(feature = "std", test)), no_std)] +#![deny(missing_docs, warnings)] #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] +#[cfg(not(any(feature = "alloc", feature = "std")))] +compile_error!("This crate can only be used when feature `alloc` or `std` is enabled"); + #[cfg(not(feature = "std"))] extern crate alloc; -/// [`AsyncWaitGroup`](crate::future::AsyncWaitGroup) for `futures`. +#[cfg(feature = "std")] +extern crate std; + #[cfg(feature = "future")] -#[cfg_attr(docsrs, doc(cfg(feature = "future")))] -pub mod future; +mod future; -/// [`AsyncWaitGroup`](crate::tokio::AsyncWaitGroup) for `tokio` runtime. -#[cfg(feature = "tokio")] -#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] -pub mod tokio; +#[cfg(feature = "future")] +#[cfg_attr(docsrs, doc(cfg(feature = "future")))] +pub use future::*; #[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -pub use sync::*; - +mod sync; #[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -mod sync { - trait Mu { - type Guard<'a> - where - Self: 'a; - fn lock_me(&self) -> Self::Guard<'_>; - } - - #[cfg(feature = "parking_lot")] - impl Mu for parking_lot::Mutex { - type Guard<'a> = parking_lot::MutexGuard<'a, T> where Self: 'a; - - fn lock_me(&self) -> Self::Guard<'_> { - self.lock() - } - } - - #[cfg(not(feature = "parking_lot"))] - impl Mu for std::sync::Mutex { - type Guard<'a> = std::sync::MutexGuard<'a, T> where Self: 'a; - - fn lock_me(&self) -> Self::Guard<'_> { - self.lock().unwrap() - } - } - - #[cfg(feature = "parking_lot")] - use parking_lot::{Condvar, Mutex}; - #[cfg(not(feature = "triomphe"))] - use std::sync::Arc; - #[cfg(not(feature = "parking_lot"))] - use std::sync::{Condvar, Mutex}; - #[cfg(feature = "triomphe")] - use triomphe::Arc; - - struct Inner { - cvar: Condvar, - count: Mutex, - } - - /// A WaitGroup waits for a collection of threads to finish. - /// The main thread calls [`add`] to set the number of - /// thread to wait for. Then each of the goroutines - /// runs and calls Done when finished. At the same time, - /// Wait can be used to block until all goroutines have finished. - /// - /// A WaitGroup must not be copied after first use. - /// - /// # Example - /// - /// ```rust - /// use wg::WaitGroup; - /// use std::sync::Arc; - /// use std::sync::atomic::{AtomicUsize, Ordering}; - /// use std::time::Duration; - /// use std::thread::{spawn, sleep}; - /// - /// let wg = WaitGroup::new(); - /// let ctr = Arc::new(AtomicUsize::new(0)); - /// - /// for _ in 0..5 { - /// let ctrx = ctr.clone(); - /// let t_wg = wg.add(1); - /// spawn(move || { - /// // mock some time consuming task - /// sleep(Duration::from_millis(50)); - /// ctrx.fetch_add(1, Ordering::Relaxed); - /// - /// // mock task is finished - /// t_wg.done(); - /// }); - /// } - /// - /// wg.wait(); - /// assert_eq!(ctr.load(Ordering::Relaxed), 5); - /// ``` - /// - /// [`wait`]: struct.WaitGroup.html#method.wait - /// [`add`]: struct.WaitGroup.html#method.add - pub struct WaitGroup { - inner: Arc, - } - - impl Default for WaitGroup { - fn default() -> Self { - Self { - inner: Arc::new(Inner { - cvar: Condvar::new(), - count: Mutex::new(0), - }), - } - } - } - - impl From for WaitGroup { - fn from(count: usize) -> Self { - Self { - inner: Arc::new(Inner { - cvar: Condvar::new(), - count: Mutex::new(count), - }), - } - } - } - - impl Clone for WaitGroup { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - - impl std::fmt::Debug for WaitGroup { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let count = self.inner.count.lock_me(); - f.debug_struct("WaitGroup").field("count", &*count).finish() - } - } - - impl WaitGroup { - /// Creates a new wait group and returns the single reference to it. - /// - /// # Examples - /// - /// ``` - /// use wg::WaitGroup; - /// - /// let wg = WaitGroup::new(); - /// ``` - pub fn new() -> Self { - Self::default() - } - - /// Adds delta to the WaitGroup counter. - /// If the counter becomes zero, all threads blocked on [`wait`] are released. - /// - /// Note that calls with a delta that occur when the counter is zero - /// must happen before a Wait. - /// Typically this means the calls to add should execute before the statement - /// creating the thread or other event to be waited for. - /// If a `WaitGroup` is reused to [`wait`] for several independent sets of events, - /// new `add` calls must happen after all previous [`wait`] calls have returned. - /// - /// # Example - /// ```rust - /// use wg::WaitGroup; - /// - /// let wg = WaitGroup::new(); - /// - /// wg.add(3); - /// (0..3).for_each(|_| { - /// let t_wg = wg.clone(); - /// std::thread::spawn(move || { - /// // do some time consuming work - /// t_wg.done(); - /// }); - /// }); - /// - /// wg.wait(); - /// ``` - /// - /// [`wait`]: struct.AsyncWaitGroup.html#method.wait - pub fn add(&self, num: usize) -> Self { - let mut ctr = self.inner.count.lock_me(); - - *ctr += num; - Self { - inner: self.inner.clone(), - } - } - - /// done decrements the WaitGroup counter by one. - /// - /// # Example - /// - /// ```rust - /// use wg::WaitGroup; - /// use std::thread; - /// - /// let wg = WaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); - /// thread::spawn(move || { - /// // do some time consuming task - /// t_wg.done() - /// }); - /// - /// ``` - pub fn done(&self) { - let mut val = self.inner.count.lock_me(); - - *val = if val.eq(&1) { - self.inner.cvar.notify_all(); - 0 - } else if val.eq(&0) { - 0 - } else { - *val - 1 - }; - } - - /// waitings return how many jobs are waiting. - pub fn waitings(&self) -> usize { - *self.inner.count.lock_me() - } - - /// wait blocks until the WaitGroup counter is zero. - /// - /// # Example - /// - /// ```rust - /// use wg::WaitGroup; - /// use std::thread; - /// - /// let wg = WaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); - /// thread::spawn(move || { - /// // do some time consuming task - /// t_wg.done() - /// }); - /// - /// // wait other thread completes - /// wg.wait(); - /// ``` - pub fn wait(&self) { - let mut ctr = self.inner.count.lock_me(); - - if ctr.eq(&0) { - return; - } - - while *ctr > 0 { - #[cfg(feature = "parking_lot")] - { - self.inner.cvar.wait(&mut ctr); - } +pub use sync::*; - #[cfg(not(feature = "parking_lot"))] - { - ctr = self.inner.cvar.wait(ctr).unwrap(); - } - } - } - } -} +#[cfg(not(feature = "std"))] +mod no_std; +#[cfg(not(feature = "std"))] +pub use no_std::*; diff --git a/src/no_std.rs b/src/no_std.rs new file mode 100644 index 0000000..780cada --- /dev/null +++ b/src/no_std.rs @@ -0,0 +1,166 @@ +use core::sync::atomic::{AtomicUsize, Ordering}; + +use alloc::sync::Arc; + +#[derive(Debug)] +struct Inner { + counter: AtomicUsize, +} + +/// An WaitGroup waits for a collection of threads to finish. +/// The main thread calls [`add`] to set the number of +/// thread to wait for. Then each of the tasks +/// runs and calls [`WaitGroup::done`](WaitGroup::done) when finished. At the same time, +/// Wait can be used to block until all tasks have finished. +/// +/// [`wait`]: struct.WaitGroup.html#method.wait +/// [`add`]: struct.WaitGroup.html#method.add +#[cfg_attr(docsrs, doc(cfg(feature = "future")))] +pub struct WaitGroup { + inner: Arc, +} + +impl Default for WaitGroup { + fn default() -> Self { + Self { + inner: Arc::new(Inner { + counter: AtomicUsize::new(0), + }), + } + } +} + +impl From for WaitGroup { + fn from(count: usize) -> Self { + Self { + inner: Arc::new(Inner { + counter: AtomicUsize::new(count), + }), + } + } +} + +impl Clone for WaitGroup { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl core::fmt::Debug for WaitGroup { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WaitGroup") + .field("counter", &self.inner.counter) + .finish() + } +} + +impl WaitGroup { + /// Creates a new `WaitGroup` + pub fn new() -> Self { + Self::default() + } + + /// Adds delta to the WaitGroup counter. + /// If the counter becomes zero, all threads blocked on [`wait`] are released. + /// + /// Note that calls with a delta that occur when the counter is zero + /// must happen before a Wait. + /// Typically this means the calls to add should execute before the statement + /// creating the thread or other event to be waited for. + /// If a `WaitGroup` is reused to [`wait`] for several independent sets of events, + /// new `add` calls must happen after all previous [`wait`] calls have returned. + /// + /// # Example + /// + /// ```rust,ignore + /// use wg::WaitGroup; + /// + /// let wg = WaitGroup::new(); + /// + /// wg.add(3); + /// (0..3).for_each(|_| { + /// let t_wg = wg.clone(); + /// spawn(move || { + /// // do some time consuming work + /// t_wg.done(); + /// }); + /// }); + /// + /// wg.wait(); + /// ``` + /// + /// [`wait`]: struct.WaitGroup.html#method.wait + pub fn add(&self, num: usize) -> Self { + self.inner.counter.fetch_add(num, Ordering::AcqRel); + + Self { + inner: self.inner.clone(), + } + } + + /// Decrements the WaitGroup counter by one, returning the remaining count. + /// + /// # Example + /// + /// ```rust,ignore + /// use wg::WaitGroup; + /// + /// let wg = WaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// spawn(move || { + /// // do some time consuming task + /// t_wg.done(); + /// }); + /// ``` + pub fn done(&self) -> usize { + match self + .inner + .counter + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| { + if v != 0 { + Some(v - 1) + } else { + None + } + }) { + Ok(x) => x, + Err(x) => { + assert_eq!(x, 0); + x + } + } + } + + /// waitings return how many jobs are waiting. + pub fn waitings(&self) -> usize { + self.inner.counter.load(Ordering::Acquire) + } + + /// wait blocks until the [`WaitGroup`] counter is zero. + /// + /// # Example + /// + /// ```rust + /// use wg::WaitGroup; + /// + /// let wg = WaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// + /// spawn(async move { + /// // do some time consuming task + /// t_wg.done() + /// }); + /// + /// // wait other thread completes + /// wg.wait(); + /// ``` + pub fn wait(&self) { + while self.inner.counter.load(Ordering::SeqCst) != 0 { + core::hint::spin_loop(); + } + } +} diff --git a/src/sync.rs b/src/sync.rs new file mode 100644 index 0000000..2a5e6c9 --- /dev/null +++ b/src/sync.rs @@ -0,0 +1,246 @@ +trait Mu { + type Guard<'a> + where + Self: 'a; + fn lock_me(&self) -> Self::Guard<'_>; +} + +#[cfg(feature = "parking_lot")] +impl Mu for parking_lot::Mutex { + type Guard<'a> = parking_lot::MutexGuard<'a, T> where Self: 'a; + + fn lock_me(&self) -> Self::Guard<'_> { + self.lock() + } +} + +#[cfg(not(feature = "parking_lot"))] +impl Mu for std::sync::Mutex { + type Guard<'a> = std::sync::MutexGuard<'a, T> where Self: 'a; + + fn lock_me(&self) -> Self::Guard<'_> { + self.lock().unwrap() + } +} + +#[cfg(feature = "parking_lot")] +use parking_lot::{Condvar, Mutex}; +#[cfg(not(feature = "triomphe"))] +use std::sync::Arc; +#[cfg(not(feature = "parking_lot"))] +use std::sync::{Condvar, Mutex}; +#[cfg(feature = "triomphe")] +use triomphe::Arc; + +struct Inner { + cvar: Condvar, + count: Mutex, +} + +/// A WaitGroup waits for a collection of threads to finish. +/// The main thread calls [`add`] to set the number of +/// thread to wait for. Then each of the goroutines +/// runs and calls Done when finished. At the same time, +/// Wait can be used to block until all goroutines have finished. +/// +/// A WaitGroup must not be copied after first use. +/// +/// # Example +/// +/// ```rust +/// use wg::WaitGroup; +/// use std::sync::Arc; +/// use std::sync::atomic::{AtomicUsize, Ordering}; +/// use std::time::Duration; +/// use std::thread::{spawn, sleep}; +/// +/// let wg = WaitGroup::new(); +/// let ctr = Arc::new(AtomicUsize::new(0)); +/// +/// for _ in 0..5 { +/// let ctrx = ctr.clone(); +/// let t_wg = wg.add(1); +/// spawn(move || { +/// // mock some time consuming task +/// sleep(Duration::from_millis(50)); +/// ctrx.fetch_add(1, Ordering::Relaxed); +/// +/// // mock task is finished +/// t_wg.done(); +/// }); +/// } +/// +/// wg.wait(); +/// assert_eq!(ctr.load(Ordering::Relaxed), 5); +/// ``` +/// +/// [`wait`]: struct.WaitGroup.html#method.wait +/// [`add`]: struct.WaitGroup.html#method.add +pub struct WaitGroup { + inner: Arc, +} + +impl Default for WaitGroup { + fn default() -> Self { + Self { + inner: Arc::new(Inner { + cvar: Condvar::new(), + count: Mutex::new(0), + }), + } + } +} + +impl From for WaitGroup { + fn from(count: usize) -> Self { + Self { + inner: Arc::new(Inner { + cvar: Condvar::new(), + count: Mutex::new(count), + }), + } + } +} + +impl Clone for WaitGroup { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl std::fmt::Debug for WaitGroup { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let count = self.inner.count.lock_me(); + f.debug_struct("WaitGroup").field("count", &*count).finish() + } +} + +impl WaitGroup { + /// Creates a new wait group and returns the single reference to it. + /// + /// # Examples + /// + /// ``` + /// use wg::WaitGroup; + /// + /// let wg = WaitGroup::new(); + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// Adds delta to the WaitGroup counter. + /// If the counter becomes zero, all threads blocked on [`wait`] are released. + /// + /// Note that calls with a delta that occur when the counter is zero + /// must happen before a Wait. + /// Typically this means the calls to add should execute before the statement + /// creating the thread or other event to be waited for. + /// If a `WaitGroup` is reused to [`wait`] for several independent sets of events, + /// new `add` calls must happen after all previous [`wait`] calls have returned. + /// + /// # Example + /// ```rust + /// use wg::WaitGroup; + /// + /// let wg = WaitGroup::new(); + /// + /// wg.add(3); + /// (0..3).for_each(|_| { + /// let t_wg = wg.clone(); + /// std::thread::spawn(move || { + /// // do some time consuming work + /// t_wg.done(); + /// }); + /// }); + /// + /// wg.wait(); + /// ``` + /// + /// [`wait`]: struct.AsyncWaitGroup.html#method.wait + pub fn add(&self, num: usize) -> Self { + let mut ctr = self.inner.count.lock_me(); + + *ctr += num; + Self { + inner: self.inner.clone(), + } + } + + /// Decrements the WaitGroup counter by one, returning the remaining count. + /// + /// # Example + /// + /// ```rust + /// use wg::WaitGroup; + /// use std::thread; + /// + /// let wg = WaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// thread::spawn(move || { + /// // do some time consuming task + /// t_wg.done() + /// }); + /// + /// ``` + pub fn done(&self) -> usize { + let mut val = self.inner.count.lock_me(); + + *val = if val.eq(&1) { + self.inner.cvar.notify_all(); + 0 + } else if val.eq(&0) { + 0 + } else { + *val - 1 + }; + *val + } + + /// waitings return how many jobs are waiting. + pub fn waitings(&self) -> usize { + *self.inner.count.lock_me() + } + + /// wait blocks until the WaitGroup counter is zero. + /// + /// # Example + /// + /// ```rust + /// use wg::WaitGroup; + /// use std::thread; + /// + /// let wg = WaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// thread::spawn(move || { + /// // do some time consuming task + /// t_wg.done() + /// }); + /// + /// // wait other thread completes + /// wg.wait(); + /// ``` + pub fn wait(&self) { + let mut ctr = self.inner.count.lock_me(); + + if ctr.eq(&0) { + return; + } + + while *ctr > 0 { + #[cfg(feature = "parking_lot")] + { + self.inner.cvar.wait(&mut ctr); + } + + #[cfg(not(feature = "parking_lot"))] + { + ctr = self.inner.cvar.wait(ctr).unwrap(); + } + } + } +} diff --git a/src/tokio.rs b/src/tokio.rs deleted file mode 100644 index a4d7ecb..0000000 --- a/src/tokio.rs +++ /dev/null @@ -1,287 +0,0 @@ -use ::tokio::sync::{futures::Notified, Notify}; - -use core::{ - future::Future, - pin::Pin, - sync::atomic::{AtomicUsize, Ordering}, - task::{Context, Poll}, -}; - -#[cfg(feature = "std")] -use std::sync::Arc; - -#[cfg(not(feature = "std"))] -use alloc::sync::Arc; - -#[derive(Debug)] -struct AsyncInner { - counter: AtomicUsize, - notify: Notify, -} - -/// An AsyncWaitGroup waits for a collection of threads to finish. -/// The main thread calls [`add`] to set the number of -/// thread to wait for. Then each of the tasks -/// runs and calls Done when finished. At the same time, -/// Wait can be used to block until all tasks have finished. -/// -/// A WaitGroup must not be copied after first use. -/// -/// # Example -/// -/// ```rust -/// use wg::tokio::AsyncWaitGroup; -/// use std::sync::Arc; -/// use std::sync::atomic::{AtomicUsize, Ordering}; -/// use tokio::{spawn, time::{sleep, Duration}}; -/// -/// #[tokio::main] -/// async fn main() { -/// let wg = AsyncWaitGroup::new(); -/// let ctr = Arc::new(AtomicUsize::new(0)); -/// -/// for _ in 0..5 { -/// let ctrx = ctr.clone(); -/// let t_wg = wg.add(1); -/// spawn(async move { -/// // mock some time consuming task -/// sleep(Duration::from_millis(50)).await; -/// ctrx.fetch_add(1, Ordering::Relaxed); -/// -/// // mock task is finished -/// t_wg.done(); -/// }); -/// } -/// -/// wg.wait().await; -/// assert_eq!(ctr.load(Ordering::Relaxed), 5); -/// } -/// ``` -/// -/// [`wait`]: struct.AsyncWaitGroup.html#method.wait -/// [`add`]: struct.AsyncWaitGroup.html#method.add -pub struct AsyncWaitGroup { - inner: Arc, -} - -impl Default for AsyncWaitGroup { - fn default() -> Self { - Self { - inner: Arc::new(AsyncInner { - counter: AtomicUsize::new(0), - notify: Notify::new(), - }), - } - } -} - -impl From for AsyncWaitGroup { - fn from(count: usize) -> Self { - Self { - inner: Arc::new(AsyncInner { - counter: AtomicUsize::new(count), - notify: Notify::new(), - }), - } - } -} - -impl Clone for AsyncWaitGroup { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } -} - -impl core::fmt::Debug for AsyncWaitGroup { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("AsyncWaitGroup") - .field("counter", &self.inner.counter) - .finish() - } -} - -impl AsyncWaitGroup { - /// Creates a new `AsyncWaitGroup` - pub fn new() -> Self { - Self::default() - } - - /// Adds delta to the WaitGroup counter. - /// If the counter becomes zero, all threads blocked on [`wait`] are released. - /// - /// Note that calls with a delta that occur when the counter is zero - /// must happen before a Wait. - /// Typically this means the calls to add should execute before the statement - /// creating the thread or other event to be waited for. - /// If a `AsyncWaitGroup` is reused to [`wait`] for several independent sets of events, - /// new `add` calls must happen after all previous [`wait`] calls have returned. - /// - /// # Example - /// ```rust - /// use wg::tokio::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// - /// wg.add(3); - /// (0..3).for_each(|_| { - /// let t_wg = wg.clone(); - /// tokio::spawn(async move { - /// // do some time consuming work - /// t_wg.done(); - /// }); - /// }); - /// - /// wg.wait().await; - /// } - /// ``` - /// - /// [`wait`]: struct.AsyncWaitGroup.html#method.wait - pub fn add(&self, num: usize) -> Self { - self.inner.counter.fetch_add(num, Ordering::AcqRel); - - Self { - inner: self.inner.clone(), - } - } - - /// done decrements the WaitGroup counter by one. - /// - /// # Example - /// - /// ```rust - /// use wg::tokio::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); - /// tokio::spawn(async move { - /// // do some time consuming task - /// t_wg.done(); - /// }); - /// } - /// ``` - pub fn done(&self) { - if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { - self.inner.notify.notify_waiters(); - } - } - - /// waitings return how many jobs are waiting. - pub fn waitings(&self) -> usize { - self.inner.counter.load(Ordering::Acquire) - } - - /// wait blocks until the [`AsyncWaitGroup`] counter is zero. - /// - /// # Example - /// - /// ```rust - /// use wg::tokio::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); - /// - /// tokio::spawn( async move { - /// // do some time consuming task - /// t_wg.done() - /// }); - /// - /// // wait other thread completes - /// wg.wait().await; - /// } - /// ``` - pub fn wait(&self) -> WaitGroupFuture<'_> { - WaitGroupFuture { - inner: self, - notified: self.inner.notify.notified(), - _pin: core::marker::PhantomPinned, - } - } - - /// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is - /// intended to be used in a non-async context, - /// e.g. when implementing the [`Drop`] trait. - /// - /// The implementation is like a spin lock, which is not efficient, so use it with caution. - /// - /// # Example - /// - /// ```rust - /// use wg::tokio::AsyncWaitGroup; - /// - /// #[tokio::main(flavor = "multi_thread")] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); - /// - /// tokio::spawn( async move { - /// // do some time consuming task - /// t_wg.done() - /// }); - /// - /// // wait other thread completes - /// wg.block_wait(); - /// } - /// ``` - #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn block_wait(&self) { - let this = self.clone(); - let (tx, rx) = std::sync::mpsc::channel(); - ::tokio::task::spawn(async move { - this.wait().await; - let _ = tx.send(()); - }); - let _ = rx.recv(); - } -} - -pin_project_lite::pin_project! { - /// A future returned by [`AsyncWaitGroup::wait()`]. - #[derive(Debug)] - #[must_use = "futures do nothing unless you `.await` or poll them"] - #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] - pub struct WaitGroupFuture<'a> { - inner: &'a AsyncWaitGroup, - #[pin] - notified: Notified<'a>, - #[pin] - _pin: core::marker::PhantomPinned, - } -} - -impl<'a> Future for WaitGroupFuture<'a> { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.inner.inner.counter.load(Ordering::SeqCst) == 0 { - return Poll::Ready(()); - } - - let this = self.project(); - match this.notified.poll(cx) { - Poll::Pending => { - cx.waker().wake_by_ref(); - Poll::Pending - } - Poll::Ready(_) => { - if this.inner.inner.counter.load(Ordering::SeqCst) == 0 { - Poll::Ready(()) - } else { - cx.waker().wake_by_ref(); - Poll::Pending - } - } - } - } -} diff --git a/tests/future.rs b/tests/future.rs index a6b91e7..c906fcb 100644 --- a/tests/future.rs +++ b/tests/future.rs @@ -1,4 +1,5 @@ -use wg::future::AsyncWaitGroup; +use agnostic_lite::{AsyncSpawner, RuntimeLite}; +use wg::AsyncWaitGroup; use std::{ sync::{ @@ -8,8 +9,7 @@ use std::{ time::Duration, }; -#[async_std::test] -async fn test_async_wait_group() { +async fn basic_in() { let wg = AsyncWaitGroup::new(); let ctr = Arc::new(AtomicUsize::new(0)); @@ -17,25 +17,40 @@ async fn test_async_wait_group() { let ctrx = ctr.clone(); let wg = wg.add(1); - async_std::task::spawn(async move { - async_std::task::sleep(Duration::from_millis(50)).await; + S::spawn_detach(async move { + S::sleep(Duration::from_millis(50)).await; ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); + let remaining = wg.done(); + println!("remaining: {}", remaining); }); } wg.wait().await; assert_eq!(ctr.load(Ordering::Relaxed), 5); } +#[tokio::test] +async fn tokio_basic() { + basic_in::().await; +} + #[async_std::test] -async fn test_async_wait_group_reuse() { +async fn async_std_basic() { + basic_in::().await; +} + +#[test] +fn smol_basic() { + smol::block_on(basic_in::()) +} + +async fn reuse_in() { let wg = AsyncWaitGroup::new(); let ctr = Arc::new(AtomicUsize::new(0)); for _ in 0..6 { let wg = wg.add(1); let ctrx = ctr.clone(); - async_std::task::spawn(async move { - async_std::task::sleep(Duration::from_millis(5)).await; + S::spawn_detach(async move { + S::sleep(Duration::from_millis(5)).await; ctrx.fetch_add(1, Ordering::Relaxed); wg.done(); }); @@ -47,8 +62,8 @@ async fn test_async_wait_group_reuse() { let worker = wg.add(1); let ctrx = ctr.clone(); - async_std::task::spawn(async move { - async_std::task::sleep(Duration::from_millis(5)).await; + S::spawn_detach(async move { + S::sleep(Duration::from_millis(5)).await; ctrx.fetch_add(1, Ordering::Relaxed); worker.done(); }); @@ -57,17 +72,31 @@ async fn test_async_wait_group_reuse() { assert_eq!(ctr.load(Ordering::Relaxed), 7); } +#[tokio::test] +async fn tokio_reuse() { + reuse_in::().await; +} + #[async_std::test] -async fn test_async_wait_group_nested() { +async fn async_std_reuse() { + reuse_in::().await; +} + +#[test] +fn smol_reuse() { + smol::block_on(reuse_in::()) +} + +async fn nested_in() { let wg = AsyncWaitGroup::new(); let ctr = Arc::new(AtomicUsize::new(0)); for _ in 0..5 { let worker = wg.add(1); let ctrx = ctr.clone(); - async_std::task::spawn(async move { + S::spawn_detach(async move { let nested_worker = worker.add(1); let ctrxx = ctrx.clone(); - async_std::task::spawn(async move { + S::spawn_detach(async move { ctrxx.fetch_add(1, Ordering::Relaxed); nested_worker.done(); }); @@ -80,12 +109,26 @@ async fn test_async_wait_group_nested() { assert_eq!(ctr.load(Ordering::Relaxed), 10); } +#[tokio::test] +async fn tokio_nested() { + nested_in::().await; +} + #[async_std::test] -async fn test_async_wait_group_from() { +async fn async_std_nested() { + nested_in::().await; +} + +#[test] +fn smol_nested() { + smol::block_on(nested_in::()) +} + +async fn from_in() { let wg = AsyncWaitGroup::from(5); for _ in 0..5 { let t = wg.clone(); - async_std::task::spawn(async move { + S::spawn_detach(async move { t.done(); }); } @@ -93,66 +136,91 @@ async fn test_async_wait_group_from() { } #[async_std::test] -async fn test_sync_wait_group() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); +async fn from_async_std() { + from_in::().await; +} - for _ in 0..5 { - let ctrx = ctr.clone(); - let wg = wg.add(1); - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(50)); - ctrx.fetch_add(1, Ordering::Relaxed); +#[tokio::test] +async fn from_tokio() { + from_in::().await; +} - wg.done(); - }); - } - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); +#[test] +fn from_smol() { + smol::block_on(from_in::()) } -#[async_std::test] -async fn test_async_waitings() { +#[test] +fn test_async_waitings() { let wg = AsyncWaitGroup::new(); wg.add(1); wg.add(1); assert_eq!(wg.waitings(), 2); } -#[test] -fn test_async_block_wait() { +async fn block_wait_in() { let wg = AsyncWaitGroup::new(); let t_wg = wg.add(1); - std::thread::spawn(move || { + S::spawn_detach(async move { // do some time consuming task t_wg.done(); + S::yield_now().await; }); // wait other thread completes - wg.block_wait::(); + wg.wait_blocking(); assert_eq!(wg.waitings(), 0); } #[async_std::test] -async fn test_wake_after_updating() { +async fn block_wait_async_std() { + block_wait_in::().await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn block_wait_tokio() { + block_wait_in::().await; +} + +#[test] +fn block_wait_smol() { + smol::block_on(block_wait_in::()) +} + +async fn wake_after_updating_in() { let wg = AsyncWaitGroup::new(); for _ in 0..100000 { let worker = wg.add(1); - async_std::task::spawn(async move { - async_std::task::sleep(std::time::Duration::from_millis(10)).await; + S::spawn_detach(async move { + S::sleep(std::time::Duration::from_millis(10)).await; let mut a = 0; for _ in 0..1000 { a += 1; } - println!("{a}"); - async_std::task::sleep(std::time::Duration::from_millis(10)).await; + println!("{}", a); + S::sleep(std::time::Duration::from_millis(10)).await; worker.done(); }); } wg.wait().await; } +#[async_std::test] +async fn wake_after_updating_async_std() { + wake_after_updating_in::().await; +} + +#[tokio::test] +async fn wake_after_updating_tokio() { + wake_after_updating_in::().await; +} + +#[test] +fn wake_after_updating_smol() { + smol::block_on(wake_after_updating_in::()) +} + #[test] fn test_clone_and_fmt() { let awg = AsyncWaitGroup::new(); @@ -160,3 +228,11 @@ fn test_clone_and_fmt() { awg1.add(3); assert_eq!(format!("{:?}", awg), format!("{:?}", awg1)); } + +#[test] +fn test_over_done() { + let wg = AsyncWaitGroup::new(); + assert_eq!(wg.done(), 0); + assert_eq!(wg.done(), 0); + assert_eq!(wg.waitings(), 0); +} diff --git a/tests/tokio.rs b/tests/tokio.rs deleted file mode 100644 index 79fa576..0000000 --- a/tests/tokio.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::{ - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - time::Duration, -}; -use wg::tokio::*; - -#[::tokio::test] -async fn test_async_wait_group() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let wg = wg.add(1); - - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(50)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); - }); - } - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); -} - -#[::tokio::test] -async fn test_async_wait_group_reuse() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..6 { - let wg = wg.add(1); - let ctrx = ctr.clone(); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(5)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 6); - - let worker = wg.add(1); - - let ctrx = ctr.clone(); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(5)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 7); -} - -#[::tokio::test] -async fn test_async_wait_group_nested() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..5 { - let worker = wg.add(1); - let ctrx = ctr.clone(); - tokio::spawn(async move { - let nested_worker = worker.add(1); - let ctrxx = ctrx.clone(); - tokio::spawn(async move { - ctrxx.fetch_add(1, Ordering::Relaxed); - nested_worker.done(); - }); - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 10); -} - -#[::tokio::test] -async fn test_async_wait_group_from() { - let wg = AsyncWaitGroup::from(5); - for _ in 0..5 { - let t = wg.clone(); - tokio::spawn(async move { - t.done(); - }); - } - wg.wait().await; -} - -#[::tokio::test] -async fn test_sync_wait_group() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let wg = wg.add(1); - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(50)); - ctrx.fetch_add(1, Ordering::Relaxed); - - wg.done(); - }); - } - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); -} - -#[::tokio::test] -async fn test_async_waitings() { - let wg = AsyncWaitGroup::new(); - wg.add(1); - wg.add(1); - assert_eq!(wg.waitings(), 2); -} - -#[::tokio::test(flavor = "multi_thread")] -async fn test_async_block_wait() { - let wg = AsyncWaitGroup::new(); - let t_wg = wg.add(1); - ::tokio::spawn(async move { - // do some time consuming task - t_wg.done(); - ::tokio::task::yield_now().await; - }); - - // wait other thread completes - wg.block_wait(); - - assert_eq!(wg.waitings(), 0); -} - -#[::tokio::test] -async fn test_wake_after_updating() { - let wg = AsyncWaitGroup::new(); - for _ in 0..100000 { - let worker = wg.add(1); - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let mut a = 0; - for _ in 0..1000 { - a += 1; - } - println!("{a}"); - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - worker.done(); - }); - } - wg.wait().await; -} - -#[test] -fn test_clone_and_fmt() { - let awg = AsyncWaitGroup::new(); - let awg1 = awg.clone(); - awg1.add(3); - assert_eq!(format!("{:?}", awg), format!("{:?}", awg1)); -}