Skip to content

Commit

Permalink
feat(net): add socket timeouts to Server and Client
Browse files Browse the repository at this point in the history
While these methods are marked unstable in libstd, this is behind a
feature flag, `timeouts`. The Client and Server both have
`set_read_timeout` and `set_write_timeout` methods, that will affect all
connections with that entity.

BREAKING CHANGE: Any custom implementation of NetworkStream must now
  implement `set_read_timeout` and `set_write_timeout`, so those will
  break. Most users who only use the provided streams should work with
  no changes needed.

Closes #315
  • Loading branch information
seanmonstar committed Jul 27, 2015
1 parent 421422b commit 7d1f154
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 49 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ env_logger = "*"
default = ["ssl"]
ssl = ["openssl", "cookie/secure"]
serde-serialization = ["serde"]
nightly = []

timeouts = []
nightly = ["timeouts"]
43 changes: 40 additions & 3 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,16 @@ use std::default::Default;
use std::io::{self, copy, Read};
use std::iter::Extend;

#[cfg(feature = "timeouts")]
use std::time::Duration;

use url::UrlParser;
use url::ParseError as UrlError;

use header::{Headers, Header, HeaderFormat};
use header::{ContentLength, Location};
use method::Method;
use net::{NetworkConnector, NetworkStream};
use net::{NetworkConnector, NetworkStream, Fresh};
use {Url};
use Error;

Expand All @@ -87,7 +90,9 @@ pub struct Client {
protocol: Box<Protocol + Send + Sync>,
redirect_policy: RedirectPolicy,
#[cfg(feature = "timeouts")]
read_timeout: Option<Duration>
read_timeout: Option<Duration>,
#[cfg(feature = "timeouts")]
write_timeout: Option<Duration>,
}

impl Client {
Expand All @@ -108,11 +113,23 @@ impl Client {
Client::with_protocol(Http11Protocol::with_connector(connector))
}

#[cfg(not(feature = "timeouts"))]
/// Create a new client with a specific `Protocol`.
pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client {
Client {
protocol: Box::new(protocol),
redirect_policy: Default::default(),
}
}

#[cfg(feature = "timeouts")]
/// Create a new client with a specific `Protocol`.
pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client {
Client {
protocol: Box::new(protocol),
redirect_policy: Default::default()
redirect_policy: Default::default(),
read_timeout: None,
write_timeout: None,
}
}

Expand All @@ -127,6 +144,12 @@ impl Client {
self.read_timeout = dur;
}

/// Set the write timeout value for all requests.
#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self.write_timeout = dur;
}

/// Build a Get request.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<U> {
self.request(Method::Get, url)
Expand Down Expand Up @@ -236,6 +259,20 @@ impl<'a, U: IntoUrl> RequestBuilder<'a, U> {
let mut req = try!(Request::with_message(method.clone(), url.clone(), message));
headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter()));

#[cfg(not(feature = "timeouts"))]
fn set_timeouts(_req: &mut Request<Fresh>, _client: &Client) -> ::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
fn set_timeouts(req: &mut Request<Fresh>, client: &Client) -> ::Result<()> {
try!(req.set_write_timeout(client.write_timeout));
try!(req.set_read_timeout(client.read_timeout));
Ok(())
}

try!(set_timeouts(&mut req, &client));

match (can_have_body, body.as_ref()) {
(true, Some(body)) => match body.size() {
Some(size) => req.headers_mut().set(ContentLength(size)),
Expand Down
15 changes: 15 additions & 0 deletions src/client/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::io::{self, Read, Write};
use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use net::{NetworkConnector, NetworkStream, DefaultConnector};

/// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`.
Expand Down Expand Up @@ -153,6 +156,18 @@ impl<S: NetworkStream> NetworkStream for PooledStream<S> {
self.inner.as_mut().unwrap().1.peer_addr()
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().1.set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().1.set_write_timeout(dur)
}

#[inline]
fn close(&mut self, how: Shutdown) -> io::Result<()> {
self.is_closed = true;
Expand Down
17 changes: 17 additions & 0 deletions src/client/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
use std::marker::PhantomData;
use std::io::{self, Write};

#[cfg(feature = "timeouts")]
use std::time::Duration;

use url::Url;

use method::{self, Method};
Expand Down Expand Up @@ -39,6 +42,20 @@ impl<W> Request<W> {
/// Read the Request method.
#[inline]
pub fn method(&self) -> method::Method { self.method.clone() }

/// Set the write timeout.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.message.set_write_timeout(dur)
}

/// Set the read timeout.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.message.set_read_timeout(dur)
}
}

impl Request<Fresh> {
Expand Down
47 changes: 43 additions & 4 deletions src/http/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::cmp::min;
use std::fmt;
use std::io::{self, Write, BufWriter, BufRead, Read};
use std::net::Shutdown;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use httparse;

Expand Down Expand Up @@ -192,6 +194,19 @@ impl HttpMessage for Http11Message {
})
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_read_timeout(dur)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.get_ref().set_write_timeout(dur)
}

#[inline]
fn close_connection(&mut self) -> ::Result<()> {
try!(self.get_mut().close(Shutdown::Both));
Ok(())
Expand All @@ -214,13 +229,27 @@ impl Http11Message {

/// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the
/// `Http11Message`.
pub fn get_mut(&mut self) -> &mut Box<NetworkStream + Send> {
pub fn get_ref(&self) -> &(NetworkStream + Send) {
if self.stream.is_some() {
self.stream.as_mut().unwrap()
&**self.stream.as_ref().unwrap()
} else if self.writer.is_some() {
self.writer.as_mut().unwrap().get_mut().get_mut()
&**self.writer.as_ref().unwrap().get_ref().get_ref()
} else if self.reader.is_some() {
self.reader.as_mut().unwrap().get_mut().get_mut()
&**self.reader.as_ref().unwrap().get_ref().get_ref()
} else {
panic!("Http11Message lost its underlying stream somehow");
}
}

/// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the
/// `Http11Message`.
pub fn get_mut(&mut self) -> &mut (NetworkStream + Send) {
if self.stream.is_some() {
&mut **self.stream.as_mut().unwrap()
} else if self.writer.is_some() {
&mut **self.writer.as_mut().unwrap().get_mut().get_mut()
} else if self.reader.is_some() {
&mut **self.reader.as_mut().unwrap().get_mut().get_mut()
} else {
panic!("Http11Message lost its underlying stream somehow");
}
Expand Down Expand Up @@ -344,6 +373,16 @@ impl<R: Read> HttpReader<R> {
}
}

/// Gets a borrowed reference to the underlying Reader.
pub fn get_ref(&self) -> &R {
match *self {
SizedReader(ref r, _) => r,
ChunkedReader(ref r, _) => r,
EofReader(ref r) => r,
EmptyReader(ref r) => r,
}
}

/// Gets a mutable reference to the underlying Reader.
pub fn get_mut(&mut self) -> &mut R {
match *self {
Expand Down
15 changes: 15 additions & 0 deletions src/http/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::io::{self, Write, Read, Cursor};
use std::net::Shutdown;
use std::ascii::AsciiExt;
use std::mem;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use http::{
Protocol,
Expand Down Expand Up @@ -398,6 +400,19 @@ impl<S> HttpMessage for Http2Message<S> where S: CloneableStream {
Ok(head)
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_read_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[cfg(feature = "timeouts")]
#[inline]
fn set_write_timeout(&self, _dur: Option<Duration>) -> io::Result<()> {
Ok(())
}

#[inline]
fn close_connection(&mut self) -> ::Result<()> {
Ok(())
}
Expand Down
13 changes: 10 additions & 3 deletions src/http/message.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
//! Defines the `HttpMessage` trait that serves to encapsulate the operations of a single
//! request-response cycle on any HTTP connection.

use std::fmt::Debug;
use std::any::{Any, TypeId};
use std::fmt::Debug;
use std::io::{Read, Write};

use std::mem;

#[cfg(feature = "timeouts")]
use std::io;
#[cfg(feature = "timeouts")]
use std::time::Duration;

use typeable::Typeable;

use header::Headers;
Expand Down Expand Up @@ -62,7 +66,10 @@ pub trait HttpMessage: Write + Read + Send + Any + Typeable + Debug {
fn get_incoming(&mut self) -> ::Result<ResponseHead>;
/// Set the read timeout duration for this message.
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, dur: Option<Duration>) -> ::Result<()>;
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
/// Set the write timeout duration for this message.
#[cfg(feature = "timeouts")]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
/// Closes the underlying HTTP connection.
fn close_connection(&mut self) -> ::Result<()>;
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![cfg_attr(test, deny(missing_docs))]
#![cfg_attr(test, deny(warnings))]
#![cfg_attr(all(test, feature = "nightly"), feature(test))]
#![cfg_attr(feature = "timeouts", feature(duration, socket_timeout))]

//! # Hyper
//!
Expand Down
Loading

0 comments on commit 7d1f154

Please sign in to comment.