Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async tls client fix. New server fixture and tests. Don't rely on close_notify. #294

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions components/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,8 @@ thiserror = "1"
log = "0.4"
env_logger = "0.10"

# testing
rstest = "0.12"

# misc
derive_builder = "0.12"
1 change: 1 addition & 0 deletions components/tls/tls-client-async/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ tokio = { workspace = true, features = [
webpki-roots.workspace = true
hyper = { workspace = true, features = ["client", "http1"] }
tls-server-fixture = { path = "../tls-server-fixture" }
rstest = { workspace = true }
33 changes: 19 additions & 14 deletions components/tls/tls-client-async/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
trace!("received {} tls bytes from server", received);

// Loop until we've processed all the data we received in this read.
// Note that we must make one iteration even if `received == 0`.
let mut processed = 0;
while processed < received {
loop {
processed += client.read_tls(&mut &rx_tls_buf[processed..received])?;
match client.process_new_packets().await {
Ok(_) => {}
Expand All @@ -123,8 +124,16 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
}
}

debug_assert!(processed <= received);

if processed == received {
break;
}
}

// by convention if `AsyncRead::read` returns 0, it means EOF, i.e. the peer
// has closed the socket
if received == 0 {
#[cfg(feature = "tracing")]
debug!("server closed connection");
Expand All @@ -151,7 +160,6 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(

#[cfg(feature = "tracing")]
trace!("sending close_notify to server");

client.send_close_notify().await?;

// Flush all remaining plaintext
Expand All @@ -168,7 +176,7 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(

#[cfg(feature = "tracing")]
debug!("client closed connection");
}
},
}

while client.wants_write() && !client_closed {
Expand All @@ -189,18 +197,19 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if server_closed {
#[cfg(feature = "tracing")]
debug!("server closed, no more data to read");
debug!("server closed without close_notify, no more data to read");

// We didn't get Ok(0) to indicate a clean closure, yet the
// server has already closed. We do not treat this as an error.
break 'outer;
} else {
break;
}
}
// Some servers will not send a close_notify, in which case we need to
// error because we can't reveal the MAC key to the Notary.
// Some servers will not send a close_notify but we do not treat this as
// an error.
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
#[cfg(feature = "tracing")]
error!("server did not send close_notify");
return Err(e)?;
break 'outer;
}
Err(e) => return Err(e)?,
};
Expand All @@ -215,14 +224,10 @@ pub fn bind_client<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
.await;
} else {
#[cfg(feature = "tracing")]
debug!("server closed, no more data to read");
debug!("server closed cleanly, no more data to read");
break 'outer;
}
}

if client_closed && server_closed {
break;
}
}

#[cfg(feature = "tracing")]
Expand Down
Loading