diff --git a/src/proto/conn.rs b/src/proto/conn.rs index 8f1a32fd51..2bcd911a9c 100644 --- a/src/proto/conn.rs +++ b/src/proto/conn.rs @@ -453,6 +453,14 @@ where I: AsyncRead + AsyncWrite, pub fn close_write(&mut self) { self.state.close_write(); } + + pub fn disable_keep_alive(&mut self) { + if self.state.is_idle() { + self.state.close_read(); + } else { + self.state.disable_keep_alive(); + } + } } // ==== tokio_proto impl ==== @@ -700,6 +708,10 @@ impl State { } } + fn disable_keep_alive(&mut self) { + self.keep_alive.disable() + } + fn busy(&mut self) { if let KA::Disabled = self.keep_alive.status() { return; @@ -869,7 +881,7 @@ mod tests { other => panic!("unexpected frame: {:?}", other) } - // client + // client let io = AsyncIo::new_buf(vec![], 1); let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); conn.state.busy(); diff --git a/src/proto/dispatch.rs b/src/proto/dispatch.rs index 903855418d..f309ca6e45 100644 --- a/src/proto/dispatch.rs +++ b/src/proto/dispatch.rs @@ -54,6 +54,10 @@ where } } + pub fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive() + } + fn poll_read(&mut self) -> Poll<(), ::Error> { loop { if self.conn.can_read_head() { diff --git a/src/server/mod.rs b/src/server/mod.rs index 4e2cd34e36..b06f90a4f8 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -536,6 +536,18 @@ where } } +impl Connection +where S: Service, Error = ::Error> + 'static, + I: AsyncRead + AsyncWrite + 'static, + B: Stream + 'static, + B::Item: AsRef<[u8]>, +{ + /// Disables keep-alive for this connection. + pub fn disable_keep_alive(&mut self) { + self.conn.disable_keep_alive() + } +} + mod unnameable { // This type is specifically not exported outside the crate, // so no one can actually name the type. With no methods, we make no diff --git a/tests/server.rs b/tests/server.rs index 21d0f2d3cf..b3a994e8cc 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -6,7 +6,7 @@ extern crate pretty_env_logger; extern crate tokio_core; use futures::{Future, Stream}; -use futures::future::{self, FutureResult}; +use futures::future::{self, FutureResult, Either}; use futures::sync::oneshot; use tokio_core::net::TcpListener; @@ -551,6 +551,106 @@ fn pipeline_enabled() { assert_eq!(n, 0); } +#[test] +fn disable_keep_alive_mid_request() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + let (tx1, rx1) = oneshot::channel(); + let (tx2, rx2) = oneshot::channel(); + + let child = thread::spawn(move || { + let mut req = connect(&addr); + req.write_all(b"GET / HTTP/1.1\r\n").unwrap(); + tx1.send(()).unwrap(); + rx2.wait().unwrap(); + req.write_all(b"Host: localhost\r\n\r\n").unwrap(); + let mut buf = vec![]; + req.read_to_end(&mut buf).unwrap(); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new().serve_connection(socket, HelloWorld) + .select2(rx1) + .then(|r| { + match r { + Ok(Either::A(_)) => panic!("expected rx first"), + Ok(Either::B(((), mut conn))) => { + conn.disable_keep_alive(); + tx2.send(()).unwrap(); + conn + } + Err(Either::A((e, _))) => panic!("unexpected error {}", e), + Err(Either::B((e, _))) => panic!("unexpected error {}", e), + } + }) + }); + + core.run(fut).unwrap(); + child.join().unwrap(); +} + +#[test] +fn disable_keep_alive_post_request() { + let mut core = Core::new().unwrap(); + let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); + let addr = listener.local_addr().unwrap(); + + let (tx1, rx1) = oneshot::channel(); + + let child = thread::spawn(move || { + let mut req = connect(&addr); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: localhost\r\n\ + \r\n\ + ").unwrap(); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf).expect("reading 1"); + if n < buf.len() { + if &buf[n - HELLO.len()..n] == HELLO.as_bytes() { + break; + } + } + } + + tx1.send(()).unwrap(); + + let nread = req.read(&mut buf).unwrap(); + assert_eq!(nread, 0); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| unreachable!()) + .and_then(|(item, _incoming)| { + let (socket, _) = item.unwrap(); + Http::::new().serve_connection(socket, HelloWorld) + .select2(rx1) + .then(|r| { + match r { + Ok(Either::A(_)) => panic!("expected rx first"), + Ok(Either::B(((), mut conn))) => { + conn.disable_keep_alive(); + conn + } + Err(Either::A((e, _))) => panic!("unexpected error {}", e), + Err(Either::B((e, _))) => panic!("unexpected error {}", e), + } + }) + }); + + core.run(fut).unwrap(); + child.join().unwrap(); +} + #[test] fn no_proto_empty_parse_eof_does_not_return_error() { let mut core = Core::new().unwrap(); @@ -719,6 +819,8 @@ impl Service for TestService { } +const HELLO: &'static str = "hello"; + struct HelloWorld; impl Service for HelloWorld { @@ -728,7 +830,10 @@ impl Service for HelloWorld { type Future = FutureResult; fn call(&self, _req: Request) -> Self::Future { - future::ok(Response::new()) + let mut response = Response::new(); + response.headers_mut().set(hyper::header::ContentLength(HELLO.len() as u64)); + response.set_body(HELLO); + future::ok(response) } }