Skip to content

Commit

Permalink
only build TLS connector once
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 committed Dec 3, 2021
1 parent 0c64dd8 commit 3b39d60
Showing 1 changed file with 40 additions and 25 deletions.
65 changes: 40 additions & 25 deletions ws-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use std::{
convert::TryFrom,
io,
net::{SocketAddr, ToSocketAddrs},
sync::Arc,
time::Duration,
};
use thiserror::Error;
Expand Down Expand Up @@ -174,21 +173,39 @@ impl<'a> WsTransportClientBuilder<'a> {
let mut target = self.target;
let mut err = None;

// Only build TLS connector if `wss` in URL.
#[cfg(feature = "tls")]
let mut connector = match target.mode {
Mode::Tls => Some(build_tls_config(&self.certificate_store)?),
Mode::Plain => None,
};

for _ in 0..self.max_redirections {
tracing::debug!("Connecting to target: {:?}", target);

// The sockaddrs might get reused if the server replies with a relative URI.
let sockaddrs = std::mem::take(&mut target.sockaddrs);
for sockaddr in &sockaddrs {
let tcp_stream =
match connect(*sockaddr, self.timeout, &target.host, &self.certificate_store, &target.mode).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
err = Some(Err(e));
continue;
}
};
#[cfg(feature = "tls")]
let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, connector.as_ref()).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
err = Some(Err(e));
continue;
}
};

#[cfg(not(feature = "tls"))]
let tcp_stream = match connect(*sockaddr, self.timeout).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
err = Some(Err(e));
continue;
}
};

let mut client = WsHandshakeClient::new(
BufReader::new(BufWriter::new(tcp_stream)),
&target.host_header,
Expand Down Expand Up @@ -219,6 +236,12 @@ impl<'a> WsTransportClientBuilder<'a> {
// Absolute URI.
if uri.scheme().is_some() {
target = uri.try_into()?;

// Only build TLS connector if `wss` in redirection URL.
#[cfg(feature = "tls")]
if connector.is_none() && matches!(target.mode, Mode::Tls) {
connector = Some(build_tls_config(&self.certificate_store)?);
}
}
// Relative URI.
else {
Expand Down Expand Up @@ -266,8 +289,7 @@ async fn connect(
sockaddr: SocketAddr,
timeout_dur: Duration,
host: &str,
cert_store: &CertificateStore,
mode: &Mode,
tls_connector: Option<&tokio_rustls::TlsConnector>,
) -> Result<EitherStream, WsHandshakeError> {
let socket = TcpStream::connect(sockaddr);
let timeout = tokio::time::sleep(timeout_dur);
Expand All @@ -277,11 +299,9 @@ async fn connect(
if let Err(err) = socket.set_nodelay(true) {
tracing::warn!("set nodelay failed: {:?}", err);
}
match mode {
Mode::Plain => Ok(EitherStream::Plain(socket)),
Mode::Tls => {
// TODO(niklasad1): cache this.
let connector = build_tls_config(cert_store)?;
match tls_connector {
None => Ok(EitherStream::Plain(socket)),
Some(connector) => {
let server_name: tokio_rustls::rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?;
let tls_stream = connector.connect(server_name, socket).await?;
Ok(EitherStream::Tls(tls_stream))
Expand All @@ -293,12 +313,7 @@ async fn connect(
}

#[cfg(not(feature = "tls"))]
async fn connect(
sockaddr: SocketAddr,
timeout_dur: Duration,
host: &str,
cert_store: &CertificateStore,
) -> Result<EitherStream, WsHandshakeError> {
async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result<EitherStream, WsHandshakeError> {
let socket = TcpStream::connect(sockaddr);
let timeout = tokio::time::sleep(timeout_dur);
tokio::select! {
Expand Down Expand Up @@ -378,7 +393,7 @@ impl TryFrom<Uri> for Target {
// NOTE: this is slow and should be used sparingly.
#[cfg(feature = "tls")]
fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsConnector, WsHandshakeError> {
use tokio_rustls::rustls as rustls;
use tokio_rustls::rustls;

let mut roots = rustls::RootCertStore::empty();

Expand Down Expand Up @@ -412,7 +427,7 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsCo
let config =
rustls::ClientConfig::builder().with_safe_defaults().with_root_certificates(roots).with_no_client_auth();

Ok(Arc::new(config).into())
Ok(std::sync::Arc::new(config).into())
}

#[cfg(test)]
Expand Down

0 comments on commit 3b39d60

Please sign in to comment.