Skip to content

Commit

Permalink
fix: more carefully handle 'benign' connection loss
Browse files Browse the repository at this point in the history
This changes the behaviour of the `ConnectionIncoming` pseudo-stream to
signal end-of-stream when the connection closes for a 'benign' reason,
rather than returning an error. This makes it much easier to implement
message polling for callers, without them having to apply their own
filtering.

Since we're doing this detection in a couple of places, it made sense to
move the check into `ConnectionError` itself.
  • Loading branch information
Chris Connelly authored and connec committed Oct 28, 2021
1 parent f7163a0 commit 9e2be50
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 39 deletions.
137 changes: 101 additions & 36 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
use crate::{
config::{RetryConfig, SERVER_NAME},
error::{
Close, ConnectionError, RecvError, RpcError, SendError, SerializationError, StreamError,
},
error::{ConnectionError, RecvError, RpcError, SendError, SerializationError, StreamError},
wire_msg::WireMsg,
};
use bytes::Bytes;
use futures::{
future,
stream::{self, StreamExt, TryStreamExt},
stream::{self, Stream, StreamExt, TryStream, TryStreamExt},
};
use std::{fmt, net::SocketAddr, sync::Arc, time::Duration};
use std::{fmt, net::SocketAddr, pin::Pin, sync::Arc, task, time::Duration};
use tokio::{
sync::{mpsc, watch},
time::timeout,
Expand Down Expand Up @@ -294,19 +292,23 @@ fn start_message_listeners(
) {
let _ = tokio::spawn(listen_on_uni_streams(
peer_addr,
uni_streams,
FilterBenignClose(uni_streams),
alive_rx.clone(),
message_tx.clone(),
));

let _ = tokio::spawn(listen_on_bi_streams(
endpoint, peer_addr, bi_streams, alive_rx, message_tx,
endpoint,
peer_addr,
FilterBenignClose(bi_streams),
alive_rx,
message_tx,
));
}

async fn listen_on_uni_streams(
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
uni_streams: FilterBenignClose<quinn::IncomingUniStreams>,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<Bytes, RecvError>>,
) {
Expand Down Expand Up @@ -380,7 +382,7 @@ async fn listen_on_uni_streams(
async fn listen_on_bi_streams(
endpoint: quinn::Endpoint,
peer_addr: SocketAddr,
bi_streams: quinn::IncomingBiStreams,
bi_streams: FilterBenignClose<quinn::IncomingBiStreams>,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<Bytes, RecvError>>,
) {
Expand Down Expand Up @@ -467,31 +469,12 @@ async fn listen_on_bi_streams(
);
}
future::Either::Left((Err(error), _)) => {
match error.into() {
ConnectionError::Closed(Close::Local) => {
trace!(
"Stopped listener for incoming bi-streams from {}: connection closed locally",
peer_addr
);
}
ConnectionError::Closed(Close::Application {
error_code: 0,
reason,
}) => {
trace!(
"Stopped listener for incoming bi-streams from {}: closed by peer (error code: 0, reason: {:?})",
peer_addr,
String::from_utf8_lossy(&reason)
);
}
error => {
// TODO: consider more carefully how to handle this
warn!(
"Stopped listener for incoming bi-streams from {} due to error: {:?}",
peer_addr, error
);
}
}
// A connection error occurred on bi_streams, we don't propagate anything here as we
// expect propagation to be handled in listen_on_uni_streams.
warn!(
"Stopped listener for incoming bi-streams from {} due to error: {:?}",
peer_addr, error
);
}
future::Either::Right((_, _)) => {
// the connection was closed
Expand Down Expand Up @@ -550,17 +533,47 @@ async fn handle_endpoint_verification(
Ok(())
}

struct FilterBenignClose<S>(S);

impl<S> Stream for FilterBenignClose<S>
where
S: Stream<Item = Result<S::Ok, S::Error>> + TryStream + Unpin,
S::Error: Into<ConnectionError>,
{
type Item = Result<S::Ok, ConnectionError>;

fn poll_next(
mut self: Pin<&mut Self>,
ctx: &mut task::Context,
) -> task::Poll<Option<Self::Item>> {
let next = futures::ready!(self.0.poll_next_unpin(ctx));
task::Poll::Ready(match next.transpose() {
Ok(next) => next.map(Ok),
Err(error) => {
let error = error.into();
if error.is_benign() {
None
} else {
Some(Err(error))
}
}
})
}
}

#[cfg(test)]
mod tests {
use super::Connection;
use crate::{
config::{InternalConfig, SERVER_NAME},
config::{Config, InternalConfig, SERVER_NAME},
error::{ConnectionError, SendError},
tests::local_addr,
wire_msg::WireMsg,
};
use bytes::Bytes;
use color_eyre::eyre::{bail, Result};
use futures::{StreamExt, TryStreamExt};
use std::time::Duration;

#[tokio::test]
#[tracing_test::traced_test]
Expand Down Expand Up @@ -625,6 +638,58 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn benign_connection_loss() -> Result<()> {
    // Both endpoints share a configuration whose idle timeout (1s) is shorter than the
    // sleep further down, so the connection is expected to time out during the test.
    let config = InternalConfig::try_from_config(Config {
        idle_timeout: Some(Duration::from_secs(1)),
        ..Default::default()
    })?;

    // First endpoint; its incoming-connection stream is not needed.
    let mut peer1_builder = quinn::Endpoint::builder();
    let _ = peer1_builder
        .listen(config.server.clone())
        .default_client_config(config.client.clone());
    let (peer1, _) = peer1_builder.bind(&local_addr())?;

    // Second endpoint; this one accepts the connection opened by the first.
    let mut peer2_builder = quinn::Endpoint::builder();
    let _ = peer2_builder
        .listen(config.server.clone())
        .default_client_config(config.client.clone());
    let (peer2, peer2_incoming) = peer2_builder.bind(&local_addr())?;

    // Connect peer1 -> peer2 and keep only the sending half on peer1's side.
    let outgoing = peer1.connect(&peer2.local_addr()?, SERVER_NAME)?.await?;
    let (p1_tx, _) = Connection::new(peer1.clone(), None, outgoing);

    // Accept the connection on peer2 and keep only the receiving half.
    let (_, mut p2_rx) = match timeout(peer2_incoming.then(|c| c).try_next()).await?? {
        Some(connection) => Connection::new(peer2.clone(), None, connection),
        None => bail!("did not receive incoming connection when one was expected"),
    };

    // Wait for twice the idle timeout.
    tokio::time::sleep(Duration::from_secs(2)).await;

    // Sending must surface the timeout as an error...
    let send_result = p1_tx.send(b"hello"[..].into()).await;
    if !matches!(
        send_result,
        Err(SendError::ConnectionLost(ConnectionError::TimedOut))
    ) {
        bail!("unexpected send result: {:?}", send_result);
    }

    // ...whereas receiving must simply report end-of-stream rather than an error.
    let recv_result = p2_rx.next().await;
    if !matches!(recv_result, Ok(None)) {
        bail!("unexpected recv result: {:?}", recv_result);
    }

    Ok(())
}

#[tokio::test]
async fn endpoint_echo() -> Result<()> {
let config = InternalConfig::try_from_config(Default::default())?;
Expand Down Expand Up @@ -764,6 +829,6 @@ mod tests {
/// Awaits `f` under a 500 ms deadline, returning `Err(Elapsed)` if it does not finish in time.
// NOTE(review): the two `tokio::time::timeout` lines below are the before/after forms of the
// same statement as rendered by the diff view; they differ only in how `Duration` is spelled.
async fn timeout<F: std::future::Future>(
f: F,
) -> Result<F::Output, tokio::time::error::Elapsed> {
tokio::time::timeout(std::time::Duration::from_millis(500), f).await
tokio::time::timeout(Duration::from_millis(500), f).await
}
}
27 changes: 27 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,33 @@ pub enum ConnectionError {
Closed(Close),
}

impl ConnectionError {
    /// Returns `true` when this error reflects a natural change of connection state rather
    /// than a fault on the connection.
    ///
    /// A benign error may still matter to the caller (for example, a closed connection when
    /// trying to send), but in other situations (for example, a receive loop) it can be
    /// treated as a quiet end of the stream. Keeping the classification here means every call
    /// site applies the same rule.
    pub(crate) fn is_benign(&self) -> bool {
        match self {
            // Idle timeouts are the expected way for unused connections to wind down, and a
            // local close is a graceful shutdown from our own end.
            Self::TimedOut | Self::Closed(Close::Local) => true,

            // A close from the peer is graceful only with the default code and no reason.
            Self::Closed(Close::Application {
                error_code: 0,
                reason,
            }) => reason.is_empty(),

            // Everything else is a real error and should always be propagated.
            _ => false,
        }
    }
}

impl From<quinn::ConnectError> for ConnectionError {
fn from(error: quinn::ConnectError) -> Self {
match error {
Expand Down
18 changes: 15 additions & 3 deletions src/wire_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
utils,
};
use bytes::Bytes;
use futures::TryFutureExt;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::{fmt, net::SocketAddr};
Expand Down Expand Up @@ -39,9 +40,20 @@ impl WireMsg {
) -> Result<Option<Self>, RecvError> {
let mut header_bytes = [0; MSG_HEADER_LEN];

match recv.read(&mut header_bytes).await? {
None => return Ok(None),
Some(len) => {
match recv.read(&mut header_bytes).err_into().await {
Err(RecvError::ConnectionLost(error)) if error.is_benign() => {
// We ignore 'benign' connection loss for the initial read, this follows from the
// understanding that quinn would always yield any successfully read bytes, so we
// would move further into the function, which will always propagate encountered
// errors.
return Ok(None);
}
Err(error) => {
// Any other error would indicate a real issue, so return it
return Err(error);
}
Ok(None) => return Ok(None),
Ok(Some(len)) => {
if len < MSG_HEADER_LEN {
recv.read_exact(&mut header_bytes[len..]).await?;
}
Expand Down

0 comments on commit 9e2be50

Please sign in to comment.