Skip to content

Commit

Permalink
feat(all): add socket timeouts
Browse files Browse the repository at this point in the history
Methods added to `Client` and `Server` to control read and write
timeouts of the underlying socket.

Keep-Alive is re-enabled by default on the server, with a default
timeout of 5 seconds.

BREAKING CHANGE: This adds 2 required methods to the `NetworkStream`
  trait, `set_read_timeout` and `set_write_timeout`. Any local
  implementations will need to add them.
  • Loading branch information
seanmonstar committed Nov 24, 2015
1 parent 21c4f51 commit fec6e3e
Show file tree
Hide file tree
Showing 12 changed files with 24 additions and 109 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,4 @@ env_logger = "*"
default = ["ssl"]
ssl = ["openssl", "cookie/secure"]
serde-serialization = ["serde"]
timeouts = []
nightly = ["timeouts"]
nightly = []
3 changes: 0 additions & 3 deletions benches/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ extern crate test;
use std::fmt;
use std::io::{self, Read, Write, Cursor};
use std::net::SocketAddr;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use hyper::net;
Expand Down Expand Up @@ -75,12 +74,10 @@ impl net::NetworkStream for MockStream {
fn peer_addr(&mut self) -> io::Result<SocketAddr> {
Ok("127.0.0.1:1337".parse().unwrap())
}
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, _: Option<Duration>) -> io::Result<()> {
// can't time out
Ok(())
}
#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, _: Option<Duration>) -> io::Result<()> {
// can't time out
Ok(())
Expand Down
32 changes: 3 additions & 29 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ use std::default::Default;
use std::io::{self, copy, Read};
use std::iter::Extend;

#[cfg(feature = "timeouts")]
use std::time::Duration;

use url::UrlParser;
Expand All @@ -68,7 +67,7 @@ use url::ParseError as UrlError;
use header::{Headers, Header, HeaderFormat};
use header::{ContentLength, Location};
use method::Method;
use net::{NetworkConnector, NetworkStream, Fresh};
use net::{NetworkConnector, NetworkStream};
use {Url};
use Error;

Expand All @@ -89,9 +88,7 @@ use http::h1::Http11Protocol;
pub struct Client {
protocol: Box<Protocol + Send + Sync>,
redirect_policy: RedirectPolicy,
#[cfg(feature = "timeouts")]
read_timeout: Option<Duration>,
#[cfg(feature = "timeouts")]
write_timeout: Option<Duration>,
}

Expand All @@ -113,16 +110,6 @@ impl Client {
Client::with_protocol(Http11Protocol::with_connector(connector))
}

#[cfg(not(feature = "timeouts"))]
/// Create a new client with a specific `Protocol`.
pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client {
Client {
protocol: Box::new(protocol),
redirect_policy: Default::default(),
}
}

#[cfg(feature = "timeouts")]
/// Create a new client with a specific `Protocol`.
pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client {
Client {
Expand All @@ -139,13 +126,11 @@ impl Client {
}

/// Set the read timeout value for all requests.
#[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self.read_timeout = dur;
}

/// Set the write timeout value for all requests.
#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self.write_timeout = dur;
}
Expand Down Expand Up @@ -273,19 +258,8 @@ impl<'a> RequestBuilder<'a> {
let mut req = try!(Request::with_message(method.clone(), url.clone(), message));
headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter()));

#[cfg(not(feature = "timeouts"))]
fn set_timeouts(_req: &mut Request<Fresh>, _client: &Client) -> ::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_timeouts(req: &mut Request<Fresh>, client: &Client) -> ::Result<()> {
try!(req.set_write_timeout(client.write_timeout));
try!(req.set_read_timeout(client.read_timeout));
Ok(())
}

try!(set_timeouts(&mut req, &client));
try!(req.set_write_timeout(client.write_timeout));
try!(req.set_read_timeout(client.read_timeout));

match (can_have_body, body.as_ref()) {
(true, Some(body)) => match body.size() {
Expand Down
3 changes: 0 additions & 3 deletions src/client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::io::{self, Read, Write};
use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use net::{NetworkConnector, NetworkStream, DefaultConnector};
Expand Down Expand Up @@ -176,13 +175,11 @@ impl<S: NetworkStream> NetworkStream for PooledStream<S> {
self.inner.as_mut().unwrap().stream.peer_addr()
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
Expand Down
3 changes: 0 additions & 3 deletions src/client/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use std::marker::PhantomData;
use std::io::{self, Write};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use url::Url;
Expand Down Expand Up @@ -44,14 +43,12 @@ impl<W> Request<W> {
pub fn method(&self) -> method::Method { self.method.clone() }

/// Set the write timeout.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.message.set_write_timeout(dur)
}

/// Set the read timeout.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.message.set_read_timeout(dur)
Expand Down
3 changes: 0 additions & 3 deletions src/http/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::cmp::min;
use std::fmt;
use std::io::{self, Write, BufWriter, BufRead, Read};
use std::net::Shutdown;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use httparse;
Expand Down Expand Up @@ -341,13 +340,11 @@ impl HttpMessage for Http11Message {
}
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_write_timeout(dur)
Expand Down
3 changes: 0 additions & 3 deletions src/http/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::io::{self, Write, Read, Cursor};
use std::net::Shutdown;
use std::ascii::AsciiExt;
use std::mem;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use http::{
Expand Down Expand Up @@ -404,13 +403,11 @@ impl<S> HttpMessage for Http2Message<S> where S: CloneableStream {
true
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
Expand Down
4 changes: 0 additions & 4 deletions src/http/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use std::fmt::Debug;
use std::io::{Read, Write};
use std::mem;

#[cfg(feature = "timeouts")]
use std::io;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use typeable::Typeable;
Expand Down Expand Up @@ -65,10 +63,8 @@ pub trait HttpMessage: Write + Read + Send + Any + Typeable + Debug {
/// the response body.
fn get_incoming(&mut self) -> ::Result<ResponseHead>;
/// Set the read timeout duration for this message.
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
/// Set the write timeout duration for this message.
#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
/// Closes the underlying HTTP connection.
fn close_connection(&mut self) -> ::Result<()>;
Expand Down
21 changes: 0 additions & 21 deletions src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ use std::io::{self, Read, Write, Cursor};
use std::cell::RefCell;
use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex};
#[cfg(feature = "timeouts")]
use std::time::Duration;
#[cfg(feature = "timeouts")]
use std::cell::Cell;

use solicit::http::HttpScheme;
Expand All @@ -24,9 +22,7 @@ pub struct MockStream {
pub is_closed: bool,
pub error_on_write: bool,
pub error_on_read: bool,
#[cfg(feature = "timeouts")]
pub read_timeout: Cell<Option<Duration>>,
#[cfg(feature = "timeouts")]
pub write_timeout: Cell<Option<Duration>>,
}

Expand All @@ -45,7 +41,6 @@ impl MockStream {
MockStream::with_responses(vec![input])
}

#[cfg(feature = "timeouts")]
pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream {
MockStream {
read: Cursor::new(responses.remove(0).to_vec()),
Expand All @@ -58,18 +53,6 @@ impl MockStream {
write_timeout: Cell::new(None),
}
}

#[cfg(not(feature = "timeouts"))]
pub fn with_responses(mut responses: Vec<&[u8]>) -> MockStream {
MockStream {
read: Cursor::new(responses.remove(0).to_vec()),
next_reads: responses.into_iter().map(|arr| arr.to_vec()).collect(),
write: vec![],
is_closed: false,
error_on_write: false,
error_on_read: false,
}
}
}

impl Read for MockStream {
Expand Down Expand Up @@ -111,13 +94,11 @@ impl NetworkStream for MockStream {
Ok("127.0.0.1:1337".parse().unwrap())
}

#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.read_timeout.set(dur);
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.write_timeout.set(dur);
Ok(())
Expand Down Expand Up @@ -167,12 +148,10 @@ impl NetworkStream for CloneableMockStream {
self.inner.lock().unwrap().peer_addr()
}

#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.lock().unwrap().set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.lock().unwrap().set_write_timeout(dur)
}
Expand Down
10 changes: 0 additions & 10 deletions src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::mem;
#[cfg(feature = "openssl")]
pub use self::openssl::Openssl;

#[cfg(feature = "timeouts")]
use std::time::Duration;

use typeable::Typeable;
Expand Down Expand Up @@ -53,11 +52,9 @@ pub trait NetworkStream: Read + Write + Any + Send + Typeable {
fn peer_addr(&mut self) -> io::Result<SocketAddr>;

/// Set the maximum time to wait for a read to complete.
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>;

/// Set the maximum time to wait for a write to complete.
#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>;

/// This will be called when Stream should no longer be kept alive.
Expand Down Expand Up @@ -341,13 +338,11 @@ impl NetworkStream for HttpStream {
self.0.peer_addr()
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.0.set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.0.set_write_timeout(dur)
Expand Down Expand Up @@ -471,7 +466,6 @@ impl<S: NetworkStream> NetworkStream for HttpsStream<S> {
}
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match *self {
Expand All @@ -480,7 +474,6 @@ impl<S: NetworkStream> NetworkStream for HttpsStream<S> {
}
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match *self {
Expand Down Expand Up @@ -580,7 +573,6 @@ mod openssl {
use std::net::{SocketAddr, Shutdown};
use std::path::Path;
use std::sync::Arc;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use openssl::ssl::{Ssl, SslContext, SslStream, SslMethod, SSL_VERIFY_NONE};
Expand Down Expand Up @@ -660,13 +652,11 @@ mod openssl {
self.get_mut().peer_addr()
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_write_timeout(dur)
Expand Down
Loading

0 comments on commit fec6e3e

Please sign in to comment.