Skip to content

Commit

Permalink
cargo: rustls->0.22.1, tokio-rustls->0.25.0
Browse files Browse the repository at this point in the history
This patch is from this axum-server draft PR, credit to @eric-seppanen:
- programatik29#106

It looks like axum-server will skip directly to 0.23, so this patch can
be removed then. programatik29#112
  • Loading branch information
eric-seppanen authored and MaxFangX committed Mar 26, 2024
1 parent 61fdf52 commit 10dca0a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 65 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ tracing = "0.1"
# optional dependencies
## rustls
arc-swap = { version = "1", optional = true }
rustls = { version = "0.21", features = ["dangerous_configuration"], optional = true }
rustls = { version = "0.22.1", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio-rustls = { version = "0.24", optional = true }
tokio-rustls = { version = "0.25.0", optional = true }

## openssl
openssl = { version = "0.10", optional = true }
tokio-openssl = { version = "0.6", optional = true }

[dev-dependencies]
serial_test = "2.0"
axum = "0.7"
axum = "0.7.1"
hyper = { version = "1.0.1", features = ["full"] }
tokio = { version = "1", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
Expand Down
125 changes: 63 additions & 62 deletions src/tls_rustls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ use crate::{
server::{io_other, Server},
};
use arc_swap::ArcSwap;
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::Item;
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
ServerConfig,
};
use std::time::Duration;
use std::{fmt, io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
Expand Down Expand Up @@ -172,10 +174,8 @@ impl RustlsConfig {
/// The certificate must be DER-encoded X.509.
///
/// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
pub async fn from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<Self> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
pub async fn from_der(cert: Vec<Vec<u8>>, key: PrivateKeyDer<'static>) -> io::Result<Self> {
let server_config = config_from_der(cert, key)?;
let inner = Arc::new(ArcSwap::from_pointee(server_config));

Ok(Self { inner })
Expand Down Expand Up @@ -218,10 +218,12 @@ impl RustlsConfig {
/// The certificate must be DER-encoded X.509.
///
/// The private key must be DER-encoded ASN.1 in either PKCS#8 or PKCS#1 format.
pub async fn reload_from_der(&self, cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<()> {
let server_config = spawn_blocking(|| config_from_der(cert, key))
.await
.unwrap()?;
pub async fn reload_from_der(
&self,
cert: Vec<Vec<u8>>,
key: PrivateKeyDer<'static>,
) -> io::Result<()> {
let server_config = config_from_der(cert, key)?;
let inner = Arc::new(server_config);

self.inner.store(inner);
Expand Down Expand Up @@ -278,12 +280,10 @@ impl fmt::Debug for RustlsConfig {
}
}

fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert = cert.into_iter().map(Certificate).collect();
let key = PrivateKey(key);
fn config_from_der(cert: Vec<Vec<u8>>, key: PrivateKeyDer<'static>) -> io::Result<ServerConfig> {
let cert = cert.into_iter().map(CertificateDer::from).collect();

let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.map_err(io_other)?;
Expand All @@ -295,24 +295,13 @@ fn config_from_der(cert: Vec<Vec<u8>>, key: Vec<u8>) -> io::Result<ServerConfig>

fn config_from_pem(cert: Vec<u8>, key: Vec<u8>) -> io::Result<ServerConfig> {
let cert = rustls_pemfile::certs(&mut cert.as_ref())
.map(|it| it.map(|it| it.to_vec()))
.map(|cert| cert.map(|cert| cert.as_ref().to_vec()))
.collect::<Result<Vec<_>, _>>()?;
// Check the entire PEM file for the key in case it is not first section
let mut key_vec: Vec<Vec<u8>> = rustls_pemfile::read_all(&mut key.as_ref())
.filter_map(|i| match i.ok()? {
Item::Sec1Key(key) => Some(key.secret_sec1_der().to_vec()),
Item::Pkcs1Key(key) => Some(key.secret_pkcs1_der().to_vec().into()),
Item::Pkcs8Key(key) => Some(key.secret_pkcs8_der().to_vec().into()),
_ => None,
})
.collect();

// Make sure file contains only one key
if key_vec.len() != 1 {
return Err(io_other("private key format not supported"));
}
// Use the first private key found.
let key = rustls_pemfile::private_key(&mut key.as_ref())?
.ok_or(io_other("private key format not found"))?;

config_from_der(cert, key_vec.pop().unwrap())
config_from_der(cert, key)
}

async fn config_from_pem_file(
Expand All @@ -330,21 +319,12 @@ async fn config_from_pem_chain_file(
chain: impl AsRef<Path>,
) -> io::Result<ServerConfig> {
let cert = tokio::fs::read(cert.as_ref()).await?;
let cert = rustls_pemfile::certs(&mut cert.as_ref())
.map(|it| it.map(|it| rustls::Certificate(it.to_vec())))
.collect::<Result<Vec<_>, _>>()?;
let cert = rustls_pemfile::certs(&mut cert.as_ref()).collect::<Result<Vec<_>, _>>()?;
let key = tokio::fs::read(chain.as_ref()).await?;
let key_cert: rustls::PrivateKey = match rustls_pemfile::read_one(&mut key.as_ref())?
.ok_or_else(|| io_other("could not parse pem file"))?
{
Item::Pkcs8Key(key) => Ok(rustls::PrivateKey(key.secret_pkcs8_der().to_vec().into())),
x => Err(io_other(format!(
"invalid certificate format, received: {x:?}"
))),
}?;
let key_cert = rustls_pemfile::private_key(&mut key.as_ref())?
.ok_or_else(|| io_other("could not parse pem file"))?;

ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key_cert)
.map_err(|_| io_other("invalid certificate"))
Expand All @@ -362,17 +342,10 @@ mod tests {
use http_body_util::BodyExt;
use hyper::client::conn::http1::{handshake, SendRequest};
use hyper_util::rt::TokioIo;
use rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate, ClientConfig, ServerName,
};
use std::{
convert::TryFrom,
io,
net::SocketAddr,
sync::Arc,
time::{Duration, SystemTime},
};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName};
use rustls::{ClientConfig, SignatureScheme};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
use tokio_rustls::TlsConnector;
Expand Down Expand Up @@ -552,13 +525,15 @@ mod tests {
(handle, server_task, addr)
}

async fn get_first_cert(addr: SocketAddr) -> Certificate {
async fn get_first_cert(addr: SocketAddr) -> CertificateDer<'static> {
let stream = TcpStream::connect(addr).await.unwrap();
let tls_stream = tls_connector().connect(dns_name(), stream).await.unwrap();

let (_io, client_connection) = tls_stream.into_inner();

client_connection.peer_certificates().unwrap()[0].clone()
client_connection.peer_certificates().unwrap()[0]
.clone()
.into_owned()
}

async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
Expand Down Expand Up @@ -586,24 +561,50 @@ mod tests {
}

fn tls_connector() -> TlsConnector {
#[derive(Debug)]
struct NoVerify;

impl ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
_ocsp_response: &[u8],
_now: SystemTime,
_now: rustls::pki_types::UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
]
}
}

let mut client_config = ClientConfig::builder()
.with_safe_defaults()
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoVerify))
.with_no_client_auth();

Expand All @@ -612,7 +613,7 @@ mod tests {
TlsConnector::from(Arc::new(client_config))
}

fn dns_name() -> ServerName {
fn dns_name() -> ServerName<'static> {
ServerName::try_from("localhost").unwrap()
}
}

0 comments on commit 10dca0a

Please sign in to comment.