Skip to content

Commit

Permalink
sync: add PollSemaphore::poll_acquire_many (#5137)
Browse files Browse the repository at this point in the history
  • Loading branch information
duarten authored Oct 30, 2022
1 parent 3886a3e commit 29c6de0
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 7 deletions.
49 changes: 42 additions & 7 deletions tokio-util/src/sync/poll_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use super::ReusableBoxFuture;
/// [`Semaphore`]: tokio::sync::Semaphore
pub struct PollSemaphore {
semaphore: Arc<Semaphore>,
permit_fut: Option<ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>>,
permit_fut: Option<(
u32, // The number of permits requested.
ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
)>,
}

impl PollSemaphore {
Expand Down Expand Up @@ -53,25 +56,57 @@ impl PollSemaphore {
/// the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
self.poll_acquire_many(cx, 1)
}

/// Poll to acquire many permits from the semaphore.
///
/// This can return the following values:
///
/// - `Poll::Pending` if a permit is not currently available.
/// - `Poll::Ready(Some(permit))` if a permit was acquired.
/// - `Poll::Ready(None)` if the semaphore has been closed.
///
/// When this method returns `Poll::Pending`, the current task is scheduled
/// to receive a wakeup when the permits become available, or when the
/// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
/// the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
pub fn poll_acquire_many(
&mut self,
cx: &mut Context<'_>,
permits: u32,
) -> Poll<Option<OwnedSemaphorePermit>> {
let permit_future = match self.permit_fut.as_mut() {
Some(fut) => fut,
Some((prev_permits, fut)) if *prev_permits == permits => fut,
Some((old_permits, fut_box)) => {
// We're requesting a different number of permits, so replace the future
// and record the new amount.
let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
fut_box.set(fut);
*old_permits = permits;
fut_box
}
None => {
// avoid allocations completely if we can grab a permit immediately
match Arc::clone(&self.semaphore).try_acquire_owned() {
match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
Ok(permit) => return Poll::Ready(Some(permit)),
Err(TryAcquireError::Closed) => return Poll::Ready(None),
Err(TryAcquireError::NoPermits) => {}
}

let next_fut = Arc::clone(&self.semaphore).acquire_owned();
self.permit_fut
.get_or_insert(ReusableBoxFuture::new(next_fut))
let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
&mut self
.permit_fut
.get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
.1
}
};

let result = ready!(permit_future.poll(cx));

let next_fut = Arc::clone(&self.semaphore).acquire_owned();
// Assume we'll request the same amount of permits in a subsequent call.
let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
permit_future.set(next_fut);

match result {
Expand Down
48 changes: 48 additions & 0 deletions tokio-util/tests/poll_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ fn semaphore_poll(
tokio_test::task::spawn(fut)
}

fn semaphore_poll_many(
sem: &mut PollSemaphore,
permits: u32,
) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + '_> {
let fut = futures::future::poll_fn(move |cx| sem.poll_acquire_many(cx, permits));
tokio_test::task::spawn(fut)
}

#[tokio::test]
async fn it_works() {
let sem = Arc::new(Semaphore::new(1));
Expand All @@ -34,3 +42,43 @@ async fn it_works() {
assert!(semaphore_poll(&mut poll_sem).await.is_none());
assert!(semaphore_poll(&mut poll_sem).await.is_none());
}

#[tokio::test]
async fn can_acquire_many_permits() {
let sem = Arc::new(Semaphore::new(4));
let mut poll_sem = PollSemaphore::new(sem.clone());

let permit1 = semaphore_poll(&mut poll_sem).poll();
assert!(matches!(permit1, Poll::Ready(Some(_))));

let permit2 = semaphore_poll_many(&mut poll_sem, 2).poll();
assert!(matches!(permit2, Poll::Ready(Some(_))));

assert_eq!(sem.available_permits(), 1);

drop(permit2);

let mut permit4 = semaphore_poll_many(&mut poll_sem, 4);
assert!(permit4.poll().is_pending());

drop(permit1);

let permit4 = permit4.poll();
assert!(matches!(permit4, Poll::Ready(Some(_))));
assert_eq!(sem.available_permits(), 0);
}

#[tokio::test]
async fn can_poll_different_amounts_of_permits() {
let sem = Arc::new(Semaphore::new(4));
let mut poll_sem = PollSemaphore::new(sem.clone());
assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());

let permit = sem.acquire_many(4).await.unwrap();
assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_pending());
drop(permit);
assert!(semaphore_poll_many(&mut poll_sem, 5).poll().is_pending());
assert!(semaphore_poll_many(&mut poll_sem, 4).poll().is_ready());
}

0 comments on commit 29c6de0

Please sign in to comment.