diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 1d44807865e..e4c6f763867 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -4,7 +4,12 @@ - Update to `libp2p-core` `v0.32.0`. +- Disconnect pending connections with `Swarm::disconnect`. See [PR 2517]. + +- Report aborted connections via `SwarmEvent::OutgoingConnectionError`. See [PR 2517]. + [PR 2492]: https://github.com/libp2p/rust-libp2p/pull/2492 +[PR 2517]: https://github.com/libp2p/rust-libp2p/pull/2517 # 0.33.0 [2022-01-27] diff --git a/swarm/Cargo.toml b/swarm/Cargo.toml index c8542454455..299633121f9 100644 --- a/swarm/Cargo.toml +++ b/swarm/Cargo.toml @@ -26,6 +26,7 @@ void = "1" [dev-dependencies] async-std = { version = "1.6.2", features = ["attributes"] } +env_logger = "0.9" libp2p = { path = "../", default-features = false, features = ["identify", "ping", "plaintext", "yamux"] } libp2p-mplex = { path = "../muxers/mplex" } libp2p-noise = { path = "../transports/noise" } diff --git a/swarm/src/connection/pool.rs b/swarm/src/connection/pool.rs index 3cabc7b0a15..139f4f8c5fc 100644 --- a/swarm/src/connection/pool.rs +++ b/swarm/src/connection/pool.rs @@ -136,7 +136,7 @@ struct PendingConnectionInfo { handler: THandler, endpoint: PendingPoint, /// When dropped, notifies the task which then knows to terminate. - _drop_notifier: oneshot::Sender, + abort_notifier: Option>, } impl fmt::Debug for Pool { @@ -340,10 +340,7 @@ where /// Returns `None` if the pool has no connection with the given ID. pub fn get(&mut self, id: ConnectionId) -> Option> { if let hash_map::Entry::Occupied(entry) = self.pending.entry(id) { - Some(PoolConnection::Pending(PendingConnection { - entry, - counters: &mut self.counters, - })) + Some(PoolConnection::Pending(PendingConnection { entry })) } else { self.established .iter_mut() @@ -406,11 +403,7 @@ where .entry(pending_connection) .expect_occupied("Iterating pending connections"); - PendingConnection { - entry, - counters: &mut self.counters, - } - .abort(); + PendingConnection { entry }.abort(); } } @@ -501,13 +494,13 @@ where let connection_id = self.next_connection_id(); - let (drop_notifier, drop_receiver) = oneshot::channel(); + let (abort_notifier, abort_receiver) = oneshot::channel(); self.spawn( task::new_for_pending_outgoing_connection( connection_id, dial, - drop_receiver, + abort_receiver, self.pending_connection_events_tx.clone(), ) .boxed(), @@ -521,8 +514,8 @@ where PendingConnectionInfo { peer_id: peer, handler, - endpoint: endpoint, - _drop_notifier: drop_notifier, + endpoint, + abort_notifier: Some(abort_notifier), }, ); Ok(connection_id) @@ -550,13 +543,13 @@ where let connection_id = self.next_connection_id(); - let (drop_notifier, drop_receiver) = oneshot::channel(); + let (abort_notifier, abort_receiver) = oneshot::channel(); self.spawn( task::new_for_pending_incoming_connection( connection_id, future, - drop_receiver, + abort_receiver, self.pending_connection_events_tx.clone(), ) .boxed(), @@ -569,7 +562,7 @@ where peer_id: None, handler, endpoint: endpoint.into(), - _drop_notifier: drop_notifier, + abort_notifier: Some(abort_notifier), }, ); Ok(connection_id) @@ -685,7 +678,7 @@ where peer_id: expected_peer_id, handler, endpoint, - _drop_notifier, + abort_notifier: _, } = self .pending .remove(&id) @@ -854,7 +847,7 @@ where peer_id, handler, endpoint, - _drop_notifier, + abort_notifier: _, }) = self.pending.remove(&id) { self.counters.dec_pending(&endpoint); @@ -911,14 +904,14 @@ pub enum PoolConnection<'a, THandler: IntoConnectionHandler> { /// A pending connection in a pool. pub struct PendingConnection<'a, THandler: IntoConnectionHandler> { entry: hash_map::OccupiedEntry<'a, ConnectionId, PendingConnectionInfo>, - counters: &'a mut ConnectionCounters, } impl PendingConnection<'_, THandler> { /// Aborts the connection attempt, closing the connection. - pub fn abort(self) { - self.counters.dec_pending(&self.entry.get().endpoint); - self.entry.remove(); + pub fn abort(mut self) { + if let Some(notifier) = self.entry.get_mut().abort_notifier.take() { + drop(notifier); + } } } diff --git a/swarm/src/connection/pool/task.rs b/swarm/src/connection/pool/task.rs index 1c1065d1fa5..d3ea7bee08b 100644 --- a/swarm/src/connection/pool/task.rs +++ b/swarm/src/connection/pool/task.rs @@ -42,7 +42,7 @@ use libp2p_core::muxing::StreamMuxer; use std::pin::Pin; use void::Void; -/// Commands that can be sent to a task. +/// Commands that can be sent to a task driving an established connection. #[derive(Debug)] pub enum Command { /// Notify the connection handler of an event. @@ -104,12 +104,12 @@ pub enum EstablishedConnectionEvent { pub async fn new_for_pending_outgoing_connection( connection_id: ConnectionId, dial: ConcurrentDial, - drop_receiver: oneshot::Receiver, + abort_receiver: oneshot::Receiver, mut events: mpsc::Sender>, ) where TTrans: Transport, { - match futures::future::select(drop_receiver, Box::pin(dial)).await { + match futures::future::select(abort_receiver, Box::pin(dial)).await { Either::Left((Err(oneshot::Canceled), _)) => { let _ = events .send(PendingConnectionEvent::PendingFailed { @@ -142,13 +142,13 @@ pub async fn new_for_pending_outgoing_connection( pub async fn new_for_pending_incoming_connection( connection_id: ConnectionId, future: TFut, - drop_receiver: oneshot::Receiver, + abort_receiver: oneshot::Receiver, mut events: mpsc::Sender>, ) where TTrans: Transport, TFut: Future> + Send + 'static, { - match futures::future::select(drop_receiver, Box::pin(future)).await { + match futures::future::select(abort_receiver, Box::pin(future)).await { Either::Left((Err(oneshot::Canceled), _)) => { let _ = events .send(PendingConnectionEvent::PendingFailed { diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 2d495e8d8b1..206bf7a1175 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -624,12 +624,14 @@ where /// with [`ProtocolsHandler::connection_keep_alive`] or directly with /// [`ProtocolsHandlerEvent::Close`]. pub fn disconnect_peer_id(&mut self, peer_id: PeerId) -> Result<(), ()> { - if self.pool.is_connected(peer_id) { - self.pool.disconnect(peer_id); - return Ok(()); - } + let was_connected = self.pool.is_connected(peer_id); + self.pool.disconnect(peer_id); - Err(()) + if was_connected { + Ok(()) + } else { + Err(()) + } } /// Checks whether there is an established connection to a peer. @@ -2422,4 +2424,39 @@ mod tests { })) .unwrap(); } + + #[test] + fn aborting_pending_connection_surfaces_error() { + let _ = env_logger::try_init(); + + let mut dialer = new_test_swarm::<_, ()>(DummyProtocolsHandler::default()).build(); + let mut listener = new_test_swarm::<_, ()>(DummyProtocolsHandler::default()).build(); + + let listener_peer_id = *listener.local_peer_id(); + listener.listen_on(multiaddr![Memory(0u64)]).unwrap(); + let listener_address = match block_on(listener.next()).unwrap() { + SwarmEvent::NewListenAddr { address, .. } => address, + e => panic!("Unexpected network event: {:?}", e), + }; + + dialer + .dial( + DialOpts::peer_id(listener_peer_id) + .addresses(vec![listener_address]) + .build(), + ) + .unwrap(); + + dialer + .disconnect_peer_id(listener_peer_id) + .expect_err("Expect peer to not yet be connected."); + + match block_on(dialer.next()).unwrap() { + SwarmEvent::OutgoingConnectionError { + error: DialError::Aborted, + .. + } => {} + e => panic!("Unexpected swarm event {:?}.", e), + } + } }