diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs index ad03674d..d69e1603 100644 --- a/src/transport/quic/mod.rs +++ b/src/transport/quic/mod.rs @@ -34,7 +34,11 @@ use crate::{ PeerId, }; -use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; +use futures::{ + future::BoxFuture, + stream::{AbortHandle, FuturesUnordered}, + Stream, StreamExt, TryFutureExt, +}; use multiaddr::{Multiaddr, Protocol}; use quinn::{ClientConfig, Connecting, Connection, Endpoint, IdleTimeout}; @@ -66,6 +70,25 @@ struct NegotiatedConnection { connection: Connection, } +#[derive(Debug)] +enum RawConnectionResult { + /// The first successful connection. + Connected { + connection_id: ConnectionId, + address: Multiaddr, + stream: NegotiatedConnection, + }, + + /// All connection attempts failed. + Failed { + connection_id: ConnectionId, + errors: Vec<(Multiaddr, DialError)>, + }, + + /// Future was canceled. + Canceled { connection_id: ConnectionId }, +} + /// QUIC transport object. pub(crate) struct QuicTransport { /// Transport handle. @@ -92,21 +115,15 @@ pub(crate) struct QuicTransport { pending_open: HashMap, /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture< - 'static, - Result< - (ConnectionId, Multiaddr, NegotiatedConnection), - (ConnectionId, Vec<(Multiaddr, DialError)>), - >, - >, - >, + pending_raw_connections: FuturesUnordered>, /// Opened raw connection, waiting for approval/rejection from `TransportManager`. opened_raw: HashMap, /// Canceled raw connections. canceled: HashSet, + + cancel_futures: HashMap, } impl QuicTransport { @@ -225,6 +242,7 @@ impl TransportBuilder for QuicTransport { pending_inbound_connections: HashMap::new(), pending_raw_connections: FuturesUnordered::new(), pending_connections: FuturesUnordered::new(), + cancel_futures: HashMap::new(), }, listen_addresses, )) @@ -407,12 +425,18 @@ impl Transport for QuicTransport { }) .collect(); - self.pending_raw_connections.push(Box::pin(async move { + // Future that will resolve to the first successful connection. + let future = async move { let mut errors = Vec::with_capacity(num_addresses); while let Some(result) = futures.next().await { match result { - Ok((address, connection)) => return Ok((connection_id, address, connection)), + Ok((address, stream)) => + return RawConnectionResult::Connected { + connection_id, + address, + stream, + }, Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -425,8 +449,16 @@ impl Transport for QuicTransport { } } - Err((connection_id, errors)) - })); + RawConnectionResult::Failed { + connection_id, + errors, + } + }; + + let (fut, handle) = futures::future::abortable(future); + let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id }); + self.pending_raw_connections.push(Box::pin(fut)); + self.cancel_futures.insert(connection_id, handle); Ok(()) } @@ -446,6 +478,7 @@ impl Transport for QuicTransport { /// Cancel opening connections. fn cancel(&mut self, connection_id: ConnectionId) { self.canceled.insert(connection_id); + self.cancel_futures.remove(&connection_id).map(|handle| handle.abort()); } } @@ -470,16 +503,14 @@ impl Stream for QuicTransport { } while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); + tracing::trace!(target: LOG_TARGET, ?result, "raw connection result"); + match result { + RawConnectionResult::Connected { + connection_id, + address, + stream, + } => if !self.canceled.remove(&connection_id) { self.opened_raw.insert(connection_id, (stream, address.clone())); @@ -487,15 +518,20 @@ impl Stream for QuicTransport { connection_id, address, })); - } - } - Err((connection_id, errors)) => + }, + RawConnectionResult::Failed { + connection_id, + errors, + } => if !self.canceled.remove(&connection_id) { return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id, errors, })); }, + RawConnectionResult::Canceled { connection_id } => { + self.canceled.remove(&connection_id); + } } } diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 24e4ef54..4ef52104 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -38,7 +38,8 @@ use crate::{ use futures::{ future::BoxFuture, - stream::{FuturesUnordered, Stream, StreamExt}, + stream::{AbortHandle, FuturesUnordered, Stream, StreamExt}, + TryFutureExt, }; use multiaddr::Multiaddr; use socket2::{Domain, Socket, Type}; @@ -70,6 +71,25 @@ struct PendingInboundConnection { address: SocketAddr, } +#[derive(Debug)] +enum RawConnectionResult { + /// The first successful connection. + Connected { + connection_id: ConnectionId, + address: Multiaddr, + stream: TcpStream, + }, + + /// All connection attempts failed. + Failed { + connection_id: ConnectionId, + errors: Vec<(Multiaddr, DialError)>, + }, + + /// Future was canceled. + Canceled { connection_id: ConnectionId }, +} + /// TCP transport. pub(crate) struct TcpTransport { /// Transport context. @@ -96,15 +116,7 @@ pub(crate) struct TcpTransport { >, /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture< - 'static, - Result< - (ConnectionId, Multiaddr, TcpStream), - (ConnectionId, Vec<(Multiaddr, DialError)>), - >, - >, - >, + pending_raw_connections: FuturesUnordered>, /// Opened raw connection, waiting for approval/rejection from `TransportManager`. opened_raw: HashMap, @@ -112,6 +124,8 @@ pub(crate) struct TcpTransport { /// Canceled raw connections. canceled: HashSet, + cancel_futures: HashMap, + /// Connections which have been opened and negotiated but are being validated by the /// `TransportManager`. pending_open: HashMap, @@ -284,6 +298,7 @@ impl TransportBuilder for TcpTransport { pending_inbound_connections: HashMap::new(), pending_connections: FuturesUnordered::new(), pending_raw_connections: FuturesUnordered::new(), + cancel_futures: HashMap::new(), }, listen_addresses, )) @@ -412,11 +427,17 @@ impl Transport for TcpTransport { }) .collect(); - self.pending_raw_connections.push(Box::pin(async move { + // Future that will resolve to the first successful connection. + let future = async move { let mut errors = Vec::with_capacity(num_addresses); while let Some(result) = futures.next().await { match result { - Ok((address, stream)) => return Ok((connection_id, address, stream)), + Ok((address, stream)) => + return RawConnectionResult::Connected { + connection_id, + address, + stream, + }, Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -429,8 +450,16 @@ impl Transport for TcpTransport { } } - Err((connection_id, errors)) - })); + RawConnectionResult::Failed { + connection_id, + errors, + } + }; + + let (fut, handle) = futures::future::abortable(future); + let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id }); + self.pending_raw_connections.push(Box::pin(fut)); + self.cancel_futures.insert(connection_id, handle); Ok(()) } @@ -488,6 +517,7 @@ impl Transport for TcpTransport { fn cancel(&mut self, connection_id: ConnectionId) { self.canceled.insert(connection_id); + self.cancel_futures.remove(&connection_id).map(|handle| handle.abort()); } } @@ -523,16 +553,14 @@ impl Stream for TcpTransport { } while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); + tracing::trace!(target: LOG_TARGET, ?result, "raw connection result"); + match result { + RawConnectionResult::Connected { + connection_id, + address, + stream, + } => if !self.canceled.remove(&connection_id) { self.opened_raw.insert(connection_id, (stream, address.clone())); @@ -540,15 +568,20 @@ impl Stream for TcpTransport { connection_id, address, })); - } - } - Err((connection_id, errors)) => + }, + RawConnectionResult::Failed { + connection_id, + errors, + } => if !self.canceled.remove(&connection_id) { return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id, errors, })); }, + RawConnectionResult::Canceled { connection_id } => { + self.canceled.remove(&connection_id); + } } } diff --git a/src/transport/websocket/mod.rs b/src/transport/websocket/mod.rs index 03c58191..bcf37002 100644 --- a/src/transport/websocket/mod.rs +++ b/src/transport/websocket/mod.rs @@ -36,7 +36,11 @@ use crate::{ DialError, PeerId, }; -use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; +use futures::{ + future::BoxFuture, + stream::{AbortHandle, FuturesUnordered}, + Stream, StreamExt, TryFutureExt, +}; use multiaddr::{Multiaddr, Protocol}; use socket2::{Domain, Socket, Type}; use std::net::SocketAddr; @@ -71,6 +75,25 @@ struct PendingInboundConnection { address: SocketAddr, } +#[derive(Debug)] +enum RawConnectionResult { + /// The first successful connection. + Connected { + connection_id: ConnectionId, + address: Multiaddr, + stream: WebSocketStream>, + }, + + /// All connection attempts failed. + Failed { + connection_id: ConnectionId, + errors: Vec<(Multiaddr, DialError)>, + }, + + /// Future was canceled. + Canceled { connection_id: ConnectionId }, +} + /// WebSocket transport. pub(crate) struct WebSocketTransport { /// Transport context. @@ -97,19 +120,7 @@ pub(crate) struct WebSocketTransport { >, /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture< - 'static, - Result< - ( - ConnectionId, - Multiaddr, - WebSocketStream>, - ), - (ConnectionId, Vec<(Multiaddr, DialError)>), - >, - >, - >, + pending_raw_connections: FuturesUnordered>, /// Opened raw connection, waiting for approval/rejection from `TransportManager`. opened_raw: HashMap>, Multiaddr)>, @@ -117,6 +128,8 @@ pub(crate) struct WebSocketTransport { /// Canceled raw connections. canceled: HashSet, + cancel_futures: HashMap, + /// Negotiated connections waiting validation. pending_open: HashMap, } @@ -315,6 +328,7 @@ impl TransportBuilder for WebSocketTransport { pending_inbound_connections: HashMap::new(), pending_connections: FuturesUnordered::new(), pending_raw_connections: FuturesUnordered::new(), + cancel_futures: HashMap::new(), }, listen_addresses, )) @@ -458,12 +472,17 @@ impl Transport for WebSocketTransport { }) .collect(); - self.pending_raw_connections.push(Box::pin(async move { + // Future that will resolve to the first successful connection. + let future = async move { let mut errors = Vec::with_capacity(num_addresses); - while let Some(result) = futures.next().await { match result { - Ok((address, stream)) => return Ok((connection_id, address, stream)), + Ok((address, stream)) => + return RawConnectionResult::Connected { + connection_id, + address, + stream, + }, Err(error) => { tracing::debug!( target: LOG_TARGET, @@ -476,8 +495,16 @@ impl Transport for WebSocketTransport { } } - Err((connection_id, errors)) - })); + RawConnectionResult::Failed { + connection_id, + errors, + } + }; + + let (fut, handle) = futures::future::abortable(future); + let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id }); + self.pending_raw_connections.push(Box::pin(fut)); + self.cancel_futures.insert(connection_id, handle); Ok(()) } @@ -536,6 +563,7 @@ impl Transport for WebSocketTransport { fn cancel(&mut self, connection_id: ConnectionId) { self.canceled.insert(connection_id); + self.cancel_futures.remove(&connection_id).map(|handle| handle.abort()); } } @@ -565,16 +593,14 @@ impl Stream for WebSocketTransport { } while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); + tracing::trace!(target: LOG_TARGET, ?result, "raw connection result"); + match result { + RawConnectionResult::Connected { + connection_id, + address, + stream, + } => if !self.canceled.remove(&connection_id) { self.opened_raw.insert(connection_id, (stream, address.clone())); @@ -582,15 +608,20 @@ impl Stream for WebSocketTransport { connection_id, address, })); - } - } - Err((connection_id, errors)) => + }, + RawConnectionResult::Failed { + connection_id, + errors, + } => if !self.canceled.remove(&connection_id) { return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id, errors, })); }, + RawConnectionResult::Canceled { connection_id } => { + self.canceled.remove(&connection_id); + } } }