Skip to content

Commit

Permalink
fix(server): use a timeout for Server keep-alive
Browse files Browse the repository at this point in the history
Server keep-alive is now **off** by default. In order to turn it on, the
`keep_alive` method must be called on the `Server` object.

Closes #368
  • Loading branch information
seanmonstar committed Oct 8, 2015
1 parent 388ddf6 commit 9482614
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 60 deletions.
150 changes: 90 additions & 60 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ use std::fmt;
use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::thread::{self, JoinHandle};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use num_cpus;
Expand Down Expand Up @@ -146,20 +144,16 @@ mod listener;
#[derive(Debug)]
pub struct Server<L = HttpListener> {
listener: L,
_timeouts: Timeouts,
timeouts: Timeouts,
}

#[cfg(feature = "timeouts")]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts {
read: Option<Duration>,
write: Option<Duration>,
keep_alive: Option<Duration>,
}

#[cfg(not(feature = "timeouts"))]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts;

macro_rules! try_option(
($e:expr) => {{
match $e {
Expand All @@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> {
pub fn new(listener: L) -> Server<L> {
Server {
listener: listener,
_timeouts: Timeouts::default(),
timeouts: Timeouts::default(),
}
}

/// Enables keep-alive for this server.
///
/// The timeout duration passed will be used to determine how long
/// to keep the connection alive before dropping it.
///
/// **NOTE**: The timeout will only be used when the `timeouts` feature
/// is enabled for hyper, and rustc is 1.4 or greater.
#[inline]
pub fn keep_alive(&mut self, timeout: Duration) {
self.timeouts.keep_alive = Some(timeout);
}

#[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.read = dur;
self.timeouts.read = dur;
}

#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.write = dur;
self.timeouts.write = dur;
}


Expand Down Expand Up @@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static {

debug!("threads = {:?}", threads);
let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server._timeouts);
let worker = Worker::new(handler, server.timeouts);
let work = move |mut stream| worker.handle_connection(&mut stream);

let guard = thread::spawn(move || pool.accept(work, threads));
Expand All @@ -241,15 +247,15 @@ L: NetworkListener + Send + 'static {

struct Worker<H: Handler + 'static> {
handler: H,
_timeouts: Timeouts,
timeouts: Timeouts,
}

impl<H: Handler + 'static> Worker<H> {

fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
Worker {
handler: handler,
_timeouts: timeouts,
timeouts: timeouts,
}
}

Expand All @@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> {

self.handler.on_connection_start();

if let Err(e) = self.set_timeouts(stream) {
if let Err(e) = self.set_timeouts(&(stream as &mut NetworkStream)) {
error!("set_timeouts error: {:?}", e);
return;
}
Expand All @@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> {

// FIXME: Use Type ascription
let stream_clone: &mut NetworkStream = &mut stream.clone();
let rdr = BufReader::new(stream_clone);
let wrt = BufWriter::new(stream);
let mut rdr = BufReader::new(stream_clone);
let mut wrt = BufWriter::new(stream);

self.keep_alive_loop(rdr, wrt, addr);
while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
if let Err(e) = self.set_read_timeout(rdr.get_mut(), self.timeouts.keep_alive) {
error!("set_read_timeout keep_alive {:?}", e);
break;
}
}

self.handler.on_connection_end();

debug!("keep_alive loop ending for {}", addr);
}

fn set_timeouts(&self, s: & &mut NetworkStream) -> io::Result<()> {
try!(self.set_read_timeout(s, self.timeouts.read));
self.set_write_timeout(s, self.timeouts.write)
}


#[cfg(not(feature = "timeouts"))]
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
fn set_write_timeout(&self, _s: & &mut NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
try!(s.set_read_timeout(self._timeouts.read));
s.set_write_timeout(self._timeouts.write)
fn set_write_timeout(&self, s: & &mut NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_write_timeout(timeout)
}

fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
mut wrt: W, addr: SocketAddr) {
let mut keep_alive = true;
while keep_alive {
let req = match Request::new(&mut rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
break;
}
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
break;
}
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
break;
}
};
#[cfg(not(feature = "timeouts"))]
fn set_read_timeout(&self, _s: & &mut NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, s: & &mut NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_read_timeout(timeout)
}

if !self.handle_expect(&req, &mut wrt) {
break;
fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>,
wrt: &mut W, addr: SocketAddr) -> bool {
let req = match Request::new(rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
return false;
}

keep_alive = http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
return false;
}
{
let mut res = Response::new(&mut wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
return false;
}
};

// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}

debug!("keep_alive = {:?} for {}", keep_alive, addr);
if !self.handle_expect(&req, wrt) {
return false;
}

if let Err(e) = req.set_read_timeout(self.timeouts.read) {
error!("set_read_timeout {:?}", e);
return false;
}

let mut keep_alive = self.timeouts.keep_alive.is_some() &&
http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
}
{
let mut res = Response::new(wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
}

// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}

debug!("keep_alive = {:?} for {}", keep_alive, addr);
keep_alive
}

fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
Expand Down
14 changes: 14 additions & 0 deletions src/server/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! target URI, headers, and message body.
use std::io::{self, Read};
use std::net::SocketAddr;
use std::time::Duration;

use buffer::BufReader;
use net::NetworkStream;
Expand Down Expand Up @@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> {
})
}

/// Set the read timeout of the underlying NetworkStream.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
self.body.get_ref().get_ref().set_read_timeout(timeout)
}

/// Set the read timeout of the underlying NetworkStream.
#[cfg(not(feature = "timeouts"))]
#[inline]
pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
/// Get a reference to the underlying `NetworkStream`.
#[inline]
pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> {
Expand Down

0 comments on commit 9482614

Please sign in to comment.