From 29c6de0e3e823206c8edb4a810f018c1fab7bf3c Mon Sep 17 00:00:00 2001 From: Duarte Nunes Date: Sun, 30 Oct 2022 06:06:10 -0300 Subject: [PATCH] sync: add `PollSemaphore::poll_acquire_many` (#5137) --- tokio-util/src/sync/poll_semaphore.rs | 49 +++++++++++++++++++++++---- tokio-util/tests/poll_semaphore.rs | 48 ++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/tokio-util/src/sync/poll_semaphore.rs b/tokio-util/src/sync/poll_semaphore.rs index d0b1dedc273..bf47dd9261c 100644 --- a/tokio-util/src/sync/poll_semaphore.rs +++ b/tokio-util/src/sync/poll_semaphore.rs @@ -12,7 +12,10 @@ use super::ReusableBoxFuture; /// [`Semaphore`]: tokio::sync::Semaphore pub struct PollSemaphore { semaphore: Arc, - permit_fut: Option>>, + permit_fut: Option<( + u32, // The number of permits requested. + ReusableBoxFuture<'static, Result>, + )>, } impl PollSemaphore { @@ -53,25 +56,57 @@ impl PollSemaphore { /// the `Waker` from the `Context` passed to the most recent call is /// scheduled to receive a wakeup. pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_acquire_many(cx, 1) + } + + /// Poll to acquire many permits from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when the permits become available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire_many( + &mut self, + cx: &mut Context<'_>, + permits: u32, + ) -> Poll> { let permit_future = match self.permit_fut.as_mut() { - Some(fut) => fut, + Some((prev_permits, fut)) if *prev_permits == permits => fut, + Some((old_permits, fut_box)) => { + // We're requesting a different number of permits, so replace the future + // and record the new amount. + let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); + fut_box.set(fut); + *old_permits = permits; + fut_box + } None => { // avoid allocations completely if we can grab a permit immediately - match Arc::clone(&self.semaphore).try_acquire_owned() { + match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) { Ok(permit) => return Poll::Ready(Some(permit)), Err(TryAcquireError::Closed) => return Poll::Ready(None), Err(TryAcquireError::NoPermits) => {} } - let next_fut = Arc::clone(&self.semaphore).acquire_owned(); - self.permit_fut - .get_or_insert(ReusableBoxFuture::new(next_fut)) + let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); + &mut self + .permit_fut + .get_or_insert((permits, ReusableBoxFuture::new(next_fut))) + .1 } }; let result = ready!(permit_future.poll(cx)); - let next_fut = Arc::clone(&self.semaphore).acquire_owned(); + // Assume we'll request the same amount of permits in a subsequent call. + let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits); permit_future.set(next_fut); match result { diff --git a/tokio-util/tests/poll_semaphore.rs b/tokio-util/tests/poll_semaphore.rs index 50f36dd803b..28beca19fa3 100644 --- a/tokio-util/tests/poll_semaphore.rs +++ b/tokio-util/tests/poll_semaphore.rs @@ -13,6 +13,14 @@ fn semaphore_poll( tokio_test::task::spawn(fut) } +fn semaphore_poll_many( + sem: &mut PollSemaphore, + permits: u32, +) -> tokio_test::task::Spawn + '_> { + let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits)); + tokio_test::task::spawn(fut) +} + #[tokio::test] async fn it_works() { let sem = Arc::new(Semaphore::new(1)); @@ -34,3 +42,43 @@ async fn it_works() { assert!(semaphore_poll(&mut poll_sem).await.is_none()); assert!(semaphore_poll(&mut poll_sem).await.is_none()); } + +#[tokio::test] +async fn can_acquire_many_permits() { + let sem = Arc::new(Semaphore::new(4)); + let mut poll_sem = PollSemaphore::new(sem.clone()); + + let permit1 = semaphore_poll(&mut poll_sem).poll(); + assert!(matches!(permit1, Poll::Ready(Some(_)))); + + let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll(); + assert!(matches!(permit2, Poll::Ready(Some(_)))); + + assert_eq!(sem.available_permits(), 1); + + drop(permit2); + + let mut permit4 = semaphore_poll_many(&mut poll_sem, 4); + assert!(permit4.poll().is_pending()); + + drop(permit1); + + let permit4 = permit4.poll(); + assert!(matches!(permit4, Poll::Ready(Some(_)))); + assert_eq!(sem.available_permits(), 0); +} + +#[tokio::test] +async fn can_poll_different_amounts_of_permits() { + let sem = Arc::new(Semaphore::new(4)); + let mut poll_sem = PollSemaphore::new(sem.clone()); + assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); + assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready()); + + let permit = sem.acquire_many(4).await.unwrap(); + assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); + assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending()); + drop(permit); + assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending()); + assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready()); +}