From 0fbca18fe11c271ca8f570dc398abe9d6ffa9985 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 1 Feb 2023 17:56:32 +0100 Subject: [PATCH] feat(http client): add tower middleware (#981) * feat(http client): add tower middleware * small fixes * fix rustdoc * no more mutex * fix nits * cleanup * fix grumbles * fix: opt reading response body * clippify * fix grumbles * Update core/src/http_helpers.rs --- client/http-client/Cargo.toml | 1 + client/http-client/src/client.rs | 66 +++++++-- client/http-client/src/transport.rs | 211 +++++++++++++++++++++------- core/src/error.rs | 4 +- core/src/http_helpers.rs | 93 ++++++++---- examples/examples/http.rs | 19 ++- 6 files changed, 299 insertions(+), 95 deletions(-) diff --git a/client/http-client/Cargo.toml b/client/http-client/Cargo.toml index 524e737b65..552b4d7eb3 100644 --- a/client/http-client/Cargo.toml +++ b/client/http-client/Cargo.toml @@ -21,6 +21,7 @@ serde_json = "1.0" thiserror = "1.0" tokio = { version = "1.16", features = ["time"] } tracing = "0.1.34" +tower = { version = "0.4.13", features = ["util"] } [dev-dependencies] tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } diff --git a/client/http-client/src/client.rs b/client/http-client/src/client.rs index a890269dd0..7b8fc570d6 100644 --- a/client/http-client/src/client.rs +++ b/client/http-client/src/client.rs @@ -25,14 +25,17 @@ // DEALINGS IN THE SOFTWARE. use std::borrow::Cow as StdCow; +use std::error::Error as StdError; use std::fmt; use std::sync::Arc; use std::time::Duration; -use crate::transport::HttpTransportClient; +use crate::transport::{self, Error as TransportError, HttpTransportClient}; use crate::types::{ErrorResponse, NotificationSer, RequestSer, Response}; use async_trait::async_trait; +use hyper::body::HttpBody; use hyper::http::HeaderMap; +use hyper::Body; use jsonrpsee_core::client::{ generate_batch_id_range, BatchResponse, CertificateStore, ClientT, IdKind, RequestIdManager, Subscription, SubscriptionClientT, @@ -43,6 +46,8 @@ use jsonrpsee_core::{Error, JsonRawValue, TEN_MB_SIZE_BYTES}; use jsonrpsee_types::error::CallError; use jsonrpsee_types::{ErrorObject, TwoPointZero}; use serde::de::DeserializeOwned; +use tower::layer::util::Identity; +use tower::{Layer, Service}; use tracing::instrument; /// Http Client Builder. @@ -70,7 +75,7 @@ use tracing::instrument; /// /// ``` #[derive(Debug)] -pub struct HttpClientBuilder { +pub struct HttpClientBuilder { max_request_size: u32, max_response_size: u32, request_timeout: Duration, @@ -79,9 +84,10 @@ pub struct HttpClientBuilder { id_kind: IdKind, max_log_length: u32, headers: HeaderMap, + service_builder: tower::ServiceBuilder, } -impl HttpClientBuilder { +impl HttpClientBuilder { /// Set the maximum size of a request body in bytes. Default is 10 MiB. pub fn max_request_size(mut self, size: u32) -> Self { self.max_request_size = size; @@ -134,8 +140,32 @@ impl HttpClientBuilder { self } + /// Set custom tower middleware. + pub fn set_middleware(self, service_builder: tower::ServiceBuilder) -> HttpClientBuilder { + HttpClientBuilder { + certificate_store: self.certificate_store, + id_kind: self.id_kind, + headers: self.headers, + max_log_length: self.max_log_length, + max_concurrent_requests: self.max_concurrent_requests, + max_request_size: self.max_request_size, + max_response_size: self.max_response_size, + service_builder, + request_timeout: self.request_timeout, + } + } +} + +impl HttpClientBuilder +where + L: Layer, + S: Service, Response = hyper::Response, Error = TransportError> + Clone, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ /// Build the HTTP client with target to connect to. - pub fn build(self, target: impl AsRef) -> Result { + pub fn build(self, target: impl AsRef) -> Result, Error> { let Self { max_request_size, max_response_size, @@ -145,6 +175,8 @@ impl HttpClientBuilder { id_kind, headers, max_log_length, + service_builder, + .. } = self; let transport = HttpTransportClient::new( @@ -154,6 +186,7 @@ impl HttpClientBuilder { certificate_store, max_log_length, headers, + service_builder, ) .map_err(|e| Error::Transport(e.into()))?; Ok(HttpClient { @@ -164,7 +197,7 @@ impl HttpClientBuilder { } } -impl Default for HttpClientBuilder { +impl Default for HttpClientBuilder { fn default() -> Self { Self { max_request_size: TEN_MB_SIZE_BYTES, @@ -175,15 +208,16 @@ impl Default for HttpClientBuilder { id_kind: IdKind::Number, max_log_length: 4096, headers: HeaderMap::new(), + service_builder: tower::ServiceBuilder::new(), } } } /// JSON-RPC HTTP Client that provides functionality to perform method calls and notifications. #[derive(Debug, Clone)] -pub struct HttpClient { +pub struct HttpClient { /// HTTP transport client. - transport: HttpTransportClient, + transport: HttpTransportClient, /// Request timeout. Defaults to 60sec. request_timeout: Duration, /// Request ID manager. @@ -191,7 +225,14 @@ pub struct HttpClient { } #[async_trait] -impl ClientT for HttpClient { +impl ClientT for HttpClient +where + S: Service, Response = hyper::Response, Error = TransportError> + Send + Sync + Clone, + >>::Future: Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ #[instrument(name = "notification", skip(self, params), level = "trace")] async fn notification(&self, method: &str, params: Params) -> Result<(), Error> where @@ -329,7 +370,14 @@ impl ClientT for HttpClient { } #[async_trait] -impl SubscriptionClientT for HttpClient { +impl SubscriptionClientT for HttpClient +where + S: Service, Response = hyper::Response, Error = TransportError> + Send + Sync + Clone, + >>::Future: Send, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ /// Send a subscription request to the server. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`]. #[instrument(name = "subscription", fields(method = _subscribe_method), skip(self, _params, _subscribe_method, _unsubscribe_method), level = "trace")] async fn subscribe<'a, N, Params>( diff --git a/client/http-client/src/transport.rs b/client/http-client/src/transport.rs index ad728f2800..77be138f5b 100644 --- a/client/http-client/src/transport.rs +++ b/client/http-client/src/transport.rs @@ -6,6 +6,7 @@ // that we need to be guaranteed that hyper doesn't re-use an existing connection if we ever reset // the JSON-RPC request id to a value that might have already been used. +use hyper::body::{Body, HttpBody}; use hyper::client::{Client, HttpConnector}; use hyper::http::{HeaderMap, HeaderValue}; use hyper::Uri; @@ -13,36 +14,72 @@ use jsonrpsee_core::client::CertificateStore; use jsonrpsee_core::error::GenericTransportError; use jsonrpsee_core::http_helpers; use jsonrpsee_core::tracing::{rx_log_from_bytes, tx_log_from_str}; +use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use thiserror::Error; +use tower::{Layer, Service, ServiceExt}; const CONTENT_TYPE_JSON: &str = "application/json"; -#[derive(Debug, Clone)] -enum HyperClient { +/// Wrapper over HTTP transport and connector. +#[derive(Debug)] +pub enum HttpBackend { /// Hyper client with https connector. #[cfg(feature = "tls")] - Https(Client>), + Https(Client, B>), /// Hyper client with http connector. - Http(Client), + Http(Client), +} + +impl Clone for HttpBackend { + fn clone(&self) -> Self { + match self { + Self::Http(inner) => Self::Http(inner.clone()), + #[cfg(feature = "tls")] + Self::Https(inner) => Self::Https(inner.clone()), + } + } } -impl HyperClient { - fn request(&self, req: hyper::Request) -> hyper::client::ResponseFuture { +impl tower::Service> for HttpBackend +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ + type Response = hyper::Response; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll> { match self { - Self::Http(client) => client.request(req), + Self::Http(inner) => inner.poll_ready(ctx), #[cfg(feature = "tls")] - Self::Https(client) => client.request(req), + Self::Https(inner) => inner.poll_ready(ctx), } + .map_err(Into::into) + } + + fn call(&mut self, req: hyper::Request) -> Self::Future { + let resp = match self { + Self::Http(inner) => inner.call(req), + #[cfg(feature = "tls")] + Self::Https(inner) => inner.call(req), + }; + + Box::pin(async move { resp.await.map_err(Into::into) }) } } /// HTTP Transport Client. #[derive(Debug, Clone)] -pub struct HttpTransportClient { +pub struct HttpTransportClient { /// Target to connect to. - target: Uri, + target: ParsedUri, /// HTTP client - client: HyperClient, + client: S, /// Configurable max request body size max_request_size: u32, /// Configurable max response body size @@ -55,23 +92,27 @@ pub struct HttpTransportClient { headers: HeaderMap, } -impl HttpTransportClient { +impl HttpTransportClient +where + S: Service, Response = hyper::Response, Error = Error> + Clone, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ /// Initializes a new HTTP client. - pub(crate) fn new( + pub(crate) fn new, Service = S>>( max_request_size: u32, target: impl AsRef, max_response_size: u32, cert_store: CertificateStore, max_log_length: u32, headers: HeaderMap, + service_builder: tower::ServiceBuilder, ) -> Result { - let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {e}")))?; - if target.port_u16().is_none() { - return Err(Error::Url("Port number is missing in the URL".into())); - } + let uri = ParsedUri::try_from(target.as_ref())?; - let client = match target.scheme_str() { - Some("http") => HyperClient::Http(Client::new()), + let client = match uri.0.scheme_str() { + Some("http") => HttpBackend::Http(Client::new()), #[cfg(feature = "tls")] Some("https") => { let connector = match cert_store { @@ -87,7 +128,7 @@ impl HttpTransportClient { .build(), _ => return Err(Error::InvalidCertficateStore), }; - HyperClient::Https(Client::builder().build::<_, hyper::Body>(connector)) + HttpBackend::Https(Client::builder().build::<_, hyper::Body>(connector)) } _ => { #[cfg(feature = "tls")] @@ -110,23 +151,30 @@ impl HttpTransportClient { } } - Ok(Self { target, client, max_request_size, max_response_size, max_log_length, headers: cached_headers }) + Ok(Self { + target: uri, + client: service_builder.service(client), + max_request_size, + max_response_size, + max_log_length, + headers: cached_headers, + }) } - async fn inner_send(&self, body: String) -> Result, Error> { + async fn inner_send(&self, body: String) -> Result, Error> { tx_log_from_str(&body, self.max_log_length); if body.len() > self.max_request_size as usize { return Err(Error::RequestTooLarge); } - let mut req = hyper::Request::post(&self.target); + let mut req = hyper::Request::post(&self.target.0); if let Some(headers) = req.headers_mut() { *headers = self.headers.clone(); } let req = req.body(From::from(body)).expect("URI and request headers are valid; qed"); + let response = self.client.clone().ready().await?.call(req).await?; - let response = self.client.request(req).await.map_err(|e| Error::Http(Box::new(e)))?; if response.status().is_success() { Ok(response) } else { @@ -153,6 +201,22 @@ impl HttpTransportClient { } } +#[derive(Debug, Clone)] +struct ParsedUri(Uri); + +impl TryFrom<&str> for ParsedUri { + type Error = Error; + + fn try_from(target: &str) -> Result { + let uri: Uri = target.parse().map_err(|e| Error::Url(format!("Invalid URL: {e}")))?; + if uri.port_u16().is_none() { + Err(Error::Url("Port number is missing in the URL".into())) + } else { + Ok(ParsedUri(uri)) + } + } +} + /// Error that can happen during a request. #[derive(Debug, Error)] pub enum Error { @@ -184,73 +248,112 @@ pub enum Error { InvalidCertficateStore, } -impl From> for Error -where - T: std::error::Error + Send + Sync + 'static, -{ - fn from(err: GenericTransportError) -> Self { +impl From for Error { + fn from(err: GenericTransportError) -> Self { match err { - GenericTransportError::::TooLarge => Self::RequestTooLarge, - GenericTransportError::::Malformed => Self::Malformed, - GenericTransportError::::Inner(e) => Self::Http(Box::new(e)), + GenericTransportError::TooLarge => Self::RequestTooLarge, + GenericTransportError::Malformed => Self::Malformed, + GenericTransportError::Inner(e) => Self::Http(e.into()), } } } +impl From for Error { + fn from(err: hyper::Error) -> Self { + Self::Http(Box::new(err)) + } +} + #[cfg(test)] mod tests { use super::*; + use jsonrpsee_core::client::CertificateStore; fn assert_target( - client: &HttpTransportClient, + client: &HttpTransportClient, host: &str, scheme: &str, path_and_query: &str, port: u16, max_request_size: u32, ) { - assert_eq!(client.target.scheme_str(), Some(scheme)); - assert_eq!(client.target.path_and_query().map(|pq| pq.as_str()), Some(path_and_query)); - assert_eq!(client.target.host(), Some(host)); - assert_eq!(client.target.port_u16(), Some(port)); + assert_eq!(client.target.0.scheme_str(), Some(scheme)); + assert_eq!(client.target.0.path_and_query().map(|pq| pq.as_str()), Some(path_and_query)); + assert_eq!(client.target.0.host(), Some(host)); + assert_eq!(client.target.0.port_u16(), Some(port)); assert_eq!(client.max_request_size, max_request_size); } #[test] fn invalid_http_url_rejected() { - let err = - HttpTransportClient::new(80, "ws://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new()) - .unwrap_err(); + let err = HttpTransportClient::new( + 80, + "ws://localhost:9933", + 80, + CertificateStore::Native, + 80, + HeaderMap::new(), + tower::ServiceBuilder::new(), + ) + .unwrap_err(); assert!(matches!(err, Error::Url(_))); } #[cfg(feature = "tls")] #[test] fn https_works() { - let client = - HttpTransportClient::new(80, "https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new()) - .unwrap(); + let client = HttpTransportClient::new( + 80, + "https://localhost:9933", + 80, + CertificateStore::Native, + 80, + HeaderMap::new(), + tower::ServiceBuilder::new(), + ) + .unwrap(); assert_target(&client, "localhost", "https", "/", 9933, 80); } #[cfg(not(feature = "tls"))] #[test] fn https_fails_without_tls_feature() { - let err = - HttpTransportClient::new(80, "https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new()) - .unwrap_err(); + let err = HttpTransportClient::new( + 80, + "https://localhost:9933", + 80, + CertificateStore::Native, + 80, + HeaderMap::new(), + tower::ServiceBuilder::new(), + ) + .unwrap_err(); assert!(matches!(err, Error::Url(_))); } #[test] fn faulty_port() { - let err = - HttpTransportClient::new(80, "http://localhost:-43", 80, CertificateStore::Native, 80, HeaderMap::new()) - .unwrap_err(); + let err = HttpTransportClient::new( + 80, + "http://localhost:-43", + 80, + CertificateStore::Native, + 80, + HeaderMap::new(), + tower::ServiceBuilder::new(), + ) + .unwrap_err(); assert!(matches!(err, Error::Url(_))); - let err = - HttpTransportClient::new(80, "http://localhost:-99999", 80, CertificateStore::Native, 80, HeaderMap::new()) - .unwrap_err(); + let err = HttpTransportClient::new( + 80, + "http://localhost:-99999", + 80, + CertificateStore::Native, + 80, + HeaderMap::new(), + tower::ServiceBuilder::new(), + ) + .unwrap_err(); assert!(matches!(err, Error::Url(_))); } @@ -263,6 +366,7 @@ mod tests { CertificateStore::Native, 80, HeaderMap::new(), + tower::ServiceBuilder::new(), ) .unwrap(); assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337); @@ -277,6 +381,7 @@ mod tests { CertificateStore::WebPki, 80, HeaderMap::new(), + tower::ServiceBuilder::new(), ) .unwrap(); assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX); @@ -291,6 +396,7 @@ mod tests { CertificateStore::Native, 80, HeaderMap::new(), + tower::ServiceBuilder::new(), ) .unwrap(); assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999); @@ -308,6 +414,7 @@ mod tests { CertificateStore::WebPki, 99, HeaderMap::new(), + tower::ServiceBuilder::new(), ) .unwrap(); assert_eq!(client.max_request_size, eighty_bytes_limit); diff --git a/core/src/error.rs b/core/src/error.rs index 33894e0ff8..0dd1ab55ce 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -191,7 +191,7 @@ impl From for ErrorObjectOwned { /// Generic transport error. #[derive(Debug, thiserror::Error)] -pub enum GenericTransportError { +pub enum GenericTransportError { /// Request was too large. #[error("The request was too big")] TooLarge, @@ -200,7 +200,7 @@ pub enum GenericTransportError { Malformed, /// Concrete transport error. #[error("Transport error: {0}")] - Inner(T), + Inner(anyhow::Error), } impl From for Error { diff --git a/core/src/http_helpers.rs b/core/src/http_helpers.rs index e63917c7b9..e0bd62e0a1 100644 --- a/core/src/http_helpers.rs +++ b/core/src/http_helpers.rs @@ -27,53 +27,84 @@ //! Utility methods relying on hyper use crate::error::GenericTransportError; -use futures_util::stream::StreamExt; +use anyhow::anyhow; +use hyper::body::{Buf, HttpBody}; +use std::error::Error as StdError; -/// Read a data from a [`hyper::Body`] and return the data if it is valid and within the allowed size range. +/// Read a data from [`hyper::body::HttpBody`] and return the data if it is valid JSON and within the allowed size range. /// /// Returns `Ok((bytes, single))` if the body was in valid size range; and a bool indicating whether the JSON-RPC /// request is a single or a batch. /// Returns `Err` if the body was too large or the body couldn't be read. -pub async fn read_body( +pub async fn read_body( headers: &hyper::HeaderMap, - mut body: hyper::Body, - max_request_body_size: u32, -) -> Result<(Vec, bool), GenericTransportError> { + body: B, + max_body_size: u32, +) -> Result<(Vec, bool), GenericTransportError> +where + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into>, +{ // NOTE(niklasad1): Values bigger than `u32::MAX` will be turned into zero here. This is unlikely to occur in - // practice and for that case we fallback to allocating in the while-loop below instead of pre-allocating. + // practice and in that case we fallback to allocating in the while-loop below instead of pre-allocating. let body_size = read_header_content_length(headers).unwrap_or(0); - if body_size > max_request_body_size { + if body_size > max_body_size { return Err(GenericTransportError::TooLarge); } - let first_chunk = - body.next().await.ok_or(GenericTransportError::Malformed)?.map_err(GenericTransportError::Inner)?; - - if first_chunk.len() > max_request_body_size as usize { - return Err(GenericTransportError::TooLarge); + futures_util::pin_mut!(body); + + // only allocate up to 16KB initially + let mut received_data = Vec::with_capacity(std::cmp::min(body_size as usize, 16 * 1024)); + let mut is_single = None; + + while let Some(d) = body.data().await { + let data = d.map_err(|e| GenericTransportError::Inner(anyhow!(e.into())))?; + + // if it's the first chunk, trim the whitespaces to determine whether it's valid JSON-RPC call. + if received_data.is_empty() { + let first_non_whitespace = + data.chunk().iter().enumerate().take(128).find(|(_, byte)| !byte.is_ascii_whitespace()); + + let skip = match first_non_whitespace { + Some((idx, b'{')) => { + is_single = Some(true); + idx + } + Some((idx, b'[')) => { + is_single = Some(false); + idx + } + _ => return Err(GenericTransportError::Malformed), + }; + + if data.chunk().len() - skip > max_body_size as usize { + return Err(GenericTransportError::TooLarge); + } + + // ignore whitespace as these doesn't matter just makes the JSON decoding slower. + received_data.extend_from_slice(&data.chunk()[skip..]); + } else { + if data.chunk().len() + received_data.len() > max_body_size as usize { + return Err(GenericTransportError::TooLarge); + } + + received_data.extend_from_slice(data.chunk()); + } } - let first_non_whitespace = first_chunk.iter().find(|byte| !byte.is_ascii_whitespace()); - - let single = match first_non_whitespace { - Some(b'{') => true, - Some(b'[') => false, - _ => return Err(GenericTransportError::Malformed), - }; - - let mut received_data = Vec::with_capacity(body_size as usize); - received_data.extend_from_slice(&first_chunk); - - while let Some(chunk) = body.next().await { - let chunk = chunk.map_err(GenericTransportError::Inner)?; - let body_length = chunk.len() + received_data.len(); - if body_length > max_request_body_size as usize { - return Err(GenericTransportError::TooLarge); + match is_single { + Some(single) if !received_data.is_empty() => { + tracing::trace!( + "HTTP response body: {}", + std::str::from_utf8(&received_data).unwrap_or("Invalid UTF-8 data") + ); + Ok((received_data, single)) } - received_data.extend_from_slice(&chunk); + _ => Err(GenericTransportError::Malformed), } - Ok((received_data, single)) } /// Read the `Content-Length` HTTP Header. Must fit into a `u32`; returns `None` otherwise. diff --git a/examples/examples/http.rs b/examples/examples/http.rs index 0c9da1ec99..7f33369d57 100644 --- a/examples/examples/http.rs +++ b/examples/examples/http.rs @@ -25,11 +25,15 @@ // DEALINGS IN THE SOFTWARE. use std::net::SocketAddr; +use std::time::Duration; +use hyper::body::Bytes; use jsonrpsee::core::client::ClientT; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::rpc_params; use jsonrpsee::server::{RpcModule, ServerBuilder}; +use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}; +use tower_http::LatencyUnit; use tracing_subscriber::util::SubscriberInitExt; #[tokio::main] @@ -41,7 +45,20 @@ async fn main() -> anyhow::Result<()> { let server_addr = run_server().await?; let url = format!("http://{}", server_addr); - let client = HttpClientBuilder::default().build(url)?; + let middleware = tower::ServiceBuilder::new() + .layer( + TraceLayer::new_for_http() + .on_request( + |request: &hyper::Request, _span: &tracing::Span| tracing::info!(request = ?request, "on_request"), + ) + .on_body_chunk(|chunk: &Bytes, latency: Duration, _: &tracing::Span| { + tracing::info!(size_bytes = chunk.len(), latency = ?latency, "sending body chunk") + }) + .make_span_with(DefaultMakeSpan::new().include_headers(true)) + .on_response(DefaultOnResponse::new().include_headers(true).latency_unit(LatencyUnit::Micros)), + ); + + let client = HttpClientBuilder::default().set_middleware(middleware).build(url)?; let params = rpc_params![1_u64, 2, 3]; let response: Result = client.request("say_hello", params).await; tracing::info!("r: {:?}", response);