diff --git a/tower/Cargo.toml b/tower/Cargo.toml
index 23871e821..f9f8b33ff 100644
--- a/tower/Cargo.toml
+++ b/tower/Cargo.toml
@@ -45,11 +45,11 @@ full = [
 ]
 log = ["tracing/log"]
 balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"]
-buffer = ["tokio/sync", "tokio/rt", "tokio-stream", "tracing"]
+buffer = ["tokio/sync", "tokio/rt", "tokio-util", "tracing"]
 discover = []
 filter = ["futures-util"]
 hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time", "tracing"]
-limit = ["tokio/time", "tokio/sync", "tracing"]
+limit = ["tokio/time", "tokio/sync", "tokio-util", "tracing"]
 load = ["tokio/time", "tracing"]
 load-shed = []
 make = ["tokio/io-std"]
@@ -74,6 +74,7 @@ rand = { version = "0.8", features = ["small_rng"], optional = true }
 slab = { version = "0.4", optional = true }
 tokio = { version = "1", optional = true, features = ["sync"] }
 tokio-stream = { version = "0.1.0", optional = true }
+tokio-util = { version = "0.6.3", default-features = false, optional = true }
 tracing = { version = "0.1.2", optional = true }

 [dev-dependencies]
diff --git a/tower/src/buffer/message.rs b/tower/src/buffer/message.rs
index 069828edf..2b2fcdedc 100644
--- a/tower/src/buffer/message.rs
+++ b/tower/src/buffer/message.rs
@@ -1,5 +1,5 @@
 use super::error::ServiceError;
-use tokio::sync::oneshot;
+use tokio::sync::{oneshot, OwnedSemaphorePermit};

 /// Message sent over buffer
 #[derive(Debug)]
@@ -7,7 +7,7 @@ pub(crate) struct Message<Request, Fut> {
     pub(crate) request: Request,
     pub(crate) tx: Tx<Fut>,
     pub(crate) span: tracing::Span,
-    pub(super) _permit: crate::semaphore::Permit,
+    pub(super) _permit: OwnedSemaphorePermit,
 }

 /// Response sender
diff --git a/tower/src/buffer/service.rs b/tower/src/buffer/service.rs
index 36b8f2f26..9b690f0f6 100644
--- a/tower/src/buffer/service.rs
+++ b/tower/src/buffer/service.rs
@@ -4,9 +4,11 @@ use super::{
     worker::{Handle, Worker},
 };

-use crate::semaphore::Semaphore;
+use futures_core::ready;
+use std::sync::Arc;
 use std::task::{Context, Poll};
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
+use tokio_util::sync::PollSemaphore;
 use tower_service::Service;

 /// Adds an mpsc buffer in front of an inner service.
@@ -29,7 +31,11 @@ where
     // `async fn ready`, which borrows the sender. Therefore, we implement our
     // own bounded MPSC on top of the unbounded channel, using a semaphore to
     // limit how many items are in the channel.
-    semaphore: Semaphore,
+    semaphore: PollSemaphore,
+    // The current semaphore permit, if one has been acquired.
+    //
+    // This is acquired in `poll_ready` and taken in `call`.
+    permit: Option<OwnedSemaphorePermit>,
     handle: Handle,
 }

@@ -83,16 +89,15 @@ where
         Request: Send + 'static,
     {
         let (tx, rx) = mpsc::unbounded_channel();
-        let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
-        let (handle, worker) = Worker::new(service, rx, wake_waiters);
-        (
-            Buffer {
-                tx,
-                handle,
-                semaphore,
-            },
-            worker,
-        )
+        let semaphore = Arc::new(Semaphore::new(bound));
+        let (handle, worker) = Worker::new(service, rx, &semaphore);
+        let buffer = Buffer {
+            tx,
+            handle,
+            semaphore: PollSemaphore::new(semaphore),
+            permit: None,
+        };
+        (buffer, worker)
     }

     fn get_worker_error(&self) -> crate::BoxError {
@@ -116,19 +121,28 @@ where
             return Poll::Ready(Err(self.get_worker_error()));
         }

-        // Then, poll to acquire a semaphore permit. If we acquire a permit,
-        // then there's enough buffer capacity to send a new request. Otherwise,
-        // we need to wait for capacity.
-        self.semaphore
-            .poll_acquire(cx)
-            .map_err(|_| self.get_worker_error())
+        // Then, check if we've already acquired a permit.
+        if self.permit.is_some() {
+            // We've already reserved capacity to send a request. We're ready!
+            return Poll::Ready(Ok(()));
+        }
+
+        // Finally, if we haven't already acquired a permit, poll the semaphore
+        // to acquire one. If we acquire a permit, then there's enough buffer
+        // capacity to send a new request. Otherwise, we need to wait for
+        // capacity.
+        let permit =
+            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
+        self.permit = Some(permit);
+
+        Poll::Ready(Ok(()))
     }

     fn call(&mut self, request: Request) -> Self::Future {
         tracing::trace!("sending request to buffer worker");
         let _permit = self
-            .semaphore
-            .take_permit()
+            .permit
+            .take()
             .expect("buffer full; poll_ready must be called first");

         // get the current Span so that we can explicitly propagate it to the worker
@@ -161,6 +175,9 @@ where
             tx: self.tx.clone(),
             handle: self.handle.clone(),
             semaphore: self.semaphore.clone(),
+            // The new clone hasn't acquired a permit yet. It will when it's
+            // next polled ready.
+            permit: None,
         }
     }
 }
diff --git a/tower/src/buffer/worker.rs b/tower/src/buffer/worker.rs
index c8d8d6a9c..0a8d91705 100644
--- a/tower/src/buffer/worker.rs
+++ b/tower/src/buffer/worker.rs
@@ -4,13 +4,13 @@ use super::{
 };
 use futures_core::ready;
 use pin_project::pin_project;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, Weak};
 use std::{
     future::Future,
     pin::Pin,
     task::{Context, Poll},
 };
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, Semaphore};
 use tower_service::Service;

 /// Task that handles processing the buffer. This type should not be used
@@ -33,7 +33,7 @@ where
     finish: bool,
     failed: Option<ServiceError>,
     handle: Handle,
-    close: Option<crate::semaphore::Close>,
+    close: Option<Weak<Semaphore>>,
 }

 /// Get the error out
@@ -50,12 +50,13 @@ where
     pub(crate) fn new(
         service: T,
         rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
-        close: crate::semaphore::Close,
+        semaphore: &Arc<Semaphore>,
     ) -> (Handle, Worker<T, Request>) {
         let handle = Handle {
             inner: Arc::new(Mutex::new(None)),
         };

+        let semaphore = Arc::downgrade(semaphore);
         let worker = Worker {
             current_message: None,
             finish: false,
@@ -63,7 +64,7 @@ where
             rx,
             service,
             handle: handle.clone(),
-            close: Some(close),
+            close: Some(semaphore),
         };

         (handle, worker)
@@ -140,6 +141,17 @@ where
         // requests that we receive before we've exhausted the receiver receive the error:
         self.failed = Some(error);
     }
+
+    /// Closes the buffer's semaphore if it is still open, waking any pending
+    /// tasks.
+    fn close_semaphore(&mut self) {
+        if let Some(close) = self.close.take().as_ref().and_then(Weak::upgrade) {
+            tracing::debug!("buffer closing; waking pending tasks");
+            close.close();
+        } else {
+            tracing::trace!("buffer already closed");
+        }
+    }
 }

 impl<T, Request> Future for Worker<T, Request>
@@ -199,10 +211,7 @@ where
                             .expect("Worker::failed did not set self.failed?")
                             .clone()));
                         // Wake any tasks waiting on channel capacity.
-                        if let Some(close) = self.close.take() {
-                            tracing::debug!("waking pending tasks");
-                            close.close();
-                        }
+                        self.close_semaphore();
                     }
                 }
             }
@@ -223,9 +232,7 @@ where
     T::Error: Into<crate::BoxError>,
 {
     fn drop(mut self: Pin<&mut Self>) {
-        if let Some(close) = self.as_mut().close.take() {
-            close.close();
-        }
+        self.as_mut().close_semaphore();
     }
 }

diff --git a/tower/src/lib.rs b/tower/src/lib.rs
index ca31cef45..f03d62cb3 100644
--- a/tower/src/lib.rs
+++ b/tower/src/lib.rs
@@ -220,8 +220,6 @@ pub use crate::make::MakeService;
 pub use tower_layer::Layer;
 #[doc(inline)]
 pub use tower_service::Service;
-#[cfg(any(feature = "buffer", feature = "limit"))]
-mod semaphore;

 #[allow(unreachable_pub)]
 mod sealed {
diff --git a/tower/src/limit/concurrency/future.rs b/tower/src/limit/concurrency/future.rs
index 103aa11a5..1fea13438 100644
--- a/tower/src/limit/concurrency/future.rs
+++ b/tower/src/limit/concurrency/future.rs
@@ -1,7 +1,6 @@
 //! [`Future`] types
 //!
 //! [`Future`]: std::future::Future
-use crate::semaphore::Permit;
 use futures_core::ready;
 use pin_project::pin_project;
 use std::{
@@ -9,6 +8,7 @@ use std::{
     pin::Pin,
     task::{Context, Poll},
 };
+use tokio::sync::OwnedSemaphorePermit;

 /// Future for the [`ConcurrencyLimit`] service.
 ///
@@ -19,11 +19,11 @@ pub struct ResponseFuture<T> {
     #[pin]
     inner: T,
     // Keep this around so that it is dropped when the future completes
-    _permit: Permit,
+    _permit: OwnedSemaphorePermit,
 }

 impl<T> ResponseFuture<T> {
-    pub(crate) fn new(inner: T, _permit: Permit) -> ResponseFuture<T> {
+    pub(crate) fn new(inner: T, _permit: OwnedSemaphorePermit) -> ResponseFuture<T> {
         ResponseFuture { inner, _permit }
     }
 }
diff --git a/tower/src/limit/concurrency/service.rs b/tower/src/limit/concurrency/service.rs
index 44ac2c630..5f8601f88 100644
--- a/tower/src/limit/concurrency/service.rs
+++ b/tower/src/limit/concurrency/service.rs
@@ -1,16 +1,26 @@
 use super::future::ResponseFuture;
-use crate::semaphore::Semaphore;
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
+use tokio_util::sync::PollSemaphore;
 use tower_service::Service;

 use futures_core::ready;
-use std::task::{Context, Poll};
+use std::{
+    sync::Arc,
+    task::{Context, Poll},
+};

 /// Enforces a limit on the concurrent number of requests the underlying
 /// service can handle.
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct ConcurrencyLimit<T> {
     inner: T,
-    semaphore: Semaphore,
+    semaphore: PollSemaphore,
+    /// The currently acquired semaphore permit, if there is sufficient
+    /// concurrency to send a new request.
+    ///
+    /// The permit is acquired in `poll_ready`, and taken in `call` when sending
+    /// a new request.
+    permit: Option<OwnedSemaphorePermit>,
 }

 impl<T> ConcurrencyLimit<T> {
@@ -18,7 +28,8 @@ impl<T> ConcurrencyLimit<T> {
     pub fn new(inner: T, max: usize) -> Self {
         ConcurrencyLimit {
             inner,
-            semaphore: Semaphore::new(max),
+            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(max))),
+            permit: None,
         }
     }

@@ -47,20 +58,27 @@ where
     type Future = ResponseFuture<S::Future>;

     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        // First, poll the semaphore...
-        ready!(self.semaphore.poll_acquire(cx)).expect(
-            "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
-             should never fail",
-        );
-        // ...and if it's ready, poll the inner service.
+        // If we haven't already acquired a permit from the semaphore, try to
+        // acquire one first.
+        if self.permit.is_none() {
+            self.permit = ready!(self.semaphore.poll_acquire(cx));
+            debug_assert!(
+                self.permit.is_some(),
+                "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
+                 should never fail",
+            );
+        }
+
+        // Once we've acquired a permit (or if we already had one), poll the
+        // inner service.
         self.inner.poll_ready(cx)
     }

     fn call(&mut self, request: Request) -> Self::Future {
         // Take the permit
         let permit = self
-            .semaphore
-            .take_permit()
+            .permit
+            .take()
             .expect("max requests in-flight; poll_ready must be called first");

         // Call the inner service
@@ -70,6 +88,19 @@ where
     }
 }

+impl<T: Clone> Clone for ConcurrencyLimit<T> {
+    fn clone(&self) -> Self {
+        // Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
+        // Instead, when cloning the service, create a new service with the
+        // same semaphore, but with the permit in the un-acquired state.
+        Self {
+            inner: self.inner.clone(),
+            semaphore: self.semaphore.clone(),
+            permit: None,
+        }
+    }
+}
+
 #[cfg(feature = "load")]
 #[cfg_attr(docsrs, doc(cfg(feature = "load")))]
 impl<S> crate::load::Load for ConcurrencyLimit<S>
diff --git a/tower/src/semaphore.rs b/tower/src/semaphore.rs
deleted file mode 100644
index 6b6511f8a..000000000
--- a/tower/src/semaphore.rs
+++ /dev/null
@@ -1,100 +0,0 @@
-pub(crate) use self::sync::{AcquireError, OwnedSemaphorePermit as Permit};
-use futures_core::ready;
-use std::{
-    fmt,
-    future::Future,
-    mem,
-    pin::Pin,
-    sync::{Arc, Weak},
-    task::{Context, Poll},
-};
-use tokio::sync;
-
-#[derive(Debug)]
-pub(crate) struct Semaphore {
-    semaphore: Arc<sync::Semaphore>,
-    state: State,
-}
-
-#[derive(Debug)]
-pub(crate) struct Close {
-    semaphore: Weak<sync::Semaphore>,
-}
-
-enum State {
-    Waiting(Pin<Box<dyn Future<Output = Result<Permit, AcquireError>> + Send + Sync + 'static>>),
-    Ready(Permit),
-    Empty,
-}
-
-impl Semaphore {
-    pub(crate) fn new_with_close(permits: usize) -> (Self, Close) {
-        let semaphore = Arc::new(sync::Semaphore::new(permits));
-        let close = Close {
-            semaphore: Arc::downgrade(&semaphore),
-        };
-        let semaphore = Self {
-            semaphore,
-            state: State::Empty,
-        };
-        (semaphore, close)
-    }
-
-    pub(crate) fn new(permits: usize) -> Self {
-        Self {
-            semaphore: Arc::new(sync::Semaphore::new(permits)),
-            state: State::Empty,
-        }
-    }
-
-    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), AcquireError>> {
-        loop {
-            self.state = match self.state {
-                State::Ready(_) => return Poll::Ready(Ok(())),
-                State::Waiting(ref mut fut) => {
-                    let permit = ready!(Pin::new(fut).poll(cx))?;
-                    State::Ready(permit)
-                }
-                State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())),
-            };
-        }
-    }
-
-    pub(crate) fn take_permit(&mut self) -> Option<Permit> {
-        if let State::Ready(permit) = mem::replace(&mut self.state, State::Empty) {
-            return Some(permit);
-        }
-        None
-    }
-}
-
-impl Clone for Semaphore {
-    fn clone(&self) -> Self {
-        Self {
-            semaphore: self.semaphore.clone(),
-            state: State::Empty,
-        }
-    }
-}
-
-impl fmt::Debug for State {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self {
-            State::Waiting(_) => f
-                .debug_tuple("State::Waiting")
-                .field(&format_args!("..."))
-                .finish(),
-            State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(),
-            State::Empty => f.debug_tuple("State::Empty").finish(),
-        }
-    }
-}
-
-impl Close {
-    /// Close the semaphore, waking any remaining tasks currently awaiting a permit.
-    pub(crate) fn close(self) {
-        if let Some(semaphore) = self.semaphore.upgrade() {
-            semaphore.close()
-        }
-    }
-}
diff --git a/tower/tests/buffer/main.rs b/tower/tests/buffer/main.rs
index 41b0336f4..3bb12dafe 100644
--- a/tower/tests/buffer/main.rs
+++ b/tower/tests/buffer/main.rs
@@ -346,6 +346,49 @@ async fn wakes_pending_waiters_on_failure() {
     );
 }

+#[tokio::test(flavor = "current_thread")]
+async fn doesnt_leak_permits() {
+    let _t = support::trace_init();
+
+    let (service, mut handle) = mock::pair::<_, ()>();
+
+    let (mut service1, worker) = Buffer::pair(service, 2);
+    let mut worker = task::spawn(worker);
+    let mut service2 = service1.clone();
+    let mut service3 = service1.clone();
+
+    // Attempt to poll the first clone of the buffer to readiness multiple
+    // times. These should all succeed, because the readiness is never
+    // *consumed* --- no request is sent.
+    assert_ready_ok!(task::spawn(service1.ready_and()).poll());
+    assert_ready_ok!(task::spawn(service1.ready_and()).poll());
+    assert_ready_ok!(task::spawn(service1.ready_and()).poll());
+
+    // It should also be possible to drive the second clone of the service to
+    // readiness --- it should only acquire one permit, as well.
+    assert_ready_ok!(task::spawn(service2.ready_and()).poll());
+    assert_ready_ok!(task::spawn(service2.ready_and()).poll());
+    assert_ready_ok!(task::spawn(service2.ready_and()).poll());
+
+    // The third clone *doesn't* poll ready, because the first two clones have
+    // each acquired one permit.
+    let mut ready3 = task::spawn(service3.ready_and());
+    assert_pending!(ready3.poll());
+
+    // Consume the first service's readiness.
+    let mut response = task::spawn(service1.call(()));
+    handle.allow(1);
+    assert_pending!(worker.poll());
+
+    handle.next_request().await.unwrap().1.send_response(());
+    assert_pending!(worker.poll());
+    assert_ready_ok!(response.poll());
+
+    // Now, the third service should acquire a permit...
+    assert!(ready3.is_woken());
+    assert_ready_ok!(ready3.poll());
+}
+
 type Mock = mock::Mock<&'static str, &'static str>;
 type Handle = mock::Handle<&'static str, &'static str>;
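
Note: the following is a minimal, self-contained sketch of the permit-caching pattern this patch applies to both `Buffer` and `ConcurrencyLimit`, assuming tokio 1.x and tokio-util 0.6. The `Limited` type and its methods are hypothetical illustrations for this note, not part of tower's API:

    use futures_core::ready;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use tokio::sync::{OwnedSemaphorePermit, Semaphore};
    use tokio_util::sync::PollSemaphore;

    struct Limited {
        semaphore: PollSemaphore,
        // Acquired in `poll_ready`, consumed in `call`, so repeated calls to
        // `poll_ready` are idempotent and permits are never leaked.
        permit: Option<OwnedSemaphorePermit>,
    }

    impl Limited {
        fn new(bound: usize) -> Self {
            Self {
                semaphore: PollSemaphore::new(Arc::new(Semaphore::new(bound))),
                permit: None,
            }
        }

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), &'static str>> {
            if self.permit.is_none() {
                // `PollSemaphore::poll_acquire` yields `None` only when the
                // semaphore has been closed (as the buffer worker does on drop).
                match ready!(self.semaphore.poll_acquire(cx)) {
                    Some(permit) => self.permit = Some(permit),
                    None => return Poll::Ready(Err("semaphore closed")),
                }
            }
            Poll::Ready(Ok(()))
        }

        fn call(&mut self) -> OwnedSemaphorePermit {
            // The returned permit should be held for as long as the request is
            // in flight; dropping it releases capacity back to the semaphore.
            self.permit
                .take()
                .expect("poll_ready must be called first")
        }
    }

Because `OwnedSemaphorePermit` is not `Clone`, a clone of such a service shares the underlying semaphore but starts with `permit: None`; that is why the patch replaces the derived `Clone` impls with manual ones that reset the permit.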