From 25010fc1fc3e61ed9948d10b36b83e062fbd5612 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 25 Apr 2016 15:37:03 -0700 Subject: [PATCH] feat(client): add Proxy support This works by configuring proxy options on a `Client`, such as `client.set_proxy("http", "127.0.0.1", "8018")`. Closes #531 --- src/client/mod.rs | 64 +++++++++++++++++++++++++++++++++++++------ src/client/pool.rs | 7 +++++ src/client/request.rs | 38 +++++++++++++++++++++++-- src/http/h1.rs | 15 ++++++++-- 4 files changed, 112 insertions(+), 12 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 171da0a71a..4ad5c2d5fd 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -55,9 +55,9 @@ //! clone2.post("http://example.domain/post").body("foo=bar").send().unwrap(); //! }); //! ``` +use std::borrow::Cow; use std::default::Default; use std::io::{self, copy, Read}; -use std::iter::Extend; use std::fmt; use std::time::Duration; @@ -66,7 +66,7 @@ use url::Url; use url::ParseError as UrlError; use header::{Headers, Header, HeaderFormat}; -use header::{ContentLength, Location}; +use header::{ContentLength, Host, Location}; use method::Method; use net::{NetworkConnector, NetworkStream}; use Error; @@ -90,6 +90,7 @@ pub struct Client { redirect_policy: RedirectPolicy, read_timeout: Option, write_timeout: Option, + proxy: Option<(Cow<'static, str>, Cow<'static, str>, u16)> } impl fmt::Debug for Client { @@ -98,6 +99,7 @@ impl fmt::Debug for Client { .field("redirect_policy", &self.redirect_policy) .field("read_timeout", &self.read_timeout) .field("write_timeout", &self.write_timeout) + .field("proxy", &self.proxy) .finish() } } @@ -127,6 +129,7 @@ impl Client { redirect_policy: Default::default(), read_timeout: None, write_timeout: None, + proxy: None, } } @@ -145,6 +148,12 @@ impl Client { self.write_timeout = dur; } + /// Set a proxy for requests of this Client. + pub fn set_proxy(&mut self, scheme: S, host: H, port: u16) + where S: Into>, H: Into> { + self.proxy = Some((scheme.into(), host.into(), port)); + } + /// Build a Get request. pub fn get(&self, url: U) -> RequestBuilder { self.request(Method::Get, url) @@ -247,7 +256,7 @@ impl<'a> RequestBuilder<'a> { pub fn send(self) -> ::Result { let RequestBuilder { client, method, url, headers, body } = self; let mut url = try!(url); - trace!("send {:?} {:?}", method, url); + trace!("send method={:?}, url={:?}, client={:?}", method, url, client); let can_have_body = match method { Method::Get | Method::Head => false, @@ -261,12 +270,25 @@ impl<'a> RequestBuilder<'a> { }; loop { - let message = { - let (host, port) = try!(get_host_and_port(&url)); - try!(client.protocol.new_message(&host, port, url.scheme())) + let mut req = { + let (scheme, host, port) = match client.proxy { + Some(ref proxy) => (proxy.0.as_ref(), proxy.1.as_ref(), proxy.2), + None => { + let hp = try!(get_host_and_port(&url)); + (url.scheme(), hp.0, hp.1) + } + }; + let mut headers = match headers { + Some(ref headers) => headers.clone(), + None => Headers::new(), + }; + headers.set(Host { + hostname: host.to_owned(), + port: Some(port), + }); + let message = try!(client.protocol.new_message(&host, port, scheme)); + Request::with_headers_and_message(method.clone(), url.clone(), headers, message) }; - let mut req = try!(Request::with_message(method.clone(), url.clone(), message)); - headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter())); try!(req.set_write_timeout(client.write_timeout)); try!(req.set_read_timeout(client.read_timeout)); @@ -456,6 +478,8 @@ fn get_host_and_port(url: &Url) -> ::Result<(&str, u16)> { mod tests { use std::io::Read; use header::Server; + use http::h1::Http11Message; + use mock::{MockStream}; use super::{Client, RedirectPolicy}; use super::pool::Pool; use url::Url; @@ -477,6 +501,30 @@ mod tests { " }); + + #[test] + fn test_proxy() { + use super::pool::PooledStream; + mock_connector!(ProxyConnector { + b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" + }); + let mut client = Client::with_connector(Pool::with_connector(Default::default(), ProxyConnector)); + client.set_proxy("http", "example.proxy", 8008); + let mut dump = vec![]; + client.get("http://127.0.0.1/foo/bar").send().unwrap().read_to_end(&mut dump).unwrap(); + + { + let box_message = client.protocol.new_message("example.proxy", 8008, "http").unwrap(); + let message = box_message.downcast::().unwrap(); + let stream = message.into_inner().downcast::>().unwrap().into_inner(); + let s = ::std::str::from_utf8(&stream.write).unwrap(); + let request_line = "GET http://127.0.0.1/foo/bar HTTP/1.1\r\n"; + assert_eq!(&s[..request_line.len()], request_line); + assert!(s.contains("Host: example.proxy:8008\r\n")); + } + + } + #[test] fn test_redirect_followall() { let mut client = Client::with_connector(MockRedirectPolicy); diff --git a/src/client/pool.rs b/src/client/pool.rs index 7a530302f3..2c44fb37c1 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -133,6 +133,13 @@ pub struct PooledStream { pool: Arc>>, } +impl PooledStream { + /// Take the wrapped stream out of the pool completely. + pub fn into_inner(mut self) -> S { + self.inner.take().expect("PooledStream lost its inner stream").stream + } +} + #[derive(Debug)] struct PooledStreamInner { key: Key, diff --git a/src/client/request.rs b/src/client/request.rs index 612f944d65..db0fce0a94 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -70,14 +70,20 @@ impl Request { }); } - Ok(Request { + Ok(Request::with_headers_and_message(method, url, headers, message)) + } + + #[doc(hidden)] + pub fn with_headers_and_message(method: Method, url: Url, headers: Headers, message: Box) + -> Request { + Request { method: method, headers: headers, url: url, version: version::HttpVersion::Http11, message: message, _marker: PhantomData, - }) + } } /// Create a new client request. @@ -129,6 +135,8 @@ impl Request { pub fn headers_mut(&mut self) -> &mut Headers { &mut self.headers } } + + impl Request { /// Completes writing the request, and returns a response to read from. /// @@ -246,6 +254,32 @@ mod tests { assert!(!s.contains("Content-Length:")); } + #[test] + fn test_host_header() { + let url = Url::parse("http://example.dom").unwrap(); + let req = Request::with_connector( + Get, url, &mut MockConnector + ).unwrap(); + let bytes = run_request(req); + let s = from_utf8(&bytes[..]).unwrap(); + assert!(s.contains("Host: example.dom")); + } + + #[test] + fn test_proxy() { + let url = Url::parse("http://example.dom").unwrap(); + let proxy_url = Url::parse("http://pro.xy").unwrap(); + let mut req = Request::with_connector( + Get, proxy_url, &mut MockConnector + ).unwrap(); + req.url = url; + let bytes = run_request(req); + let s = from_utf8(&bytes[..]).unwrap(); + let request_line = "GET http://example.dom/ HTTP/1.1"; + assert_eq!(&s[..request_line.len()], request_line); + assert!(s.contains("Host: pro.xy")); + } + #[test] fn test_post_chunked_with_encoding() { let url = Url::parse("http://example.dom").unwrap(); diff --git a/src/http/h1.rs b/src/http/h1.rs index 63246e2819..7b9e4abd21 100644 --- a/src/http/h1.rs +++ b/src/http/h1.rs @@ -11,7 +11,7 @@ use url::Position as UrlPosition; use buffer::BufReader; use Error; -use header::{Headers, ContentLength, TransferEncoding}; +use header::{Headers, Host, ContentLength, TransferEncoding}; use header::Encoding::Chunked; use method::{Method}; use net::{NetworkConnector, NetworkStream}; @@ -144,7 +144,18 @@ impl HttpMessage for Http11Message { let mut stream = BufWriter::new(stream); { - let uri = &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery]; + let uri = match head.headers.get::() { + Some(host) + if Some(&*host.hostname) == head.url.host_str() + && host.port == head.url.port_or_known_default() => { + &head.url[UrlPosition::BeforePath..UrlPosition::AfterQuery] + }, + _ => { + trace!("url and host header dont match, using absolute uri form"); + head.url.as_ref() + } + + }; let version = version::HttpVersion::Http11; debug!("request line: {:?} {:?} {:?}", head.method, uri, version);