Skip to content

Commit

Permalink
feat(http client): add tower middleware (#981)
Browse files Browse the repository at this point in the history
* feat(http client): add tower middleware

* small fixes

* fix rustdoc

* no more mutex

* fix nits

* cleanup

* fix grumbles

* fix: opt reading response body

* clippify

* fix grumbles

* Update core/src/http_helpers.rs
  • Loading branch information
niklasad1 authored Feb 1, 2023
1 parent d671bda commit 0fbca18
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 95 deletions.
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.16", features = ["time"] }
tracing = "0.1.34"
tower = { version = "0.4.13", features = ["util"] }

[dev-dependencies]
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
Expand Down
66 changes: 57 additions & 9 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@
// DEALINGS IN THE SOFTWARE.

use std::borrow::Cow as StdCow;
use std::error::Error as StdError;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use crate::transport::HttpTransportClient;
use crate::transport::{self, Error as TransportError, HttpTransportClient};
use crate::types::{ErrorResponse, NotificationSer, RequestSer, Response};
use async_trait::async_trait;
use hyper::body::HttpBody;
use hyper::http::HeaderMap;
use hyper::Body;
use jsonrpsee_core::client::{
generate_batch_id_range, BatchResponse, CertificateStore, ClientT, IdKind, RequestIdManager, Subscription,
SubscriptionClientT,
Expand All @@ -43,6 +46,8 @@ use jsonrpsee_core::{Error, JsonRawValue, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::error::CallError;
use jsonrpsee_types::{ErrorObject, TwoPointZero};
use serde::de::DeserializeOwned;
use tower::layer::util::Identity;
use tower::{Layer, Service};
use tracing::instrument;

/// Http Client Builder.
Expand Down Expand Up @@ -70,7 +75,7 @@ use tracing::instrument;
///
/// ```
#[derive(Debug)]
pub struct HttpClientBuilder {
pub struct HttpClientBuilder<L = Identity> {
max_request_size: u32,
max_response_size: u32,
request_timeout: Duration,
Expand All @@ -79,9 +84,10 @@ pub struct HttpClientBuilder {
id_kind: IdKind,
max_log_length: u32,
headers: HeaderMap,
service_builder: tower::ServiceBuilder<L>,
}

impl HttpClientBuilder {
impl<L> HttpClientBuilder<L> {
/// Set the maximum size of a request body in bytes. Default is 10 MiB.
pub fn max_request_size(mut self, size: u32) -> Self {
self.max_request_size = size;
Expand Down Expand Up @@ -134,8 +140,32 @@ impl HttpClientBuilder {
self
}

/// Set custom tower middleware.
pub fn set_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> HttpClientBuilder<T> {
HttpClientBuilder {
certificate_store: self.certificate_store,
id_kind: self.id_kind,
headers: self.headers,
max_log_length: self.max_log_length,
max_concurrent_requests: self.max_concurrent_requests,
max_request_size: self.max_request_size,
max_response_size: self.max_response_size,
service_builder,
request_timeout: self.request_timeout,
}
}
}

impl<B, S, L> HttpClientBuilder<L>
where
L: Layer<transport::HttpBackend, Service = S>,
S: Service<hyper::Request<Body>, Response = hyper::Response<B>, Error = TransportError> + Clone,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S>, Error> {
let Self {
max_request_size,
max_response_size,
Expand All @@ -145,6 +175,8 @@ impl HttpClientBuilder {
id_kind,
headers,
max_log_length,
service_builder,
..
} = self;

let transport = HttpTransportClient::new(
Expand All @@ -154,6 +186,7 @@ impl HttpClientBuilder {
certificate_store,
max_log_length,
headers,
service_builder,
)
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
Expand All @@ -164,7 +197,7 @@ impl HttpClientBuilder {
}
}

impl Default for HttpClientBuilder {
impl Default for HttpClientBuilder<Identity> {
fn default() -> Self {
Self {
max_request_size: TEN_MB_SIZE_BYTES,
Expand All @@ -175,23 +208,31 @@ impl Default for HttpClientBuilder {
id_kind: IdKind::Number,
max_log_length: 4096,
headers: HeaderMap::new(),
service_builder: tower::ServiceBuilder::new(),
}
}
}

/// JSON-RPC HTTP Client that provides functionality to perform method calls and notifications.
#[derive(Debug, Clone)]
pub struct HttpClient {
pub struct HttpClient<S> {
/// HTTP transport client.
transport: HttpTransportClient,
transport: HttpTransportClient<S>,
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
/// Request ID manager.
id_manager: Arc<RequestIdManager>,
}

#[async_trait]
impl ClientT for HttpClient {
impl<B, S> ClientT for HttpClient<S>
where
S: Service<hyper::Request<Body>, Response = hyper::Response<B>, Error = TransportError> + Send + Sync + Clone,
<S as Service<hyper::Request<Body>>>::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
#[instrument(name = "notification", skip(self, params), level = "trace")]
async fn notification<Params>(&self, method: &str, params: Params) -> Result<(), Error>
where
Expand Down Expand Up @@ -329,7 +370,14 @@ impl ClientT for HttpClient {
}

#[async_trait]
impl SubscriptionClientT for HttpClient {
impl<B, S> SubscriptionClientT for HttpClient<S>
where
S: Service<hyper::Request<Body>, Response = hyper::Response<B>, Error = TransportError> + Send + Sync + Clone,
<S as Service<hyper::Request<Body>>>::Future: Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
/// Send a subscription request to the server. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`].
#[instrument(name = "subscription", fields(method = _subscribe_method), skip(self, _params, _subscribe_method, _unsubscribe_method), level = "trace")]
async fn subscribe<'a, N, Params>(
Expand Down
Loading

0 comments on commit 0fbca18

Please sign in to comment.