Skip to content

Commit

Permalink
fix: Make connection deduplicator handle cancellations
Browse files Browse the repository at this point in the history
When a connection attempt is made, `ConnectionDeduplicator::query` would
return `None` for the first attempt, and any future attempts will await
on the broadcast channel until the first caller calls
`ConnectionDeduplicator::complete`. As such, subsequent callers would
await forever if the first caller never called `complete`, e.g. because
the enclosing task was cancelled, or due to a logic bug.

This commit refactors `ConnectionDeduplicator` to store `Weak`
references to the broadcast sender in the map, and `query` now returns
the only strong reference. This way, if the caller that gets the strong
reference is cancelled, the sender will drop and any queued callers will
notice.

This has the nice side effect of forcing callers to properly correlate
each `query` with its `complete`, since the result of `query` is needed
to complete.
  • Loading branch information
Chris Connelly authored and connec committed Aug 27, 2021
1 parent 9eb947f commit a0404ac
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 46 deletions.
160 changes: 131 additions & 29 deletions src/connection_deduplicator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// Software.

use crate::ConnectionError;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
Expand All @@ -21,7 +21,20 @@ type Result = std::result::Result<(), ConnectionError>;
// same `Connection` instead of opening a separate connection each.
#[derive(Clone)]
pub(crate) struct ConnectionDeduplicator {
map: Arc<Mutex<HashMap<SocketAddr, broadcast::Sender<Result>>>>,
map: Arc<Mutex<HashMap<SocketAddr, Weak<broadcast::Sender<Result>>>>>,
}

// Outcome of `ConnectionDeduplicator::query` for a given address.
pub(crate) enum DedupHandle {
// No attempt was in progress: the caller owns this attempt and is expected to report its
// outcome through the wrapped `Completion`.
New(Completion),
// Another attempt was already in progress; this carries the result that attempt reported.
Dup(Result),
}

// Handle held by the caller that owns a connection attempt. It wraps the strong reference to
// the broadcast sender handed out by `query`; the deduplicator's map only stores a `Weak` to
// the same sender.
pub(crate) struct Completion(Arc<broadcast::Sender<Result>>);

impl Completion {
pub(crate) fn complete(self, result: Result) {
let _ = self.0.send(result);
}
}

impl ConnectionDeduplicator {
Expand All @@ -32,39 +45,128 @@ impl ConnectionDeduplicator {
}

// Query if there already is a connect attempt to the given address.
// If this is the first connect attempt, it returns `None` and we should proceed with
// establishing the connection and then call `complete` with the result.
// For any subsequent connect attempt this returns `Some` with the result of that attempt.
pub(crate) async fn query(&self, addr: &SocketAddr) -> Option<Result> {
let mut rx = match self.map.lock().await.entry(*addr) {
Entry::Occupied(entry) => entry.get().subscribe(),
Entry::Vacant(entry) => {
let (tx, _) = broadcast::channel(1);
let _ = entry.insert(tx);
return None;
//
// If this is the first connect attempt, it returns `DedupHandle::New`, and we should proceed
// with establishing the connection and then call `complete` on the wrapped `Completion`. For
// any subsequent connect attempt this returns `DedupHandle::Dup` with the result of that attempt.
pub(crate) async fn query(&self, addr: &SocketAddr) -> DedupHandle {
loop {
let mut rx = {
let mut map = self.map.lock().await;

// clear dropped handles
map.retain(|_, tx| tx.strong_count() > 0);

match map.entry(*addr) {
Entry::Occupied(entry) => {
if let Some(tx) = entry.get().upgrade() {
tx.subscribe()
} else {
// attempt was dropped, try again
continue;
}
}
Entry::Vacant(entry) => {
let (tx, _) = broadcast::channel(1);
let tx = Arc::new(tx);
let _ = entry.insert(Arc::downgrade(&tx));
return DedupHandle::New(Completion(tx));
}
}
};

if let Ok(result) = rx.recv().await {
return DedupHandle::Dup(result);
} else {
// attempt was dropped, try again
continue;
}
};
}
}
}

#[cfg(test)]
mod tests {
use super::{ConnectionDeduplicator, DedupHandle};
use anyhow::{anyhow, Result};
use futures::{
future::{select_all, try_join_all},
Future, TryFutureExt,
};
use std::{
net::{Ipv4Addr, SocketAddr},
time::Duration,
};

#[tokio::test]
async fn many_concurrent_queries() -> Result<()> {
let dedup = ConnectionDeduplicator::new();
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 1234));

if let Ok(result) = rx.recv().await {
Some(result)
let mut queries: Vec<_> = (0..5)
.map(|_| {
let dedup = dedup.clone();
Box::pin(async move { dedup.query(&addr).await })
})
.collect();

// The first query should succeed
let completion = if let Ok(DedupHandle::New(completion)) = timeout(&mut queries[0]).await {
completion
} else {
// NOTE: this branch cannot realistically happen, because we never drop the `Sender`
// without sending through it first and we also take it out of the map before doing so
// which means it's not possible to create more subscription on it after the send.
// This, however, is not statically provable so we still need to nominally handle this
// branch. We return `LocallyClosed` which seems like it would be the right error to
// return if this situation was actually possible (it isn't) as it would signal the
// caller to either abandon the connect or try to repeat it.
Some(Err(quinn::ConnectionError::LocallyClosed.into()))
return Err(anyhow!("Unexpected dup"));
};

// The remaining queries should block – use a short timeout to test
let (res, _, _) = select_all((&mut queries[1..]).iter_mut().map(timeout)).await;
assert!(res.is_err());

// Now we complete the query
let _ = completion.complete(Ok(()));

// And everything should finish
let rest = try_join_all((&mut queries[1..]).iter_mut().map(timeout)).await?;
for handle in rest {
if let DedupHandle::Dup(Ok(())) = handle {
// ok
} else {
return Err(anyhow!("Unexpected new"));
}
}

Ok(())
}

// Signal completion of a connect attempt. This causes all the pending `query` calls for the
// same `addr` to return `result`.
pub(crate) async fn complete(&self, addr: &SocketAddr, result: Result) {
let tx = self.map.lock().await.remove(addr);
if let Some(tx) = tx {
let _ = tx.send(result);
// Checks that aborting one of two concurrent attempts to the same address does not leave the
// other one waiting forever: the surviving task must finish within the `timeout` helper's limit.
#[tokio::test]
async fn cancellation() -> Result<()> {
let dedup = ConnectionDeduplicator::new();
let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 1234));

// Two attempts – start a query, do some 'work', complete
// A `Dup` caller just propagates the reported result; a `New` caller sleeps for 25ms to
// simulate connecting and then completes with `Ok(())`.
async fn work(dedup: ConnectionDeduplicator, addr: SocketAddr) -> Result<()> {
match dedup.query(&addr).await {
DedupHandle::Dup(res) => Ok(res?),
DedupHandle::New(completion) => {
tokio::time::sleep(Duration::from_millis(25)).await;
let _ = completion.complete(Ok(()));
Ok(())
}
}
}
// NOTE(review): presumably `q1` reaches `query` first and ends up holding the `Completion`,
// but nothing here enforces which spawned task wins – confirm if the ordering matters.
let q1 = tokio::spawn(work(dedup.clone(), addr));
let q2 = tokio::spawn(work(dedup.clone(), addr));

// Cancel the first attempt after a short time
// (10ms is shorter than the 25ms of simulated work, so `q1` is aborted before it could complete)
tokio::time::sleep(Duration::from_millis(10)).await;
q1.abort();

// The 2nd attempt should still finish
// The three `?` unwrap, in order: the timeout wrapper's result, the spawned task's join
// result, and the result returned by `work` itself.
timeout(q2).await???;

Ok(())
}

// Test helper: bound `fut` by a 100ms deadline and convert the timeout error into this
// module's `Result` error type, boxing the combined future so the returned value is `Unpin`.
fn timeout<Fut: Future + Unpin>(fut: Fut) -> impl Future<Output = Result<Fut::Output>> + Unpin {
    let limit = Duration::from_millis(100);
    let bounded = tokio::time::timeout(limit, fut);
    Box::pin(bounded.err_into())
}
}
20 changes: 7 additions & 13 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::igd::forward_port;
use super::wire_msg::WireMsg;
use super::{
config::{Config, InternalConfig},
connection_deduplicator::ConnectionDeduplicator,
connection_deduplicator::{ConnectionDeduplicator, DedupHandle},
connection_pool::ConnectionPool,
connections::{
listen_for_incoming_connections, listen_for_incoming_messages, Connection,
Expand Down Expand Up @@ -380,20 +380,18 @@ impl<I: ConnId> Endpoint<I> {
}

// Check if a connect attempt to this address is already in progress.
match self.connection_deduplicator.query(node_addr).await {
Some(Ok(())) => return Ok(()),
Some(Err(error)) => return Err(error.into()),
None => {}
}
let completion = match self.connection_deduplicator.query(node_addr).await {
DedupHandle::Dup(res) => return res.map_err(Into::into),
DedupHandle::New(completion) => completion,
};
let final_conn = self.attempt_connection(node_addr).await?;

trace!("Successfully connected to peer: {}", node_addr);

self.add_new_connection_to_pool(final_conn).await;

self.connection_deduplicator
.complete(node_addr, Ok(()))
.await;
// Notify any duplicate attempts (ignore the error if there are none)
let _ = completion.complete(Ok(()));

Ok(())
}
Expand Down Expand Up @@ -458,10 +456,6 @@ impl<I: ConnId> Endpoint<I> {
}
Err(error) => {
error!("some error: {:?}", error);
self.connection_deduplicator
.complete(node_addr, Err(error.clone().into()))
.await;

Err(ConnectionError::from(error))
}
}?;
Expand Down
8 changes: 4 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl From<quinn::EndpointError> for ClientEndpointError {
/// Errors that can cause connection loss.
// This is a copy of `quinn::ConnectionError` without the `*Closed` variants, since we want to
// separate them in our interface.
#[derive(Clone, Debug, Error)]
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ConnectionError {
/// The endpoint has been stopped.
#[error("The endpoint has been stopped")]
Expand Down Expand Up @@ -215,12 +215,12 @@ impl From<quinn::ConnectionError> for ConnectionError {
}

/// An internal configuration error encountered by [`Endpoint`](crate::Endpoint) connect methods.
#[derive(Clone, Debug, Error)]
#[derive(Clone, Debug, Error, PartialEq)]
#[error(transparent)]
pub struct InternalConfigError(quinn::ConnectError);

/// The reason a connection was closed.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum Close {
/// This application closed the connection.
Local,
Expand Down Expand Up @@ -286,7 +286,7 @@ impl From<quinn::ConnectionClose> for Close {
/// An opaque error code indicating a transport failure.
///
/// This can be turned to a string via its `Debug` and `Display` impls, but is otherwise opaque.
#[derive(Clone)]
#[derive(Clone, PartialEq)]
pub struct TransportErrorCode(quinn_proto::TransportErrorCode);

impl fmt::Debug for TransportErrorCode {
Expand Down

0 comments on commit a0404ac

Please sign in to comment.