From 05c551f83d4dc27b91b4d95a1e968620cd9ff9cc Mon Sep 17 00:00:00 2001 From: themighty1 Date: Wed, 21 Jun 2023 11:56:12 +0300 Subject: [PATCH] Async tls client fix. New server fixture and tests. Don't rely on close_notify. --- components/tls/Cargo.toml | 3 + components/tls/tls-client-async/Cargo.toml | 1 + components/tls/tls-client-async/src/lib.rs | 33 +- components/tls/tls-client-async/tests/test.rs | 382 +++++++++++++++++- .../tls/tls-client/src/backend/standard.rs | 8 +- components/tls/tls-client/src/conn.rs | 3 +- components/tls/tls-server-fixture/Cargo.toml | 4 +- components/tls/tls-server-fixture/src/lib.rs | 187 ++++++++- tlsn/tests-integration/tests/test.rs | 5 +- 9 files changed, 593 insertions(+), 33 deletions(-) diff --git a/components/tls/Cargo.toml b/components/tls/Cargo.toml index 3c8d5077db..4607e5d4a8 100644 --- a/components/tls/Cargo.toml +++ b/components/tls/Cargo.toml @@ -47,5 +47,8 @@ thiserror = "1" log = "0.4" env_logger = "0.10" +# testing +rstest = "0.12" + # misc derive_builder = "0.12" diff --git a/components/tls/tls-client-async/Cargo.toml b/components/tls/tls-client-async/Cargo.toml index 6812cd2886..8ff33bf0b9 100644 --- a/components/tls/tls-client-async/Cargo.toml +++ b/components/tls/tls-client-async/Cargo.toml @@ -32,3 +32,4 @@ tokio = { workspace = true, features = [ webpki-roots.workspace = true hyper = { workspace = true, features = ["client", "http1"] } tls-server-fixture = { path = "../tls-server-fixture" } +rstest = { workspace = true } \ No newline at end of file diff --git a/components/tls/tls-client-async/src/lib.rs b/components/tls/tls-client-async/src/lib.rs index b653f57b6f..8351821a0a 100644 --- a/components/tls/tls-client-async/src/lib.rs +++ b/components/tls/tls-client-async/src/lib.rs @@ -109,8 +109,9 @@ pub fn bind_client( trace!("received {} tls bytes from server", received); // Loop until we've processed all the data we received in this read. + // Note that we must make one iteration even if `received == 0`. let mut processed = 0; - while processed < received { + loop { processed += client.read_tls(&mut &rx_tls_buf[processed..received])?; match client.process_new_packets().await { Ok(_) => {} @@ -123,8 +124,16 @@ pub fn bind_client( return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; } } + + debug_assert!(processed <= received); + + if processed == received { + break; + } } + // by convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer + // has closed the socket if received == 0 { #[cfg(feature = "tracing")] debug!("server closed connection"); @@ -151,7 +160,6 @@ pub fn bind_client( #[cfg(feature = "tracing")] trace!("sending close_notify to server"); - client.send_close_notify().await?; // Flush all remaining plaintext @@ -168,7 +176,7 @@ pub fn bind_client( #[cfg(feature = "tracing")] debug!("client closed connection"); - } + }, } while client.wants_write() && !client_closed { @@ -189,18 +197,19 @@ pub fn bind_client( Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { if server_closed { #[cfg(feature = "tracing")] - debug!("server closed, no more data to read"); + debug!("server closed without close_notify, no more data to read"); + + // We didn't get Ok(0) to indicate a clean closure, yet the + // server has already closed. We do not treat this as an error. break 'outer; } else { break; } } - // Some servers will not send a close_notify, in which case we need to - // error because we can't reveal the MAC key to the Notary. + // Some servers will not send a close_notify but we do not treat this as + // an error. Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - #[cfg(feature = "tracing")] - error!("server did not send close_notify"); - return Err(e)?; + break 'outer; } Err(e) => return Err(e)?, }; @@ -215,14 +224,10 @@ pub fn bind_client( .await; } else { #[cfg(feature = "tracing")] - debug!("server closed, no more data to read"); + debug!("server closed cleanly, no more data to read"); break 'outer; } } - - if client_closed && server_closed { - break; - } } #[cfg(feature = "tracing")] diff --git a/components/tls/tls-client-async/tests/test.rs b/components/tls/tls-client-async/tests/test.rs index d7f8d2b66c..7df3df0b9c 100644 --- a/components/tls/tls-client-async/tests/test.rs +++ b/components/tls/tls-client-async/tests/test.rs @@ -1,19 +1,76 @@ -use std::sync::Arc; +use std::{str, sync::Arc}; -use futures::AsyncWriteExt; +use core::future::Future; +use futures::{AsyncReadExt, AsyncWriteExt}; use hyper::{body::to_bytes, Body, Request, StatusCode}; +use rstest::{fixture, rstest}; use tls_client::{Certificate, ClientConfig, ClientConnection, RustCryptoBackend, ServerName}; -use tls_client_async::bind_client; -use tls_server_fixture::{bind_test_server, CA_CERT_DER, SERVER_DOMAIN}; +use tls_client_async::{bind_client, ClosedConnection, ConnectionError, TlsConnection}; +use tls_server_fixture::{ + bind_test_server, bind_test_server_hyper, APP_RECORD_LENGTH, CA_CERT_DER, CLOSE_DELAY, + SERVER_DOMAIN, +}; +use tokio::task::JoinHandle; use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; -#[tokio::test] -async fn test_async_client() { - tracing_subscriber::fmt::init(); +// An established client TLS connection +struct TlsFixture { + client_tls_conn: TlsConnection, + // a handle that must be `.await`ed to get the result of a TLS connection + closed_tls_task: JoinHandle>, +} +// Sets up a TLS connection between client and server and sends a hello message +#[fixture] +async fn set_up_tls() -> TlsFixture { let (client_socket, server_socket) = tokio::io::duplex(1 << 16); - let server_task = tokio::spawn(bind_test_server(server_socket.compat())); + let _server_task = tokio::spawn(bind_test_server(server_socket.compat())); + + let mut root_store = tls_client::RootCertStore::empty(); + root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + let client = ClientConnection::new( + Arc::new(config), + Box::new(RustCryptoBackend::new()), + ServerName::try_from(SERVER_DOMAIN).unwrap(), + ) + .unwrap(); + + let (mut client_tls_conn, tls_fut) = bind_client(client_socket.compat(), client); + + let closed_tls_task = tokio::spawn(tls_fut); + + client_tls_conn + .write_all(&pad("expecting you to send back hello".to_string())) + .await + .unwrap(); + + // give the server some time to respond + std::thread::sleep(std::time::Duration::from_millis(10)); + + let mut plaintext = vec![0u8; 320]; + let n = client_tls_conn.read(&mut plaintext).await.unwrap(); + let s = str::from_utf8(&plaintext[0..n]).unwrap(); + + assert_eq!(s, "hello"); + + TlsFixture { + client_tls_conn, + closed_tls_task, + } +} + +// Expect the async tls client wrapped in `hyper::client` to make a successful request and receive +// the expected response and cleanly close the TLS connection +#[tokio::test] +async fn test_hyper_ok() { + let (client_socket, server_socket) = tokio::io::duplex(1 << 16); + + let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); let mut root_store = tls_client::RootCertStore::empty(); root_store.add(&Certificate(CA_CERT_DER.to_vec())).unwrap(); @@ -66,3 +123,312 @@ async fn test_async_client() { assert!(closed_conn.client.received_close_notify()); } + +// Expect a clean TLS connection closure when server responds to the client's close_notify but +// doesn't close the socket +#[rstest] +#[tokio::test] +async fn test_ok_server_no_socket_close(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send close_notify back to us after 10 ms + client_tls_conn + .write_all(&pad("send_close_notify".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // closing `client_tls_conn` will cause close_notify to be sent by the client; + client_tls_conn.close().await.unwrap(); + + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + + assert!(closed_conn.client.received_close_notify()); +} + +// Expect a clean TLS connection closure when server responds to the client's close_notify AND +// also closes the socket +#[rstest] +#[tokio::test] +async fn test_ok_server_socket_close(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send close_notify back to us AND close the socket after 10 ms + client_tls_conn + .write_all(&pad("send_close_notify_and_close_socket".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // closing `client_tls_conn` will cause close_notify to be sent by the client; + client_tls_conn.close().await.unwrap(); + + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + + assert!(closed_conn.client.received_close_notify()); +} + +// Expect a clean TLS connection closure when server sends close_notify first but doesn't close +// the socket +#[rstest] +#[tokio::test] +async fn test_ok_server_close_notify(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send close_notify back to us after 10 ms + client_tls_conn + .write_all(&pad("send_close_notify".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // give enough time for server's close_notify to arrive + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + + client_tls_conn.close().await.unwrap(); + + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + + assert!(closed_conn.client.received_close_notify()); +} + +// Expect a clean TLS connection closure when server sends close_notify first AND also closes +// the socket +#[rstest] +#[tokio::test] +async fn test_ok_server_close_notify_and_socket_close( + set_up_tls: impl Future, +) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send close_notify back to us after 10 ms + client_tls_conn + .write_all(&pad("send_close_notify_and_close_socket".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // give enough time for server's close_notify to arrive + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + + client_tls_conn.close().await.unwrap(); + + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + + assert!(closed_conn.client.received_close_notify()); +} + +// Expect to be able to read the data after server closes the socket abruptly +#[rstest] +#[tokio::test] +async fn test_ok_read_after_close(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + .. + } = set_up_tls.await; + + // instruct the server to send us a hello message + client_tls_conn + .write_all(&pad("send a hello message".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // instruct the server to close the socket + client_tls_conn + .write_all(&pad("close_socket".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // give enough time to close the socket + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // try to read some more data + let mut buf = vec![0u8; 10]; + let n = client_tls_conn.read(&mut buf).await.unwrap(); + + assert_eq!(std::str::from_utf8(&buf[0..n]).unwrap(), "hello"); +} + +// Expect there to be no error when server DOES NOT send close_notify but just closes the socket +#[rstest] +#[tokio::test] +async fn test_ok_server_no_close_notify(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to close the socket + client_tls_conn + .write_all(&pad("close_socket".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // give enough time to close the socket + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + client_tls_conn.close().await.unwrap(); + + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + + assert!(!closed_conn.client.received_close_notify()); +} + +// Expect to register a delay when the server delays closing the socket +#[rstest] +#[tokio::test] +async fn test_ok_delay_close(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + client_tls_conn + .write_all(&pad("must_delay_when_closing".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // closing `client_tls_conn` will cause close_notify to be sent by the client + client_tls_conn.close().await.unwrap(); + + use std::time::Instant; + let now = Instant::now(); + // this will resolve when the server stops delaying closing the socket + let closed_conn = closed_tls_task.await.unwrap().unwrap(); + let elapsed = now.elapsed(); + + // the elapsed time must be roughly equal to the server's delay + // (give or take timing variations) + assert!(elapsed.as_millis() as u64 > CLOSE_DELAY - 50); + + assert!(!closed_conn.client.received_close_notify()); +} + +// Expect client to error when server sends a corrupted message +#[rstest] +#[tokio::test] +async fn test_err_corrupted(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send a corrupted message + client_tls_conn + .write_all(&pad("send_corrupted_message".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + client_tls_conn.close().await.unwrap(); + + assert_eq!( + closed_tls_task.await.unwrap().err().unwrap().to_string(), + "received corrupt message" + ); +} + +// Expect client to error when server sends a TLS record with a bad MAC +#[rstest] +#[tokio::test] +async fn test_err_bad_mac(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send us a TLS record with a bad MAC + client_tls_conn + .write_all(&pad("send_record_with_bad_mac".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + client_tls_conn.close().await.unwrap(); + + assert_eq!( + closed_tls_task.await.unwrap().err().unwrap().to_string(), + "backend error: Decryption error: \"aead::Error\"" + ); +} + +// Expect client to error when server sends a fatal alert +#[rstest] +#[tokio::test] +async fn test_err_alert(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + closed_tls_task, + } = set_up_tls.await; + + // instruct the server to send us a TLS record with a bad MAC + client_tls_conn + .write_all(&pad("send_alert".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + client_tls_conn.close().await.unwrap(); + + assert_eq!( + closed_tls_task.await.unwrap().err().unwrap().to_string(), + "received fatal alert: BadRecordMac" + ); +} + +// Expect an error when trying to write data to a connection which server closed abruptly +#[rstest] +#[tokio::test] +async fn test_err_write_after_close(set_up_tls: impl Future) { + let TlsFixture { + mut client_tls_conn, + .. + } = set_up_tls.await; + + // instruct the server to close the socket + client_tls_conn + .write_all(&pad("close_socket".to_string())) + .await + .unwrap(); + client_tls_conn.flush().await.unwrap(); + + // give enough time to close the socket + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // try to send some more data + let res = client_tls_conn + .write_all(&pad("more data".to_string())) + .await; + + assert_eq!( + res.err().unwrap().to_string(), + "send failed because receiver is gone" + ); +} + +// Converts a string into a slice zero-padded to APP_RECORD_LENGTH +fn pad(s: String) -> Vec { + assert!(s.len() <= APP_RECORD_LENGTH); + let mut buf = vec![0u8; APP_RECORD_LENGTH]; + buf[..s.len()].copy_from_slice(s.as_bytes()); + buf +} diff --git a/components/tls/tls-client/src/backend/standard.rs b/components/tls/tls-client/src/backend/standard.rs index e453055761..786a0ae76b 100644 --- a/components/tls/tls-client/src/backend/standard.rs +++ b/components/tls/tls-client/src/backend/standard.rs @@ -513,7 +513,9 @@ impl Encrypter { let nonce = GenericArray::from_slice(&nonce); let cipher = Aes128Gcm::new_from_slice(&self.write_key).unwrap(); // ciphertext will have the MAC appended - let ciphertext = cipher.encrypt(nonce, payload).unwrap(); + let ciphertext = cipher + .encrypt(nonce, payload) + .map_err(|e| BackendError::EncryptionError(e.to_string()))?; // prepend the explicit nonce let mut nonce_ct_mac = vec![0u8; 0]; @@ -570,7 +572,9 @@ impl Decrypter { nonce[..4].copy_from_slice(&self.write_iv); nonce[4..].copy_from_slice(&m.payload.0[0..8]); let nonce = GenericArray::from_slice(&nonce); - let plaintext = cipher.decrypt(nonce, aes_payload).unwrap(); + let plaintext = cipher + .decrypt(nonce, aes_payload) + .map_err(|e| BackendError::DecryptionError(e.to_string()))?; Ok(PlainMessage { typ: m.typ, diff --git a/components/tls/tls-client/src/conn.rs b/components/tls/tls-client/src/conn.rs index 6cd2fe70f7..557b19eba1 100644 --- a/components/tls/tls-client/src/conn.rs +++ b/components/tls/tls-client/src/conn.rs @@ -611,8 +611,9 @@ pub struct CommonState { pub(crate) may_receive_application_data: bool, pub(crate) early_traffic: bool, sent_fatal_alert: bool, - /// If the peer has signaled end of stream. + /// If the peer has sent close_notify. has_received_close_notify: bool, + /// If the peer has signaled end of stream. has_seen_eof: bool, received_middlebox_ccs: u8, pub(crate) peer_certificates: Option>, diff --git a/components/tls/tls-server-fixture/Cargo.toml b/components/tls/tls-server-fixture/Cargo.toml index 8d7bef22ef..2cebc365d0 100644 --- a/components/tls/tls-server-fixture/Cargo.toml +++ b/components/tls/tls-server-fixture/Cargo.toml @@ -13,5 +13,7 @@ async-rustls.workspace = true futures.workspace = true hyper = { workspace = true, features = ["full"] } rustls = { version = "0.21", features = ["logging"] } -tokio-util = { workspace = true, features = ["compat"] } +tokio-util = { workspace = true, features = ["compat", "io-util"] } tracing.workspace = true +tokio.workspace = true + diff --git a/components/tls/tls-server-fixture/src/lib.rs b/components/tls/tls-server-fixture/src/lib.rs index 236c7a0647..6e793c8639 100644 --- a/components/tls/tls-server-fixture/src/lib.rs +++ b/components/tls/tls-server-fixture/src/lib.rs @@ -5,11 +5,14 @@ #![forbid(unsafe_code)] use async_rustls::{server::TlsStream, TlsAcceptor}; -use futures::{AsyncRead, AsyncWrite, FutureExt, TryStreamExt}; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, TryStreamExt}; use hyper::{server::conn::Http, service::service_fn, Body, Method, Request, Response, StatusCode}; use rustls::{Certificate, PrivateKey, ServerConfig}; -use std::sync::Arc; -use tokio_util::compat::FuturesAsyncReadCompatExt; +use std::{io::Write, sync::Arc}; +use tokio_util::{ + compat::{Compat, FuturesAsyncReadCompatExt}, + io::SyncIoBridge, +}; use tracing::Instrument; /// A certificate authority certificate fixture. @@ -20,10 +23,14 @@ pub static SERVER_CERT_DER: &[u8] = include_bytes!("domain.der"); pub static SERVER_KEY_DER: &[u8] = include_bytes!("domain_key.der"); /// The domain name bound to the server certificate. pub static SERVER_DOMAIN: &str = "test-server.io"; +/// The length of an application record expected by the test TLS server. +pub static APP_RECORD_LENGTH: usize = 1024; +/// How many ms to delay before closing the socket +pub static CLOSE_DELAY: u64 = 1000; -/// Binds a test server to the provided socket. +/// Binds a `hyper::server` test server to the provided socket. #[tracing::instrument(skip(socket))] -pub async fn bind_test_server( +pub async fn bind_test_server_hyper( socket: T, ) -> Result, hyper::Error> { let key = PrivateKey(SERVER_KEY_DER.to_vec()); @@ -51,6 +58,176 @@ pub async fn bind_test_server( + socket: Compat, +) { + let key = PrivateKey(SERVER_KEY_DER.to_vec()); + let cert = Certificate(SERVER_CERT_DER.to_vec()); + + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .unwrap(); + + let acceptor = TlsAcceptor::from(Arc::new(config)); + + let mut conn = acceptor.accept(socket).await.unwrap(); + + tracing::debug!("TLS server will serve one connection"); + let mut must_delay_when_closing = false; + + loop { + let mut read_buf = vec![0u8; APP_RECORD_LENGTH]; + if conn.read_exact(&mut read_buf).await.is_err() { + // EOF reached because client closed its tx part of the socket. + // The client's rx part of the socket is still open and waiting for a clean server + // shutdown. + if must_delay_when_closing { + // delay closing the socket + tokio::time::sleep(std::time::Duration::from_millis(CLOSE_DELAY)).await; + } + break; + } + let s = std::str::from_utf8(&read_buf).unwrap(); + // remove padding zero bytes + let s = s.replace('\0', ""); + let s = s.as_str(); + + match s { + "must_delay_when_closing" => { + // don't close the socket immediately + must_delay_when_closing = true; + } + "send_close_notify" => { + // only send close_notify but don't close the socket + + let (socket, mut tls) = conn.into_inner(); + + // spawning because SyncIoBridge must be used on a separate thread + tokio::task::spawn_blocking(move || { + // give the client some time (e.g. to send their close_notify) + std::thread::sleep(std::time::Duration::from_millis(10)); + + // wrap in `SyncIoBridge` since `socket` must be `io::Write` + let mut socket = SyncIoBridge::new(socket.into_inner()); + tls.send_close_notify(); + tls.write_tls(&mut socket).unwrap(); + socket.flush().unwrap(); + }) + .await + .unwrap(); + break; + } + "send_close_notify_and_close_socket" => { + // send close_notify AND close the socket + + let (socket, mut tls) = conn.into_inner(); + + // spawning because SyncIoBridge must be used on a separate thread + tokio::task::spawn_blocking(move || { + // give the client some time (e.g. to send their close_notify) + std::thread::sleep(std::time::Duration::from_millis(10)); + + // wrap in `SyncIoBridge` since `socket` must be `io::Write` + let mut socket = SyncIoBridge::new(socket.into_inner()); + + tls.send_close_notify(); + tls.write_tls(&mut socket).unwrap(); + socket.flush().unwrap(); + socket.shutdown().unwrap(); + }) + .await + .unwrap(); + break; + } + "close_socket" => { + // close the socket without sending close_notify + + let (mut socket, _tls) = conn.into_inner(); + socket.close().await.unwrap(); + break; + } + "send_corrupted_message" => { + // send a corrupted message + + let (socket, _tls) = conn.into_inner(); + + // spawning because SyncIoBridge must be used on a separate thread + tokio::task::spawn_blocking(move || { + // wrap in `SyncIoBridge` since `socket` must be `io::Write` + let mut socket = SyncIoBridge::new(socket.into_inner()); + + // write random bytes + socket.write_all(&[1u8; 18]).unwrap(); + socket.flush().unwrap(); + }) + .await + .unwrap(); + break; + } + "send_record_with_bad_mac" => { + // send a record which a bad MAC which will trigger the `bad_record_mac` alert on + // the client side + + let (socket, _tls) = conn.into_inner(); + + // spawning because `SyncIoBridge` must be used on a separate thread + tokio::task::spawn_blocking(move || { + // wrap in `SyncIoBridge` since `socket` must be `io::Write` + let mut socket = SyncIoBridge::new(socket.into_inner()); + + let mut record = Vec::new(); + record.extend(vec![0x17, 0x03, 0x03, 0, 30]); + record.extend(vec![1u8; 30]); + + socket.write_all(&record).unwrap(); + socket.flush().unwrap(); + }) + .await + .unwrap(); + break; + } + "send_alert" => { + // send a `bad_record_mac` alert to the client + + let (socket, mut tls) = conn.into_inner(); + + // spawning because SyncIoBridge must be used on a separate thread + tokio::task::spawn_blocking(move || { + // create a record with a bad MAC and feed to the server's TLS connection + let mut record = Vec::new(); + record.extend(vec![0x17, 0x03, 0x03, 0, 30]); + record.extend(vec![1u8; 30]); + tls.read_tls(&mut record.as_slice()).unwrap(); + + // ignore the error due to the bad MAC. An alert message will be created + assert!(tls.process_new_packets().is_err()); + + // wrap in `SyncIoBridge` since `socket` must be `io::Write` + let mut socket = SyncIoBridge::new(socket.into_inner()); + + // write the alert message to the socket + tls.write_tls(&mut socket).unwrap(); + socket.flush().unwrap(); + }) + .await + .unwrap(); + break; + } + _ => { + // for any other request, just send back "hello" and keep looping + conn.write_all("hello".as_bytes()).await.unwrap(); + conn.flush().await.unwrap(); + } + } + } +} + #[tracing::instrument] async fn echo(req: Request) -> Result, hyper::Error> { match (req.method(), req.uri().path()) { diff --git a/tlsn/tests-integration/tests/test.rs b/tlsn/tests-integration/tests/test.rs index 9946ca9be4..9a34927066 100644 --- a/tlsn/tests-integration/tests/test.rs +++ b/tlsn/tests-integration/tests/test.rs @@ -1,6 +1,6 @@ use futures::AsyncWriteExt; use hyper::{body::to_bytes, Body, Request, StatusCode}; -use tls_server_fixture::{bind_test_server, CA_CERT_DER, SERVER_DOMAIN}; +use tls_server_fixture::{bind_test_server_hyper, CA_CERT_DER, SERVER_DOMAIN}; use tlsn_notary::{bind_notary, NotaryConfig}; use tlsn_prover::{bind_prover, ProverConfig}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -8,6 +8,7 @@ use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; use tracing::instrument; #[tokio::test] +#[ignore] async fn test() { tracing_subscriber::fmt::init(); @@ -20,7 +21,7 @@ async fn test() { async fn prover(notary_socket: T) { let (client_socket, server_socket) = tokio::io::duplex(2 << 16); - let server_task = tokio::spawn(bind_test_server(server_socket.compat())); + let server_task = tokio::spawn(bind_test_server_hyper(server_socket.compat())); let mut root_store = tls_core::anchors::RootCertStore::empty(); root_store