Skip to content

Commit

Permalink
feat!: return Connection from Endpoint::connect_*
Browse files Browse the repository at this point in the history
This is a small step towards removing the `ConnectionPool`. The
`ConnectionPool` is used for a number of things, one of which is
enabling the `Endpoint` APIs that operate exclusively on `SocketAddr`s,
rather than connection handles.

This commit exposes the `Connection<I>` type, and updates the
`Endpoint::new`, `Endpoint::connect_to`, and `Endpoint::connect_to_any`
methods to return `Connection<I>` rather than `SocketAddr`. For now, the
only public method on `Connection<I>` is `remote_address`, which callers
can then use with the existing (unchanged) `SocketAddr`-based `Endpoint`
APIs. Future commits will introduce additional `Connection` methods to
replace these `SocketAddr`-based APIs.

BREAKING CHANGE: `Endpoint::new`, `Endpoint::connect_to`, and
`Endpoint::connect_to_any` now use `Connection<I>` instead of
`SocketAddr` in their return type.
  • Loading branch information
Chris Connelly authored and connec committed Sep 21, 2021
1 parent a988b66 commit 9ce6947
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 51 deletions.
22 changes: 16 additions & 6 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,18 @@ use tokio::sync::mpsc::{Receiver, Sender};
use tokio::time::{timeout, Duration};
use tracing::{trace, warn};

/// Connection instance to a node which can be used to send messages to it
/// A connection between two [`Endpoint`]s.
///
/// This is backed by an `Arc` and a small amount of metadata, so cloning is fairly cheap. The
/// connection is also pooled, meaning the same underlying connection will be used when connecting
/// multiple times to the same peer. If an error occurs on the connection, it will be removed from
/// the pool. See the documentation of [`Endpoint::connect_to`] for more details about connection
/// pooling.
///
/// [`Endpoint`]: crate::Endpoint
/// [`Endpoint::connect_to`]: crate::Endpoint::connect_to
#[derive(Clone)]
pub(crate) struct Connection<I: ConnId> {
pub struct Connection<I: ConnId> {
quic_conn: quinn::Connection,
remover: ConnectionRemover<I>,
}
Expand All @@ -45,6 +54,11 @@ impl<I: ConnId> Connection<I> {
Self { quic_conn, remover }
}

/// Get the address of the connected peer.
pub fn remote_address(&self) -> SocketAddr {
self.quic_conn.remote_address()
}

/// Priority default is 0. Both lower and higher can be passed in.
pub(crate) async fn open_bi(
&self,
Expand Down Expand Up @@ -81,10 +95,6 @@ impl<I: ConnId> Connection<I> {
}
}

pub(crate) fn remote_address(&self) -> SocketAddr {
self.quic_conn.remote_address()
}

async fn handle_error<T, E>(&self, result: Result<T, E>) -> Result<T, E> {
if result.is_err() {
self.remover.remove().await
Expand Down
52 changes: 29 additions & 23 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl<I: ConnId> Endpoint<I> {
IncomingConnections,
IncomingMessages,
DisconnectionEvents,
Option<SocketAddr>,
Option<Connection<I>>,
),
EndpointError,
> {
Expand All @@ -140,7 +140,7 @@ impl<I: ConnId> Endpoint<I> {
Self::build_endpoint(local_addr.into(), config, builder)?;

let contact = endpoint.connect_to_any(contacts).await;
let public_addr = endpoint.resolve_public_addr(contact).await?;
let public_addr = endpoint.resolve_public_addr(contact.as_ref()).await?;

#[cfg(feature = "igd")]
if endpoint.config.forward_port {
Expand All @@ -166,11 +166,14 @@ impl<I: ConnId> Endpoint<I> {
endpoint.clone(),
);

if let Some(peer) = contact {
if let Some(contact) = contact.as_ref() {
let valid = endpoint
.endpoint_verification(peer, public_addr)
.endpoint_verification(contact, public_addr)
.await
.map_err(|error| EndpointError::EndpointVerification { peer, error })?;
.map_err(|error| EndpointError::EndpointVerification {
peer: contact.remote_address(),
error,
})?;
if !valid {
return Err(EndpointError::Unreachable { public_addr });
}
Expand Down Expand Up @@ -296,9 +299,11 @@ impl<I: ConnId> Endpoint<I> {
/// When sending a message on `Connection` fails, the connection is also automatically removed
/// from the pool and the subsequent call to `connect_to` is guaranteed to reopen new connection
/// too.
pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result<(), ConnectionError> {
let _ = self.get_or_connect_to(node_addr).await?;
Ok(())
pub async fn connect_to(
&self,
node_addr: &SocketAddr,
) -> Result<Connection<I>, ConnectionError> {
self.get_or_connect_to(node_addr).await
}

/// Connect to any of the given peers.
Expand All @@ -311,7 +316,7 @@ impl<I: ConnId> Endpoint<I> {
///
/// The successful connection, if any, will be stored in the connection pool (see
/// [`connect_to`](Self::connect_to) for more info on connection pooling).
pub async fn connect_to_any(&self, peer_addrs: &[SocketAddr]) -> Option<SocketAddr> {
pub async fn connect_to_any(&self, peer_addrs: &[SocketAddr]) -> Option<Connection<I>> {
trace!("Connecting to any of {:?}", peer_addrs);
if peer_addrs.is_empty() {
return None;
Expand All @@ -323,7 +328,7 @@ impl<I: ConnId> Endpoint<I> {
.map(|addr| Box::pin(self.get_or_connect_to(addr)));

match futures::future::select_ok(tasks).await {
Ok((connection, _)) => Some(connection.remote_address()),
Ok((connection, _)) => Some(connection),
Err(error) => {
error!("Failed to bootstrap to the network, last error: {}", error);
None
Expand Down Expand Up @@ -373,9 +378,9 @@ impl<I: ConnId> Endpoint<I> {
}

/// Get the SocketAddr of a connection using the connection ID
pub async fn get_socket_addr_by_id(&self, addr: &I) -> Option<SocketAddr> {
pub async fn get_socket_addr_by_id(&self, id: &I) -> Option<SocketAddr> {
self.connection_pool
.get_by_id(addr)
.get_by_id(id)
.await
.map(|(_, remover)| *remover.remote_addr())
}
Expand Down Expand Up @@ -572,17 +577,18 @@ impl<I: ConnId> Endpoint<I> {
// set an appropriate public address based on `config` and a reachability check.
async fn resolve_public_addr(
&mut self,
contact: Option<SocketAddr>,
contact: Option<&Connection<I>>,
) -> Result<SocketAddr, EndpointError> {
let mut public_addr = self.local_addr;

// get the IP seen for us by our contact
let visible_addr = if let Some(peer) = contact {
Some(
self.endpoint_echo(peer)
.await
.map_err(|error| EndpointError::EndpointEcho { peer, error })?,
)
let visible_addr = if let Some(contact) = contact {
Some(self.endpoint_echo(contact).await.map_err(|error| {
EndpointError::EndpointEcho {
peer: contact.remote_address(),
error,
}
})?)
} else {
None
};
Expand Down Expand Up @@ -644,8 +650,8 @@ impl<I: ConnId> Endpoint<I> {
}

/// Perform the endpoint echo RPC with the given peer.
async fn endpoint_echo(&self, peer: SocketAddr) -> Result<SocketAddr, RpcError> {
let (mut send, mut recv) = self.open_bidirectional_stream(&peer, 0).await?;
async fn endpoint_echo(&self, contact: &Connection<I>) -> Result<SocketAddr, RpcError> {
let (mut send, mut recv) = contact.open_bi(0).await?;

send.send(WireMsg::EndpointEchoReq).await?;

Expand All @@ -658,10 +664,10 @@ impl<I: ConnId> Endpoint<I> {
/// Perform the endpoint verification RPC with the given peer.
async fn endpoint_verification(
&self,
peer: SocketAddr,
contact: &Connection<I>,
public_addr: SocketAddr,
) -> Result<bool, RpcError> {
let (mut send, mut recv) = self.open_bidirectional_stream(&peer, 0).await?;
let (mut send, mut recv) = contact.open_bi(0).await?;

send.send(WireMsg::EndpointVerificationReq(public_addr))
.await?;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod wire_msg;

pub use config::{Config, ConfigError, RetryConfig};
pub use connection_pool::ConnId;
pub use connections::{DisconnectionEvents, RecvStream, SendStream};
pub use connections::{Connection, DisconnectionEvents, RecvStream, SendStream};
pub use endpoint::{Endpoint, IncomingConnections, IncomingMessages};
#[cfg(feature = "igd")]
pub use error::UpnpError;
Expand Down
37 changes: 19 additions & 18 deletions src/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn successful_connection() -> Result<()> {
let peer1_addr = peer1.public_addr();

let (peer2, _, _, _, _) = new_endpoint().await?;
peer2.connect_to(&peer1_addr).await?;
peer2.connect_to(&peer1_addr).await.map(drop)?;
let peer2_addr = peer2.public_addr();

if let Some(connecting_peer) = peer1_incoming_connections.next().await {
Expand All @@ -44,7 +44,7 @@ async fn single_message() -> Result<()> {
let peer2_addr = peer2.public_addr();

// Peer 2 connects and sends a message
peer2.connect_to(&peer1_addr).await?;
peer2.connect_to(&peer1_addr).await.map(drop)?;
let msg_from_peer2 = random_msg(1024);
peer2
.send_message(msg_from_peer2.clone(), &peer1_addr, 0)
Expand Down Expand Up @@ -77,7 +77,7 @@ async fn reuse_outgoing_connection() -> Result<()> {
let bob_addr = bob.public_addr();

// Connect for the first time and send a message.
alice.connect_to(&bob_addr).await?;
alice.connect_to(&bob_addr).await.map(drop)?;
let msg0 = random_msg(1024);
alice.send_message(msg0.clone(), &bob_addr, 0).await?;

Expand All @@ -96,7 +96,7 @@ async fn reuse_outgoing_connection() -> Result<()> {
}

// Try connecting again and send a message
alice.connect_to(&bob_addr).await?;
alice.connect_to(&bob_addr).await.map(drop)?;
let msg1 = random_msg(1024);
alice.send_message(msg1.clone(), &bob_addr, 0).await?;

Expand Down Expand Up @@ -127,7 +127,7 @@ async fn reuse_incoming_connection() -> Result<()> {
let bob_addr = bob.public_addr();

// Connect for the first time and send a message.
alice.connect_to(&bob_addr).await?;
alice.connect_to(&bob_addr).await.map(drop)?;
let msg0 = random_msg(1024);
alice.send_message(msg0.clone(), &bob_addr, 0).await?;

Expand All @@ -146,7 +146,7 @@ async fn reuse_incoming_connection() -> Result<()> {
}

// Bob tries to connect to alice and sends a message
bob.connect_to(&alice_addr).await?;
bob.connect_to(&alice_addr).await.map(drop)?;
let msg1 = random_msg(1024);
bob.send_message(msg1.clone(), &alice_addr, 0).await?;

Expand Down Expand Up @@ -177,7 +177,7 @@ async fn disconnection() -> Result<()> {
let bob_addr = bob.public_addr();

// Alice connects to Bob who should receive an incoming connection.
alice.connect_to(&bob_addr).await?;
alice.connect_to(&bob_addr).await.map(drop)?;

if let Some(connecting_peer) = bob_incoming_connections.next().await {
assert_eq!(connecting_peer, alice_addr);
Expand All @@ -201,7 +201,7 @@ async fn disconnection() -> Result<()> {
}

// This time bob connects to Alice. Since this is a *new connection*, Alice should get the connection event
bob.connect_to(&alice_addr).await?;
bob.connect_to(&alice_addr).await.map(drop)?;

if let Some(connected_peer) = alice_incoming_connections.next().await {
assert_eq!(connected_peer, bob_addr);
Expand Down Expand Up @@ -231,7 +231,9 @@ async fn simultaneous_incoming_and_outgoing_connections() -> Result<()> {
new_endpoint().await?;
let bob_addr = bob.public_addr();

future::try_join(alice.connect_to(&bob_addr), bob.connect_to(&alice_addr)).await?;
future::try_join(alice.connect_to(&bob_addr), bob.connect_to(&alice_addr))
.await
.map(drop)?;

if let Some(connecting_peer) = alice_incoming_connections.next().await {
assert_eq!(connecting_peer, bob_addr);
Expand Down Expand Up @@ -277,7 +279,7 @@ async fn simultaneous_incoming_and_outgoing_connections() -> Result<()> {

// Bob connects to Alice again. This does not open a new connection but returns the connection
// previously initiated by Alice from the pool.
bob.connect_to(&alice_addr).await?;
bob.connect_to(&alice_addr).await.map(drop)?;

if let Ok(Some(connecting_peer)) =
timeout(Duration::from_secs(2), alice_incoming_connections.next()).await
Expand Down Expand Up @@ -306,8 +308,7 @@ async fn multiple_concurrent_connects_to_the_same_peer() -> Result<()> {
let bob_addr = bob.public_addr();

// Try to establish two connections to the same peer at the same time.
let ((), ()) =
future::try_join(bob.connect_to(&alice_addr), bob.connect_to(&alice_addr)).await?;
let (_, _) = future::try_join(bob.connect_to(&alice_addr), bob.connect_to(&alice_addr)).await?;

// Alice get only one incoming connection
if let Some(connecting_peer) = alice_incoming_connections.next().await {
Expand Down Expand Up @@ -383,7 +384,7 @@ async fn multiple_connections_with_many_concurrent_messages() -> Result<()> {
let _ = hash(&msg);
}
// Send the hash result back.
sending_endpoint.connect_to(&src).await?;
sending_endpoint.connect_to(&src).await.map(drop)?;
sending_endpoint
.send_message(hash_result.to_vec().into(), &src, 0)
.await?;
Expand Down Expand Up @@ -412,7 +413,7 @@ async fn multiple_connections_with_many_concurrent_messages() -> Result<()> {

async move {
let mut hash_results = BTreeSet::new();
send_endpoint.connect_to(&server_addr).await?;
send_endpoint.connect_to(&server_addr).await.map(drop)?;
for (index, message) in messages.iter().enumerate().take(num_messages_each) {
let _ = hash_results.insert(hash(message));
info!("sender #{} sending message #{}", id, index);
Expand Down Expand Up @@ -519,7 +520,7 @@ async fn multiple_connections_with_many_larger_concurrent_messages() -> Result<(
async move {
let mut hash_results = BTreeSet::new();

send_endpoint.connect_to(&server_addr).await?;
send_endpoint.connect_to(&server_addr).await.map(drop)?;
for (index, message) in messages.iter().enumerate().take(num_messages_each) {
let _ = hash_results.insert(hash(message));

Expand Down Expand Up @@ -589,7 +590,7 @@ async fn many_messages() -> Result<()> {
async move {
info!("sending {}", id);
let msg = id.to_le_bytes().to_vec().into();
endpoint.connect_to(&recv_addr).await?;
endpoint.connect_to(&recv_addr).await.map(drop)?;
endpoint.send_message(msg, &recv_addr, 0).await?;
info!("sent {}", id);

Expand Down Expand Up @@ -650,8 +651,8 @@ async fn connection_attempts_to_bootstrap_contacts_should_succeed() -> Result<()
bootstrapped_peer.ok_or_else(|| eyre!("Failed to connecto to any contact"))?;

for peer in contacts {
if peer != bootstrapped_peer {
ep.connect_to(&peer).await?;
if peer != bootstrapped_peer.remote_address() {
ep.connect_to(&peer).await.map(drop)?;
}
}
Ok(())
Expand Down
6 changes: 3 additions & 3 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
// Software.

use crate::{
Config, ConnId, DisconnectionEvents, Endpoint, IncomingConnections, IncomingMessages,
RetryConfig,
Config, ConnId, Connection, DisconnectionEvents, Endpoint, IncomingConnections,
IncomingMessages, RetryConfig,
};
use bytes::Bytes;
use color_eyre::eyre::Result;
Expand Down Expand Up @@ -42,7 +42,7 @@ pub(crate) async fn new_endpoint() -> Result<(
IncomingConnections,
IncomingMessages,
DisconnectionEvents,
Option<SocketAddr>,
Option<Connection<[u8; 32]>>,
)> {
Ok(Endpoint::new(
local_addr(),
Expand Down

0 comments on commit 9ce6947

Please sign in to comment.