Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniform API for custom headers between clients #814

Merged
merged 13 commits into from
Jul 13, 2022
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ thiserror = "1.0"
tokio = { version = "1.16", features = ["time"] }
tracing = "0.1.34"
tracing-futures = "0.2.5"
http = "0.2.0"
lexnv marked this conversation as resolved.
Show resolved Hide resolved

[dev-dependencies]
jsonrpsee-test-utils = { path = "../../test-utils" }
Expand Down
21 changes: 18 additions & 3 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub struct HttpClientBuilder {
certificate_store: CertificateStore,
id_kind: IdKind,
max_log_length: u32,
headers: http::HeaderMap,
}

impl HttpClientBuilder {
Expand Down Expand Up @@ -88,11 +89,24 @@ impl HttpClientBuilder {
self
}

/// Set a custom header passed to the server with every request (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport =
HttpTransportClient::new(target, self.max_request_body_size, self.certificate_store, self.max_log_length)
.map_err(|e| Error::Transport(e.into()))?;
let transport = HttpTransportClient::new(
target,
self.max_request_body_size,
self.certificate_store,
self.max_log_length,
self.headers,
)
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(self.max_concurrent_requests, self.id_kind)),
Expand All @@ -110,6 +124,7 @@ impl Default for HttpClientBuilder {
certificate_store: CertificateStore::Native,
id_kind: IdKind::Number,
max_log_length: 4096,
headers: http::HeaderMap::new(),
}
}
}
Expand Down
85 changes: 68 additions & 17 deletions client/http-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct HttpTransportClient {
///
/// Logs bigger than this limit will be truncated.
max_log_length: u32,
/// Custom headers to pass with every request.
headers: http::HeaderMap,
}

impl HttpTransportClient {
Expand All @@ -57,6 +59,7 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
max_log_length: u32,
headers: http::HeaderMap,
) -> Result<Self, Error> {
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
Expand Down Expand Up @@ -90,7 +93,18 @@ impl HttpTransportClient {
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size, max_log_length })

// Cache request headers: 2 default headers, followed by user custom headers.
// Maintain order for headers in case of duplicate keys:
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2
let mut cached_headers = http::HeaderMap::with_capacity(2 + headers.len());
cached_headers.insert(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON));
cached_headers.insert(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON));
for (key, value) in headers.iter() {
lexnv marked this conversation as resolved.
Show resolved Hide resolved
cached_headers.insert(key, value.clone());
}

Ok(Self { target, client, max_request_body_size, max_log_length, headers: cached_headers })
}

async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
Expand All @@ -100,11 +114,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}

let req = hyper::Request::post(&self.target)
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
.expect("URI and request headers are valid; qed");
let mut req = hyper::Request::post(&self.target);
req.headers_mut().map(|headers| *headers = self.headers.clone());
let req = req.body(From::from(body)).expect("URI and request headers are valid; qed");

let response = self.client.request(req).await.map_err(|e| Error::Http(Box::new(e)))?;
if response.status().is_success() {
Expand Down Expand Up @@ -198,37 +210,67 @@ mod tests {

#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80, http::HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"https://localhost:9933",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}

#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new(
"https://localhost:9933",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80, http::HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new(
"http://localhost:-99999",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native, 80)
.unwrap();
let client = HttpTransportClient::new(
"http://localhost:9944/my-special-path",
1337,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}

Expand All @@ -239,22 +281,31 @@ mod tests {
u32::MAX,
CertificateStore::WebPki,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
}

#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"http://127.0.0.1:9944/my.htm#ignore",
999,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}

#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
let client = HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99).unwrap();
let client =
HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99, http::HeaderMap::new())
.unwrap();
assert_eq!(client.max_request_body_size, eighty_bytes_limit);

let body = "a".repeat(81);
Expand Down
26 changes: 15 additions & 11 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,32 @@ pub struct Receiver {

/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder<'a> {
pub struct WsTransportClientBuilder {
/// What certificate store to use
pub certificate_store: CertificateStore,
/// Timeout for the connection.
pub connection_timeout: Duration,
/// Custom headers to pass during the HTTP handshake. If `None`, no
/// custom header is passed.
pub headers: Vec<Header<'a>>,
/// Custom headers to pass during the HTTP handshake.
pub headers: http::HeaderMap,
/// Max payload size
pub max_request_body_size: u32,
/// Max number of redirections.
pub max_redirections: usize,
}

impl<'a> Default for WsTransportClientBuilder<'a> {
impl Default for WsTransportClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
connection_timeout: Duration::from_secs(10),
headers: Vec::new(),
headers: http::HeaderMap::new(),
max_redirections: 5,
}
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Set whether to use system certificates (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
Expand All @@ -107,8 +106,8 @@ impl<'a> WsTransportClientBuilder<'a> {
/// Set a custom header passed to the server during the handshake (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

Expand Down Expand Up @@ -240,7 +239,7 @@ impl TransportReceiverT for Receiver {
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Try to establish the connection.
pub async fn build(self, uri: Uri) -> Result<(Sender, Receiver), WsHandshakeError> {
let target: Target = uri.try_into()?;
Expand Down Expand Up @@ -289,7 +288,12 @@ impl<'a> WsTransportClientBuilder<'a> {
&target.path_and_query,
);

client.set_headers(&self.headers);
let headers: Vec<_> = self
.headers
.iter()
.map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
.collect();
client.set_headers(&headers);

// Perform the initial handshake.
match client.handshake().await {
Expand Down
1 change: 1 addition & 0 deletions client/ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client"
jsonrpsee-types = { path = "../../types", version = "0.14.0" }
jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", features = ["ws"] }
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-client"] }
http = "0.2.0"

[dev-dependencies]
env_logger = "0.9"
Expand Down
22 changes: 12 additions & 10 deletions client/ws-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use jsonrpsee_types as types;

use std::time::Duration;

use jsonrpsee_client_transport::ws::{Header, InvalidUri, Uri, WsTransportClientBuilder};
use jsonrpsee_client_transport::ws::{InvalidUri, Uri, WsTransportClientBuilder};
use jsonrpsee_core::client::{CertificateStore, ClientBuilder, IdKind};
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};

Expand All @@ -57,8 +57,10 @@ use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
/// #[tokio::main]
/// async fn main() {
/// // build client
/// let mut headers = http::HeaderMap::new();
/// headers.insert("Any-Header-You-Like", http::HeaderValue::from_static("42"));
/// let client = WsClientBuilder::default()
/// .add_header("Any-Header-You-Like", "42")
/// .set_headers(headers)
/// .build("wss://localhost:443")
/// .await
/// .unwrap();
Expand All @@ -68,28 +70,28 @@ use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
///
/// ```
#[derive(Clone, Debug)]
pub struct WsClientBuilder<'a> {
pub struct WsClientBuilder {
certificate_store: CertificateStore,
max_request_body_size: u32,
request_timeout: Duration,
connection_timeout: Duration,
ping_interval: Option<Duration>,
headers: Vec<Header<'a>>,
headers: http::HeaderMap,
max_concurrent_requests: usize,
max_notifs_per_subscription: usize,
max_redirections: usize,
id_kind: IdKind,
}

impl<'a> Default for WsClientBuilder<'a> {
impl Default for WsClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
request_timeout: Duration::from_secs(60),
connection_timeout: Duration::from_secs(10),
ping_interval: None,
headers: Vec::new(),
headers: http::HeaderMap::new(),
max_concurrent_requests: 256,
max_notifs_per_subscription: 1024,
max_redirections: 5,
Expand All @@ -98,7 +100,7 @@ impl<'a> Default for WsClientBuilder<'a> {
}
}

impl<'a> WsClientBuilder<'a> {
impl WsClientBuilder {
/// See documentation [`WsTransportClientBuilder::certificate_store`] (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
Expand Down Expand Up @@ -129,9 +131,9 @@ impl<'a> WsClientBuilder<'a> {
self
}

/// See documentation [`WsTransportClientBuilder::add_header`] (default is none).
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
/// See documentation [`WsTransportClientBuilder::set_headers`] (default is none).
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

Expand Down