Skip to content

Commit

Permalink
clients: feature gate tls (#545)
Browse files Browse the repository at this point in the history
* clients: introduce tls feature flag

* Update tests/tests/integration_tests.rs

* fix: don't rebuild tls connector of every connect

* fix tests + remove url dep

* fix tests again
  • Loading branch information
niklasad1 authored Dec 6, 2021
1 parent 3cb5eda commit 3f1c7fc
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 79 deletions.
7 changes: 5 additions & 2 deletions http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ documentation = "https://docs.rs/jsonrpsee-http-client"
async-trait = "0.1"
fnv = "1"
hyper = { version = "0.14.10", features = ["client", "http1", "http2", "tcp"] }
hyper-rustls = { version = "0.23", features = ["webpki-tokio"] }
hyper-rustls = { version = "0.23", optional = true }
jsonrpsee-types = { path = "../types", version = "0.6.0" }
jsonrpsee-utils = { path = "../utils", version = "0.6.0", features = ["client", "http-helpers"] }
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.8", features = ["time"] }
tracing = "0.1"
url = "2.2"

[dev-dependencies]
jsonrpsee-test-utils = { path = "../test-utils" }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros"] }

[features]
default = ["tls"]
tls = ["hyper-rustls/webpki-tokio"]
139 changes: 120 additions & 19 deletions http-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,39 @@

use crate::types::error::GenericTransportError;
use hyper::client::{Client, HttpConnector};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper::Uri;
use jsonrpsee_types::CertificateStore;
use jsonrpsee_utils::http_helpers;
use thiserror::Error;

const CONTENT_TYPE_JSON: &str = "application/json";

#[derive(Debug, Clone)]
enum HyperClient {
/// Hyper client with https connector.
#[cfg(feature = "tls")]
Https(Client<hyper_rustls::HttpsConnector<HttpConnector>>),
/// Hyper client with http connector.
Http(Client<HttpConnector>),
}

impl HyperClient {
fn request(&self, req: hyper::Request<hyper::Body>) -> hyper::client::ResponseFuture {
match self {
Self::Http(client) => client.request(req),
#[cfg(feature = "tls")]
Self::Https(client) => client.request(req),
}
}
}

/// HTTP Transport Client.
#[derive(Debug, Clone)]
pub(crate) struct HttpTransportClient {
/// Target to connect to.
target: url::Url,
target: Uri,
/// HTTP client
client: Client<HttpsConnector<HttpConnector>>,
client: HyperClient,
/// Configurable max request body size
max_request_body_size: u32,
}
Expand All @@ -33,22 +52,40 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
) -> Result<Self, Error> {
let target = url::Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.scheme() == "http" || target.scheme() == "https" {
let connector = match cert_store {
CertificateStore::Native => {
HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1()
}
CertificateStore::WebPki => {
HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1()
}
_ => return Err(Error::InvalidCertficateStore),
};
let client = Client::builder().build::<_, hyper::Body>(connector.build());
Ok(HttpTransportClient { target, client, max_request_body_size })
} else {
Err(Error::Url("URL scheme not supported, expects 'http' or 'https'".into()))
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
return Err(Error::Url("Port number is missing in the URL".into()));
}

let client = match target.scheme_str() {
Some("http") => {
let connector = HttpConnector::new();
let client = Client::builder().build::<_, hyper::Body>(connector);
HyperClient::Http(client)
}
#[cfg(feature = "tls")]
Some("https") => {
let connector = match cert_store {
CertificateStore::Native => {
hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1()
}
CertificateStore::WebPki => {
hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1()
}
_ => return Err(Error::InvalidCertficateStore),
};
let client = Client::builder().build::<_, hyper::Body>(connector.build());
HyperClient::Https(client)
}
_ => {
#[cfg(feature = "tls")]
let err = "URL scheme not supported, expects 'http' or 'https'";
#[cfg(not(feature = "tls"))]
let err = "URL scheme not supported, expects 'http'";
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size })
}

async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
Expand All @@ -58,7 +95,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}

let req = hyper::Request::post(self.target.as_str())
// NOTE(niklasad1): this annoying we could just take `&str` here but more user-friendly to check
// that the URI is well-formed in the constructor.
let req = hyper::Request::post(self.target.clone())
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
Expand Down Expand Up @@ -135,12 +174,74 @@ where
mod tests {
use super::{CertificateStore, Error, HttpTransportClient};

fn assert_target(
client: &HttpTransportClient,
host: &str,
scheme: &str,
path_and_query: &str,
port: u16,
max_request_size: u32,
) {
assert_eq!(client.target.scheme_str(), Some(scheme));
assert_eq!(client.target.path_and_query().map(|pq| pq.as_str()), Some(path_and_query));
assert_eq!(client.target.host(), Some(host));
assert_eq!(client.target.port_u16(), Some(port));
assert_eq!(client.max_request_body_size, max_request_size);
}

#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}

#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native).unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}

#[test]
fn url_with_query_works() {
let client = HttpTransportClient::new(
"http://127.0.0.1:9999/my?name1=value1&name2=value2",
u32::MAX,
CertificateStore::WebPki,
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
}

#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native).unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}

#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
Expand Down
3 changes: 1 addition & 2 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ async fn ws_with_non_ascii_url_doesnt_hang_or_panic() {

#[tokio::test]
async fn http_with_non_ascii_url_doesnt_hang_or_panic() {
let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap();
let err: Result<(), Error> = client.request("system_chain", None).await;
let err = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂");
assert!(matches!(err, Err(Error::Transport(_))));
}

Expand Down
6 changes: 5 additions & 1 deletion ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ serde_json = "1"
soketto = "0.7.1"
thiserror = "1"
tokio = { version = "1.8", features = ["net", "time", "rt-multi-thread", "macros"] }
tokio-rustls = "0.23"
tokio-rustls = { version = "0.23", optional = true }
tokio-util = { version = "0.6", features = ["compat"] }
tracing = "0.1"
webpki-roots = "0.22.0"
Expand All @@ -32,3 +32,7 @@ env_logger = "0.9"
jsonrpsee-test-utils = { path = "../test-utils" }
jsonrpsee-utils = { path = "../utils", features = ["client"] }
tokio = { version = "1.8", features = ["macros"] }

[features]
default = ["tls"]
tls = ["tokio-rustls"]
28 changes: 14 additions & 14 deletions ws-client/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,29 @@ use futures::{
};
use pin_project::pin_project;
use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll};
use tokio::net::TcpStream;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

/// Stream to represent either a unencrypted or encrypted socket stream.
#[pin_project(project = EitherStreamProj)]
#[derive(Debug, Copy, Clone)]
pub enum EitherStream<S, T> {
#[derive(Debug)]
pub enum EitherStream {
/// Unencrypted socket stream.
Plain(#[pin] S),
Plain(#[pin] TcpStream),
/// Encrypted socket stream.
Tls(#[pin] T),
#[cfg(feature = "tls")]
Tls(#[pin] tokio_rustls::client::TlsStream<TcpStream>),
}

impl<S, T> AsyncRead for EitherStream<S, T>
where
S: TokioAsyncReadCompatExt,
T: TokioAsyncReadCompatExt,
{
impl AsyncRead for EitherStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat();
futures::pin_mut!(compat);
AsyncRead::poll_read(compat, cx, buf)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures::pin_mut!(compat);
Expand All @@ -75,6 +74,7 @@ where
futures::pin_mut!(compat);
AsyncRead::poll_read_vectored(compat, cx, bufs)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures::pin_mut!(compat);
Expand All @@ -84,18 +84,15 @@ where
}
}

impl<S, T> AsyncWrite for EitherStream<S, T>
where
S: TokioAsyncWriteCompatExt,
T: TokioAsyncWriteCompatExt,
{
impl AsyncWrite for EitherStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
let compat = s.compat_write();
futures::pin_mut!(compat);
AsyncWrite::poll_write(compat, cx, buf)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
Expand All @@ -111,6 +108,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_write_vectored(compat, cx, bufs)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
Expand All @@ -126,6 +124,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_flush(compat, cx)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
Expand All @@ -141,6 +140,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_close(compat, cx)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
Expand Down
Loading

0 comments on commit 3f1c7fc

Please sign in to comment.