Skip to content

Commit

Permalink
Cross-origin protection (#375)
Browse files Browse the repository at this point in the history
* Initial implementation

* Comments

* Send a 403 on denied origin

* Noodling around with `set_allowed_origins`

* Error on empty list

* Soketto 0.6

* fmt

* Add `Builder::allow_all_origins`, clarify doc comments

* Rename Cors -> AllowedOrigins, nits, no panic
  • Loading branch information
maciejhirsz authored Jun 18, 2021
1 parent 6c69a8c commit 26b0613
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 16 deletions.
2 changes: 1 addition & 1 deletion test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ hyper = { version = "0.14", features = ["full"] }
log = "0.4"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
soketto = "0.5"
soketto = "0.6"
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.6", features = ["compat"] }
6 changes: 3 additions & 3 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ async fn server_backend(listener: tokio::net::TcpListener, mut exit: Receiver<()
async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) {
let mut server = Server::new(socket.compat());

let websocket_key = match server.receive_request().await {
Ok(req) => req.into_key(),
let key = match server.receive_request().await {
Ok(req) => req.key(),
Err(_) => return,
};

let accept = server.send_response(&Response::Accept { key: &websocket_key, protocol: None }).await;
let accept = server.send_response(&Response::Accept { key, protocol: None }).await;

if accept.is_err() {
return;
Expand Down
2 changes: 1 addition & 1 deletion types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ log = { version = "0.4", default-features = false }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] }
thiserror = "1.0"
soketto = "0.5"
soketto = "0.6"
hyper = "0.14"
3 changes: 3 additions & 0 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ pub enum Error {
/// Configured max number of request slots exceeded.
#[error("Configured max number of request slots exceeded")]
MaxSlotsExceeded,
/// List passed into `set_allowed_origins` was empty
#[error("Must set at least one allowed origin")]
EmptyAllowedOrigins,
/// Custom error.
#[error("Custom error: {0}")]
Custom(String),
Expand Down
2 changes: 1 addition & 1 deletion ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jsonrpsee-types = { path = "../types", version = "0.2.0" }
log = "0.4"
serde = "1"
serde_json = "1"
soketto = "0.5"
soketto = "0.6"
pin-project = "1"
thiserror = "1"
url = "2"
Expand Down
2 changes: 1 addition & 1 deletion ws-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ log = "0.4"
rustc-hash = "1.1.0"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", features = ["raw_value"] }
soketto = "0.5"
soketto = "0.6"
tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] }
tokio-stream = { version = "0.1.1", features = ["net"] }
tokio-util = { version = "0.6", features = ["compat"] }
Expand Down
92 changes: 83 additions & 9 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,22 @@ impl Server {
let mut incoming = TcpListenerStream::new(self.listener);
let methods = self.methods;
let conn_counter = Arc::new(());
let cfg = self.cfg;
let mut id = 0;

while let Some(socket) = incoming.next().await {
if let Ok(socket) = socket {
socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e));
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
}

if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}
let methods = methods.clone();
let counter = conn_counter.clone();
let cfg = self.cfg.clone();

tokio::spawn(async move {
let r = background_task(socket, id, methods, cfg).await;
Expand All @@ -111,14 +114,24 @@ async fn background_task(
// For each incoming background_task we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));

let websocket_key = {
let key = {
let req = server.receive_request().await?;
req.into_key()

cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key())
};

// Here we accept the client unconditionally.
let accept = Response::Accept { key: &websocket_key, protocol: None };
server.send_response(&accept).await?;
match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;

return Err(error);
}
}

// And we can finally transition to a websocket background_task.
let (mut sender, mut receiver) = server.into_builder().finish();
Expand Down Expand Up @@ -185,18 +198,44 @@ async fn background_task(
}
}

#[derive(Debug, Clone)]
enum AllowedOrigins {
Any,
OneOf(Arc<[String]>),
}

impl AllowedOrigins {
fn verify(&self, origin: Option<&[u8]>) -> Result<(), Error> {
if let (AllowedOrigins::OneOf(list), Some(origin)) = (self, origin) {
if !list.iter().any(|o| o.as_bytes() == origin) {
let error = format!("Origin denied: {}", String::from_utf8_lossy(origin));
log::warn!("{}", error);
return Err(Error::Request(error));
}
}

Ok(())
}
}

/// JSON-RPC Websocket server settings.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
struct Settings {
/// Maximum size in bytes of a request.
max_request_body_size: u32,
/// Maximum number of incoming connections allowed.
max_connections: u64,
/// Cross-origin policy by which to accept or deny incoming requests.
allowed_origins: AllowedOrigins,
}

impl Default for Settings {
fn default() -> Self {
Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS }
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_connections: MAX_CONNECTIONS,
allowed_origins: AllowedOrigins::Any,
}
}
}

Expand All @@ -219,6 +258,41 @@ impl Builder {
self
}

/// Set a list of allowed origins. During the handshake, the `Origin` header will be
/// checked against the list, connections without a matching origin will be denied.
/// Values should include protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_origins(vec!["https://example.com"]);
/// ```
///
/// By default allows any `Origin`.
///
/// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the default.
pub fn set_allowed_origins<Origin, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Origin>,
Origin: Into<String>,
{
let list: Arc<_> = list.into_iter().map(Into::into).collect();

if list.len() == 0 {
return Err(Error::EmptyAllowedOrigins);
}

self.settings.allowed_origins = AllowedOrigins::OneOf(list);

Ok(self)
}

/// Restores the default behavior of allowing connections with `Origin` header
/// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins).
pub fn allow_all_origins(mut self) -> Self {
self.settings.allowed_origins = AllowedOrigins::Any;
self
}

/// Finalize the configuration of the server. Consumes the [`Builder`].
pub async fn build(self, addr: impl ToSocketAddrs) -> Result<Server, Error> {
let listener = TcpListener::bind(addr).await?;
Expand Down

0 comments on commit 26b0613

Please sign in to comment.