From 94c881b83c8fd5a3facba17a90c5a386467c44eb Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 5 Oct 2021 17:33:07 +0200 Subject: [PATCH] ws client redirections (#397) * feat(ws client): support redirections * reuse socket * reuse socket * add hacks * fix build * remove hacks * fix bad merge * address grumbles * fix grumbles * fix grumbles * fix nit * add redirection test * Update test-utils/src/types.rs * Resolved todo * Check that redirected client actually works * Rename test-utils "types" to "mocks" * Fix windows test (?) * fmt * What is wrong with you windows? * Ignore redirect test on windows * fix bad transport errors * debug windows tests * update soketto * maybe fix windows test * add config flag for max redirections * revert faulty change. Relative reference must start with either `/` or `//` * revert windows path * use manual join paths * remove url dep * Update ws-client/src/tests.rs * default max redirects 5 * remove needless clone vec * fix bad merge * cmon CI run Co-authored-by: David Palm --- http-client/src/tests.rs | 2 +- http-server/src/tests.rs | 2 +- test-utils/Cargo.toml | 2 +- test-utils/src/helpers.rs | 2 +- test-utils/src/lib.rs | 2 +- test-utils/src/{types.rs => mocks.rs} | 63 +++++- ws-client/Cargo.toml | 7 +- ws-client/src/client.rs | 33 +-- ws-client/src/tests.rs | 33 ++- ws-client/src/transport.rs | 285 ++++++++++++++++---------- ws-server/src/tests.rs | 2 +- 11 files changed, 300 insertions(+), 133 deletions(-) rename test-utils/src/{types.rs => mocks.rs} (84%) diff --git a/http-client/src/tests.rs b/http-client/src/tests.rs index 8ce53edb98..a7ba27cc8a 100644 --- a/http-client/src/tests.rs +++ b/http-client/src/tests.rs @@ -31,7 +31,7 @@ use crate::types::{ }; use crate::HttpClientBuilder; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::Id; +use jsonrpsee_test_utils::mocks::Id; use jsonrpsee_test_utils::TimeoutFutureExt; #[tokio::test] diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs index f122128e76..55dcc9834c 100644 --- a/http-server/src/tests.rs +++ b/http-server/src/tests.rs @@ -32,7 +32,7 @@ use crate::types::error::{CallError, Error}; use crate::{server::StopHandle, HttpServerBuilder, RpcModule}; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext}; +use jsonrpsee_test_utils::mocks::{Id, StatusCode, TestContext}; use jsonrpsee_test_utils::TimeoutFutureExt; use serde_json::Value as JsonValue; use tokio::task::JoinHandle; diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 05bdf3371e..e4bfac7f27 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -15,6 +15,6 @@ hyper = { version = "0.14.10", features = ["full"] } log = "0.4" serde = { version = "1", default-features = false, features = ["derive"] } serde_json = "1" -soketto = "0.7" +soketto = { version = "0.7", features = ["http"] } tokio = { version = "1", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.6", features = ["compat"] } diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index 06cbe068b3..7cb3e7521f 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -24,7 +24,7 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::types::{Body, HttpResponse, Id, Uri}; +use crate::mocks::{Body, HttpResponse, Id, Uri}; use hyper::service::{make_service_fn, service_fn}; use hyper::{Request, Response, Server}; use serde_json::Value; diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 6f6133eeaa..c47211bc62 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -32,7 +32,7 @@ use std::{future::Future, time::Duration}; use tokio::time::{timeout, Timeout}; pub mod helpers; -pub mod types; +pub mod mocks; /// Helper extension trait which allows to limit execution time for the futures. /// It is helpful in tests to ensure that no future will ever get stuck forever. diff --git a/test-utils/src/types.rs b/test-utils/src/mocks.rs similarity index 84% rename from test-utils/src/types.rs rename to test-utils/src/mocks.rs index 5d3da21d6f..c3bb183ba0 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/mocks.rs @@ -34,10 +34,8 @@ use futures_util::{ stream::{self, StreamExt}, }; use serde::{Deserialize, Serialize}; -use soketto::handshake::{self, server::Response, Error as SokettoError, Server}; -use std::io; -use std::net::SocketAddr; -use std::time::Duration; +use soketto::handshake::{self, http::is_upgrade_request, server::Response, Error as SokettoError, Server}; +use std::{io, net::SocketAddr, time::Duration}; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -314,3 +312,60 @@ async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut ex } } } + +// Run a WebSocket server running on localhost that redirects requests for testing. +// Requests to any url except for `/myblock/two` will redirect one or two times (HTTP 301) and eventually end up in `/myblock/two`. +pub fn ws_server_with_redirect(other_server: String) -> String { + let addr = ([127, 0, 0, 1], 0).into(); + + let service = hyper::service::make_service_fn(move |_| { + let other_server = other_server.clone(); + async move { + Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| { + let other_server = other_server.clone(); + async move { handler(req, other_server).await } + })) + } + }); + let server = hyper::Server::bind(&addr).serve(service); + let addr = server.local_addr(); + + tokio::spawn(async move { server.await }); + format!("ws://{}", addr) +} + +/// Handle incoming HTTP Requests. +async fn handler( + req: hyper::Request, + other_server: String, +) -> Result, soketto::BoxedError> { + if is_upgrade_request(&req) { + log::debug!("{:?}", req); + + match req.uri().path() { + "/myblock/two" => { + let response = hyper::Response::builder() + .status(301) + .header("Location", other_server) + .body(Body::empty()) + .unwrap(); + Ok(response) + } + "/myblock/one" => { + let response = + hyper::Response::builder().status(301).header("Location", "two").body(Body::empty()).unwrap(); + Ok(response) + } + _ => { + let response = hyper::Response::builder() + .status(301) + .header("Location", "/myblock/one") + .body(Body::empty()) + .unwrap(); + Ok(response) + } + } + } else { + panic!("expect upgrade to WS"); + } +} diff --git a/ws-client/Cargo.toml b/ws-client/Cargo.toml index 11074d919c..b57b895e0d 100644 --- a/ws-client/Cargo.toml +++ b/ws-client/Cargo.toml @@ -18,15 +18,16 @@ arrayvec = "0.7.1" async-trait = "0.1" fnv = "1" futures = { version = "0.3.14", default-features = false, features = ["std"] } +http = "0.2" jsonrpsee-types = { path = "../types", version = "0.3.0" } log = "0.4" +pin-project = "1" +rustls-native-certs = "0.5.0" serde = "1" serde_json = "1" soketto = "0.7" -pin-project = "1" thiserror = "1" -url = "2" -rustls-native-certs = "0.5.0" [dev-dependencies] jsonrpsee-test-utils = { path = "../test-utils" } +env_logger = "0.9" diff --git a/ws-client/src/client.rs b/ws-client/src/client.rs index 048c99a83d..1a1dfc8cd4 100644 --- a/ws-client/src/client.rs +++ b/ws-client/src/client.rs @@ -24,7 +24,7 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Receiver as WsReceiver, Sender as WsSender, Target, WsTransportClientBuilder}; +use crate::transport::{Receiver as WsReceiver, Sender as WsSender, WsHandshakeError, WsTransportClientBuilder}; use crate::types::{ traits::{Client, SubscriptionClient}, v2::{Id, Notification, NotificationSer, ParamsSer, RequestSer, Response, RpcError, SubscriptionResponse}, @@ -46,10 +46,13 @@ use futures::{ prelude::*, sink::SinkExt, }; +use http::uri::{InvalidUri, Uri}; use tokio::sync::Mutex; use serde::de::DeserializeOwned; -use std::{borrow::Cow, time::Duration}; +use std::{borrow::Cow, convert::TryInto, time::Duration}; + +pub use soketto::handshake::client::Header; /// Wrapper over a [`oneshot::Receiver`](futures::channel::oneshot::Receiver) that reads /// the underlying channel once and then stores the result in String. @@ -109,6 +112,7 @@ pub struct WsClientBuilder<'a> { origin_header: Option>, max_concurrent_requests: usize, max_notifs_per_subscription: usize, + max_redirections: usize, } impl<'a> Default for WsClientBuilder<'a> { @@ -121,6 +125,7 @@ impl<'a> Default for WsClientBuilder<'a> { origin_header: None, max_concurrent_requests: 256, max_notifs_per_subscription: 1024, + max_redirections: 5, } } } @@ -151,8 +156,8 @@ impl<'a> WsClientBuilder<'a> { } /// Set origin header to pass during the handshake. - pub fn origin_header(mut self, origin: &'a str) -> Self { - self.origin_header = Some(Cow::Borrowed(origin)); + pub fn origin_header(mut self, origin: Cow<'a, str>) -> Self { + self.origin_header = Some(origin); self } @@ -176,18 +181,19 @@ impl<'a> WsClientBuilder<'a> { self } + /// Set the max number of redirections to perform until a connection is regarded as failed. + pub fn max_redirections(mut self, redirect: usize) -> Self { + self.max_redirections = redirect; + self + } + /// Build the client with specified URL to connect to. - /// If the port number is missing from the URL, the default port number is used. - /// - /// - /// `ws://host` - port 80 is used - /// - /// `wss://host` - port 443 is used + /// You must provide the port number in the URL. /// /// ## Panics /// /// Panics if being called outside of `tokio` runtime context. - pub async fn build(self, url: &'a str) -> Result { + pub async fn build(self, uri: &'a str) -> Result { let certificate_store = self.certificate_store; let max_capacity_per_subscription = self.max_notifs_per_subscription; let max_concurrent_requests = self.max_concurrent_requests; @@ -195,12 +201,15 @@ impl<'a> WsClientBuilder<'a> { let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); let (err_tx, err_rx) = oneshot::channel(); + let uri: Uri = uri.parse().map_err(|e: InvalidUri| Error::Transport(e.into()))?; + let builder = WsTransportClientBuilder { certificate_store, - target: Target::parse(url).map_err(|e| Error::Transport(e.into()))?, + target: uri.try_into().map_err(|e: WsHandshakeError| Error::Transport(e.into()))?, timeout: self.connection_timeout, origin_header: self.origin_header, max_request_body_size: self.max_request_body_size, + max_redirections: self.max_redirections, }; let (sender, receiver) = builder.build().await.map_err(|e| Error::Transport(e.into()))?; diff --git a/ws-client/src/tests.rs b/ws-client/src/tests.rs index 2b9f74cbee..b2b0261826 100644 --- a/ws-client/src/tests.rs +++ b/ws-client/src/tests.rs @@ -32,7 +32,7 @@ use crate::types::{ }; use crate::WsClientBuilder; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, WebSocketTestServer}; +use jsonrpsee_test_utils::mocks::{Id, WebSocketTestServer}; use jsonrpsee_test_utils::TimeoutFutureExt; use serde_json::Value as JsonValue; @@ -263,3 +263,34 @@ fn assert_error_response(err: Error, exp: ErrorObject) { e => panic!("Expected error: \"{}\", got: {:?}", err, e), }; } + +#[tokio::test] +async fn redirections() { + let _ = env_logger::try_init(); + let expected = "abc 123"; + let server = WebSocketTestServer::with_hardcoded_response( + "127.0.0.1:0".parse().unwrap(), + ok_response(expected.into(), Id::Num(0)), + ) + .with_default_timeout() + .await + .unwrap(); + + let server_url = format!("ws://{}", server.local_addr()); + let redirect_url = jsonrpsee_test_utils::mocks::ws_server_with_redirect(server_url); + + // The client will first connect to a server that only performs re-directions and finally + // redirect to another server to complete the handshake. + let client = WsClientBuilder::default().build(&redirect_url).with_default_timeout().await; + // It's an ok client + let client = match client { + Ok(Ok(client)) => client, + Ok(Err(e)) => panic!("WsClient builder failed with: {:?}", e), + Err(e) => panic!("WsClient builder timed out with: {:?}", e), + }; + // It's connected + assert!(client.is_connected()); + // It works + let response = client.request::("anything", ParamsSer::NoParams).with_default_timeout().await.unwrap(); + assert_eq!(response.unwrap(), String::from(expected)); +} diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 7a00ca6c1a..e18f34229f 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -24,12 +24,21 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::stream::EitherStream; use arrayvec::ArrayVec; use futures::io::{BufReader, BufWriter}; -use futures::prelude::*; +use http::Uri; use soketto::connection; -use soketto::handshake::client::{Client as WsRawClient, Header, ServerResponse}; -use std::{borrow::Cow, io, net::SocketAddr, sync::Arc, time::Duration}; +use soketto::handshake::client::{Client as WsHandshakeClient, Header, ServerResponse}; +use std::convert::TryInto; +use std::{ + borrow::Cow, + convert::TryFrom, + io, + net::{SocketAddr, ToSocketAddrs}, + sync::Arc, + time::Duration, +}; use thiserror::Error; use tokio::net::TcpStream; use tokio_rustls::{ @@ -39,7 +48,7 @@ use tokio_rustls::{ TlsConnector, }; -type TlsOrPlain = crate::stream::EitherStream>; +type TlsOrPlain = EitherStream>; /// Sending end of WebSocket transport. #[derive(Debug)] @@ -67,6 +76,8 @@ pub struct WsTransportClientBuilder<'a> { pub origin_header: Option>, /// Max payload size pub max_request_body_size: u32, + /// Max number of redirections. + pub max_redirections: usize, } /// Stream mode, either plain TCP or TLS. @@ -95,42 +106,42 @@ pub enum CertificateStore { #[derive(Debug, Error)] pub enum WsHandshakeError { /// Failed to load system certs - #[error("Failed to load system certs: {}", 0)] + #[error("Failed to load system certs: {0}")] CertificateStore(io::Error), /// Invalid URL. - #[error("Invalid url: {}", 0)] + #[error("Invalid URL: {0}")] Url(Cow<'static, str>), /// Error when opening the TCP socket. - #[error("Error when opening the TCP socket: {}", 0)] + #[error("Error when opening the TCP socket: {0}")] Io(io::Error), /// Error in the transport layer. - #[error("Error in the WebSocket handshake: {}", 0)] + #[error("Error in the WebSocket handshake: {0}")] Transport(#[source] soketto::handshake::Error), /// Invalid DNS name error for TLS - #[error("Invalid DNS name: {}", 0)] + #[error("Invalid DNS name: {0}")] InvalidDnsName(#[source] InvalidDNSNameError), - /// RawServer rejected our handshake. - #[error("Connection rejected with status code: {}", status_code)] + /// Server rejected the handshake. + #[error("Connection rejected with status code: {status_code}")] Rejected { /// HTTP status code that the server returned. status_code: u16, }, /// Timeout while trying to connect. - #[error("Connection timeout exceeded: {}", 0)] + #[error("Connection timeout exceeded: {0:?}")] Timeout(Duration), /// Failed to resolve IP addresses for this hostname. - #[error("Failed to resolve IP addresses for this hostname: {}", 0)] + #[error("Failed to resolve IP addresses for this hostname: {0}")] ResolutionFailed(io::Error), /// Couldn't find any IP address for this hostname. - #[error("No IP address found for this hostname: {}", 0)] + #[error("No IP address found for this hostname: {0}")] NoAddressFound(String), } @@ -186,79 +197,143 @@ impl<'a> WsTransportClientBuilder<'a> { Mode::Plain => None, }; - let mut err = None; - for sockaddr in &self.target.sockaddrs { - match self.try_connect(*sockaddr, &connector).await { - Ok(res) => return Ok(res), - Err(e) => { - log::debug!("Failed to connect to sockaddr: {:?} with err: {:?}", sockaddr, e); - err = Some(Err(e)); - } - } - } - // NOTE(niklasad1): this is most likely unreachable because [`Url::socket_addrs`] doesn't - // return an empty `Vec` if no socket address was found for the host name. - err.unwrap_or(Err(WsHandshakeError::NoAddressFound(self.target.host))) + self.try_connect(connector).await } async fn try_connect( - &self, - sockaddr: SocketAddr, - tls_connector: &Option, + self, + mut tls_connector: Option, ) -> Result<(Sender, Receiver), WsHandshakeError> { - // Try establish the TCP connection. - let tcp_stream = { - let socket = TcpStream::connect(sockaddr); - let timeout = tokio::time::sleep(self.timeout); - futures::pin_mut!(socket, timeout); - match future::select(socket, timeout).await { - future::Either::Left((socket, _)) => { - let socket = socket?; - if let Err(err) = socket.set_nodelay(true) { - log::warn!("set nodelay failed: {:?}", err); - } - match tls_connector { - None => TlsOrPlain::Plain(socket), - Some(connector) => { - let dns_name = DNSNameRef::try_from_ascii_str(&self.target.host)?; - let tls_stream = connector.connect(dns_name, socket).await?; - TlsOrPlain::Tls(tls_stream) - } - } - } - future::Either::Right((_, _)) => return Err(WsHandshakeError::Timeout(self.timeout)), - } - }; - - log::debug!("Connecting to target: {:?}", self.target); - let mut client = WsRawClient::new( - BufReader::new(BufWriter::new(tcp_stream)), - &self.target.host_header, - &self.target.path_and_query, - ); - + let mut target = self.target; let mut headers: ArrayVec = ArrayVec::new(); + let mut err = None; if let Some(origin) = self.origin_header.as_ref() { headers.push(Header { name: "Origin", value: origin.as_bytes() }); } - client.set_headers(&headers); + for _ in 0..self.max_redirections { + log::debug!("Connecting to target: {:?}", target); + + // The sockaddrs might get reused if the server replies with a relative URI. + let sockaddrs = std::mem::take(&mut target.sockaddrs); + for sockaddr in &sockaddrs { + let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, &tls_connector).await { + Ok(stream) => stream, + Err(e) => { + log::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + let mut client = WsHandshakeClient::new( + BufReader::new(BufWriter::new(tcp_stream)), + &target.host_header, + &target.path_and_query, + ); + + client.set_headers(&headers); + + // Perform the initial handshake. + match client.handshake().await { + Ok(ServerResponse::Accepted { .. }) => { + log::info!("Connection established to target: {:?}", target); + let mut builder = client.into_builder(); + builder.set_max_message_size(self.max_request_body_size as usize); + let (sender, receiver) = builder.finish(); + return Ok((Sender { inner: sender }, Receiver { inner: receiver })); + } - // Perform the initial handshake. - match client.handshake().await? { - ServerResponse::Accepted { .. } => {} - ServerResponse::Rejected { status_code } | ServerResponse::Redirect { status_code, .. } => { - // TODO: HTTP redirects also lead here #339. - return Err(WsHandshakeError::Rejected { status_code }); + Ok(ServerResponse::Rejected { status_code }) => { + log::debug!("Connection rejected: {:?}", status_code); + err = Some(Err(WsHandshakeError::Rejected { status_code })); + } + Ok(ServerResponse::Redirect { status_code, location }) => { + log::debug!("Redirection: status_code: {}, location: {}", status_code, location); + match location.parse::() { + // redirection with absolute path => need to lookup. + Ok(uri) => { + // Absolute URI. + if uri.scheme().is_some() { + target = uri.try_into()?; + tls_connector = match target.mode { + Mode::Tls => { + let mut client_config = ClientConfig::default(); + if let CertificateStore::Native = self.certificate_store { + client_config.root_store = rustls_native_certs::load_native_certs() + .map_err(|(_, e)| WsHandshakeError::CertificateStore(e))?; + } + Some(Arc::new(client_config).into()) + } + Mode::Plain => None, + }; + break; + } + // Relative URI. + else { + // Replace the entire path_and_query if `location` starts with `/` or `//`. + if location.starts_with('/') { + target.path_and_query = location; + } else { + match target.path_and_query.rfind('/') { + Some(offset) => { + target.path_and_query.replace_range(offset + 1.., &location) + } + None => { + err = Some(Err(WsHandshakeError::Url( + format!( + "path_and_query: {}; this is a bug it must contain `/` please open issue", + location + ) + .into(), + ))); + continue; + } + }; + } + target.sockaddrs = sockaddrs; + break; + } + } + Err(e) => { + err = Some(Err(WsHandshakeError::Url(e.to_string().into()))); + } + }; + } + Err(e) => { + err = Some(Err(e.into())); + } + }; } } + err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host))) + } +} - // If the handshake succeeded, return. - let mut builder = client.into_builder(); - builder.set_max_message_size(self.max_request_body_size as usize); - let (sender, receiver) = builder.finish(); - Ok((Sender { inner: sender }, Receiver { inner: receiver })) +async fn connect( + sockaddr: SocketAddr, + timeout_dur: Duration, + host: &str, + tls_connector: &Option, +) -> Result>, WsHandshakeError> { + let socket = TcpStream::connect(sockaddr); + let timeout = tokio::time::sleep(timeout_dur); + tokio::select! { + socket = socket => { + let socket = socket?; + if let Err(err) = socket.set_nodelay(true) { + log::warn!("set nodelay failed: {:?}", err); + } + match tls_connector { + None => Ok(TlsOrPlain::Plain(socket)), + Some(connector) => { + let dns_name = DNSNameRef::try_from_ascii_str(host)?; + let tls_stream = connector.connect(dns_name, socket).await?; + Ok(TlsOrPlain::Tls(tls_stream)) + } + } + } + _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur)) } } @@ -301,34 +376,32 @@ pub struct Target { path_and_query: String, } -impl Target { - /// Parse an URL String to a WebSocket address. - pub fn parse(url: impl AsRef) -> Result { - let url = - url::Url::parse(url.as_ref()).map_err(|e| WsHandshakeError::Url(format!("Invalid URL: {}", e).into()))?; - let mode = match url.scheme() { - "ws" => Mode::Plain, - "wss" => Mode::Tls, +impl TryFrom for Target { + type Error = WsHandshakeError; + + fn try_from(uri: Uri) -> Result { + let mode = match uri.scheme_str() { + Some("ws") => Mode::Plain, + Some("wss") => Mode::Tls, _ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())), }; - let host = - url.host_str().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?; - let port = url.port_or_known_default().ok_or_else(|| WsHandshakeError::Url("No port number in URL".into()))?; + let host = uri.host().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?; + let port = uri + .port_u16() + .ok_or_else(|| WsHandshakeError::Url("No port number in URL (default port is not supported)".into()))?; let host_header = format!("{}:{}", host, port); - let mut path_and_query = url.path().to_owned(); - if let Some(query) = url.query() { - path_and_query.push('?'); - path_and_query.push_str(query); - } - // NOTE: `Url::socket_addrs` is using the default port if it's missing (ws:// - 80, wss:// - 443) - let sockaddrs = url.socket_addrs(|| None).map_err(WsHandshakeError::ResolutionFailed)?; - Ok(Self { sockaddrs, host, host_header, mode, path_and_query }) + let parts = uri.into_parts(); + let path_and_query = parts.path_and_query.ok_or_else(|| WsHandshakeError::Url("No path in URL".into()))?; + let sockaddrs = host_header.to_socket_addrs().map_err(WsHandshakeError::ResolutionFailed)?; + Ok(Self { sockaddrs: sockaddrs.collect(), host, host_header, mode, path_and_query: path_and_query.to_string() }) } } #[cfg(test)] mod tests { - use super::{Mode, Target, WsHandshakeError}; + use super::{Mode, Target, Uri, WsHandshakeError}; + use http::uri::InvalidUri; + use std::convert::TryInto; fn assert_ws_target(target: Target, host: &str, host_header: &str, mode: Mode, path_and_query: &str) { assert_eq!(&target.host, host); @@ -337,53 +410,51 @@ mod tests { assert_eq!(&target.path_and_query, path_and_query); } + fn parse_target(uri: &str) -> Result { + uri.parse::().map_err(|e: InvalidUri| WsHandshakeError::Url(e.to_string().into()))?.try_into() + } + #[test] fn ws_works() { - let target = Target::parse("ws://127.0.0.1:9933").unwrap(); + let target = parse_target("ws://127.0.0.1:9933").unwrap(); assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/"); } #[test] fn wss_works() { - let target = Target::parse("wss://kusama-rpc.polkadot.io:443").unwrap(); + let target = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap(); assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:443", Mode::Tls, "/"); } #[test] fn faulty_url_scheme() { - let err = Target::parse("http://kusama-rpc.polkadot.io:443").unwrap_err(); + let err = parse_target("http://kusama-rpc.polkadot.io:443").unwrap_err(); assert!(matches!(err, WsHandshakeError::Url(_))); } #[test] fn faulty_port() { - let err = Target::parse("ws://127.0.0.1:-43").unwrap_err(); + let err = parse_target("ws://127.0.0.1:-43").unwrap_err(); assert!(matches!(err, WsHandshakeError::Url(_))); - let err = Target::parse("ws://127.0.0.1:99999").unwrap_err(); + let err = parse_target("ws://127.0.0.1:99999").unwrap_err(); assert!(matches!(err, WsHandshakeError::Url(_))); } - #[test] - fn default_port_works() { - let target = Target::parse("ws://127.0.0.1").unwrap(); - assert_ws_target(target, "127.0.0.1", "127.0.0.1:80", Mode::Plain, "/"); - } - #[test] fn url_with_path_works() { - let target = Target::parse("wss://127.0.0.1/my-special-path").unwrap(); + let target = parse_target("wss://127.0.0.1:443/my-special-path").unwrap(); assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my-special-path"); } #[test] fn url_with_query_works() { - let target = Target::parse("wss://127.0.0.1/my?name1=value1&name2=value2").unwrap(); + let target = parse_target("wss://127.0.0.1:443/my?name1=value1&name2=value2").unwrap(); assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my?name1=value1&name2=value2"); } #[test] fn url_with_fragment_is_ignored() { - let target = Target::parse("wss://127.0.0.1/my.htm#ignore").unwrap(); + let target = parse_target("wss://127.0.0.1:443/my.htm#ignore").unwrap(); assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my.htm"); } } diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index d1b9cc7dbd..5739cd37ba 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -30,7 +30,7 @@ use crate::types::error::{CallError, Error}; use crate::{future::StopHandle, RpcModule, WsServerBuilder}; use anyhow::anyhow; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient, WebSocketTestError}; +use jsonrpsee_test_utils::mocks::{Id, TestContext, WebSocketTestClient, WebSocketTestError}; use jsonrpsee_test_utils::TimeoutFutureExt; use serde_json::Value as JsonValue; use std::fmt;