Skip to content

Commit

Permalink
Factor out ensure_writeable methods.
Browse files Browse the repository at this point in the history
There is some important logic that is easy to overlook in the client and
server channels: streams of data to write to the transport should not be
polled until the transport is known to be ready to buffer a message. In
the case that a transport's buffer is full, it needs to be flushed to
make room for more messages.

Without this logic, start_send() could return an error when the buffer
is full, which would cause the entire Channel to error out.

Due to the importance of this logic, it's now factored out into its own
method that's easier to understand: fn ensure_writeable. There is one in
the client module and and one in the server module.
  • Loading branch information
tikue committed Mar 8, 2021
1 parent 66419db commit 66cdc99
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 85 deletions.
150 changes: 78 additions & 72 deletions tarpc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
use futures::{prelude::*, ready, stream::Fuse, task::*};
use in_flight_requests::InFlightRequests;
use log::{info, trace};
use pin_project::{pin_project, pinned_drop};
use pin_project::pin_project;
use std::{
convert::TryFrom,
fmt, io,
Expand Down Expand Up @@ -165,7 +165,6 @@ impl<Req, Resp> Channel<Req, Resp> {

/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct DispatchResponse<Resp> {
response: oneshot::Receiver<Response<Resp>>,
Expand Down Expand Up @@ -193,11 +192,9 @@ impl<Resp> Future for DispatchResponse<Resp> {
}

// Cancels the request when dropped, if not already complete.
#[pinned_drop]
impl<Resp> PinnedDrop for DispatchResponse<Resp> {
fn drop(mut self: Pin<&mut Self>) {
let self_ = self.project();
if let Some(cancellation) = self_.cancellation {
impl<Resp> Drop for DispatchResponse<Resp> {
fn drop(&mut self) {
if let Some(cancellation) = &mut self.cancellation {
// The receiver needs to be closed to handle the edge case that the request has not
// yet been received by the dispatch task. It is possible for the cancel message to
// arrive before the request itself, in which case the request could get stuck in the
Expand All @@ -208,8 +205,8 @@ impl<Resp> PinnedDrop for DispatchResponse<Resp> {
// closing the receiver before sending the cancel message, it is guaranteed that if the
// dispatch task misses an early-arriving cancellation message, then it will see the
// receiver as closed.
self_.response.close();
cancellation.cancel(*self_.request_id);
self.response.close();
cancellation.cancel(self.request_id);
}
}
}
Expand Down Expand Up @@ -252,10 +249,8 @@ pub struct RequestDispatch<Req, Resp, C> {
#[pin]
transport: Fuse<C>,
/// Requests waiting to be written to the wire.
#[pin]
pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
/// Requests that were dropped.
#[pin]
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Resp>,
Expand All @@ -271,16 +266,28 @@ where
self.as_mut().project().in_flight_requests
}

fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
self.as_mut().project().transport
}

fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
self.as_mut().project().canceled_requests
}

fn pending_requests_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
self.as_mut().project().pending_requests
}

fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Poll::Ready(
match ready!(self.as_mut().project().transport.poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
},
)
Poll::Ready(match ready!(self.transport_pin_mut().poll_next(cx)?) {
Some(response) => {
self.complete(response);
Some(Ok(()))
}
None => None,
})
}

fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
Expand All @@ -289,20 +296,14 @@ where
Closed,
}

let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
Poll::Ready(Some(dispatch_request)) => {
self.as_mut().write_request(dispatch_request)?;
return Poll::Ready(Some(Ok(())));
}
let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};

let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
Poll::Ready(Some((context, request_id))) => {
self.as_mut().write_cancel(context, request_id)?;
return Poll::Ready(Some(Ok(())));
}
let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
Poll::Ready(None) => ReceiverStatus::Closed,
Poll::Pending => ReceiverStatus::NotReady,
};
Expand All @@ -319,12 +320,12 @@ where

match (pending_requests_status, canceled_requests_status) {
(ReceiverStatus::Closed, ReceiverStatus::Closed) => {
ready!(self.as_mut().project().transport.poll_flush(cx)?);
ready!(self.transport_pin_mut().poll_flush(cx)?);
Poll::Ready(None)
}
(ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
// No more messages to process, so flush any messages buffered in the transport.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
ready!(self.transport_pin_mut().poll_flush(cx)?);

// Even if we fully-flush, we return Pending, because we have no more requests
// or cancellations right now.
Expand All @@ -334,6 +335,9 @@ where
}

/// Yields the next pending request, if one is ready to be sent.
///
/// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
/// start_send would succeed).
fn poll_next_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand All @@ -350,19 +354,10 @@ where
return Poll::Pending;
}

while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
// We can't yield a request-to-be-sent before the transport is capable of buffering it.
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
ready!(self.ensure_writeable(cx)?);

loop {
match ready!(self.as_mut().project().pending_requests.poll_recv(cx)) {
match ready!(self.pending_requests_mut().poll_recv(cx)) {
Some(request) => {
if request.response_completion.is_closed() {
trace!(
Expand All @@ -380,27 +375,17 @@ where
}

/// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
///
/// Note that a request to cancel will only be yielded if the transport is *ready* to be
/// written to (i.e. start_send would succeed).
fn poll_next_cancellation(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, u64)> {
while self
.as_mut()
.project()
.transport
.poll_ready(cx)?
.is_pending()
{
ready!(self.as_mut().project().transport.poll_flush(cx)?);
}
ready!(self.ensure_writeable(cx)?);

loop {
let cancellation = self
.as_mut()
.project()
.canceled_requests
.poll_next_unpin(cx);
match ready!(cancellation) {
match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
Some(request_id) => {
if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
return Poll::Ready(Some(Ok((ctx, request_id))));
Expand All @@ -411,10 +396,24 @@ where
}
}

fn write_request(
mut self: Pin<&mut Self>,
dispatch_request: DispatchRequest<Req, Resp>,
) -> io::Result<()> {
/// Returns Ready if writing a message to the transport (i.e. via write_request or
/// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
/// written to, flushes it until it is ready.
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
while self.transport_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.transport_pin_mut().poll_flush(cx)?);
}
Poll::Ready(Some(Ok(())))
}

fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
Some(dispatch_request) => dispatch_request,
None => return Poll::Ready(None),
};
// poll_next_request only returns Ready if there is room to buffer another request.
// Therefore, we can call write_request without fear of erroring due to a full
// buffer.
let request_id = dispatch_request.request_id;
let request = ClientMessage::Request(Request {
id: request_id,
Expand All @@ -424,30 +423,31 @@ where
trace_context: dispatch_request.ctx.trace_context,
},
});
self.as_mut().project().transport.start_send(request)?;
self.transport_pin_mut().start_send(request)?;
self.in_flight_requests()
.insert_request(
request_id,
dispatch_request.ctx,
dispatch_request.response_completion,
)
.expect("Request IDs should be unique");
Ok(())
Poll::Ready(Some(Ok(())))
}

fn write_cancel(
mut self: Pin<&mut Self>,
context: context::Context,
request_id: u64,
) -> io::Result<()> {
fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
Some((context, request_id)) => (context, request_id),
None => return Poll::Ready(None),
};

let trace_id = *context.trace_id();
let cancel = ClientMessage::Cancel {
trace_context: context.trace_context,
request_id,
};
self.as_mut().project().transport.start_send(cancel)?;
self.transport_pin_mut().start_send(cancel)?;
trace!("[{}] Cancel message sent.", trace_id);
Ok(())
Poll::Ready(Some(Ok(())))
}

/// Sends a server response to the client task that initiated the associated request.
Expand Down Expand Up @@ -532,11 +532,17 @@ impl RequestCancellation {
}
}

impl CanceledRequests {
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
}
}

impl Stream for CanceledRequests {
type Item = u64;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
self.poll_recv(cx)
}
}

Expand Down
37 changes: 24 additions & 13 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,8 @@ where
#[pin]
channel: C,
/// Responses waiting to be written to the wire.
#[pin]
pending_responses: mpsc::Receiver<(context::Context, Response<C::Resp>)>,
/// Handed out to request handlers to fan in responses.
#[pin]
responses_tx: mpsc::Sender<(context::Context, Response<C::Resp>)>,
}

Expand All @@ -397,6 +395,13 @@ where
self.as_mut().project().channel
}

/// Returns the inner channel over which messages are sent and received.
pub fn pending_responses_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut mpsc::Receiver<(context::Context, Response<C::Resp>)> {
self.as_mut().project().pending_responses
}

fn pump_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -451,12 +456,8 @@ where
context.trace_id(),
self.channel.in_flight_requests(),
);
// TODO: it's possible for poll_flush to be starved and start_send to end up full.
// Currently that would cause the channel to shut down. serde_transport internally
// uses tokio-util Framed, which will allocate as much as needed. But other
// transports may work differently.
//
// There should be a way to know if a flush is needed soon.
// A Ready result from poll_next_response means the Channel is ready to be written
// to. Therefore, we can call start_send without worry of a full buffer.
self.channel_pin_mut().start_send(response)?;
Poll::Ready(Some(Ok(())))
}
Expand All @@ -481,23 +482,33 @@ where
}
}

/// Yields a response ready to be written to the Channel sink.
///
/// Note that a response will only be yielded if the Channel is *ready* to be written to (i.e.
/// start_send would succeed).
fn poll_next_response(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> PollIo<(context::Context, Response<C::Resp>)> {
// Ensure there's room to write a response.
while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.as_mut().project().channel.poll_flush(cx)?);
}
ready!(self.ensure_writeable(cx)?);

match ready!(self.as_mut().project().pending_responses.poll_recv(cx)) {
match ready!(self.pending_responses_mut().poll_recv(cx)) {
Some(response) => Poll::Ready(Some(Ok(response))),
None => {
// This branch likely won't happen, since the Requests stream is holding a Sender.
Poll::Ready(None)
}
}
}

/// Returns Ready if writing a message to the Channel would not fail due to a full buffer. If
/// the Channel is not ready to be written to, flushes it until it is ready.
fn ensure_writeable<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
ready!(self.channel_pin_mut().poll_flush(cx)?);
}
Poll::Ready(Some(Ok(())))
}
}

impl<C> fmt::Debug for Requests<C>
Expand Down

0 comments on commit 66cdc99

Please sign in to comment.