From 7976023b594ec6784e40a147d3baec99a947b118 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 9 Jan 2018 17:46:29 -0800 Subject: [PATCH] fix(client): don't error on read before writing request --- src/proto/conn.rs | 18 ++++++++-- src/proto/dispatch.rs | 77 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 77 insertions(+), 18 deletions(-) diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 3a45659fc4..652c6714f1 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -238,7 +238,21 @@ where I: AsyncRead + AsyncWrite, ret } - pub fn maybe_park_read(&mut self) { + pub fn read_keep_alive(&mut self) -> Result<(), ::Error> { + debug_assert!(!self.can_read_head() && !self.can_read_body()); + + trace!("Conn::read_keep_alive"); + + if T::should_read_first() || !self.state.is_idle() { + self.maybe_park_read(); + } else { + self.try_empty_read()?; + } + + Ok(()) + } + + fn maybe_park_read(&mut self) { if !self.io.is_read_blocked() { // the Io object is ready to read, which means it will never alert // us that it is ready until we drain it. However, we're currently @@ -258,7 +272,7 @@ where I: AsyncRead + AsyncWrite, // // This should only be called for Clients wanting to enter the idle // state. - pub fn try_empty_read(&mut self) -> io::Result<()> { + fn try_empty_read(&mut self) -> io::Result<()> { assert!(!self.can_read_head() && !self.can_read_body()); if !self.io.read_buf().is_empty() { diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index cc345eda44..ef165ef8c5 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -66,6 +66,20 @@ where self.conn.disable_keep_alive() } + fn poll2(&mut self) -> Poll<(), ::Error> { + self.poll_read()?; + self.poll_write()?; + self.poll_flush()?; + + if self.is_done() { + try_ready!(self.conn.shutdown()); + trace!("Dispatch::poll done"); + Ok(Async::Ready(())) + } else { + Ok(Async::NotReady) + } + } + fn poll_read(&mut self) -> Poll<(), ::Error> { loop { if self.is_closing { @@ -163,12 +177,8 @@ where } else { // just drop, the body will close automatically } - } else if !T::should_read_first() { - self.conn.try_empty_read()?; - return Ok(Async::NotReady); } else { - self.conn.maybe_park_read(); - return Ok(Async::Ready(())); + return self.conn.read_keep_alive().map(Async::Ready); } } } @@ -266,17 +276,13 @@ where #[inline] fn poll(&mut self) -> Poll { trace!("Dispatcher::poll"); - self.poll_read()?; - self.poll_write()?; - self.poll_flush()?; - - if self.is_done() { - try_ready!(self.conn.shutdown()); - trace!("Dispatch::poll done"); - Ok(Async::Ready(())) - } else { - Ok(Async::NotReady) - } + self.poll2().or_else(|e| { + // An error means we're shutting down either way. + // We just try to give the error to the user, + // and close the connection with an Ok. If we + // cannot give it to the user, then return the Err. + self.dispatch.recv_msg(Err(e)).map(Async::Ready) + }) } } @@ -399,6 +405,9 @@ where if let Some(cb) = self.callback.take() { let _ = cb.send(Err(err)); Ok(()) + } else if let Ok(Async::Ready(Some(ClientMsg::Request(_, _, cb)))) = self.rx.poll() { + let _ = cb.send(Err(err)); + Ok(()) } else { Err(err) } @@ -424,3 +433,39 @@ where self.callback.is_none() } } + +#[cfg(test)] +mod tests { + use futures::Sink; + + use super::*; + use mock::AsyncIo; + use proto::ClientTransaction; + + #[test] + fn client_read_response_before_writing_request() { + extern crate pretty_env_logger; + let _ = pretty_env_logger::try_init(); + ::futures::lazy(|| { + let io = AsyncIo::new_buf(b"HTTP/1.1 200 OK\r\n\r\n".to_vec(), 100); + let (mut tx, rx) = mpsc::channel(0); + let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io, Default::default()); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + let req = RequestHead { + version: ::HttpVersion::Http11, + subject: ::proto::RequestLine::default(), + headers: Default::default(), + }; + let (res_tx, res_rx) = oneshot::channel(); + tx.start_send(ClientMsg::Request(req, None::<::Body>, res_tx)).unwrap(); + + dispatcher.poll().expect("dispatcher poll 1"); + dispatcher.poll().expect("dispatcher poll 2"); + let _res = res_rx.wait() + .expect("callback poll") + .expect("callback response"); + Ok::<(), ()>(()) + }).wait().unwrap(); + } +}