Skip to content

Commit

Permalink
fix(client): close connections when Response Future or Body is dropped
Browse files Browse the repository at this point in the history
Closes #1397
  • Loading branch information
seanmonstar committed Dec 14, 2017
1 parent 8f6931b commit ef40081
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 23 deletions.
97 changes: 82 additions & 15 deletions src/proto/body.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use bytes::Bytes;
use futures::{Poll, Stream};
use futures::sync::mpsc;
use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
use futures::sync::{mpsc, oneshot};
use tokio_proto;
use std::borrow::Cow;

Expand All @@ -12,20 +12,36 @@ pub type BodySender = mpsc::Sender<Result<Chunk, ::Error>>;
/// A `Stream` for `Chunk`s used in requests and responses.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Body(TokioBody);
pub struct Body(Inner);

#[derive(Debug)]
enum Inner {
Tokio(TokioBody),
Hyper {
close_tx: oneshot::Sender<()>,
rx: mpsc::Receiver<Result<Chunk, ::Error>>,
}
}

//pub(crate)
#[derive(Debug)]
pub struct ChunkSender {
close_rx: oneshot::Receiver<()>,
tx: BodySender,
}

impl Body {
/// Return an empty body stream
#[inline]
pub fn empty() -> Body {
Body(TokioBody::empty())
Body(Inner::Tokio(TokioBody::empty()))
}

/// Return a body stream with an associated sender half
#[inline]
pub fn pair() -> (mpsc::Sender<Result<Chunk, ::Error>>, Body) {
let (tx, rx) = TokioBody::pair();
let rx = Body(rx);
let rx = Body(Inner::Tokio(rx));
(tx, rx)
}
}
Expand All @@ -43,7 +59,51 @@ impl Stream for Body {

#[inline]
fn poll(&mut self) -> Poll<Option<Chunk>, ::Error> {
self.0.poll()
match self.0 {
Inner::Tokio(ref mut rx) => rx.poll(),
Inner::Hyper { ref mut rx, .. } => match rx.poll().expect("mpsc cannot error") {
Async::Ready(Some(Ok(chunk))) => Ok(Async::Ready(Some(chunk))),
Async::Ready(Some(Err(err))) => Err(err),
Async::Ready(None) => Ok(Async::Ready(None)),
Async::NotReady => Ok(Async::NotReady),
},
}
}
}

//pub(crate)
pub fn channel() -> (ChunkSender, Body) {
let (tx, rx) = mpsc::channel(0);
let (close_tx, close_rx) = oneshot::channel();

let tx = ChunkSender {
close_rx: close_rx,
tx: tx,
};
let rx = Body(Inner::Hyper {
close_tx: close_tx,
rx: rx,
});

(tx, rx)
}

impl ChunkSender {
pub fn poll_ready(&mut self) -> Poll<(), ()> {
match self.close_rx.poll() {
Ok(Async::Ready(())) | Err(_) => return Err(()),
Ok(Async::NotReady) => (),
}

self.tx.poll_ready().map_err(|_| ())
}

pub fn start_send(&mut self, msg: Result<Chunk, ::Error>) -> StartSend<(), ()> {
match self.tx.start_send(msg) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(())),
Err(_) => Err(()),
}
}
}

Expand All @@ -52,7 +112,14 @@ impl Stream for Body {
impl From<Body> for tokio_proto::streaming::Body<Chunk, ::Error> {
#[inline]
fn from(b: Body) -> tokio_proto::streaming::Body<Chunk, ::Error> {
b.0
match b.0 {
Inner::Tokio(b) => b,
Inner::Hyper { close_tx, rx } => {
warn!("converting hyper::Body into a tokio_proto Body is deprecated");
::std::mem::forget(close_tx);
rx.into()
}
}
}
}

Expand All @@ -61,42 +128,42 @@ impl From<Body> for tokio_proto::streaming::Body<Chunk, ::Error> {
impl From<tokio_proto::streaming::Body<Chunk, ::Error>> for Body {
#[inline]
fn from(tokio_body: tokio_proto::streaming::Body<Chunk, ::Error>) -> Body {
Body(tokio_body)
Body(Inner::Tokio(tokio_body))
}
}

impl From<mpsc::Receiver<Result<Chunk, ::Error>>> for Body {
#[inline]
fn from(src: mpsc::Receiver<Result<Chunk, ::Error>>) -> Body {
Body(src.into())
TokioBody::from(src).into()
}
}

impl From<Chunk> for Body {
#[inline]
fn from (chunk: Chunk) -> Body {
Body(TokioBody::from(chunk))
TokioBody::from(chunk).into()
}
}

impl From<Bytes> for Body {
#[inline]
fn from (bytes: Bytes) -> Body {
Body(TokioBody::from(Chunk::from(bytes)))
Body::from(TokioBody::from(Chunk::from(bytes)))
}
}

impl From<Vec<u8>> for Body {
#[inline]
fn from (vec: Vec<u8>) -> Body {
Body(TokioBody::from(Chunk::from(vec)))
Body::from(TokioBody::from(Chunk::from(vec)))
}
}

impl From<&'static [u8]> for Body {
#[inline]
fn from (slice: &'static [u8]) -> Body {
Body(TokioBody::from(Chunk::from(slice)))
Body::from(TokioBody::from(Chunk::from(slice)))
}
}

Expand All @@ -113,14 +180,14 @@ impl From<Cow<'static, [u8]>> for Body {
impl From<String> for Body {
#[inline]
fn from (s: String) -> Body {
Body(TokioBody::from(Chunk::from(s.into_bytes())))
Body::from(TokioBody::from(Chunk::from(s.into_bytes())))
}
}

impl From<&'static str> for Body {
#[inline]
fn from(slice: &'static str) -> Body {
Body(TokioBody::from(Chunk::from(slice.as_bytes())))
Body::from(TokioBody::from(Chunk::from(slice.as_bytes())))
}
}

Expand Down
62 changes: 54 additions & 8 deletions src/proto/dispatch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;

use futures::{Async, AsyncSink, Future, Poll, Sink, Stream};
use futures::{Async, AsyncSink, Future, Poll, Stream};
use futures::sync::{mpsc, oneshot};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_service::Service;
Expand All @@ -11,7 +11,7 @@ use ::StatusCode;
pub struct Dispatcher<D, Bs, I, B, T, K> {
conn: Conn<I, B, T, K>,
dispatch: D,
body_tx: Option<super::body::BodySender>,
body_tx: Option<super::body::ChunkSender>,
body_rx: Option<Bs>,
is_closing: bool,
}
Expand All @@ -22,6 +22,7 @@ pub trait Dispatch {
type RecvItem;
fn poll_msg(&mut self) -> Poll<Option<(Self::PollItem, Option<Self::PollBody>)>, ::Error>;
fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Option<Body>)>) -> ::Result<()>;
fn poll_ready(&mut self) -> Poll<(), ()>;
fn should_poll(&self) -> bool;
}

Expand Down Expand Up @@ -70,10 +71,22 @@ where
if self.is_closing {
return Ok(Async::Ready(()));
} else if self.conn.can_read_head() {
// can dispatch receive, or does it still care about, an incoming message?
match self.dispatch.poll_ready() {
Ok(Async::Ready(())) => (),
Ok(Async::NotReady) => unreachable!("dispatch not ready when conn is"),
Err(()) => {
trace!("dispatch no longer receiving messages");
self.is_closing = true;
return Ok(Async::Ready(()));
}
}
// dispatch is ready for a message, try to read one
match self.conn.read_head() {
Ok(Async::Ready(Some((head, has_body)))) => {
let body = if has_body {
let (tx, rx) = super::Body::pair();
let (mut tx, rx) = super::body::channel();
let _ = tx.poll_ready(); // register this task if rx is dropped
self.body_tx = Some(tx);
Some(rx)
} else {
Expand Down Expand Up @@ -111,6 +124,8 @@ where
self.conn.close_read();
return Ok(Async::Ready(()));
}
// else the conn body is done, and user dropped,
// so everything is fine!
}
}
if can_read_body {
Expand All @@ -133,7 +148,7 @@ where
}
},
Ok(Async::Ready(None)) => {
let _ = body.close();
// just drop, the body will close automatically
},
Ok(Async::NotReady) => {
self.body_tx = Some(body);
Expand All @@ -144,7 +159,7 @@ where
}
}
} else {
let _ = body.close();
// just drop, the body will close automatically
}
} else if !T::should_read_first() {
self.conn.try_empty_read()?;
Expand Down Expand Up @@ -305,6 +320,14 @@ where
Ok(())
}

fn poll_ready(&mut self) -> Poll<(), ()> {
if self.in_flight.is_some() {
Ok(Async::NotReady)
} else {
Ok(Async::Ready(()))
}
}

fn should_poll(&self) -> bool {
self.in_flight.is_some()
}
Expand Down Expand Up @@ -333,9 +356,18 @@ where

fn poll_msg(&mut self) -> Poll<Option<(Self::PollItem, Option<Self::PollBody>)>, ::Error> {
match self.rx.poll() {
Ok(Async::Ready(Some(ClientMsg::Request(head, body, cb)))) => {
self.callback = Some(cb);
Ok(Async::Ready(Some((head, body))))
Ok(Async::Ready(Some(ClientMsg::Request(head, body, mut cb)))) => {
// check that future hasn't been canceled already
match cb.poll_cancel().expect("poll_cancel cannot error") {
Async::Ready(()) => {
trace!("request canceled");
Ok(Async::Ready(None))
},
Async::NotReady => {
self.callback = Some(cb);
Ok(Async::Ready(Some((head, body))))
}
}
},
Ok(Async::Ready(Some(ClientMsg::Close))) |
Ok(Async::Ready(None)) => {
Expand Down Expand Up @@ -370,6 +402,20 @@ where
}
}

fn poll_ready(&mut self) -> Poll<(), ()> {
match self.callback {
Some(ref mut cb) => match cb.poll_cancel() {
Ok(Async::Ready(())) => {
trace!("callback receiver has dropped");
Err(())
},
Ok(Async::NotReady) => Ok(Async::Ready(())),
Err(_) => unreachable!("oneshot poll_cancel cannot error"),
},
None => Err(()),
}
}

fn should_poll(&self) -> bool {
self.callback.is_none()
}
Expand Down
Loading

0 comments on commit ef40081

Please sign in to comment.