diff --git a/src/api.rs b/src/api.rs index 76f92313..d05eea61 100644 --- a/src/api.rs +++ b/src/api.rs @@ -268,7 +268,9 @@ impl QuicP2p { })? .0; let bootstrapped_peer = successful_connection.connection.remote_address(); - endpoint.add_new_connection_to_pool(successful_connection); + endpoint + .add_new_connection_to_pool(successful_connection) + .await; Ok(bootstrapped_peer) } diff --git a/src/connection_pool.rs b/src/connection_pool.rs index 5b0e1444..333018be 100644 --- a/src/connection_pool.rs +++ b/src/connection_pool.rs @@ -7,28 +7,26 @@ // specific language governing permissions and limitations relating to use of the SAFE Network // Software. -use std::{ - collections::BTreeMap, - net::SocketAddr, - sync::{Arc, Mutex, PoisonError}, -}; +use std::{collections::BTreeMap, net::SocketAddr, sync::Arc}; + +use tokio::sync::RwLock; // Pool for keeping open connections. Pooled connections are associated with a `ConnectionRemover` // which can be used to remove them from the pool. #[derive(Clone)] pub(crate) struct ConnectionPool { - store: Arc>, + store: Arc>, } impl ConnectionPool { pub fn new() -> Self { Self { - store: Arc::new(Mutex::new(Store::default())), + store: Arc::new(RwLock::new(Store::default())), } } - pub fn insert(&self, addr: SocketAddr, conn: quinn::Connection) -> ConnectionRemover { - let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + pub async fn insert(&self, addr: SocketAddr, conn: quinn::Connection) -> ConnectionRemover { + let mut store = self.store.write().await; let key = Key { addr, @@ -42,19 +40,19 @@ impl ConnectionPool { } } - pub fn has(&self, addr: &SocketAddr) -> bool { - let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + pub async fn has(&self, addr: &SocketAddr) -> bool { + let store = self.store.read().await; // Efficiently fetch the first entry whose key is equal to `key` and check if it exists store .map - .range_mut(Key::min(*addr)..=Key::max(*addr)) + .range(Key::min(*addr)..=Key::max(*addr)) .next() .is_some() } - pub fn remove(&self, addr: &SocketAddr) -> Vec { - let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + pub async fn remove(&self, addr: &SocketAddr) -> Vec { + let mut store = self.store.write().await; let keys_to_remove = store .map @@ -69,14 +67,11 @@ impl ConnectionPool { .collect::>() } - pub fn get(&self, addr: &SocketAddr) -> Option<(quinn::Connection, ConnectionRemover)> { - let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + pub async fn get(&self, addr: &SocketAddr) -> Option<(quinn::Connection, ConnectionRemover)> { + let store = self.store.read().await; // Efficiently fetch the first entry whose key is equal to `key`. - let (key, conn) = store - .map - .range_mut(Key::min(*addr)..=Key::max(*addr)) - .next()?; + let (key, conn) = store.map.range(Key::min(*addr)..=Key::max(*addr)).next()?; let conn = conn.clone(); let remover = ConnectionRemover { @@ -91,14 +86,14 @@ impl ConnectionPool { // Handle for removing a connection from the pool. #[derive(Clone)] pub(crate) struct ConnectionRemover { - store: Arc>, + store: Arc>, key: Key, } impl ConnectionRemover { // Remove the connection from the pool. - pub fn remove(&self) { - let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner); + pub async fn remove(&self) { + let mut store = self.store.write().await; let _ = store.map.remove(&self.key); } diff --git a/src/connections.rs b/src/connections.rs index d3138ab4..fd2b635b 100644 --- a/src/connections.rs +++ b/src/connections.rs @@ -34,14 +34,15 @@ impl Connection { } pub async fn open_bi(&self) -> Result<(SendStream, RecvStream)> { - let (send_stream, recv_stream) = self.handle_error(self.quic_conn.open_bi().await)?; + let (send_stream, recv_stream) = self.handle_error(self.quic_conn.open_bi().await).await?; Ok((SendStream::new(send_stream), RecvStream::new(recv_stream))) } /// Send message to peer using a uni-directional stream. pub async fn send_uni(&self, msg: Bytes) -> Result<()> { - let mut send_stream = self.handle_error(self.quic_conn.open_uni().await)?; - self.handle_error(send_msg(&mut send_stream, msg).await)?; + let mut send_stream = self.handle_error(self.quic_conn.open_uni().await).await?; + self.handle_error(send_msg(&mut send_stream, msg).await) + .await?; // We try to make sure the stream is gracefully closed and the bytes get sent, // but if it was already closed (perhaps by the peer) then we @@ -49,15 +50,15 @@ impl Connection { match send_stream.finish().await { Ok(()) | Err(quinn::WriteError::Stopped(_)) => Ok(()), Err(err) => { - self.handle_error(Err(err))?; + self.handle_error(Err(err)).await?; Ok(()) } } } - fn handle_error(&self, result: Result) -> Result { + async fn handle_error(&self, result: Result) -> Result { if result.is_err() { - self.remover.remove() + self.remover.remove().await } result @@ -154,7 +155,7 @@ pub(super) fn listen_for_incoming_connections( .. }) => { let peer_address = connection.remote_address(); - let pool_handle = connection_pool.insert(peer_address, connection); + let pool_handle = connection_pool.insert(peer_address, connection).await; let _ = connection_tx.send(peer_address); listen_for_incoming_messages( uni_streams, @@ -197,7 +198,7 @@ pub(super) fn listen_for_incoming_messages( log::trace!("The connection to {:?} has been terminated.", src); let _ = disconnection_tx.send(src); - remover.remove(); + remover.remove().await; }); } diff --git a/src/endpoint.rs b/src/endpoint.rs index 84405bf2..979a7cbf 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -141,6 +141,7 @@ impl Endpoint { endpoint.connect_to(contact).await?; let connection = endpoint .get_connection(&contact) + .await .ok_or(Error::MissingConnection)?; let (mut send, mut recv) = connection.open_bi().await?; send.send(WireMsg::EndpointVerificationReq(addr)).await?; @@ -283,9 +284,10 @@ impl Endpoint { } /// Removes all existing connections to a given peer - pub fn disconnect_from(&self, peer_addr: &SocketAddr) -> Result<()> { + pub async fn disconnect_from(&self, peer_addr: &SocketAddr) -> Result<()> { self.connection_pool .remove(peer_addr) + .await .iter() .for_each(|conn| { conn.close(0u8.into(), b""); @@ -331,7 +333,7 @@ impl Endpoint { /// from the pool and the subsequent call to `connect_to` is guaranteed to reopen new connection /// too. pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result<()> { - if self.connection_pool.has(node_addr) { + if self.connection_pool.has(node_addr).await { trace!("We are already connected to this peer: {}", node_addr); } @@ -369,7 +371,7 @@ impl Endpoint { trace!("Successfully connected to peer: {}", node_addr); - self.add_new_connection_to_pool(new_conn); + self.add_new_connection_to_pool(new_conn).await; self.connection_deduplicator .complete(node_addr, Ok(())) @@ -426,10 +428,11 @@ impl Endpoint { Ok(new_connection) } - pub(crate) fn add_new_connection_to_pool(&self, conn: quinn::NewConnection) { + pub(crate) async fn add_new_connection_to_pool(&self, conn: quinn::NewConnection) { let guard = self .connection_pool - .insert(conn.connection.remote_address(), conn.connection); + .insert(conn.connection.remote_address(), conn.connection) + .await; listen_for_incoming_messages( conn.uni_streams, @@ -442,8 +445,8 @@ impl Endpoint { } /// Get an existing connection for the peer address. - pub(crate) fn get_connection(&self, peer_addr: &SocketAddr) -> Option { - if let Some((conn, guard)) = self.connection_pool.get(peer_addr) { + pub(crate) async fn get_connection(&self, peer_addr: &SocketAddr) -> Option { + if let Some((conn, guard)) = self.connection_pool.get(peer_addr).await { trace!("Connection exists in the connection pool: {}", peer_addr); Some(Connection::new(conn, guard)) } else { @@ -459,6 +462,7 @@ impl Endpoint { self.connect_to(peer_addr).await?; let connection = self .get_connection(peer_addr) + .await .ok_or(Error::MissingConnection)?; connection.open_bi().await } @@ -466,7 +470,10 @@ impl Endpoint { /// Sends a message to a peer. This will attempt to use an existing connection /// to the destination peer. If a connection does not exist, this will fail with `Error::MissingConnection` pub async fn try_send_message(&self, msg: Bytes, dest: &SocketAddr) -> Result<()> { - let connection = self.get_connection(dest).ok_or(Error::MissingConnection)?; + let connection = self + .get_connection(dest) + .await + .ok_or(Error::MissingConnection)?; connection.send_uni(msg).await?; Ok(()) } @@ -502,6 +509,7 @@ impl Endpoint { endpoint.connect_to(&node).await?; let connection = endpoint .get_connection(&node) + .await .ok_or(Error::MissingConnection)?; let (mut send_stream, mut recv_stream) = connection.open_bi().await?; send_stream.send(WireMsg::EndpointEchoReq).await?;