Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(server): move host filtering to tower middleware #1179

Merged
merged 11 commits into from
Aug 11, 2023
2 changes: 1 addition & 1 deletion client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ where
let mut failed_calls = 0;

for _ in 0..json_rps.len() {
responses.push(Err(ErrorObject::borrowed(0, &"", None)));
responses.push(Err(ErrorObject::borrowed(0, "", None)));
}

for rp in json_rps {
Expand Down
5 changes: 1 addition & 4 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,20 @@ parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }
route-recognizer = { version = "0.3.1", optional = true }
http = { version = "0.2.9", optional = true }


[features]
default = []
http-helpers = ["hyper", "futures-util"]
server = [
"futures-util/alloc",
"route-recognizer",
"rustc-hash/std",
"parking_lot",
"rand",
"tokio/rt",
"tokio/sync",
"tokio/macros",
"tokio/time",
"http",
]
client = ["futures-util/sink", "tokio/sync"]
async-client = [
Expand Down
2 changes: 1 addition & 1 deletion core/src/client/async_client/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub(crate) fn process_batch_response(
};

for _ in range {
let err_obj = ErrorObject::borrowed(0, &"", None);
let err_obj = ErrorObject::borrowed(0, "", None);
responses.push(Err(err_obj));
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl MethodResponse {

let err = ResponsePayload::error_borrowed(ErrorObject::borrowed(
err_code,
&OVERSIZED_RESPONSE_MSG,
OVERSIZED_RESPONSE_MSG,
data.as_deref(),
));
let result =
Expand Down
3 changes: 0 additions & 3 deletions core/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,13 @@
mod error;
/// Helpers.
pub mod helpers;
/// Host filtering.
mod host_filtering;
/// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration.
mod rpc_module;
/// Subscription related types.
mod subscription;

pub use error::*;
pub use helpers::{BatchResponseBuilder, BoundedWriter, MethodResponse, MethodSink};
pub use host_filtering::*;
pub use rpc_module::*;
pub use subscription::*;

Expand Down
6 changes: 1 addition & 5 deletions examples/examples/cors_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
// modifying requests / responses. These features are independent of one another
// and can also be used separately.
// In this example, we use both features.
let server = Server::builder()
.disable_host_filtering()
.set_middleware(middleware)
.build("127.0.0.1:0".parse::<SocketAddr>()?)
.await?;
let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;

let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| {
Expand Down
2 changes: 1 addition & 1 deletion examples/examples/http_proxy_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use std::time::Duration;
use jsonrpsee::core::client::ClientT;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer;
use jsonrpsee::server::middleware::ProxyGetRequestLayer;
use jsonrpsee::server::{RpcModule, Server};

#[tokio::main]
Expand Down
3 changes: 3 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ tokio-util = { version = "0.7", features = ["compat"] }
tokio-stream = "0.1.7"
hyper = { version = "0.14", features = ["server", "http1", "http2"] }
tower = "0.4.13"
route-recognizer = "0.3.1"
http = "0.2.9"
thiserror = "1.0.44"

[dev-dependencies]
anyhow = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,80 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! HTTP Host Header validation.
//! HTTP host validation middleware.

use std::net::SocketAddr;

use crate::Error;
use crate::transport::http as http_helpers;
use futures_util::{Future, FutureExt, TryFutureExt};
use http::uri::{InvalidUri, Uri};
use hyper::{Body, Request, Response};
use jsonrpsee_core::Error;
use route_recognizer::Router;
use std::error::Error as StdError;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};

/// Middleware to enable host filtering.
#[derive(Debug)]
pub struct HostFilterLayer(Arc<WhitelistedHosts>);

impl HostFilterLayer {
/// Enables host filtering and allow only the specified hosts.
pub fn new<T: IntoIterator<Item = U>, U: TryInto<Authority>>(allow_only: T) -> Result<Self, AuthorityError>
where
T: IntoIterator<Item = U>,
U: TryInto<Authority, Error = AuthorityError>,
{
let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
Ok(Self(Arc::new(WhitelistedHosts::from(allow_only?))))
}
}

impl<S> Layer<S> for HostFilterLayer {
type Service = HostFilter<S>;

fn layer(&self, inner: S) -> Self::Service {
HostFilter { inner, filter: self.0.clone() }
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Middleware to enable host filtering.
#[derive(Debug)]
pub struct HostFilter<S> {
inner: S,
filter: Arc<WhitelistedHosts>,
}

impl<S> Service<Request<Body>> for HostFilter<S>
where
S: Service<Request<Body>, Response = Response<Body>>,
S::Response: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>> + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = Box<dyn StdError + Send + Sync + 'static>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, request: Request<Body>) -> Self::Future {
let Some(authority) = http_helpers::authority(&request) else {
return async { Ok(http_helpers::response::malformed()) }.boxed();
};

if self.filter.recognize(&authority) {
Box::pin(self.inner.call(request).map_err(Into::into))
} else {
tracing::debug!("Denied request: {:?}", request);
async { Ok(http_helpers::response::host_not_allowed()) }.boxed()
}
}
}

/// Port pattern
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -229,6 +296,7 @@ mod tests {
assert!(Authority::try_from("user:password").is_err());
assert!(Authority::try_from("parity.io/somepath").is_err());
assert!(Authority::try_from("127.0.0.1:8545/somepath").is_err());
assert!(Authority::try_from("127.0.0.1:-1337").is_err());
}

#[test]
Expand Down
7 changes: 6 additions & 1 deletion server/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
//! Various middleware implementations for RPC specific purposes.

/// HTTP Host filtering middleware.
mod host_filter;
/// Proxy `GET /path` to internal RPC methods.
pub mod proxy_get_request;
mod proxy_get_request;

pub use host_filter::*;
pub use proxy_get_request::*;
Comment on lines +37 to +38
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Given that host_filter has types like Authority etc in, maybe it's best to expose each one in its own module so it's obvious which types are for what (especially if we ever add any more)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an uri mod and simplified the host_filter mod, lemme know if you are happy with it? :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Def looks better now that the name Host is in all of the exposed things! Personally I liked that the Authority stuff lived here (because it's only used in this filter so keeping the code next to it felt good) but I'm easy either way :)

46 changes: 1 addition & 45 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use std::time::Duration;

use crate::future::{ConnectionGuard, ServerHandle, StopHandle};
use crate::logger::{Logger, TransportProtocol};
use crate::transport::http::fetch_authority;
use crate::transport::{http, ws};

use futures_util::future::{self, Either, FutureExt};
Expand All @@ -43,7 +42,7 @@ use futures_util::io::{BufReader, BufWriter};
use hyper::body::HttpBody;
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;

use jsonrpsee_core::server::{AllowHosts, Authority, AuthorityError, Methods, WhitelistedHosts};
use jsonrpsee_core::server::Methods;
use jsonrpsee_core::traits::IdProvider;
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};

Expand Down Expand Up @@ -128,7 +127,6 @@ where
let max_response_body_size = self.cfg.max_response_body_size;
let max_log_length = self.cfg.max_log_length;
let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection;
let allow_hosts = self.cfg.allow_hosts;
let logger = self.logger;
let batch_requests_config = self.cfg.batch_requests_config;
let id_provider = self.id_provider;
Expand All @@ -148,7 +146,6 @@ where
let data = ProcessConnection {
remote_addr,
methods: methods.clone(),
allow_hosts: allow_hosts.clone(),
max_request_body_size,
max_response_body_size,
max_log_length,
Expand Down Expand Up @@ -209,8 +206,6 @@ struct Settings {
max_log_length: u32,
/// Maximum number of subscriptions per connection.
max_subscriptions_per_connection: u32,
/// Host filtering.
allow_hosts: AllowHosts,
/// Whether batch requests are supported by this server or not.
batch_requests_config: BatchRequestConfig,
/// Custom tokio runtime to run the server on.
Expand Down Expand Up @@ -245,7 +240,6 @@ impl Default for Settings {
max_connections: MAX_CONNECTIONS,
max_subscriptions_per_connection: 1024,
batch_requests_config: BatchRequestConfig::Unlimited,
allow_hosts: AllowHosts::Any,
tokio_runtime: None,
ping_interval: Duration::from_secs(60),
enable_http: true,
Expand Down Expand Up @@ -420,30 +414,6 @@ impl<B, L> Builder<B, L> {
self
}

/// Enables host filtering and allow only the specified hosts.
///
/// Default: no host filtering is enabled.
pub fn host_filter<T: IntoIterator<Item = U>, U: TryInto<Authority>>(
mut self,
allow_only: T,
) -> Result<Self, AuthorityError>
where
T: IntoIterator<Item = U>,
U: TryInto<Authority, Error = AuthorityError>,
{
let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
self.settings.allow_hosts = AllowHosts::Only(WhitelistedHosts::from(allow_only?));
Ok(self)
}

/// Disable host filtering and allow all.
///
/// Default: no host filtering is enabled.
pub fn disable_host_filtering(mut self) -> Self {
self.settings.allow_hosts = AllowHosts::Any;
self
}

/// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service.
///
/// Default: No tower layers are applied to the RPC service.
Expand Down Expand Up @@ -592,8 +562,6 @@ pub(crate) struct ServiceData<L: Logger> {
pub(crate) remote_addr: SocketAddr,
/// Registered server methods.
pub(crate) methods: Methods,
/// Access control.
pub(crate) allow_hosts: AllowHosts,
/// Max request body size.
pub(crate) max_request_body_size: u32,
/// Max response body size.
Expand Down Expand Up @@ -652,15 +620,6 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
fn call(&mut self, request: hyper::Request<hyper::Body>) -> Self::Future {
tracing::trace!("{:?}", request);

let Some(authority) = fetch_authority(&request) else {
return async { Ok(http::response::malformed()) }.boxed();
};

if let Err(e) = self.inner.allow_hosts.verify(authority) {
tracing::debug!("Denied request: {}", e);
return async { Ok(http::response::host_not_allowed()) }.boxed();
}

let is_upgrade_request = is_upgrade_request(&request);

if self.inner.enable_ws && is_upgrade_request {
Expand Down Expand Up @@ -727,8 +686,6 @@ struct ProcessConnection<L> {
remote_addr: SocketAddr,
/// Registered server methods.
methods: Methods,
/// Access control.
allow_hosts: AllowHosts,
/// Max request body size.
max_request_body_size: u32,
/// Max response body size.
Expand Down Expand Up @@ -806,7 +763,6 @@ fn process_connection<'a, L: Logger, B, U>(
inner: ServiceData {
remote_addr: cfg.remote_addr,
methods: cfg.methods,
allow_hosts: cfg.allow_hosts,
max_request_body_size: cfg.max_request_body_size,
max_response_body_size: cfg.max_response_body_size,
max_log_length: cfg.max_log_length,
Expand Down
Loading