Skip to content

Commit

Permalink
refactor: move more ConnectionHandle methods into Connection
Browse files Browse the repository at this point in the history
This will reduce the diff when the time comes to kill
`ConnectionHandle`.
  • Loading branch information
Chris Connelly authored and joshuef committed Oct 15, 2021
1 parent 6ebab02 commit 1b7c000
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 43 deletions.
72 changes: 66 additions & 6 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! A message-oriented API wrapping the underlying QUIC library (`quinn`).
use crate::{
config::SERVER_NAME,
error::{Close, ConnectionError, RecvError, RpcError, SendError, SerializationError},
config::{RetryConfig, SERVER_NAME},
error::{
Close, ConnectionError, RecvError, RpcError, SendError, SerializationError, StreamError,
},
wire_msg::WireMsg,
};
use bytes::Bytes;
Expand All @@ -27,6 +29,7 @@ const ENDPOINT_VERIFICATION_TIMEOUT: Duration = Duration::from_secs(30);
#[derive(Clone)]
pub(crate) struct Connection {
inner: quinn::Connection,
default_retry_config: Option<Arc<RetryConfig>>,

// A reference to the 'alive' marker for the connection. This isn't read by `Connection`, but
// must be held to keep background listeners alive until both halves of the connection are
Expand All @@ -37,6 +40,7 @@ pub(crate) struct Connection {
impl Connection {
pub(crate) fn new(
endpoint: quinn::Endpoint,
default_retry_config: Option<Arc<RetryConfig>>,
connection: quinn::NewConnection,
) -> (Connection, ConnectionIncoming) {
// this channel serves to keep the background message listener alive so long as one side of
Expand All @@ -48,6 +52,7 @@ impl Connection {
(
Self {
inner: connection.connection,
default_retry_config,
_alive_tx: Arc::clone(&alive_tx),
},
ConnectionIncoming::new(
Expand All @@ -66,6 +71,41 @@ impl Connection {
self.inner.remote_address()
}

/// Send a message to the peer with default retry configuration.
///
/// The message will be sent on a unidirectional QUIC stream, meaning the application is
/// responsible for correlating any anticipated responses from incoming streams.
///
/// The priority will be `0` and retry behaviour will be determined by the
/// [`Config`](crate::Config) that was used to construct the [`Endpoint`] this connection
/// belongs to. See [`send_with`](Self::send_with) if you want to send a message with specific
/// configuration.
pub(crate) async fn send(&self, msg: Bytes) -> Result<(), SendError> {
self.send_with(msg, 0, None).await
}

/// Send a message to the peer using the given configuration.
///
/// See [`send`](Self::send) if you want to send with the default configuration.
pub(crate) async fn send_with(
&self,
msg: Bytes,
priority: i32,
retry_config: Option<&RetryConfig>,
) -> Result<(), SendError> {
match retry_config.or_else(|| self.default_retry_config.as_deref()) {
Some(retry_config) => {
retry_config
.retry(|| async { Ok(self.send_uni(msg.clone(), priority).await?) })
.await?;
}
None => {
self.send_uni(msg, priority).await?;
}
}
Ok(())
}

/// Open a unidirection stream to the peer.
///
/// Messages sent over the stream will arrive at the peer in the order they were sent.
Expand Down Expand Up @@ -93,6 +133,23 @@ impl Connection {
pub(crate) fn close(&self) {
self.inner.close(0u8.into(), b"");
}

async fn send_uni(&self, msg: Bytes, priority: i32) -> Result<(), SendError> {
let mut send_stream = self.open_uni().await.map_err(SendError::ConnectionLost)?;
send_stream.set_priority(priority);

send_stream.send_user_msg(msg.clone()).await?;

// We try to make sure the stream is gracefully closed and the bytes get sent, but if it
// was already closed (perhaps by the peer) then we ignore the error.
// TODO: we probably shouldn't ignore the error...
send_stream.finish().await.or_else(|err| match err {
SendError::StreamLost(StreamError::Stopped(_)) => Ok(()),
_ => Err(err),
})?;

Ok(())
}
}

/// The sending API for a QUIC stream.
Expand Down Expand Up @@ -496,12 +553,13 @@ mod tests {
{
let (p1_tx, mut p1_rx) = Connection::new(
peer1.clone(),
None,
peer1.connect(&peer2.local_addr()?, SERVER_NAME)?.await?,
);

let (p2_tx, mut p2_rx) =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
Connection::new(peer2.clone(), None, connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
Expand Down Expand Up @@ -557,13 +615,14 @@ mod tests {
{
let (p1_tx, _) = Connection::new(
peer1.clone(),
None,
peer1.connect(&peer2.local_addr()?, SERVER_NAME)?.await?,
);

// we need to accept the connection on p2, or the message won't be processed
let _p2_handle =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
Connection::new(peer2.clone(), None, connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
Expand Down Expand Up @@ -612,13 +671,14 @@ mod tests {
{
let (p1_tx, _) = Connection::new(
peer1.clone(),
None,
peer1.connect(&peer2.local_addr()?, SERVER_NAME)?.await?,
);

// we need to accept the connection on p2, or the message won't be processed
let _p2_handle =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
Connection::new(peer2.clone(), None, connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
Expand All @@ -631,7 +691,7 @@ mod tests {
// we need to accept the connection on p1, or the message won't be processed
let _p1_handle =
if let Some(connection) = timeout(peer1_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer1.clone(), connection)
Connection::new(peer1.clone(), None, connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
Expand Down
46 changes: 11 additions & 35 deletions src/connection_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

use super::{
connection_pool::{ConnId, ConnectionPool, ConnectionRemover},
error::{ConnectionError, SendError, StreamError},
error::{ConnectionError, SendError},
};
use crate::{
connection::{Connection, ConnectionIncoming, RecvStream, SendStream},
Expand Down Expand Up @@ -83,7 +83,7 @@ impl<I: ConnId> ConnectionHandle<I> {
/// belongs to. See [`send_message_with`](Self::send_message_with) if you want to send a message
/// with specific configuration.
pub async fn send(&self, msg: Bytes) -> Result<(), SendError> {
self.send_with(msg, 0, None).await
self.handle_error(self.inner.send(msg).await).await
}

/// Send a message to the peer using the given configuration.
Expand All @@ -95,12 +95,8 @@ impl<I: ConnId> ConnectionHandle<I> {
priority: i32,
retry_config: Option<&RetryConfig>,
) -> Result<(), SendError> {
retry_config
.unwrap_or_else(|| self.default_retry_config.as_ref())
.retry(|| async { Ok(self.send_uni(msg.clone(), priority).await?) })
.await?;

Ok(())
self.handle_error(self.inner.send_with(msg, priority, retry_config).await)
.await
}

/// Priority default is 0. Both lower and higher can be passed in.
Expand All @@ -113,31 +109,6 @@ impl<I: ConnId> ConnectionHandle<I> {
Ok((send_stream, recv_stream))
}

/// Send message to peer using a uni-directional stream.
/// Priority default is 0. Both lower and higher can be passed in.
pub(crate) async fn send_uni(&self, msg: Bytes, priority: i32) -> Result<(), SendError> {
let mut send_stream = self.handle_error(self.inner.open_uni().await).await?;

// quinn returns `UnknownStream` error if the stream does not exist. We ignore it, on the
// basis that operations on the stream will fail instead (and the effect of setting priority
// or not is only observable if the stream exists).
let _ = send_stream.set_priority(priority);

self.handle_error(send_stream.send_user_msg(msg).await)
.await?;

// We try to make sure the stream is gracefully closed and the bytes get sent,
// but if it was already closed (perhaps by the peer) then we
// don't remove the connection from the pool.
match send_stream.finish().await {
Ok(()) | Err(SendError::StreamLost(StreamError::Stopped(_))) => Ok(()),
Err(err) => {
self.handle_error(Err(err)).await?;
Ok(())
}
}
}

async fn handle_error<T, E>(&self, result: Result<T, E>) -> Result<T, E> {
if result.is_err() {
self.remover.remove().await
Expand All @@ -147,6 +118,7 @@ impl<I: ConnId> ConnectionHandle<I> {
}
}

#[allow(clippy::too_many_arguments)] // this will be removed soon, so let is pass for now
pub(super) fn listen_for_incoming_connections<I: ConnId>(
mut quinn_incoming: quinn::Incoming,
connection_pool: ConnectionPool<I>,
Expand All @@ -155,14 +127,18 @@ pub(super) fn listen_for_incoming_connections<I: ConnId>(
disconnection_tx: Sender<SocketAddr>,
endpoint: Endpoint<I>,
quic_endpoint: quinn::Endpoint,
retry_config: Arc<RetryConfig>,
) {
let _ = tokio::spawn(async move {
loop {
match quinn_incoming.next().await {
Some(quinn_conn) => match quinn_conn.await {
Ok(connection) => {
let (connection, connection_incoming) =
Connection::new(quic_endpoint.clone(), connection);
let (connection, connection_incoming) = Connection::new(
quic_endpoint.clone(),
Some(retry_config.clone()),
connection,
);

let peer_address = connection.remote_address();
let id = ConnId::generate(&peer_address);
Expand Down
9 changes: 7 additions & 2 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ impl<I: ConnId> Endpoint<I> {
channels.disconnection.0.clone(),
endpoint.clone(),
endpoint.quic_endpoint.clone(),
endpoint.retry_config.clone(),
);

if let Some(contact) = contact.as_ref() {
Expand Down Expand Up @@ -368,6 +369,7 @@ impl<I: ConnId> Endpoint<I> {
// avoid the connection pool
let (connection, _) = Connection::new(
self.quic_endpoint.clone(),
Some(self.retry_config.clone()),
self.new_connection(peer_addr).await?,
);
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
Expand Down Expand Up @@ -464,8 +466,11 @@ impl<I: ConnId> Endpoint<I> {
Ok(new_connection) => {
trace!("Successfully connected to peer: {}", addr);

let (connection, connection_incoming) =
Connection::new(self.quic_endpoint.clone(), new_connection);
let (connection, connection_incoming) = Connection::new(
self.quic_endpoint.clone(),
Some(self.retry_config.clone()),
new_connection,
);
let id = ConnId::generate(&connection.remote_address());
let remover = self
.connection_pool
Expand Down

0 comments on commit 1b7c000

Please sign in to comment.