From 6edb29049a7ef13cd39ba0289a6983c165d7bafa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Cig=C3=A1nek?= Date: Thu, 12 Nov 2020 17:03:40 +0100 Subject: [PATCH] feat!: implement connection pooling BREAKING CHANGE: - `Endpoint::connect_to` now returns pair or `(Connection, Option)` (previously it returned only `Connection`). - `Connection::open_bi_stream` renamed to `open_bi` (same as in quinn) - `Connection::send` renamed to `send_bi` for consistency - `Endpoint::listen` no longer returns `Result` --- examples/echo_service.rs | 6 +- src/api.rs | 2 +- src/connection_pool.rs | 112 +++++++++++++++++++++++++------------- src/connections.rs | 72 ++++++++++++++++++------ src/endpoint.rs | 57 ++++++++++++------- src/error.rs | 2 +- src/lib.rs | 1 + src/tests/common.rs | 36 +++++++----- src/tests/echo_service.rs | 6 +- 9 files changed, 195 insertions(+), 99 deletions(-) diff --git a/examples/echo_service.rs b/examples/echo_service.rs index 4b03b2c8..b6940907 100644 --- a/examples/echo_service.rs +++ b/examples/echo_service.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), Error> { println!("Process running at: {}", &socket_addr); if genesis { println!("Waiting for connections"); - let mut incoming = endpoint.listen()?; + let mut incoming = endpoint.listen(); let mut messages = incoming .next() .await @@ -63,8 +63,8 @@ async fn main() -> Result<(), Error> { } else { println!("Echo service complete"); let node_addr = bootstrap_nodes[0]; - let connection = endpoint.connect_to(&node_addr).await?; - let (mut send, mut recv) = connection.open_bi_stream().await?; + let (connection, _) = endpoint.connect_to(&node_addr).await?; + let (mut send, mut recv) = connection.open_bi().await?; loop { println!("Enter message:"); let mut input = String::new(); diff --git a/src/api.rs b/src/api.rs index 18bb30ea..71b3c56d 100644 --- a/src/api.rs +++ b/src/api.rs @@ -391,7 +391,7 @@ async fn new_connection_to( bootstrap_nodes, qp2p_config, )?; - let connection = endpoint.connect_to(node_addr).await?; + let (connection, _) = endpoint.connect_to(node_addr).await?; Ok((endpoint, connection)) } diff --git a/src/connection_pool.rs b/src/connection_pool.rs index 0e5b0834..31e79201 100644 --- a/src/connection_pool.rs +++ b/src/connection_pool.rs @@ -7,72 +7,108 @@ // specific language governing permissions and limitations relating to use of the SAFE Network // Software. -use slotmap::{DefaultKey, DenseSlotMap}; use std::{ - collections::{hash_map::Entry, HashMap}, + collections::BTreeMap, net::SocketAddr, - sync::{Arc, RwLock}, + sync::{Arc, Mutex, PoisonError}, }; +// Pool for keeping open connections. Pooled connections are associated with a `ConnectionRemover` +// which can be used to remove them from the pool. #[derive(Clone)] pub(crate) struct ConnectionPool { - store: Arc>, + store: Arc>, } impl ConnectionPool { pub fn new() -> Self { Self { - store: Arc::new(RwLock::new(Store { - connections: DenseSlotMap::new(), - keys: HashMap::new(), - })), + store: Arc::new(Mutex::new(Store::default())), } } - pub fn insert(&self, conn: quinn::Connection) -> Handle { - let addr = conn.remote_address(); + pub fn insert(&self, addr: SocketAddr, conn: quinn::Connection) -> ConnectionRemover { + let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); - let mut store = self.store.write().expect("RwLock poisoned"); - let key = store.connections.insert(conn); - let _ = store.keys.insert(addr, key); + let key = Key { + addr, + id: store.id_gen.next(), + }; + let _ = store.map.insert(key, conn); - Handle { + ConnectionRemover { store: self.store.clone(), key, } } - pub fn get(&self, addr: &SocketAddr) -> Option { - let store = self.store.read().ok()?; - let key = store.keys.get(addr)?; - store.connections.get(*key).cloned() + pub fn get(&self, addr: &SocketAddr) -> Option<(quinn::Connection, ConnectionRemover)> { + let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + + // Efficiently fetch the first entry whose key is equal to `key`. + let (key, conn) = store + .map + .range_mut(Key::min(*addr)..=Key::max(*addr)) + .next()?; + + let conn = conn.clone(); + let remover = ConnectionRemover { + store: self.store.clone(), + key: *key, + }; + + Some((conn, remover)) } } -pub(crate) struct Handle { - store: Arc>, - key: DefaultKey, +// Handle for removing a connection from the pool. +#[derive(Clone)] +pub(crate) struct ConnectionRemover { + store: Arc>, + key: Key, } -impl Drop for Handle { - fn drop(&mut self) { - let mut store = if let Ok(store) = self.store.write() { - store - } else { - return; - }; - - if let Some(conn) = store.connections.remove(self.key) { - if let Entry::Occupied(entry) = store.keys.entry(conn.remote_address()) { - if entry.get() == &self.key { - let _ = entry.remove(); - } - } - } +impl ConnectionRemover { + // Remove the connection from the pool. + pub fn remove(&self) { + let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + let _ = store.map.remove(&self.key); } } +#[derive(Default)] struct Store { - connections: DenseSlotMap, - keys: HashMap, + map: BTreeMap, + id_gen: IdGen, +} + +// Unique key identifying a connection. Two connections will always have distict keys even if they +// have the same socket address. +#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +struct Key { + addr: SocketAddr, + id: u64, +} + +impl Key { + // Returns the minimal `Key` for the given address according to its `Ord` relation. + fn min(addr: SocketAddr) -> Self { + Self { addr, id: u64::MIN } + } + + // Returns the maximal `Key` for the given address according to its `Ord` relation. + fn max(addr: SocketAddr) -> Self { + Self { addr, id: u64::MAX } + } +} + +#[derive(Default)] +struct IdGen(u64); + +impl IdGen { + fn next(&mut self) -> u64 { + let id = self.0; + self.0 = self.0.wrapping_add(1); + id + } } diff --git a/src/connections.rs b/src/connections.rs index 1ff5bc01..8c2e5a84 100644 --- a/src/connections.rs +++ b/src/connections.rs @@ -9,6 +9,7 @@ use super::{ api::Message, + connection_pool::{ConnectionPool, ConnectionRemover}, error::{Error, Result}, wire_msg::WireMsg, }; @@ -21,6 +22,7 @@ use tokio::select; /// Connection instance to a node which can be used to send messages to it pub struct Connection { quic_conn: quinn::Connection, + remover: ConnectionRemover, } impl Drop for Connection { @@ -30,8 +32,8 @@ impl Drop for Connection { } impl Connection { - pub(crate) async fn new(quic_conn: quinn::Connection) -> Result { - Ok(Self { quic_conn }) + pub(crate) fn new(quic_conn: quinn::Connection, remover: ConnectionRemover) -> Self { + Self { quic_conn, remover } } /// Returns the address of the connected peer. @@ -78,44 +80,62 @@ impl Connection { /// let peer1_addr = peer_1.socket_addr().await?; /// /// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?; - /// let (send_stream, recv_stream) = connection.open_bi_stream().await?; + /// let (send_stream, recv_stream) = connection.open_bi().await?; /// Ok(()) /// } /// ``` - pub async fn open_bi_stream(&self) -> Result<(SendStream, RecvStream)> { - let (send_stream, recv_stream) = self.quic_conn.open_bi().await?; + pub async fn open_bi(&self) -> Result<(SendStream, RecvStream)> { + let (send_stream, recv_stream) = self.handle_error(self.quic_conn.open_bi().await)?; Ok((SendStream::new(send_stream), RecvStream::new(recv_stream))) } /// Send message to the connected peer via a bi-directional stream. /// This returns the streams to send additional messages / read responses sent using the same stream. - pub async fn send(&self, msg: Bytes) -> Result<(SendStream, RecvStream)> { - let (mut send_stream, recv_stream) = self.open_bi_stream().await?; - send_stream.send_user_msg(msg).await?; + pub async fn send_bi(&self, msg: Bytes) -> Result<(SendStream, RecvStream)> { + let (mut send_stream, recv_stream) = self.open_bi().await?; + self.handle_error(send_stream.send_user_msg(msg).await)?; Ok((send_stream, recv_stream)) } /// Send message to peer using a uni-directional stream. pub async fn send_uni(&self, msg: Bytes) -> Result<()> { - let mut send_stream = self.quic_conn.open_uni().await?; - send_msg(&mut send_stream, msg).await?; - send_stream.finish().await.map_err(Error::from) + let mut send_stream = self.handle_error(self.quic_conn.open_uni().await)?; + self.handle_error(send_msg(&mut send_stream, msg).await)?; + self.handle_error(send_stream.finish().await) + .map_err(Error::from) } /// Gracefully close connection immediatelly pub fn close(&self) { self.quic_conn.close(0u32.into(), b""); + // TODO: uncomment + // self.remover.remove(); + } + + fn handle_error(&self, result: Result) -> Result { + if result.is_err() { + self.remover.remove() + } + + result } } /// Stream of incoming QUIC connections pub struct IncomingConnections { quinn_incoming: Arc>, + connection_pool: ConnectionPool, } impl IncomingConnections { - pub(crate) fn new(quinn_incoming: Arc>) -> Result { - Ok(Self { quinn_incoming }) + pub(crate) fn new( + quinn_incoming: Arc>, + connection_pool: ConnectionPool, + ) -> Self { + Self { + quinn_incoming, + connection_pool, + } } /// Returns next QUIC connection established by a peer @@ -127,11 +147,18 @@ impl IncomingConnections { uni_streams, bi_streams, .. - }) => Some(IncomingMessages::new( - connection.remote_address(), - uni_streams, - bi_streams, - )), + }) => { + let pool_handle = self + .connection_pool + .insert(connection.remote_address(), connection.clone()); + + Some(IncomingMessages::new( + connection.remote_address(), + uni_streams, + bi_streams, + pool_handle, + )) + } Err(_err) => None, }, None => None, @@ -144,6 +171,7 @@ pub struct IncomingMessages { peer_addr: SocketAddr, uni_streams: quinn::IncomingUniStreams, bi_streams: quinn::IncomingBiStreams, + remover: ConnectionRemover, } impl IncomingMessages { @@ -151,11 +179,13 @@ impl IncomingMessages { peer_addr: SocketAddr, uni_streams: quinn::IncomingUniStreams, bi_streams: quinn::IncomingBiStreams, + remover: ConnectionRemover, ) -> Self { Self { peer_addr, uni_streams, bi_streams, + remover, } } @@ -250,6 +280,12 @@ impl IncomingMessages { } } +impl Drop for IncomingMessages { + fn drop(&mut self) { + self.remover.remove() + } +} + /// Stream to receive multiple messages pub struct RecvStream { pub(crate) quinn_recv_stream: quinn::RecvStream, diff --git a/src/endpoint.rs b/src/endpoint.rs index 15a594da..645fd68e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -11,7 +11,8 @@ use super::error::Error; use super::igd::forward_port; use super::wire_msg::WireMsg; use super::{ - connections::{Connection, IncomingConnections}, + connection_pool::ConnectionPool, + connections::{Connection, IncomingConnections, IncomingMessages}, error::Result, Config, }; @@ -34,6 +35,7 @@ pub struct Endpoint { client_cfg: quinn::ClientConfig, bootstrap_nodes: Vec, qp2p_config: Config, + connection_pool: ConnectionPool, } impl std::fmt::Debug for Endpoint { @@ -63,18 +65,19 @@ impl Endpoint { client_cfg, bootstrap_nodes, qp2p_config, + connection_pool: ConnectionPool::new(), }) } /// Endpoint local address - async fn local_addr(&self) -> Result { - Ok(self.local_addr) + fn local_addr(&self) -> SocketAddr { + self.local_addr } /// Returns the socket address of the endpoint pub async fn socket_addr(&self) -> Result { if cfg!(test) || !self.qp2p_config.forward_port { - self.local_addr().await + Ok(self.local_addr()) } else { self.public_addr().await } @@ -151,31 +154,43 @@ impl Endpoint { } } - /// Connect to another peer - pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result { - let quinn_connecting = self.quic_endpoint.connect_with( - self.client_cfg.clone(), - &node_addr, - CERT_SERVER_NAME, - )?; + /// Connects to another peer + /// + pub async fn connect_to( + &self, + node_addr: &SocketAddr, + ) -> Result<(Connection, Option)> { + if let Some((conn, guard)) = self.connection_pool.get(node_addr) { + trace!("Using cached connection to peer: {}", node_addr); + return Ok((Connection::new(conn, guard), None)); + } - let quinn::NewConnection { - connection: quinn_conn, - .. - } = quinn_connecting.await?; + let new_conn = self + .quic_endpoint + .connect_with(self.client_cfg.clone(), node_addr, CERT_SERVER_NAME)? + .await?; trace!("Successfully connected to peer: {}", node_addr); - Connection::new(quinn_conn).await + let guard = self.connection_pool.insert( + new_conn.connection.remote_address(), + new_conn.connection.clone(), + ); + + let conn = Connection::new(new_conn.connection, guard.clone()); + let incoming_msgs = + IncomingMessages::new(*node_addr, new_conn.uni_streams, new_conn.bi_streams, guard); + + Ok((conn, Some(incoming_msgs))) } /// Obtain stream of incoming QUIC connections - pub fn listen(&self) -> Result { + pub fn listen(&self) -> IncomingConnections { trace!( "Incoming connections will be received at {}", - self.quic_endpoint.local_addr()? + self.local_addr() ); - IncomingConnections::new(Arc::clone(&self.quic_incoming)) + IncomingConnections::new(self.quic_incoming.clone(), self.connection_pool.clone()) } // Private helper @@ -188,9 +203,9 @@ impl Endpoint { let mut tasks = Vec::default(); for node in self.bootstrap_nodes.iter().cloned() { debug!("Connecting to {:?}", &node); - let connection = self.connect_to(&node).await?; // TODO: move into loop + let (connection, _) = self.connect_to(&node).await?; // TODO: move into loop let task_handle = tokio::spawn(async move { - let (mut send_stream, mut recv_stream) = connection.open_bi_stream().await?; + let (mut send_stream, mut recv_stream) = connection.open_bi().await?; send_stream.send(WireMsg::EndpointEchoReq).await?; match WireMsg::read_from_stream(&mut recv_stream.quinn_recv_stream).await { Ok(WireMsg::EndpointEchoResp(socket_addr)) => Ok(socket_addr), diff --git a/src/error.rs b/src/error.rs index beb8778b..7147cc70 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,7 +12,7 @@ use std::net::SocketAddr; use std::{io, sync::mpsc}; /// Result used by `QuicP2p`. -pub type Result = std::result::Result; +pub type Result = std::result::Result; #[derive(Debug, Error)] #[allow(missing_docs)] diff --git a/src/lib.rs b/src/lib.rs index 3e137ff1..04e7993b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,7 @@ mod api; mod bootstrap_cache; mod config; +mod connection_pool; mod connections; mod endpoint; mod error; diff --git a/src/tests/common.rs b/src/tests/common.rs index fc2b60ff..29156b14 100644 --- a/src/tests/common.rs +++ b/src/tests/common.rs @@ -39,7 +39,7 @@ async fn successful_connection() -> Result<()> { let peer2 = qp2p.new_endpoint()?; let _connection = peer2.connect_to(&peer1_addr).await?; - let mut incoming_conn = peer1.listen()?; + let mut incoming_conn = peer1.listen(); let incoming_messages = incoming_conn .next() .await @@ -57,14 +57,14 @@ async fn bi_directional_streams() -> Result<()> { let peer1_addr = peer1.socket_addr().await?; let peer2 = qp2p.new_endpoint()?; - let connection = peer2.connect_to(&peer1_addr).await?; + let (connection, _) = peer2.connect_to(&peer1_addr).await?; let msg = random_msg(); // Peer 2 sends a message and gets the bi-directional streams - let (mut send_stream2, mut recv_stream2) = connection.send(msg.clone()).await?; + let (mut send_stream2, mut recv_stream2) = connection.send_bi(msg.clone()).await?; // Peer 1 gets an incoming connection - let mut incoming_conn = peer1.listen()?; + let mut incoming_conn = peer1.listen(); let mut incoming_messages = incoming_conn .next() .await @@ -111,15 +111,14 @@ async fn uni_directional_streams() -> Result<()> { let qp2p = new_qp2p(); let peer1 = qp2p.new_endpoint()?; let peer1_addr = peer1.socket_addr().await?; - let mut incoming_conn_peer1 = peer1.listen()?; + let mut incoming_conn_peer1 = peer1.listen(); let peer2 = qp2p.new_endpoint()?; let peer2_addr = peer2.socket_addr().await?; - let mut incoming_conn_peer2 = peer2.listen()?; + let mut incoming_conn_peer2 = peer2.listen(); // Peer 2 sends a message - let conn_to_peer1 = peer2.connect_to(&peer1_addr).await?; - + let (conn_to_peer1, _) = peer2.connect_to(&peer1_addr).await?; let msg_from_peer2 = random_msg(); conn_to_peer1.send_uni(msg_from_peer2.clone()).await?; drop(conn_to_peer1); @@ -140,12 +139,18 @@ async fn uni_directional_streams() -> Result<()> { assert_eq!(src, peer2_addr); src } else { - return Err(Error::UnexpectedMessageType); + panic!("Expected a unidirectional stream") }; + // Peer 2 dropped the connection to peer 1 after sending the message, so the incoming message + // stream gets closed. Drop the stream which also removes the connection from the connection + // pool. + assert!(incoming_messages.next().await.is_none()); + drop(incoming_messages); + // Peer 1 sends back a message to Peer 2 on a new uni-directional stream + let (conn_to_peer2, _) = peer1.connect_to(&src).await?; let msg_from_peer1 = random_msg(); - let conn_to_peer2 = peer1.connect_to(&src).await?; conn_to_peer2.send_uni(msg_from_peer1.clone()).await?; drop(conn_to_peer2); @@ -163,10 +168,13 @@ async fn uni_directional_streams() -> Result<()> { if let Message::UniStream { bytes, src, .. } = message { assert_eq!(msg_from_peer1, bytes); assert_eq!(src, peer1_addr); - Ok(()) } else { - Err(Error::Unexpected( - "Expected a Unidirectional stream".to_string(), - )) + panic!("Expected a unidirectional stream") } + + // Peer 1 dropped the connection to peer 2 after sending the message, so the incoming message + // stream gets closed. + assert!(incoming_messages.next().await.is_none()); + + Ok(()) } diff --git a/src/tests/echo_service.rs b/src/tests/echo_service.rs index 7f6f39f3..bce642ae 100644 --- a/src/tests/echo_service.rs +++ b/src/tests/echo_service.rs @@ -19,7 +19,7 @@ async fn echo_service() -> Result<()> { // Listen for messages / connections at peer 1 let handle1 = tokio::spawn(async move { - let mut incoming = peer1.listen()?; + let mut incoming = peer1.listen(); let mut inbound_messages = incoming .next() .await @@ -32,8 +32,8 @@ async fn echo_service() -> Result<()> { let handle2 = tokio::spawn(async move { let peer2 = qp2p.new_endpoint()?; let socket_addr = peer2.socket_addr().await?; - let connection = peer2.connect_to(&peer1_addr).await?; - let (mut send_stream, mut recv_stream) = connection.open_bi_stream().await?; + let (connection, _) = peer2.connect_to(&peer1_addr).await?; + let (mut send_stream, mut recv_stream) = connection.open_bi().await?; let echo_service_req = WireMsg::EndpointEchoReq; echo_service_req .write_to_stream(&mut send_stream.quinn_send_stream)