buffer, limit: use tokio-util's PollSemaphore #556

Merged · 5 commits · Feb 10, 2021
5 changes: 3 additions & 2 deletions tower/Cargo.toml
@@ -45,11 +45,11 @@ full = [
]
log = ["tracing/log"]
balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"]
buffer = ["tokio/sync", "tokio/rt", "tokio-stream", "tracing"]
buffer = ["tokio/sync", "tokio/rt", "tokio-util", "tracing"]
discover = []
filter = ["futures-util"]
hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time", "tracing"]
limit = ["tokio/time", "tokio/sync", "tracing"]
limit = ["tokio/time", "tokio/sync", "tokio-util", "tracing"]
load = ["tokio/time", "tracing"]
load-shed = []
make = ["tokio/io-std"]
@@ -74,6 +74,7 @@ rand = { version = "0.8", features = ["small_rng"], optional = true }
slab = { version = "0.4", optional = true }
tokio = { version = "1", optional = true, features = ["sync"] }
tokio-stream = { version = "0.1.0", optional = true }
tokio-util = { version = "0.6.3", default-features = false, optional = true }
tracing = { version = "0.1.2", optional = true }

[dev-dependencies]
4 changes: 2 additions & 2 deletions tower/src/buffer/message.rs
@@ -1,13 +1,13 @@
use super::error::ServiceError;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, OwnedSemaphorePermit};

/// Message sent over buffer
#[derive(Debug)]
pub(crate) struct Message<Request, Fut> {
pub(crate) request: Request,
pub(crate) tx: Tx<Fut>,
pub(crate) span: tracing::Span,
pub(super) _permit: crate::semaphore::Permit,
pub(super) _permit: OwnedSemaphorePermit,
}

/// Response sender
59 changes: 38 additions & 21 deletions tower/src/buffer/service.rs
@@ -4,9 +4,11 @@ use super::{
worker::{Handle, Worker},
};

use crate::semaphore::Semaphore;
use futures_core::ready;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower_service::Service;

/// Adds an mpsc buffer in front of an inner service.
@@ -29,7 +31,11 @@ where
// `async fn ready`, which borrows the sender. Therefore, we implement our
// own bounded MPSC on top of the unbounded channel, using a semaphore to
// limit how many items are in the channel.
semaphore: Semaphore,
semaphore: PollSemaphore,
// The current semaphore permit, if one has been acquired.
//
// This is acquired in `poll_ready` and taken in `call`.
permit: Option<OwnedSemaphorePermit>,
handle: Handle,
}

@@ -83,16 +89,15 @@ where
Request: Send + 'static,
{
let (tx, rx) = mpsc::unbounded_channel();
let (semaphore, wake_waiters) = Semaphore::new_with_close(bound);
let (handle, worker) = Worker::new(service, rx, wake_waiters);
(
Buffer {
tx,
handle,
semaphore,
},
worker,
)
let semaphore = Arc::new(Semaphore::new(bound));
let (handle, worker) = Worker::new(service, rx, &semaphore);
let buffer = Buffer {
tx,
handle,
semaphore: PollSemaphore::new(semaphore),
permit: None,
};
(buffer, worker)
}

fn get_worker_error(&self) -> crate::BoxError {
@@ -116,19 +121,28 @@ where
return Poll::Ready(Err(self.get_worker_error()));
}

// Then, poll to acquire a semaphore permit. If we acquire a permit,
// then there's enough buffer capacity to send a new request. Otherwise,
// we need to wait for capacity.
self.semaphore
.poll_acquire(cx)
.map_err(|_| self.get_worker_error())
// Then, check if we've already acquired a permit.
if self.permit.is_some() {
// We've already reserved capacity to send a request. We're ready!
return Poll::Ready(Ok(()));
}

// Finally, if we haven't already acquired a permit, poll the semaphore
// to acquire one. If we acquire a permit, then there's enough buffer
// capacity to send a new request. Otherwise, we need to wait for
// capacity.
let permit =
ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
self.permit = Some(permit);

Poll::Ready(Ok(()))
}

fn call(&mut self, request: Request) -> Self::Future {
tracing::trace!("sending request to buffer worker");
let _permit = self
.semaphore
.take_permit()
.permit
.take()
.expect("buffer full; poll_ready must be called first");

// get the current Span so that we can explicitly propagate it to the worker
@@ -161,6 +175,9 @@ where
tx: self.tx.clone(),
handle: self.handle.clone(),
semaphore: self.semaphore.clone(),
// The new clone hasn't acquired a permit yet. It will when it's
// next polled ready.
permit: None,
}
}
}
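
Taken together, the buffer service changes follow one pattern: `poll_ready` acquires an `OwnedSemaphorePermit` from the shared `PollSemaphore` and parks it on the service, and `call` takes that permit and sends it along with the request. The sketch below is a minimal, self-contained version of that pattern, assuming tokio 1.x and tokio-util 0.6; `CapacityLimited` and its methods are hypothetical stand-ins, not tower's actual `Buffer`.

```rust
use futures_core::ready;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;

/// Hypothetical stand-in for a bounded handle like `Buffer`.
struct CapacityLimited {
    semaphore: PollSemaphore,
    /// Permit acquired in `poll_ready`, consumed in `call`.
    permit: Option<OwnedSemaphorePermit>,
}

impl CapacityLimited {
    fn new(bound: usize) -> Self {
        Self {
            semaphore: PollSemaphore::new(Arc::new(Semaphore::new(bound))),
            permit: None,
        }
    }

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), &'static str>> {
        if self.permit.is_none() {
            // `poll_acquire` yields `None` only if the semaphore was closed;
            // the real Buffer turns that case into the worker's error.
            match ready!(self.semaphore.poll_acquire(cx)) {
                Some(permit) => self.permit = Some(permit),
                None => return Poll::Ready(Err("semaphore closed")),
            }
        }
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: String) -> (String, OwnedSemaphorePermit) {
        // Move the reserved permit out; capacity is released when the
        // returned permit is eventually dropped.
        let permit = self
            .permit
            .take()
            .expect("poll_ready must be called first");
        (request, permit)
    }
}
```

Because the permit is owned rather than tied to a borrow of the semaphore, it can travel inside the queued `Message` and is released automatically when the worker drops that message, which is what keeps the channel bounded.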
31 changes: 19 additions & 12 deletions tower/src/buffer/worker.rs
@@ -4,13 +4,13 @@ use super::{
};
use futures_core::ready;
use pin_project::pin_project;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, Weak};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Semaphore};
use tower_service::Service;

/// Task that handles processing the buffer. This type should not be used
@@ -33,7 +33,7 @@ where
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
close: Option<crate::semaphore::Close>,
close: Option<Weak<Semaphore>>,
}

/// Get the error out
@@ -50,20 +50,21 @@ where
pub(crate) fn new(
service: T,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
close: crate::semaphore::Close,
semaphore: &Arc<Semaphore>,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};

let semaphore = Arc::downgrade(semaphore);
let worker = Worker {
current_message: None,
finish: false,
failed: None,
rx,
service,
handle: handle.clone(),
close: Some(close),
close: Some(semaphore),
};

(handle, worker)
@@ -140,6 +141,17 @@ where
// requests that we receive before we've exhausted the receiver receive the error:
self.failed = Some(error);
}

/// Closes the buffer's semaphore if it is still open, waking any pending
/// tasks.
fn close_semaphore(&mut self) {
if let Some(close) = self.close.take().as_ref().and_then(Weak::upgrade) {
tracing::debug!("buffer closing; waking pending tasks");
close.close();
} else {
tracing::trace!("buffer already closed");
}
}
}

impl<T, Request> Future for Worker<T, Request>
@@ -199,10 +211,7 @@ where
.expect("Worker::failed did not set self.failed?")
.clone()));
// Wake any tasks waiting on channel capacity.
if let Some(close) = self.close.take() {
tracing::debug!("waking pending tasks");
close.close();
}
self.close_semaphore();
}
}
}
@@ -223,9 +232,7 @@ where
T::Error: Into<crate::BoxError>,
{
fn drop(mut self: Pin<&mut Self>) {
if let Some(close) = self.as_mut().close.take() {
close.close();
}
self.as_mut().close_semaphore();
}
}

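The worker now keeps only a `Weak<Semaphore>` for shutdown: closing the semaphore wakes every task parked in `poll_acquire` (they observe `None`), while the weak reference avoids keeping the semaphore alive after all `Buffer` handles are gone. Below is a small sketch of that behavior, assuming tokio 1.x with the "full" feature; `WorkerHandle` is a hypothetical stand-in for the worker.

```rust
use std::sync::{Arc, Weak};
use tokio::sync::Semaphore;

/// Hypothetical stand-in for the buffer worker's shutdown handle.
struct WorkerHandle {
    close: Option<Weak<Semaphore>>,
}

impl WorkerHandle {
    fn close_semaphore(&mut self) {
        // `upgrade` fails if every strong handle (every Buffer clone) is
        // already gone, in which case there is nobody left to wake.
        if let Some(semaphore) = self.close.take().as_ref().and_then(Weak::upgrade) {
            semaphore.close();
        }
    }
}

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(0));
    let mut handle = WorkerHandle {
        close: Some(Arc::downgrade(&semaphore)),
    };

    let waiter = tokio::spawn({
        let semaphore = semaphore.clone();
        // With zero permits this acquire would wait forever; closing the
        // semaphore wakes it with an `AcquireError` instead, whether it was
        // already waiting or only starts acquiring after the close.
        async move { semaphore.acquire_owned().await.is_err() }
    });

    handle.close_semaphore();
    assert!(waiter.await.unwrap());
}
```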
2 changes: 0 additions & 2 deletions tower/src/lib.rs
@@ -220,8 +220,6 @@ pub use crate::make::MakeService;
pub use tower_layer::Layer;
#[doc(inline)]
pub use tower_service::Service;
#[cfg(any(feature = "buffer", feature = "limit"))]
mod semaphore;

#[allow(unreachable_pub)]
mod sealed {
6 changes: 3 additions & 3 deletions tower/src/limit/concurrency/future.rs
@@ -1,14 +1,14 @@
//! [`Future`] types
//!
//! [`Future`]: std::future::Future
use crate::semaphore::Permit;
use futures_core::ready;
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::OwnedSemaphorePermit;

/// Future for the [`ConcurrencyLimit`] service.
///
@@ -19,11 +19,11 @@ pub struct ResponseFuture<T> {
#[pin]
inner: T,
// Keep this around so that it is dropped when the future completes
_permit: Permit,
_permit: OwnedSemaphorePermit,
}

impl<T> ResponseFuture<T> {
pub(crate) fn new(inner: T, _permit: Permit) -> ResponseFuture<T> {
pub(crate) fn new(inner: T, _permit: OwnedSemaphorePermit) -> ResponseFuture<T> {
ResponseFuture { inner, _permit }
}
}
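Holding the permit in the response future works because dropping an `OwnedSemaphorePermit` returns its capacity to the semaphore, so each completed (or cancelled) request frees one concurrency slot. A quick demonstration of that property, assuming tokio 1.x with the "full" feature:

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));

    let permit = semaphore.clone().acquire_owned().await.unwrap();
    assert_eq!(semaphore.available_permits(), 0);

    // Dropping the permit (as the response future does when it completes or
    // is cancelled) hands the slot back.
    drop(permit);
    assert_eq!(semaphore.available_permits(), 1);
}
```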
57 changes: 44 additions & 13 deletions tower/src/limit/concurrency/service.rs
@@ -1,24 +1,35 @@
use super::future::ResponseFuture;
use crate::semaphore::Semaphore;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower_service::Service;

use futures_core::ready;
use std::task::{Context, Poll};
use std::{
sync::Arc,
task::{Context, Poll},
};

/// Enforces a limit on the concurrent number of requests the underlying
/// service can handle.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ConcurrencyLimit<T> {
inner: T,
semaphore: Semaphore,
semaphore: PollSemaphore,
/// The currently acquired semaphore permit, if there is sufficient
/// concurrency to send a new request.
///
/// The permit is acquired in `poll_ready`, and taken in `call` when sending
/// a new request.
permit: Option<OwnedSemaphorePermit>,
}

impl<T> ConcurrencyLimit<T> {
/// Create a new concurrency limiter.
pub fn new(inner: T, max: usize) -> Self {
ConcurrencyLimit {
inner,
semaphore: Semaphore::new(max),
semaphore: PollSemaphore::new(Arc::new(Semaphore::new(max))),
permit: None,
}
}

@@ -47,20 +58,27 @@ where
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// First, poll the semaphore...
ready!(self.semaphore.poll_acquire(cx)).expect(
"ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
should never fail",
);
// ...and if it's ready, poll the inner service.
// If we haven't already acquired a permit from the semaphore, try to
// acquire one first.
if self.permit.is_none() {
self.permit = ready!(self.semaphore.poll_acquire(cx));
debug_assert!(
self.permit.is_some(),
"ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
should never fail",
);
}

// Once we've acquired a permit (or if we already had one), poll the
// inner service.
self.inner.poll_ready(cx)
}

fn call(&mut self, request: Request) -> Self::Future {
// Take the permit
let permit = self
.semaphore
.take_permit()
.permit
.take()
.expect("max requests in-flight; poll_ready must be called first");

// Call the inner service
@@ -70,6 +88,19 @@
}
}

impl<T: Clone> Clone for ConcurrencyLimit<T> {
fn clone(&self) -> Self {
// Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
// Instead, when cloning the service, create a new service with the
// same semaphore, but with the permit in the un-acquired state.
Self {
inner: self.inner.clone(),
semaphore: self.semaphore.clone(),
permit: None,
}
}
}

#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for ConcurrencyLimit<S>
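The hand-written `Clone` impl is needed because an `OwnedSemaphorePermit` belongs to exactly one handle and cannot be duplicated; a clone shares the underlying semaphore, and therefore the concurrency budget, but starts without a permit and must acquire its own in `poll_ready`. The usage sketch below shows the effect, assuming a recent tower 0.4 with the "limit" and "util" features plus tokio's "full" feature; it is illustrative rather than part of this change.

```rust
use std::convert::Infallible;
use tower::limit::ConcurrencyLimit;
use tower::{service_fn, Service, ServiceExt};

#[tokio::main]
async fn main() {
    let svc = service_fn(|req: &'static str| async move { Ok::<_, Infallible>(req) });
    let mut a = ConcurrencyLimit::new(svc, 1);
    let mut b = a.clone();

    // `a` reserves the single permit in `poll_ready`, and `call` moves the
    // permit into the response future.
    let fut = a.ready().await.unwrap().call("hello");

    // The permit is released when that future completes, which is what lets
    // the clone `b` become ready and send its own request afterwards.
    assert_eq!(fut.await.unwrap(), "hello");
    assert_eq!(b.ready().await.unwrap().call("world").await.unwrap(), "world");
}
```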