Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(server): move host filtering to tower middleware #1179

Merged
merged 11 commits into from
Aug 11, 2023
2 changes: 1 addition & 1 deletion client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ where
let mut failed_calls = 0;

for _ in 0..json_rps.len() {
responses.push(Err(ErrorObject::borrowed(0, &"", None)));
responses.push(Err(ErrorObject::borrowed(0, "", None)));
}

for rp in json_rps {
Expand Down
5 changes: 1 addition & 4 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,20 @@ parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }
route-recognizer = { version = "0.3.1", optional = true }
http = { version = "0.2.9", optional = true }


[features]
default = []
http-helpers = ["hyper", "futures-util"]
server = [
"futures-util/alloc",
"route-recognizer",
"rustc-hash/std",
"parking_lot",
"rand",
"tokio/rt",
"tokio/sync",
"tokio/macros",
"tokio/time",
"http",
]
client = ["futures-util/sink", "tokio/sync"]
async-client = [
Expand Down
2 changes: 1 addition & 1 deletion core/src/client/async_client/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub(crate) fn process_batch_response(
};

for _ in range {
let err_obj = ErrorObject::borrowed(0, &"", None);
let err_obj = ErrorObject::borrowed(0, "", None);
responses.push(Err(err_obj));
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl MethodResponse {

let err = ResponsePayload::error_borrowed(ErrorObject::borrowed(
err_code,
&OVERSIZED_RESPONSE_MSG,
OVERSIZED_RESPONSE_MSG,
data.as_deref(),
));
let result =
Expand Down
3 changes: 0 additions & 3 deletions core/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,13 @@
mod error;
/// Helpers.
pub mod helpers;
/// Host filtering.
mod host_filtering;
/// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration.
mod rpc_module;
/// Subscription related types.
mod subscription;

pub use error::*;
pub use helpers::{BatchResponseBuilder, BoundedWriter, MethodResponse, MethodSink};
pub use host_filtering::*;
pub use rpc_module::*;
pub use subscription::*;

Expand Down
6 changes: 1 addition & 5 deletions examples/examples/cors_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ async fn run_server() -> anyhow::Result<SocketAddr> {
// modifying requests / responses. These features are independent of one another
// and can also be used separately.
// In this example, we use both features.
let server = Server::builder()
.disable_host_filtering()
.set_middleware(middleware)
.build("127.0.0.1:0".parse::<SocketAddr>()?)
.await?;
let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;

let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| {
Expand Down
82 changes: 82 additions & 0 deletions examples/examples/host_filter_middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2019-2022 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! This example shows how to configure `host filtering` by tower middleware on the jsonrpsee server.
//!
//! The server whitelist only `example.com` and any call from localhost will be
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
//! rejected both by HTTP and WebSocket transports.

use std::net::SocketAddr;

use jsonrpsee::core::client::ClientT;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::server::middleware::HostFilterLayer;
use jsonrpsee::server::{RpcModule, Server};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");

let addr = run_server().await?;
let url = format!("http://{}", addr);

// Use RPC client to get the response of `say_hello` method.
let client = HttpClientBuilder::default().build(&url)?;
// This call will be denied because only `example.com` URIs/hosts are allowed by the host filter.
let response = client.request::<String, _>("say_hello", rpc_params![]).await.unwrap_err();
println!("[main]: response: {}", response);

Ok(())
}

async fn run_server() -> anyhow::Result<SocketAddr> {
// Custom tower service to handle the RPC requests
let service_builder = tower::ServiceBuilder::new()
// For this example we only want to permit requests from `example.com`
// all other request are denied.
//
// `HostFilerLayer::new` only fails on invalid URIs..
.layer(HostFilterLayer::new(["example.com"]).unwrap());

let server = Server::builder().set_middleware(service_builder).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;

let addr = server.local_addr()?;

let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| "lo").unwrap();

let handle = server.start(module);

// In this example we don't care about doing shutdown so let's it run forever.
// You may use the `ServerHandle` to shut it down or manage it yourself.
tokio::spawn(handle.stopped());

Ok(addr)
}
2 changes: 1 addition & 1 deletion examples/examples/http_proxy_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use std::time::Duration;
use jsonrpsee::core::client::ClientT;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer;
use jsonrpsee::server::middleware::ProxyGetRequestLayer;
use jsonrpsee::server::{RpcModule, Server};

#[tokio::main]
Expand Down
3 changes: 3 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ tokio-util = { version = "0.7", features = ["compat"] }
tokio-stream = "0.1.7"
hyper = { version = "0.14", features = ["server", "http1", "http2"] }
tower = "0.4.13"
route-recognizer = "0.3.1"
http = "0.2.9"
thiserror = "1.0.44"

[dev-dependencies]
anyhow = "1"
Expand Down
1 change: 1 addition & 0 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mod transport;

pub mod logger;
pub mod middleware;
pub mod uri;

#[cfg(test)]
mod tests;
Expand Down
181 changes: 181 additions & 0 deletions server/src/middleware/host_filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright 2019-2023 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! HTTP host validation middleware.

use crate::transport::http;
use crate::uri::{Authority, AuthorityError, Port};
use futures_util::{Future, FutureExt, TryFutureExt};
use hyper::{Body, Request, Response};
use route_recognizer::Router;
use std::error::Error as StdError;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};

/// Middleware to enable host filtering.
#[derive(Debug)]
pub struct HostFilterLayer(Arc<WhitelistedHosts>);

impl HostFilterLayer {
/// Enables host filtering and allow only the specified hosts.
pub fn new<T: IntoIterator<Item = U>, U: TryInto<Authority>>(allow_only: T) -> Result<Self, AuthorityError>
where
T: IntoIterator<Item = U>,
U: TryInto<Authority, Error = AuthorityError>,
{
let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
Ok(Self(Arc::new(WhitelistedHosts::from(allow_only?))))
}
}

impl<S> Layer<S> for HostFilterLayer {
type Service = HostFilter<S>;

fn layer(&self, inner: S) -> Self::Service {
HostFilter { inner, filter: self.0.clone() }
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Middleware to enable host filtering.
#[derive(Debug)]
pub struct HostFilter<S> {
inner: S,
filter: Arc<WhitelistedHosts>,
}

impl<S> Service<Request<Body>> for HostFilter<S>
where
S: Service<Request<Body>, Response = Response<Body>>,
S::Response: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>> + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = Box<dyn StdError + Send + Sync + 'static>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}

fn call(&mut self, request: Request<Body>) -> Self::Future {
let Some(authority) = Authority::from_http_request(&request) else {
return async { Ok(http::response::malformed()) }.boxed();
};

if self.filter.recognize(&authority) {
Box::pin(self.inner.call(request).map_err(Into::into))
} else {
tracing::debug!("Denied request: {:?}", request);
async { Ok(http::response::host_not_allowed()) }.boxed()
}
}
}

/// Represent the URL patterns that is whitelisted.
#[derive(Default, Debug, Clone)]
pub struct WhitelistedHosts(Router<Port>);

impl<T> From<T> for WhitelistedHosts
where
T: IntoIterator<Item = Authority>,
{
fn from(value: T) -> Self {
let mut router = Router::new();

for auth in value.into_iter() {
router.add(&auth.host, auth.port);
}

Self(router)
}
}

impl WhitelistedHosts {
fn recognize(&self, other: &Authority) -> bool {
if let Ok(p) = self.0.recognize(&other.host) {
let p = p.handler();

match (p, &other.port) {
(Port::Any, _) => true,
(Port::Default, Port::Default) => true,
(Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true,
_ => false,
}
} else {
false
}
}
}

#[cfg(test)]
mod tests {
use super::{Authority, WhitelistedHosts};

fn unwrap_auth(a: &str) -> Authority {
a.try_into().unwrap()
}

fn unwrap_filter(list: &[&str]) -> WhitelistedHosts {
let l: Vec<_> = list.into_iter().map(|&a| a.try_into().unwrap()).collect();
WhitelistedHosts::from(l)
}

#[test]
fn should_reject_if_header_not_on_the_list() {
let filter = unwrap_filter(&[]);
assert!(!filter.recognize(&unwrap_auth("parity.io")));
}

#[test]
fn should_accept_if_on_the_list() {
let filter = unwrap_filter(&["parity.io"]);
assert!(filter.recognize(&unwrap_auth("parity.io")));
}

#[test]
fn should_accept_if_on_the_list_with_port() {
let filter = unwrap_filter(&["parity.io:443"]);
assert!(filter.recognize(&unwrap_auth("parity.io:443")));
assert!(!filter.recognize(&unwrap_auth("parity.io")));
}

#[test]
fn should_support_wildcards() {
let filter = unwrap_filter(&["*.web3.site:*"]);
assert!(filter.recognize(&unwrap_auth("parity.web3.site:8180")));
assert!(filter.recognize(&unwrap_auth("parity.web3.site")));
}

#[test]
fn should_accept_with_and_without_default_port() {
let filter = unwrap_filter(&["https://parity.io:443"]);
assert!(filter.recognize(&unwrap_auth("https://parity.io")));
assert!(filter.recognize(&unwrap_auth("https://parity.io:443")));
}
}
Loading