Skip to content

Commit

Permalink
fix(client): more reliably detect closed pooled connections (#1434)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmonstar authored Feb 5, 2018
1 parent 8fb84d2 commit 265ad67
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 60 deletions.
148 changes: 148 additions & 0 deletions src/client/cancel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use futures::{Async, Future, Poll};
use futures::task::{self, Task};

use common::Never;

use self::lock::Lock;

#[derive(Clone)]
pub struct Cancel {
inner: Arc<Inner>,
}

pub struct Canceled {
inner: Arc<Inner>,
}

struct Inner {
is_canceled: AtomicBool,
task: Lock<Option<Task>>,
}

impl Cancel {
pub fn new() -> (Cancel, Canceled) {
let inner = Arc::new(Inner {
is_canceled: AtomicBool::new(false),
task: Lock::new(None),
});
let inner2 = inner.clone();
(
Cancel {
inner: inner,
},
Canceled {
inner: inner2,
},
)
}

pub fn cancel(&self) {
if !self.inner.is_canceled.swap(true, Ordering::SeqCst) {
if let Some(mut locked) = self.inner.task.try_lock() {
if let Some(task) = locked.take() {
task.notify();
}
}
// if we couldn't take the lock, Canceled was trying to park.
// After parking, it will check is_canceled one last time,
// so we can just stop here.
}
}

pub fn is_canceled(&self) -> bool {
self.inner.is_canceled.load(Ordering::SeqCst)
}
}

impl Future for Canceled {
type Item = ();
type Error = Never;

fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if self.inner.is_canceled.load(Ordering::SeqCst) {
Ok(Async::Ready(()))
} else {
if let Some(mut locked) = self.inner.task.try_lock() {
if locked.is_none() {
// it's possible a Cancel just tried to cancel on another thread,
// and we just missed it. Once we have the lock, we should check
// one more time before parking this task and going away.
if self.inner.is_canceled.load(Ordering::SeqCst) {
return Ok(Async::Ready(()));
}
*locked = Some(task::current());
}
Ok(Async::NotReady)
} else {
// if we couldn't take the lock, then a Cancel taken has it.
// The *ONLY* reason is because it is in the process of canceling.
Ok(Async::Ready(()))
}
}
}
}

impl Drop for Canceled {
fn drop(&mut self) {
self.inner.is_canceled.store(true, Ordering::SeqCst);
}
}


// a sub module just to protect unsafety
mod lock {
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};

pub struct Lock<T> {
is_locked: AtomicBool,
value: UnsafeCell<T>,
}

impl<T> Lock<T> {
pub fn new(val: T) -> Lock<T> {
Lock {
is_locked: AtomicBool::new(false),
value: UnsafeCell::new(val),
}
}

pub fn try_lock(&self) -> Option<Locked<T>> {
if !self.is_locked.swap(true, Ordering::SeqCst) {
Some(Locked { lock: self })
} else {
None
}
}
}

unsafe impl<T: Send> Send for Lock<T> {}
unsafe impl<T: Send> Sync for Lock<T> {}

pub struct Locked<'a, T: 'a> {
lock: &'a Lock<T>,
}

impl<'a, T> Deref for Locked<'a, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.lock.value.get() }
}
}

impl<'a, T> DerefMut for Locked<'a, T> {
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.lock.value.get() }
}
}

impl<'a, T> Drop for Locked<'a, T> {
fn drop(&mut self) {
self.lock.is_locked.store(false, Ordering::SeqCst);
}
}
}
73 changes: 73 additions & 0 deletions src/client/dispatch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use futures::{Async, Future, Poll, Stream};
use futures::sync::{mpsc, oneshot};

use common::Never;
use super::cancel::{Cancel, Canceled};

pub type Callback<U> = oneshot::Sender<::Result<U>>;
pub type Promise<U> = oneshot::Receiver<::Result<U>>;

pub fn channel<T, U>() -> (Sender<T, U>, Receiver<T, U>) {
let (tx, rx) = mpsc::unbounded();
let (cancel, canceled) = Cancel::new();
let tx = Sender {
cancel: cancel,
inner: tx,
};
let rx = Receiver {
canceled: canceled,
inner: rx,
};
(tx, rx)
}

pub struct Sender<T, U> {
cancel: Cancel,
inner: mpsc::UnboundedSender<(T, Callback<U>)>,
}

impl<T, U> Sender<T, U> {
pub fn is_closed(&self) -> bool {
self.cancel.is_canceled()
}

pub fn cancel(&self) {
self.cancel.cancel();
}

pub fn send(&self, val: T) -> Result<Promise<U>, T> {
let (tx, rx) = oneshot::channel();
self.inner.unbounded_send((val, tx))
.map(move |_| rx)
.map_err(|e| e.into_inner().0)
}
}

impl<T, U> Clone for Sender<T, U> {
fn clone(&self) -> Sender<T, U> {
Sender {
cancel: self.cancel.clone(),
inner: self.inner.clone(),
}
}
}

pub struct Receiver<T, U> {
canceled: Canceled,
inner: mpsc::UnboundedReceiver<(T, Callback<U>)>,
}

impl<T, U> Stream for Receiver<T, U> {
type Item = (T, Callback<U>);
type Error = Never;

fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if let Async::Ready(()) = self.canceled.poll()? {
return Ok(Async::Ready(None));
}
self.inner.poll()
.map_err(|()| unreachable!("mpsc never errors"))
}
}

//TODO: Drop for Receiver should consume inner
78 changes: 33 additions & 45 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
//! HTTP Client

use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::fmt;
use std::io;
use std::marker::PhantomData;
use std::rc::Rc;
use std::time::Duration;

use futures::{Future, Poll, Stream};
use futures::future::{self, Executor};
use futures::{Async, Future, Poll, Stream};
use futures::future::{self, Either, Executor};
#[cfg(feature = "compat")]
use http;
use tokio::reactor::Handle;
Expand All @@ -28,7 +28,10 @@ pub use self::connect::{HttpConnector, Connect};

use self::background::{bg, Background};

mod cancel;
mod connect;
//TODO(easy): move cancel and dispatch into common instead
pub(crate) mod dispatch;
mod dns;
mod pool;
#[cfg(feature = "compat")]
Expand Down Expand Up @@ -189,20 +192,16 @@ where C: Connect,
head.headers.set_pos(0, host);
}

use futures::Sink;
use futures::sync::{mpsc, oneshot};

let checkout = self.pool.checkout(domain.as_ref());
let connect = {
let executor = self.executor.clone();
let pool = self.pool.clone();
let pool_key = Rc::new(domain.to_string());
self.connector.connect(url)
.and_then(move |io| {
// 1 extra slot for possible Close message
let (tx, rx) = mpsc::channel(1);
let (tx, rx) = dispatch::channel();
let tx = HyperClient {
tx: RefCell::new(tx),
tx: tx,
should_close: Cell::new(true),
};
let pooled = pool.pooled(pool_key, tx);
Expand All @@ -225,33 +224,26 @@ where C: Connect,
});

let resp = race.and_then(move |client| {
use proto::dispatch::ClientMsg;

let (callback, rx) = oneshot::channel();
client.should_close.set(false);

match client.tx.borrow_mut().start_send(ClientMsg::Request(head, body, callback)) {
Ok(_) => (),
Err(e) => match e.into_inner() {
ClientMsg::Request(_, _, callback) => {
error!("pooled connection was not ready, this is a hyper bug");
let err = io::Error::new(
io::ErrorKind::BrokenPipe,
"pool selected dead connection",
);
let _ = callback.send(Err(::Error::Io(err)));
},
_ => unreachable!("ClientMsg::Request was just sent"),
match client.tx.send((head, body)) {
Ok(rx) => {
client.should_close.set(false);
Either::A(rx.then(|res| {
match res {
Ok(Ok(res)) => Ok(res),
Ok(Err(err)) => Err(err),
Err(_) => panic!("dispatch dropped without returning error"),
}
}))
},
Err(_) => {
error!("pooled connection was not ready, this is a hyper bug");
let err = io::Error::new(
io::ErrorKind::BrokenPipe,
"pool selected dead connection",
);
Either::B(future::err(::Error::Io(err)))
}
}

rx.then(|res| {
match res {
Ok(Ok(res)) => Ok(res),
Ok(Err(err)) => Err(err),
Err(_) => panic!("dispatch dropped without returning error"),
}
})
});

FutureResponse(Box::new(resp))
Expand All @@ -276,13 +268,8 @@ impl<C, B> fmt::Debug for Client<C, B> {
}

struct HyperClient<B> {
// A sentinel that is usually always true. If this is dropped
// while true, this will try to shutdown the dispatcher task.
//
// This should be set to false whenever it is checked out of the
// pool and successfully used to send a request.
should_close: Cell<bool>,
tx: RefCell<::futures::sync::mpsc::Sender<proto::dispatch::ClientMsg<B>>>,
tx: dispatch::Sender<proto::dispatch::ClientMsg<B>, ::Response>,
}

impl<B> Clone for HyperClient<B> {
Expand All @@ -296,18 +283,19 @@ impl<B> Clone for HyperClient<B> {

impl<B> self::pool::Ready for HyperClient<B> {
fn poll_ready(&mut self) -> Poll<(), ()> {
self.tx
.borrow_mut()
.poll_ready()
.map_err(|_| ())
if self.tx.is_closed() {
Err(())
} else {
Ok(Async::Ready(()))
}
}
}

impl<B> Drop for HyperClient<B> {
fn drop(&mut self) {
if self.should_close.get() {
self.should_close.set(false);
let _ = self.tx.borrow_mut().try_send(proto::dispatch::ClientMsg::Close);
self.tx.cancel();
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub use self::str::ByteStr;

mod str;

pub enum Never {}
Loading

0 comments on commit 265ad67

Please sign in to comment.