diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 9915300453..aca58d3cc3 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -15,6 +15,6 @@ hyper = { version = "0.14", features = ["full"] } log = "0.4" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" -soketto = "0.5" +soketto = "0.6" tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index 419a937a41..5c5aa8c92c 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -199,12 +199,12 @@ async fn server_backend(listener: tokio::net::TcpListener, mut exit: Receiver<() async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) { let mut server = Server::new(socket.compat()); - let websocket_key = match server.receive_request().await { - Ok(req) => req.into_key(), + let key = match server.receive_request().await { + Ok(req) => req.key(), Err(_) => return, }; - let accept = server.send_response(&Response::Accept { key: &websocket_key, protocol: None }).await; + let accept = server.send_response(&Response::Accept { key, protocol: None }).await; if accept.is_err() { return; diff --git a/types/Cargo.toml b/types/Cargo.toml index 31c2c32779..6964760be6 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -18,5 +18,5 @@ log = { version = "0.4", default-features = false } serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"] } thiserror = "1.0" -soketto = "0.5" +soketto = "0.6" hyper = "0.14" diff --git a/types/src/error.rs b/types/src/error.rs index f502c27b11..84a699971e 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -81,6 +81,9 @@ pub enum Error { /// Configured max number of request slots exceeded. #[error("Configured max number of request slots exceeded")] MaxSlotsExceeded, + /// List passed into `set_allowed_origins` was empty + #[error("Must set at least one allowed origin")] + EmptyAllowedOrigins, /// Custom error. #[error("Custom error: {0}")] Custom(String), diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 3a52d54bb7..04949e29bc 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -27,7 +27,7 @@ jsonrpsee-types = { path = "../types", version = "0.2.0" } log = "0.4" serde = "1" serde_json = "1" -soketto = "0.5" +soketto = "0.6" pin-project = "1" thiserror = "1" url = "2" diff --git a/ws-server/Cargo.toml b/ws-server/Cargo.toml index 20ad2a1619..6601735702 100644 --- a/ws-server/Cargo.toml +++ b/ws-server/Cargo.toml @@ -19,7 +19,7 @@ log = "0.4" rustc-hash = "1.1.0" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = { version = "1", features = ["raw_value"] } -soketto = "0.5" +soketto = "0.6" tokio = { version = "1", features = ["net", "rt-multi-thread", "macros"] } tokio-stream = { version = "0.1.1", features = ["net"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 62f82daac4..21bcb8a8cd 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -76,12 +76,14 @@ impl Server { let mut incoming = TcpListenerStream::new(self.listener); let methods = self.methods; let conn_counter = Arc::new(()); - let cfg = self.cfg; let mut id = 0; while let Some(socket) = incoming.next().await { if let Ok(socket) = socket { - socket.set_nodelay(true).unwrap_or_else(|e| panic!("Could not set NODELAY on socket: {:?}", e)); + if let Err(e) = socket.set_nodelay(true) { + log::error!("Could not set NODELAY on socket: {:?}", e); + continue; + } if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { log::warn!("Too many connections. Try again in a while"); @@ -89,6 +91,7 @@ impl Server { } let methods = methods.clone(); let counter = conn_counter.clone(); + let cfg = self.cfg.clone(); tokio::spawn(async move { let r = background_task(socket, id, methods, cfg).await; @@ -111,14 +114,24 @@ async fn background_task( // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); - let websocket_key = { + let key = { let req = server.receive_request().await?; - req.into_key() + + cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key()) }; - // Here we accept the client unconditionally. - let accept = Response::Accept { key: &websocket_key, protocol: None }; - server.send_response(&accept).await?; + match key { + Ok(key) => { + let accept = Response::Accept { key, protocol: None }; + server.send_response(&accept).await?; + } + Err(error) => { + let reject = Response::Reject { status_code: 403 }; + server.send_response(&reject).await?; + + return Err(error); + } + } // And we can finally transition to a websocket background_task. let (mut sender, mut receiver) = server.into_builder().finish(); @@ -185,18 +198,44 @@ async fn background_task( } } +#[derive(Debug, Clone)] +enum AllowedOrigins { + Any, + OneOf(Arc<[String]>), +} + +impl AllowedOrigins { + fn verify(&self, origin: Option<&[u8]>) -> Result<(), Error> { + if let (AllowedOrigins::OneOf(list), Some(origin)) = (self, origin) { + if !list.iter().any(|o| o.as_bytes() == origin) { + let error = format!("Origin denied: {}", String::from_utf8_lossy(origin)); + log::warn!("{}", error); + return Err(Error::Request(error)); + } + } + + Ok(()) + } +} + /// JSON-RPC Websocket server settings. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] struct Settings { /// Maximum size in bytes of a request. max_request_body_size: u32, /// Maximum number of incoming connections allowed. max_connections: u64, + /// Cross-origin policy by which to accept or deny incoming requests. + allowed_origins: AllowedOrigins, } impl Default for Settings { fn default() -> Self { - Self { max_request_body_size: TEN_MB_SIZE_BYTES, max_connections: MAX_CONNECTIONS } + Self { + max_request_body_size: TEN_MB_SIZE_BYTES, + max_connections: MAX_CONNECTIONS, + allowed_origins: AllowedOrigins::Any, + } } } @@ -219,6 +258,41 @@ impl Builder { self } + /// Set a list of allowed origins. During the handshake, the `Origin` header will be + /// checked against the list, connections without a matching origin will be denied. + /// Values should include protocol. + /// + /// ```rust + /// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default(); + /// builder.set_allowed_origins(vec!["https://example.com"]); + /// ``` + /// + /// By default allows any `Origin`. + /// + /// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the default. + pub fn set_allowed_origins(mut self, list: List) -> Result + where + List: IntoIterator, + Origin: Into, + { + let list: Arc<_> = list.into_iter().map(Into::into).collect(); + + if list.len() == 0 { + return Err(Error::EmptyAllowedOrigins); + } + + self.settings.allowed_origins = AllowedOrigins::OneOf(list); + + Ok(self) + } + + /// Restores the default behavior of allowing connections with `Origin` header + /// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins). + pub fn allow_all_origins(mut self) -> Self { + self.settings.allowed_origins = AllowedOrigins::Any; + self + } + /// Finalize the configuration of the server. Consumes the [`Builder`]. pub async fn build(self, addr: impl ToSocketAddrs) -> Result { let listener = TcpListener::bind(addr).await?;