diff --git a/protocols/dcutr/src/handler/direct.rs b/protocols/dcutr/src/handler/direct.rs index c1470840b02..e336f750915 100644 --- a/protocols/dcutr/src/handler/direct.rs +++ b/protocols/dcutr/src/handler/direct.rs @@ -91,7 +91,9 @@ impl ConnectionHandler for Handler { | ConnectionEvent::FullyNegotiatedOutbound(_) | ConnectionEvent::DialUpgradeError(_) | ConnectionEvent::ListenUpgradeError(_) - | ConnectionEvent::AddressChange(_) => {} + | ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/dcutr/src/handler/relayed.rs b/protocols/dcutr/src/handler/relayed.rs index c9f49ff9497..8b59b80457b 100644 --- a/protocols/dcutr/src/handler/relayed.rs +++ b/protocols/dcutr/src/handler/relayed.rs @@ -379,7 +379,9 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/gossipsub/src/handler.rs b/protocols/gossipsub/src/handler.rs index 6e673516aa6..65a4a31b60c 100644 --- a/protocols/gossipsub/src/handler.rs +++ b/protocols/gossipsub/src/handler.rs @@ -553,7 +553,10 @@ impl ConnectionHandler for Handler { }) => { log::debug!("Protocol negotiation failed: {e}") } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } Handler::Disabled(_) => {} diff --git a/protocols/identify/src/behaviour.rs b/protocols/identify/src/behaviour.rs index 450d6dc5688..8150ec0d404 100644 --- a/protocols/identify/src/behaviour.rs +++ b/protocols/identify/src/behaviour.rs @@ -19,15 +19,14 @@ // DEALINGS IN THE SOFTWARE. use crate::handler::{self, Handler, InEvent}; -use crate::protocol::{Info, Protocol, UpgradeError}; +use crate::protocol::{Info, UpgradeError}; use libp2p_core::{multiaddr, ConnectedPoint, Endpoint, Multiaddr}; use libp2p_identity::PeerId; use libp2p_identity::PublicKey; use libp2p_swarm::behaviour::{ConnectionClosed, ConnectionEstablished, DialFailure, FromSwarm}; use libp2p_swarm::{ AddressScore, ConnectionDenied, DialError, ExternalAddresses, ListenAddresses, - NetworkBehaviour, NotifyHandler, PollParameters, StreamProtocol, StreamUpgradeError, - THandlerInEvent, ToSwarm, + NetworkBehaviour, NotifyHandler, PollParameters, StreamUpgradeError, THandlerInEvent, ToSwarm, }; use libp2p_swarm::{ConnectionId, THandler, THandlerOutEvent}; use lru::LruCache; @@ -50,10 +49,6 @@ pub struct Behaviour { config: Config, /// For each peer we're connected to, the observed address to send back to it. connected: HashMap>, - /// Pending requests to be fulfilled, either `Handler` requests for `Behaviour` info - /// to address identification requests, or push requests to peers - /// with current information about the local peer. - requests: Vec, /// Pending events to be emitted when polled. events: VecDeque>, /// The addresses of all peers that we have discovered. @@ -63,15 +58,6 @@ pub struct Behaviour { external_addresses: ExternalAddresses, } -/// A `Behaviour` request to be fulfilled, either `Handler` requests for `Behaviour` info -/// to address identification requests, or push requests to peers -/// with current information about the local peer. -#[derive(Debug, PartialEq, Eq)] -struct Request { - peer_id: PeerId, - protocol: Protocol, -} - /// Configuration for the [`identify::Behaviour`](Behaviour). #[non_exhaustive] #[derive(Debug, Clone)] @@ -184,7 +170,6 @@ impl Behaviour { Self { config, connected: HashMap::new(), - requests: Vec::new(), events: VecDeque::new(), discovered_peers, listen_addresses: Default::default(), @@ -203,13 +188,11 @@ impl Behaviour { continue; } - let request = Request { + self.events.push_back(ToSwarm::NotifyHandler { peer_id: p, - protocol: Protocol::Push, - }; - if !self.requests.contains(&request) { - self.requests.push(request); - } + handler: NotifyHandler::Any, + event: InEvent::Push, + }); } } @@ -239,6 +222,14 @@ impl Behaviour { } } } + + fn all_addresses(&self) -> HashSet { + self.listen_addresses + .iter() + .chain(self.external_addresses.iter()) + .cloned() + .collect() + } } impl NetworkBehaviour for Behaviour { @@ -261,6 +252,7 @@ impl NetworkBehaviour for Behaviour { self.config.protocol_version.clone(), self.config.agent_version.clone(), remote_addr.clone(), + self.all_addresses(), )) } @@ -280,13 +272,14 @@ impl NetworkBehaviour for Behaviour { self.config.protocol_version.clone(), self.config.agent_version.clone(), addr.clone(), // TODO: This is weird? That is the public address we dialed, shouldn't need to tell the other party? + self.all_addresses(), )) } fn on_connection_handler_event( &mut self, peer_id: PeerId, - connection_id: ConnectionId, + _: ConnectionId, event: THandlerOutEvent, ) { match event { @@ -315,12 +308,6 @@ impl NetworkBehaviour for Behaviour { self.events .push_back(ToSwarm::GenerateEvent(Event::Pushed { peer_id })); } - handler::Event::Identify => { - self.requests.push(Request { - peer_id, - protocol: Protocol::Identify(connection_id), - }); - } handler::Event::IdentificationError(error) => { self.events .push_back(ToSwarm::GenerateEvent(Event::Error { peer_id, error })); @@ -331,50 +318,13 @@ impl NetworkBehaviour for Behaviour { fn poll( &mut self, _cx: &mut Context<'_>, - params: &mut impl PollParameters, + _: &mut impl PollParameters, ) -> Poll>> { if let Some(event) = self.events.pop_front() { return Poll::Ready(event); } - // Check for pending requests. - match self.requests.pop() { - Some(Request { - peer_id, - protocol: Protocol::Push, - }) => Poll::Ready(ToSwarm::NotifyHandler { - peer_id, - handler: NotifyHandler::Any, - event: InEvent { - listen_addrs: self - .listen_addresses - .iter() - .chain(self.external_addresses.iter()) - .cloned() - .collect(), - supported_protocols: supported_protocols(params), - protocol: Protocol::Push, - }, - }), - Some(Request { - peer_id, - protocol: Protocol::Identify(connection_id), - }) => Poll::Ready(ToSwarm::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection_id), - event: InEvent { - listen_addrs: self - .listen_addresses - .iter() - .chain(self.external_addresses.iter()) - .cloned() - .collect(), - supported_protocols: supported_protocols(params), - protocol: Protocol::Identify(connection_id), - }, - }), - None => Poll::Pending, - } + Poll::Pending } fn handle_pending_outbound_connection( @@ -393,8 +343,35 @@ impl NetworkBehaviour for Behaviour { } fn on_swarm_event(&mut self, event: FromSwarm) { - self.listen_addresses.on_swarm_event(&event); - self.external_addresses.on_swarm_event(&event); + let listen_addr_changed = self.listen_addresses.on_swarm_event(&event); + let external_addr_changed = self.external_addresses.on_swarm_event(&event); + + if listen_addr_changed || external_addr_changed { + // notify all connected handlers about our changed addresses + let change_events = self + .connected + .iter() + .flat_map(|(peer, map)| map.keys().map(|id| (*peer, id))) + .map(|(peer_id, connection_id)| ToSwarm::NotifyHandler { + peer_id, + handler: NotifyHandler::One(*connection_id), + event: InEvent::AddressesChanged(self.all_addresses()), + }) + .collect::>(); + + self.events.extend(change_events) + } + + if listen_addr_changed && self.config.push_listen_addr_updates { + // trigger an identify push for all connected peers + let push_events = self.connected.keys().map(|peer| ToSwarm::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::Any, + event: InEvent::Push, + }); + + self.events.extend(push_events); + } match event { FromSwarm::ConnectionEstablished(connection_established) => { @@ -408,30 +385,11 @@ impl NetworkBehaviour for Behaviour { }) => { if remaining_established == 0 { self.connected.remove(&peer_id); - self.requests.retain(|request| { - request - != &Request { - peer_id, - protocol: Protocol::Push, - } - }); } else if let Some(addrs) = self.connected.get_mut(&peer_id) { addrs.remove(&connection_id); } } FromSwarm::DialFailure(DialFailure { peer_id, error, .. }) => { - if let Some(peer_id) = peer_id { - if !self.connected.contains_key(&peer_id) { - self.requests.retain(|request| { - request - != &Request { - peer_id, - protocol: Protocol::Push, - } - }); - } - } - if let Some(entry) = peer_id.and_then(|id| self.discovered_peers.get_mut(&id)) { if let DialError::Transport(errors) = error { for (addr, _error) in errors { @@ -440,20 +398,9 @@ impl NetworkBehaviour for Behaviour { } } } - FromSwarm::NewListenAddr(_) | FromSwarm::ExpiredListenAddr(_) => { - if self.config.push_listen_addr_updates { - for p in self.connected.keys() { - let request = Request { - peer_id: *p, - protocol: Protocol::Push, - }; - if !self.requests.contains(&request) { - self.requests.push(request); - } - } - } - } - FromSwarm::AddressChange(_) + FromSwarm::NewListenAddr(_) + | FromSwarm::ExpiredListenAddr(_) + | FromSwarm::AddressChange(_) | FromSwarm::ListenFailure(_) | FromSwarm::NewListener(_) | FromSwarm::ListenerError(_) @@ -496,17 +443,6 @@ pub enum Event { }, } -fn supported_protocols(params: &impl PollParameters) -> Vec { - // The protocol names can be bytes, but the identify protocol except UTF-8 strings. - // There's not much we can do to solve this conflict except strip non-UTF-8 characters. - params - .supported_protocols() - .filter_map(|p| { - StreamProtocol::try_from_owned(String::from_utf8_lossy(&p).to_string()).ok() - }) - .collect() -} - /// If there is a given peer_id in the multiaddr, make sure it is the same as /// the given peer_id. If there is no peer_id for the peer in the mutiaddr, this returns true. fn multiaddr_matches_peer_id(addr: &Multiaddr, peer_id: &PeerId) -> bool { diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 96f15924ee7..576585b38e3 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -18,9 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::protocol::{ - self, Identify, InboundPush, Info, OutboundPush, Protocol, Push, UpgradeError, -}; +use crate::protocol::{Identify, InboundPush, Info, OutboundPush, Push, UpgradeError}; use either::Either; use futures::future::BoxFuture; use futures::prelude::*; @@ -32,15 +30,16 @@ use libp2p_identity::PeerId; use libp2p_identity::PublicKey; use libp2p_swarm::handler::{ ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, + ProtocolSupport, }; use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, KeepAlive, NegotiatedSubstream, StreamProtocol, - StreamUpgradeError, SubstreamProtocol, + ConnectionHandler, ConnectionHandlerEvent, KeepAlive, StreamProtocol, StreamUpgradeError, + SubstreamProtocol, SupportedProtocols, }; use log::warn; use smallvec::SmallVec; -use std::collections::{HashSet, VecDeque}; -use std::{io, pin::Pin, task::Context, task::Poll, time::Duration}; +use std::collections::HashSet; +use std::{io, task::Context, task::Poll, time::Duration}; /// Protocol handler for sending and receiving identification requests. /// @@ -55,9 +54,6 @@ pub struct Handler { [ConnectionHandlerEvent>, (), Event, io::Error>; 4], >, - /// Streams awaiting `BehaviourInfo` to then send identify requests. - reply_streams: VecDeque, - /// Pending identification replies, awaiting being sent. pending_replies: FuturesUnordered>>, @@ -80,19 +76,17 @@ pub struct Handler { /// Address observed by or for the remote. observed_addr: Multiaddr, + + local_supported_protocols: SupportedProtocols, + remote_supported_protocols: HashSet, + external_addresses: HashSet, } /// An event from `Behaviour` with the information requested by the `Handler`. #[derive(Debug)] -pub struct InEvent { - /// The addresses that the peer is listening on. - pub listen_addrs: HashSet, - - /// The list of protocols supported by the peer, e.g. `/ipfs/ping/1.0.0`. - pub supported_protocols: Vec, - - /// The protocol w.r.t. the information requested. - pub protocol: Protocol, +pub enum InEvent { + AddressesChanged(HashSet), + Push, } /// Event produced by the `Handler`. @@ -105,14 +99,13 @@ pub enum Event { Identification(PeerId), /// We actively pushed our identification information to the remote. IdentificationPushed, - /// We received a request for identification. - Identify, /// Failed to identify the remote, or to reply to an identification request. IdentificationError(StreamUpgradeError), } impl Handler { /// Creates a new `Handler`. + #[allow(clippy::too_many_arguments)] pub fn new( initial_delay: Duration, interval: Duration, @@ -121,12 +114,12 @@ impl Handler { protocol_version: String, agent_version: String, observed_addr: Multiaddr, + external_addresses: HashSet, ) -> Self { Self { remote_peer_id, inbound_identify_push: Default::default(), events: SmallVec::new(), - reply_streams: VecDeque::new(), pending_replies: FuturesUnordered::new(), trigger_next_identify: Delay::new(initial_delay), interval, @@ -134,6 +127,9 @@ impl Handler { protocol_version, agent_version, observed_addr, + local_supported_protocols: SupportedProtocols::default(), + remote_supported_protocols: HashSet::default(), + external_addresses, } } @@ -148,16 +144,14 @@ impl Handler { ) { match output { future::Either::Left(substream) => { - self.events - .push(ConnectionHandlerEvent::Custom(Event::Identify)); - if !self.reply_streams.is_empty() { - warn!( - "New inbound identify request from {} while a previous one \ - is still pending. Queueing the new one.", - self.remote_peer_id, - ); - } - self.reply_streams.push_back(substream); + let peer_id = self.remote_peer_id; + let info = self.build_info(); + + self.pending_replies.push(Box::pin(async move { + crate::protocol::send(substream, info).await?; + + Ok(peer_id) + })); } future::Either::Right(fut) => { if self.inbound_identify_push.replace(fut).is_some() { @@ -182,6 +176,7 @@ impl Handler { ) { match output { future::Either::Left(remote_info) => { + self.update_supported_protocols_for_remote(&remote_info); self.events .push(ConnectionHandlerEvent::Custom(Event::Identified( remote_info, @@ -207,6 +202,47 @@ impl Handler { ))); self.trigger_next_identify.reset(self.interval); } + + fn build_info(&mut self) -> Info { + Info { + public_key: self.public_key.clone(), + protocol_version: self.protocol_version.clone(), + agent_version: self.agent_version.clone(), + listen_addrs: Vec::from_iter(self.external_addresses.iter().cloned()), + protocols: Vec::from_iter(self.local_supported_protocols.iter().cloned()), + observed_addr: self.observed_addr.clone(), + } + } + + fn update_supported_protocols_for_remote(&mut self, remote_info: &Info) { + let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone()); + + let remote_added_protocols = new_remote_protocols + .difference(&self.remote_supported_protocols) + .cloned() + .collect::>(); + let remote_removed_protocols = self + .remote_supported_protocols + .difference(&new_remote_protocols) + .cloned() + .collect::>(); + + if !remote_added_protocols.is_empty() { + self.events + .push(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Added(remote_added_protocols), + )); + } + + if !remote_removed_protocols.is_empty() { + self.events + .push(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Removed(remote_removed_protocols), + )); + } + + self.remote_supported_protocols = new_remote_protocols; + } } impl ConnectionHandler for Handler { @@ -222,42 +258,18 @@ impl ConnectionHandler for Handler { SubstreamProtocol::new(SelectUpgrade::new(Identify, Push::inbound()), ()) } - fn on_behaviour_event( - &mut self, - InEvent { - listen_addrs, - supported_protocols, - protocol, - }: Self::InEvent, - ) { - let info = Info { - public_key: self.public_key.clone(), - protocol_version: self.protocol_version.clone(), - agent_version: self.agent_version.clone(), - listen_addrs: Vec::from_iter(listen_addrs), - protocols: supported_protocols, - observed_addr: self.observed_addr.clone(), - }; - - match protocol { - Protocol::Push => { + fn on_behaviour_event(&mut self, event: Self::InEvent) { + match event { + InEvent::AddressesChanged(addresses) => { + self.external_addresses = addresses; + } + InEvent::Push => { + let info = self.build_info(); self.events .push(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()), }); } - Protocol::Identify(_) => { - let substream = self - .reply_streams - .pop_front() - .expect("A BehaviourInfo reply should have a matching substream."); - let peer = self.remote_peer_id; - let fut = Box::pin(async move { - protocol::send(substream, info).await?; - Ok(peer) - }); - self.pending_replies.push(fut); - } } } @@ -270,10 +282,6 @@ impl ConnectionHandler for Handler { return KeepAlive::Yes; } - if !self.reply_streams.is_empty() { - return KeepAlive::Yes; - } - KeepAlive::No } @@ -283,20 +291,17 @@ impl ConnectionHandler for Handler { ) -> Poll< ConnectionHandlerEvent, > { - if !self.events.is_empty() { - return Poll::Ready(self.events.remove(0)); + if let Some(event) = self.events.pop() { + return Poll::Ready(event); } // Poll the future that fires when we need to identify the node again. - match Future::poll(Pin::new(&mut self.trigger_next_identify), cx) { - Poll::Pending => {} - Poll::Ready(()) => { - self.trigger_next_identify.reset(self.interval); - let ev = ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(Either::Left(Identify), ()), - }; - return Poll::Ready(ev); - } + if let Poll::Ready(()) = self.trigger_next_identify.poll_unpin(cx) { + self.trigger_next_identify.reset(self.interval); + let ev = ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(Either::Left(Identify), ()), + }; + return Poll::Ready(ev); } if let Some(Poll::Ready(res)) = self @@ -307,20 +312,21 @@ impl ConnectionHandler for Handler { self.inbound_identify_push.take(); if let Ok(info) = res { + self.update_supported_protocols_for_remote(&info); return Poll::Ready(ConnectionHandlerEvent::Custom(Event::Identified(info))); } } // Check for pending replies to send. - match self.pending_replies.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(peer_id))) => Poll::Ready(ConnectionHandlerEvent::Custom( - Event::Identification(peer_id), - )), - Poll::Ready(Some(Err(err))) => Poll::Ready(ConnectionHandlerEvent::Custom( - Event::IdentificationError(StreamUpgradeError::Apply(err)), - )), - Poll::Ready(None) | Poll::Pending => Poll::Pending, + if let Poll::Ready(Some(result)) = self.pending_replies.poll_next_unpin(cx) { + let event = result + .map(Event::Identification) + .unwrap_or_else(|err| Event::IdentificationError(StreamUpgradeError::Apply(err))); + + return Poll::Ready(ConnectionHandlerEvent::Custom(event)); } + + Poll::Pending } fn on_connection_event( @@ -342,7 +348,12 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} + ConnectionEvent::LocalProtocolsChange(change) => { + self.local_supported_protocols.on_protocols_change(change); + } } } } diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index bedebe61e47..a508591b106 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -28,7 +28,7 @@ use libp2p_core::{ }; use libp2p_identity as identity; use libp2p_identity::PublicKey; -use libp2p_swarm::{ConnectionId, StreamProtocol}; +use libp2p_swarm::StreamProtocol; use log::{debug, trace}; use std::convert::TryFrom; use std::{io, iter, pin::Pin}; @@ -41,13 +41,6 @@ pub const PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/ipfs/id/1.0.0"); pub const PUSH_PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/ipfs/id/push/1.0.0"); -/// The type of the Substream protocol. -#[derive(Debug, PartialEq, Eq)] -pub enum Protocol { - Identify(ConnectionId), - Push, -} - /// Substream upgrade protocol for `/ipfs/id/1.0.0`. #[derive(Debug, Clone)] pub struct Identify; diff --git a/protocols/kad/src/handler_priv.rs b/protocols/kad/src/handler_priv.rs index 3ac829600a4..3fa123410ee 100644 --- a/protocols/kad/src/handler_priv.rs +++ b/protocols/kad/src/handler_priv.rs @@ -777,7 +777,10 @@ where ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/perf/src/client/handler.rs b/protocols/perf/src/client/handler.rs index 437d63f659b..a87e82cc384 100644 --- a/protocols/perf/src/client/handler.rs +++ b/protocols/perf/src/client/handler.rs @@ -137,7 +137,9 @@ impl ConnectionHandler for Handler { ); } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} ConnectionEvent::DialUpgradeError(DialUpgradeError { info: (), error }) => { let Command { id, .. } = self .requested_streams diff --git a/protocols/perf/src/server/handler.rs b/protocols/perf/src/server/handler.rs index d279c0d1700..ab70ba9aa6e 100644 --- a/protocols/perf/src/server/handler.rs +++ b/protocols/perf/src/server/handler.rs @@ -103,7 +103,9 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(DialUpgradeError { info, .. }) => { void::unreachable(info) } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} ConnectionEvent::ListenUpgradeError(ListenUpgradeError { info: (), error }) => { void::unreachable(error) } diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs index 22aee9d5d3b..d54bfe39cc2 100644 --- a/protocols/ping/src/handler.rs +++ b/protocols/ping/src/handler.rs @@ -410,7 +410,10 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/relay/src/behaviour/handler.rs b/protocols/relay/src/behaviour/handler.rs index 580a69c7f02..ff2abc65aa0 100644 --- a/protocols/relay/src/behaviour/handler.rs +++ b/protocols/relay/src/behaviour/handler.rs @@ -898,7 +898,9 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs index a2178ce4983..c134031ad7c 100644 --- a/protocols/relay/src/priv_client/handler.rs +++ b/protocols/relay/src/priv_client/handler.rs @@ -541,7 +541,9 @@ impl ConnectionHandler for Handler { ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/protocols/rendezvous/src/substream_handler.rs b/protocols/rendezvous/src/substream_handler.rs index d2a1651cd52..e4645449795 100644 --- a/protocols/rendezvous/src/substream_handler.rs +++ b/protocols/rendezvous/src/substream_handler.rs @@ -397,7 +397,9 @@ where // TODO: Handle upgrade errors properly ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) - | ConnectionEvent::DialUpgradeError(_) => {} + | ConnectionEvent::DialUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 7ee1b13f260..3a323d75edc 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -382,7 +382,9 @@ where ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { self.on_listen_upgrade_error(listen_upgrade_error) } - ConnectionEvent::AddressChange(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/swarm/CHANGELOG.md b/swarm/CHANGELOG.md index 6691876e7f6..14a584217ad 100644 --- a/swarm/CHANGELOG.md +++ b/swarm/CHANGELOG.md @@ -34,7 +34,16 @@ Users should migrate to `libp2p::connection_limits::Behaviour`. See [PR 3885]. +- Allow `ConnectionHandler`s to report and learn about the supported protocols on a connection. + The newly introduced API elements are: + - `ConnectionHandlerEvent::ReportRemoteProtocols` + - `ConnectionEvent::LocalProtocolsChange` + - `ConnectionEvent::RemoteProtocolsChange` + + See [PR 3651]. + [PR 3605]: https://github.com/libp2p/rust-libp2p/pull/3605 +[PR 3651]: https://github.com/libp2p/rust-libp2p/pull/3651 [PR 3715]: https://github.com/libp2p/rust-libp2p/pull/3715 [PR 3746]: https://github.com/libp2p/rust-libp2p/pull/3746 [PR 3865]: https://github.com/libp2p/rust-libp2p/pull/3865 diff --git a/swarm/Cargo.toml b/swarm/Cargo.toml index 722f3610b95..699e0c74b83 100644 --- a/swarm/Cargo.toml +++ b/swarm/Cargo.toml @@ -25,6 +25,7 @@ smallvec = "1.6.1" void = "1" wasm-bindgen-futures = { version = "0.4.34", optional = true } getrandom = { version = "0.2.9", features = ["js"], optional = true } # Explicit dependency to be used in `wasm-bindgen` feature +once_cell = "1.17.1" [target.'cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))'.dependencies] async-std = { version = "1.6.2", optional = true } diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index 01a10688c98..9953b08793b 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -230,6 +230,9 @@ pub trait PollParameters { /// The iterator's elements are the ASCII names as reported on the wire. /// /// Note that the list is computed once at initialization and never refreshed. + #[deprecated( + note = "Use `libp2p_swarm::SupportedProtocols` in your `ConnectionHandler` instead." + )] fn supported_protocols(&self) -> Self::SupportedProtocolsIter; /// Returns the list of the addresses we're listening on. diff --git a/swarm/src/behaviour/toggle.rs b/swarm/src/behaviour/toggle.rs index 17798e535aa..0e687fe9355 100644 --- a/swarm/src/behaviour/toggle.rs +++ b/swarm/src/behaviour/toggle.rs @@ -362,6 +362,16 @@ where ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { self.on_listen_upgrade_error(listen_upgrade_error) } + ConnectionEvent::LocalProtocolsChange(change) => { + if let Some(inner) = self.inner.as_mut() { + inner.on_connection_event(ConnectionEvent::LocalProtocolsChange(change)); + } + } + ConnectionEvent::RemoteProtocolsChange(change) => { + if let Some(inner) = self.inner.as_mut() { + inner.on_connection_event(ConnectionEvent::RemoteProtocolsChange(change)); + } + } } } } diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index a4620522dba..32a24161393 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -21,18 +21,23 @@ mod error; pub(crate) mod pool; +mod supported_protocols; pub use error::ConnectionError; pub(crate) use error::{ PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError, }; +pub use supported_protocols::SupportedProtocols; use crate::handler::{ AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound, - FullyNegotiatedOutbound, ListenUpgradeError, + FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange, + UpgradeInfoSend, }; use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper}; -use crate::{ConnectionHandlerEvent, KeepAlive, StreamUpgradeError, SubstreamProtocol}; +use crate::{ + ConnectionHandlerEvent, KeepAlive, StreamProtocol, StreamUpgradeError, SubstreamProtocol, +}; use futures::stream::FuturesUnordered; use futures::FutureExt; use futures::StreamExt; @@ -47,6 +52,7 @@ use libp2p_core::upgrade::{ }; use libp2p_core::Endpoint; use libp2p_identity::PeerId; +use std::collections::HashSet; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::Waker; @@ -147,6 +153,9 @@ where requested_substreams: FuturesUnordered< SubstreamRequested, >, + + local_supported_protocols: HashSet, + remote_supported_protocols: HashSet, } impl fmt::Debug for Connection @@ -171,10 +180,18 @@ where /// and connection handler. pub(crate) fn new( muxer: StreamMuxerBox, - handler: THandler, + mut handler: THandler, substream_upgrade_protocol_override: Option, max_negotiating_inbound_streams: usize, ) -> Self { + let initial_protocols = gather_supported_protocols(&handler); + + if !initial_protocols.is_empty() { + handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( + ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)), + )); + } + Connection { muxing: muxer, handler, @@ -184,6 +201,8 @@ where substream_upgrade_protocol_override, max_negotiating_inbound_streams, requested_substreams: Default::default(), + local_supported_protocols: initial_protocols, + remote_supported_protocols: Default::default(), } } @@ -213,6 +232,8 @@ where shutdown, max_negotiating_inbound_streams, substream_upgrade_protocol_override, + local_supported_protocols: supported_protocols, + remote_supported_protocols, } = self.get_mut(); loop { @@ -246,6 +267,31 @@ where Poll::Ready(ConnectionHandlerEvent::Close(err)) => { return Poll::Ready(Err(ConnectionError::Handler(err))); } + Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Added(protocols), + )) => { + if let Some(added) = + ProtocolsChange::add(remote_supported_protocols, &protocols) + { + handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added)); + remote_supported_protocols.extend(protocols); + } + + continue; + } + Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Removed(protocols), + )) => { + if let Some(removed) = + ProtocolsChange::remove(remote_supported_protocols, &protocols) + { + handler + .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed)); + remote_supported_protocols.retain(|p| !protocols.contains(p)); + } + + continue; + } } // In case the [`ConnectionHandler`] can not make any more progress, poll the negotiating outbound streams. @@ -376,9 +422,33 @@ where } } + let new_protocols = gather_supported_protocols(handler); + + for change in ProtocolsChange::from_full_sets(supported_protocols, &new_protocols) { + handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change)); + } + + *supported_protocols = new_protocols; + return Poll::Pending; // Nothing can make progress, return `Pending`. } } + + #[cfg(test)] + fn poll_noop_waker( + &mut self, + ) -> Poll, ConnectionError>> { + Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref())) + } +} + +fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet { + handler + .listen_protocol() + .upgrade() + .protocol_info() + .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()) + .collect() } /// Borrowed information about an incoming connection currently being negotiated. @@ -605,9 +675,10 @@ enum Shutdown { mod tests { use super::*; use crate::keep_alive; + use futures::future; use futures::AsyncRead; use futures::AsyncWrite; - use libp2p_core::upgrade::DeniedUpgrade; + use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_core::StreamMuxer; use quickcheck::*; use std::sync::{Arc, Weak}; @@ -629,8 +700,7 @@ mod tests { max_negotiating_inbound_streams, ); - let result = Pin::new(&mut connection) - .poll(&mut Context::from_waker(futures::task::noop_waker_ref())); + let result = connection.poll_noop_waker(); assert!(result.is_pending()); assert_eq!( @@ -654,13 +724,11 @@ mod tests { ); connection.handler.open_new_outbound(); - let _ = Pin::new(&mut connection) - .poll(&mut Context::from_waker(futures::task::noop_waker_ref())); + let _ = connection.poll_noop_waker(); std::thread::sleep(upgrade_timeout + Duration::from_secs(1)); - let _ = Pin::new(&mut connection) - .poll(&mut Context::from_waker(futures::task::noop_waker_ref())); + let _ = connection.poll_noop_waker(); assert!(matches!( connection.handler.error.unwrap(), @@ -668,6 +736,94 @@ mod tests { )) } + #[test] + fn propagates_changes_to_supported_inbound_protocols() { + let mut connection = Connection::new( + StreamMuxerBox::new(PendingStreamMuxer), + ConfigurableProtocolConnectionHandler::default(), + None, + 0, + ); + + // First, start listening on a single protocol. + connection.handler.listen_on(&["/foo"]); + let _ = connection.poll_noop_waker(); + + assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]); + assert!(connection.handler.local_removed.is_empty()); + + // Second, listen on two protocols. + connection.handler.listen_on(&["/foo", "/bar"]); + let _ = connection.poll_noop_waker(); + + assert_eq!( + connection.handler.local_added, + vec![vec!["/foo"], vec!["/bar"]], + "expect to only receive an event for the newly added protocols" + ); + assert!(connection.handler.local_removed.is_empty()); + + // Third, stop listening on the first protocol. + connection.handler.listen_on(&["/bar"]); + let _ = connection.poll_noop_waker(); + + assert_eq!( + connection.handler.local_added, + vec![vec!["/foo"], vec!["/bar"]] + ); + assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]); + } + + #[test] + fn only_propagtes_actual_changes_to_remote_protocols_to_handler() { + let mut connection = Connection::new( + StreamMuxerBox::new(PendingStreamMuxer), + ConfigurableProtocolConnectionHandler::default(), + None, + 0, + ); + + // First, remote supports a single protocol. + connection.handler.remote_adds_support_for(&["/foo"]); + let _ = connection.poll_noop_waker(); + + assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]); + assert!(connection.handler.remote_removed.is_empty()); + + // Second, it adds a protocol but also still includes the first one. + connection + .handler + .remote_adds_support_for(&["/foo", "/bar"]); + let _ = connection.poll_noop_waker(); + + assert_eq!( + connection.handler.remote_added, + vec![vec!["/foo"], vec!["/bar"]], + "expect to only receive an event for the newly added protocol" + ); + assert!(connection.handler.remote_removed.is_empty()); + + // Third, stop listening on a protocol it never advertised (we can't control what handlers do so this needs to be handled gracefully). + connection.handler.remote_removes_support_for(&["/baz"]); + let _ = connection.poll_noop_waker(); + + assert_eq!( + connection.handler.remote_added, + vec![vec!["/foo"], vec!["/bar"]] + ); + assert!(&connection.handler.remote_removed.is_empty()); + + // Fourth, stop listening on a protocol that was previously supported + connection.handler.remote_removes_support_for(&["/bar"]); + let _ = connection.poll_noop_waker(); + + assert_eq!( + connection.handler.remote_added, + vec![vec!["/foo"], vec!["/bar"]] + ); + assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]); + } + struct DummyStreamMuxer { counter: Arc<()>, } @@ -785,6 +941,40 @@ mod tests { } } + #[derive(Default)] + struct ConfigurableProtocolConnectionHandler { + events: Vec>, + active_protocols: HashSet, + local_added: Vec>, + local_removed: Vec>, + remote_added: Vec>, + remote_removed: Vec>, + } + + impl ConfigurableProtocolConnectionHandler { + fn listen_on(&mut self, protocols: &[&'static str]) { + self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect(); + } + + fn remote_adds_support_for(&mut self, protocols: &[&'static str]) { + self.events + .push(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Added( + protocols.iter().copied().map(StreamProtocol::new).collect(), + ), + )); + } + + fn remote_removes_support_for(&mut self, protocols: &[&'static str]) { + self.events + .push(ConnectionHandlerEvent::ReportRemoteProtocols( + ProtocolSupport::Removed( + protocols.iter().copied().map(StreamProtocol::new).collect(), + ), + )); + } + } + impl ConnectionHandler for MockConnectionHandler { type InEvent = Void; type OutEvent = Void; @@ -821,7 +1011,10 @@ mod tests { ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => { self.error = Some(error) } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } @@ -855,6 +1048,112 @@ mod tests { Poll::Pending } } + + impl ConnectionHandler for ConfigurableProtocolConnectionHandler { + type InEvent = Void; + type OutEvent = Void; + type Error = Void; + type InboundProtocol = ManyProtocolsUpgrade; + type OutboundProtocol = DeniedUpgrade; + type InboundOpenInfo = (); + type OutboundOpenInfo = (); + + fn listen_protocol( + &self, + ) -> SubstreamProtocol { + SubstreamProtocol::new( + ManyProtocolsUpgrade { + protocols: Vec::from_iter(self.active_protocols.clone()), + }, + (), + ) + } + + fn on_connection_event( + &mut self, + event: ConnectionEvent< + Self::InboundProtocol, + Self::OutboundProtocol, + Self::InboundOpenInfo, + Self::OutboundOpenInfo, + >, + ) { + match event { + ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => { + self.local_added.push(added.cloned().collect()) + } + ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => { + self.local_removed.push(removed.cloned().collect()) + } + ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => { + self.remote_added.push(added.cloned().collect()) + } + ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => { + self.remote_removed.push(removed.cloned().collect()) + } + _ => {} + } + } + + fn on_behaviour_event(&mut self, event: Self::InEvent) { + void::unreachable(event) + } + + fn connection_keep_alive(&self) -> KeepAlive { + KeepAlive::Yes + } + + fn poll( + &mut self, + _: &mut Context<'_>, + ) -> Poll< + ConnectionHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { + if let Some(event) = self.events.pop() { + return Poll::Ready(event); + } + + Poll::Pending + } + } + + struct ManyProtocolsUpgrade { + protocols: Vec, + } + + impl UpgradeInfo for ManyProtocolsUpgrade { + type Info = StreamProtocol; + type InfoIter = std::vec::IntoIter; + + fn protocol_info(&self) -> Self::InfoIter { + self.protocols.clone().into_iter() + } + } + + impl InboundUpgrade for ManyProtocolsUpgrade { + type Output = C; + type Error = Void; + type Future = future::Ready>; + + fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future { + future::ready(Ok(stream)) + } + } + + impl OutboundUpgrade for ManyProtocolsUpgrade { + type Output = C; + type Error = Void; + type Future = future::Ready>; + + fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future { + future::ready(Ok(stream)) + } + } } /// The endpoint roles associated with a pending peer-to-peer connection. diff --git a/swarm/src/connection/supported_protocols.rs b/swarm/src/connection/supported_protocols.rs new file mode 100644 index 00000000000..0575046bb44 --- /dev/null +++ b/swarm/src/connection/supported_protocols.rs @@ -0,0 +1,88 @@ +use crate::handler::ProtocolsChange; +use crate::StreamProtocol; +use std::collections::HashSet; + +#[derive(Default, Clone, Debug)] +pub struct SupportedProtocols { + protocols: HashSet, +} + +impl SupportedProtocols { + pub fn on_protocols_change(&mut self, change: ProtocolsChange) -> bool { + match change { + ProtocolsChange::Added(added) => { + let mut changed = false; + + for p in added { + changed |= self.protocols.insert(p.clone()); + } + + changed + } + ProtocolsChange::Removed(removed) => { + let mut changed = false; + + for p in removed { + changed |= self.protocols.remove(p); + } + + changed + } + } + } + + pub fn iter(&self) -> impl Iterator { + self.protocols.iter() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::handler::{ProtocolsAdded, ProtocolsRemoved}; + use once_cell::sync::Lazy; + + #[test] + fn protocols_change_added_returns_correct_changed_value() { + let mut protocols = SupportedProtocols::default(); + + let changed = protocols.on_protocols_change(add_foo()); + assert!(changed); + + let changed = protocols.on_protocols_change(add_foo()); + assert!(!changed); + + let changed = protocols.on_protocols_change(add_foo_bar()); + assert!(changed); + } + + #[test] + fn protocols_change_removed_returns_correct_changed_value() { + let mut protocols = SupportedProtocols::default(); + + let changed = protocols.on_protocols_change(remove_foo()); + assert!(!changed); + + protocols.on_protocols_change(add_foo()); + + let changed = protocols.on_protocols_change(remove_foo()); + assert!(changed); + } + + fn add_foo() -> ProtocolsChange<'static> { + ProtocolsChange::Added(ProtocolsAdded::from_set(&FOO_PROTOCOLS)) + } + + fn add_foo_bar() -> ProtocolsChange<'static> { + ProtocolsChange::Added(ProtocolsAdded::from_set(&FOO_BAR_PROTOCOLS)) + } + + fn remove_foo() -> ProtocolsChange<'static> { + ProtocolsChange::Removed(ProtocolsRemoved::from_set(&FOO_PROTOCOLS)) + } + + static FOO_PROTOCOLS: Lazy> = + Lazy::new(|| HashSet::from([StreamProtocol::new("/foo")])); + static FOO_BAR_PROTOCOLS: Lazy> = + Lazy::new(|| HashSet::from([StreamProtocol::new("/foo"), StreamProtocol::new("/bar")])); +} diff --git a/swarm/src/dummy.rs b/swarm/src/dummy.rs index c605e80ea5d..83d03a7aadb 100644 --- a/swarm/src/dummy.rs +++ b/swarm/src/dummy.rs @@ -138,7 +138,10 @@ impl crate::handler::ConnectionHandler for ConnectionHandler { unreachable!("Denied upgrade does not support any protocols") } }, - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/swarm/src/handler.rs b/swarm/src/handler.rs index e595dccde7b..2ce58638bd1 100644 --- a/swarm/src/handler.rs +++ b/swarm/src/handler.rs @@ -47,17 +47,24 @@ mod pending; mod select; pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend}; - -use instant::Instant; -use libp2p_core::Multiaddr; -use std::{cmp::Ordering, error, fmt, io, task::Context, task::Poll, time::Duration}; - pub use map_in::MapInEvent; pub use map_out::MapOutEvent; pub use one_shot::{OneShotHandler, OneShotHandlerConfig}; pub use pending::PendingConnectionHandler; pub use select::ConnectionHandlerSelect; +use crate::StreamProtocol; +use ::either::Either; +use instant::Instant; +use libp2p_core::Multiaddr; +use once_cell::sync::Lazy; +use smallvec::SmallVec; +use std::collections::hash_map::RandomState; +use std::collections::hash_set::{Difference, Intersection}; +use std::collections::HashSet; +use std::iter::Peekable; +use std::{cmp::Ordering, error, fmt, io, task::Context, task::Poll, time::Duration}; + /// A handler for a set of protocols used on a connection with a remote. /// /// This trait should be implemented for a type that maintains the state for @@ -209,6 +216,10 @@ pub enum ConnectionEvent<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IO DialUpgradeError(DialUpgradeError), /// Informs the handler that upgrading an inbound substream to the given protocol has failed. ListenUpgradeError(ListenUpgradeError), + /// The local [`ConnectionHandler`] added or removed support for one or more protocols. + LocalProtocolsChange(ProtocolsChange<'a>), + /// The remote [`ConnectionHandler`] now supports a different set of protocols. + RemoteProtocolsChange(ProtocolsChange<'a>), } impl<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI> @@ -222,6 +233,8 @@ impl<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI> } ConnectionEvent::FullyNegotiatedInbound(_) | ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) | ConnectionEvent::ListenUpgradeError(_) => false, } } @@ -234,6 +247,8 @@ impl<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI> } ConnectionEvent::FullyNegotiatedOutbound(_) | ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) | ConnectionEvent::DialUpgradeError(_) => false, } } @@ -266,6 +281,122 @@ pub struct AddressChange<'a> { pub new_address: &'a Multiaddr, } +/// [`ConnectionEvent`] variant that informs the handler about a change in the protocols supported on the connection. +#[derive(Clone)] +pub enum ProtocolsChange<'a> { + Added(ProtocolsAdded<'a>), + Removed(ProtocolsRemoved<'a>), +} + +impl<'a> ProtocolsChange<'a> { + /// Compute the [`ProtocolsChange`] that results from adding `to_add` to `existing_protocols`. + /// + /// Returns `None` if the change is a no-op, i.e. `to_add` is a subset of `existing_protocols`. + pub(crate) fn add( + existing_protocols: &'a HashSet, + to_add: &'a HashSet, + ) -> Option { + let mut actually_added_protocols = to_add.difference(existing_protocols).peekable(); + + actually_added_protocols.peek()?; + + Some(ProtocolsChange::Added(ProtocolsAdded { + protocols: actually_added_protocols, + })) + } + + /// Compute the [`ProtocolsChange`] that results from removing `to_remove` from `existing_protocols`. + /// + /// Returns `None` if the change is a no-op, i.e. none of the protocols in `to_remove` are in `existing_protocols`. + pub(crate) fn remove( + existing_protocols: &'a HashSet, + to_remove: &'a HashSet, + ) -> Option { + let mut actually_removed_protocols = existing_protocols.intersection(to_remove).peekable(); + + actually_removed_protocols.peek()?; + + Some(ProtocolsChange::Removed(ProtocolsRemoved { + protocols: Either::Right(actually_removed_protocols), + })) + } + + /// Compute the [`ProtocolsChange`]s required to go from `existing_protocols` to `new_protocols`. + pub(crate) fn from_full_sets( + existing_protocols: &'a HashSet, + new_protocols: &'a HashSet, + ) -> SmallVec<[Self; 2]> { + if existing_protocols == new_protocols { + return SmallVec::new(); + } + + let mut changes = SmallVec::new(); + + let mut added_protocols = new_protocols.difference(existing_protocols).peekable(); + let mut removed_protocols = existing_protocols.difference(new_protocols).peekable(); + + if added_protocols.peek().is_some() { + changes.push(ProtocolsChange::Added(ProtocolsAdded { + protocols: added_protocols, + })); + } + + if removed_protocols.peek().is_some() { + changes.push(ProtocolsChange::Removed(ProtocolsRemoved { + protocols: Either::Left(removed_protocols), + })); + } + + changes + } +} + +/// An [`Iterator`] over all protocols that have been added. +#[derive(Clone)] +pub struct ProtocolsAdded<'a> { + protocols: Peekable>, +} + +impl<'a> ProtocolsAdded<'a> { + pub(crate) fn from_set(protocols: &'a HashSet) -> Self { + ProtocolsAdded { + protocols: protocols.difference(&EMPTY_HASHSET).peekable(), + } + } +} + +/// An [`Iterator`] over all protocols that have been removed. +#[derive(Clone)] +pub struct ProtocolsRemoved<'a> { + protocols: Either< + Peekable>, + Peekable>, + >, +} + +impl<'a> ProtocolsRemoved<'a> { + #[cfg(test)] + pub(crate) fn from_set(protocols: &'a HashSet) -> Self { + ProtocolsRemoved { + protocols: Either::Left(protocols.difference(&EMPTY_HASHSET).peekable()), + } + } +} + +impl<'a> Iterator for ProtocolsAdded<'a> { + type Item = &'a StreamProtocol; + fn next(&mut self) -> Option { + self.protocols.next() + } +} + +impl<'a> Iterator for ProtocolsRemoved<'a> { + type Item = &'a StreamProtocol; + fn next(&mut self) -> Option { + self.protocols.next() + } +} + /// [`ConnectionEvent`] variant that informs the handler /// that upgrading an outbound substream to the given protocol has failed. pub struct DialUpgradeError { @@ -357,7 +488,7 @@ impl SubstreamProtocol { } /// Event produced by a handler. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ConnectionHandlerEvent { /// Request a new outbound substream to be opened with the remote. OutboundSubstreamRequest { @@ -374,11 +505,21 @@ pub enum ConnectionHandlerEvent), + /// The remote no longer supports these protocols. + Removed(HashSet), +} + /// Event produced by a handler. impl ConnectionHandlerEvent @@ -400,6 +541,9 @@ impl } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(val), ConnectionHandlerEvent::Close(val) => ConnectionHandlerEvent::Close(val), + ConnectionHandlerEvent::ReportRemoteProtocols(support) => { + ConnectionHandlerEvent::ReportRemoteProtocols(support) + } } } @@ -420,6 +564,9 @@ impl } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(val), ConnectionHandlerEvent::Close(val) => ConnectionHandlerEvent::Close(val), + ConnectionHandlerEvent::ReportRemoteProtocols(support) => { + ConnectionHandlerEvent::ReportRemoteProtocols(support) + } } } @@ -437,6 +584,9 @@ impl } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(map(val)), ConnectionHandlerEvent::Close(val) => ConnectionHandlerEvent::Close(val), + ConnectionHandlerEvent::ReportRemoteProtocols(support) => { + ConnectionHandlerEvent::ReportRemoteProtocols(support) + } } } @@ -454,6 +604,9 @@ impl } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(val), ConnectionHandlerEvent::Close(val) => ConnectionHandlerEvent::Close(map(val)), + ConnectionHandlerEvent::ReportRemoteProtocols(support) => { + ConnectionHandlerEvent::ReportRemoteProtocols(support) + } } } } @@ -558,3 +711,7 @@ impl Ord for KeepAlive { } } } + +/// A statically declared, empty [`HashSet`] allows us to work around borrow-checker rules for +/// [`ProtocolsAdded::from_set`]. The lifetimes don't work unless we have a [`HashSet`] with a `'static' lifetime. +static EMPTY_HASHSET: Lazy> = Lazy::new(HashSet::new); diff --git a/swarm/src/handler/either.rs b/swarm/src/handler/either.rs index a6b2152a721..08c914abd47 100644 --- a/swarm/src/handler/either.rs +++ b/swarm/src/handler/either.rs @@ -208,6 +208,22 @@ where handler.on_connection_event(ConnectionEvent::AddressChange(address_change)) } }, + ConnectionEvent::LocalProtocolsChange(supported_protocols) => match self { + Either::Left(handler) => handler.on_connection_event( + ConnectionEvent::LocalProtocolsChange(supported_protocols), + ), + Either::Right(handler) => handler.on_connection_event( + ConnectionEvent::LocalProtocolsChange(supported_protocols), + ), + }, + ConnectionEvent::RemoteProtocolsChange(supported_protocols) => match self { + Either::Left(handler) => handler.on_connection_event( + ConnectionEvent::RemoteProtocolsChange(supported_protocols), + ), + Either::Right(handler) => handler.on_connection_event( + ConnectionEvent::RemoteProtocolsChange(supported_protocols), + ), + }, } } } diff --git a/swarm/src/handler/map_out.rs b/swarm/src/handler/map_out.rs index 773df2b6681..349aa553764 100644 --- a/swarm/src/handler/map_out.rs +++ b/swarm/src/handler/map_out.rs @@ -81,6 +81,9 @@ where ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => { ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } } + ConnectionHandlerEvent::ReportRemoteProtocols(support) => { + ConnectionHandlerEvent::ReportRemoteProtocols(support) + } }) } diff --git a/swarm/src/handler/multi.rs b/swarm/src/handler/multi.rs index a934ebb6359..61c357b6597 100644 --- a/swarm/src/handler/multi.rs +++ b/swarm/src/handler/multi.rs @@ -205,6 +205,20 @@ where ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { self.on_listen_upgrade_error(listen_upgrade_error) } + ConnectionEvent::LocalProtocolsChange(supported_protocols) => { + for h in self.handlers.values_mut() { + h.on_connection_event(ConnectionEvent::LocalProtocolsChange( + supported_protocols.clone(), + )); + } + } + ConnectionEvent::RemoteProtocolsChange(supported_protocols) => { + for h in self.handlers.values_mut() { + h.on_connection_event(ConnectionEvent::RemoteProtocolsChange( + supported_protocols.clone(), + )); + } + } } } diff --git a/swarm/src/handler/one_shot.rs b/swarm/src/handler/one_shot.rs index 2ab45292f64..c4da5877c96 100644 --- a/swarm/src/handler/one_shot.rs +++ b/swarm/src/handler/one_shot.rs @@ -217,7 +217,10 @@ where self.keep_alive = KeepAlive::No; } } - ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} + ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/swarm/src/handler/pending.rs b/swarm/src/handler/pending.rs index a39e498c3f2..7cf8b9209fa 100644 --- a/swarm/src/handler/pending.rs +++ b/swarm/src/handler/pending.rs @@ -99,7 +99,9 @@ impl ConnectionHandler for PendingConnectionHandler { } ConnectionEvent::AddressChange(_) | ConnectionEvent::DialUpgradeError(_) - | ConnectionEvent::ListenUpgradeError(_) => {} + | ConnectionEvent::ListenUpgradeError(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/swarm/src/handler/select.rs b/swarm/src/handler/select.rs index d8d5639648e..204d11ce502 100644 --- a/swarm/src/handler/select.rs +++ b/swarm/src/handler/select.rs @@ -240,6 +240,9 @@ where .map_info(Either::Left), }); } + Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => { + return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)); + } Poll::Pending => (), }; @@ -257,6 +260,9 @@ where .map_info(Either::Right), }); } + Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => { + return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)); + } Poll::Pending => (), }; @@ -317,6 +323,26 @@ where ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { self.on_listen_upgrade_error(listen_upgrade_error) } + ConnectionEvent::LocalProtocolsChange(supported_protocols) => { + self.proto1 + .on_connection_event(ConnectionEvent::LocalProtocolsChange( + supported_protocols.clone(), + )); + self.proto2 + .on_connection_event(ConnectionEvent::LocalProtocolsChange( + supported_protocols, + )); + } + ConnectionEvent::RemoteProtocolsChange(supported_protocols) => { + self.proto1 + .on_connection_event(ConnectionEvent::RemoteProtocolsChange( + supported_protocols.clone(), + )); + self.proto2 + .on_connection_event(ConnectionEvent::RemoteProtocolsChange( + supported_protocols, + )); + } } } } diff --git a/swarm/src/keep_alive.rs b/swarm/src/keep_alive.rs index c22a926afe4..aa4da2db826 100644 --- a/swarm/src/keep_alive.rs +++ b/swarm/src/keep_alive.rs @@ -136,7 +136,9 @@ impl crate::handler::ConnectionHandler for ConnectionHandler { }) => void::unreachable(protocol), ConnectionEvent::DialUpgradeError(_) | ConnectionEvent::ListenUpgradeError(_) - | ConnectionEvent::AddressChange(_) => {} + | ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 6f439f06c03..a32beda411b 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -116,7 +116,7 @@ pub use behaviour::{ PollParameters, ToSwarm, }; pub use connection::pool::ConnectionCounters; -pub use connection::{ConnectionError, ConnectionId}; +pub use connection::{ConnectionError, ConnectionId, SupportedProtocols}; pub use executor::Executor; pub use handler::{ ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerSelect, KeepAlive, OneShotHandler,