From 3b39d6002300924e21131f0aa2fbc56edb4dce46 Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 3 Dec 2021 19:20:42 +0100 Subject: [PATCH] only build TLS connector once --- ws-client/src/transport.rs | 65 +++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/ws-client/src/transport.rs b/ws-client/src/transport.rs index 41a5debd02..60b7e21305 100644 --- a/ws-client/src/transport.rs +++ b/ws-client/src/transport.rs @@ -35,7 +35,6 @@ use std::{ convert::TryFrom, io, net::{SocketAddr, ToSocketAddrs}, - sync::Arc, time::Duration, }; use thiserror::Error; @@ -174,21 +173,39 @@ impl<'a> WsTransportClientBuilder<'a> { let mut target = self.target; let mut err = None; + // Only build TLS connector if `wss` in URL. + #[cfg(feature = "tls")] + let mut connector = match target.mode { + Mode::Tls => Some(build_tls_config(&self.certificate_store)?), + Mode::Plain => None, + }; + for _ in 0..self.max_redirections { tracing::debug!("Connecting to target: {:?}", target); // The sockaddrs might get reused if the server replies with a relative URI. let sockaddrs = std::mem::take(&mut target.sockaddrs); for sockaddr in &sockaddrs { - let tcp_stream = - match connect(*sockaddr, self.timeout, &target.host, &self.certificate_store, &target.mode).await { - Ok(stream) => stream, - Err(e) => { - tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); - err = Some(Err(e)); - continue; - } - }; + #[cfg(feature = "tls")] + let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, connector.as_ref()).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + + #[cfg(not(feature = "tls"))] + let tcp_stream = match connect(*sockaddr, self.timeout).await { + Ok(stream) => stream, + Err(e) => { + tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr); + err = Some(Err(e)); + continue; + } + }; + let mut client = WsHandshakeClient::new( BufReader::new(BufWriter::new(tcp_stream)), &target.host_header, @@ -219,6 +236,12 @@ impl<'a> WsTransportClientBuilder<'a> { // Absolute URI. if uri.scheme().is_some() { target = uri.try_into()?; + + // Only build TLS connector if `wss` in redirection URL. + #[cfg(feature = "tls")] + if connector.is_none() && matches!(target.mode, Mode::Tls) { + connector = Some(build_tls_config(&self.certificate_store)?); + } } // Relative URI. else { @@ -266,8 +289,7 @@ async fn connect( sockaddr: SocketAddr, timeout_dur: Duration, host: &str, - cert_store: &CertificateStore, - mode: &Mode, + tls_connector: Option<&tokio_rustls::TlsConnector>, ) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); @@ -277,11 +299,9 @@ async fn connect( if let Err(err) = socket.set_nodelay(true) { tracing::warn!("set nodelay failed: {:?}", err); } - match mode { - Mode::Plain => Ok(EitherStream::Plain(socket)), - Mode::Tls => { - // TODO(niklasad1): cache this. - let connector = build_tls_config(cert_store)?; + match tls_connector { + None => Ok(EitherStream::Plain(socket)), + Some(connector) => { let server_name: tokio_rustls::rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?; let tls_stream = connector.connect(server_name, socket).await?; Ok(EitherStream::Tls(tls_stream)) @@ -293,12 +313,7 @@ async fn connect( } #[cfg(not(feature = "tls"))] -async fn connect( - sockaddr: SocketAddr, - timeout_dur: Duration, - host: &str, - cert_store: &CertificateStore, -) -> Result { +async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result { let socket = TcpStream::connect(sockaddr); let timeout = tokio::time::sleep(timeout_dur); tokio::select! { @@ -378,7 +393,7 @@ impl TryFrom for Target { // NOTE: this is slow and should be used sparingly. #[cfg(feature = "tls")] fn build_tls_config(cert_store: &CertificateStore) -> Result { - use tokio_rustls::rustls as rustls; + use tokio_rustls::rustls; let mut roots = rustls::RootCertStore::empty(); @@ -412,7 +427,7 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result