Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tcp/websocket/quic: Fix cancel memory leak #272

Merged
merged 11 commits into from
Oct 30, 2024
60 changes: 46 additions & 14 deletions src/transport/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use multiaddr::{Multiaddr, Protocol};
use quinn::{ClientConfig, Connecting, Connection, Endpoint, IdleTimeout};

use std::{
collections::{HashMap, HashSet},
collections::HashMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::Arc,
Expand Down Expand Up @@ -120,9 +120,9 @@ pub(crate) struct QuicTransport {
/// Opened raw connection, waiting for approval/rejection from `TransportManager`.
opened_raw: HashMap<ConnectionId, (NegotiatedConnection, Multiaddr)>,

/// Canceled raw connections.
canceled: HashSet<ConnectionId>,

/// Cancel raw connections futures.
///
/// This is cancelling `Self::pending_raw_connections`.
cancel_futures: HashMap<ConnectionId, AbortHandle>,
}

Expand Down Expand Up @@ -235,7 +235,6 @@ impl TransportBuilder for QuicTransport {
context,
config,
listener,
canceled: HashSet::new(),
opened_raw: HashMap::new(),
pending_open: HashMap::new(),
pending_dials: HashMap::new(),
Expand Down Expand Up @@ -477,8 +476,11 @@ impl Transport for QuicTransport {

/// Cancel opening connections.
fn cancel(&mut self, connection_id: ConnectionId) {
self.canceled.insert(connection_id);
self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
// Cancel the future if it exists.
// State clean-up happens inside the `poll_next`.
if let Some(handle) = self.cancel_futures.get(&connection_id) {
handle.abort();
}
}
}

Expand Down Expand Up @@ -510,27 +512,57 @@ impl Stream for QuicTransport {
connection_id,
address,
stream,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?address,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
self.opened_raw.insert(connection_id, (stream, address.clone()));

return Poll::Ready(Some(TransportEvent::ConnectionOpened {
connection_id,
address,
}));
},
}
}

RawConnectionResult::Failed {
connection_id,
errors,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?errors,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
return Poll::Ready(Some(TransportEvent::OpenFailure {
connection_id,
errors,
}));
},
}
}

RawConnectionResult::Canceled { connection_id } => {
self.canceled.remove(&connection_id);
if self.cancel_futures.remove(&connection_id).is_none() {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
"raw cancelled connection without a cancel handle",
);
}
}
}
}
Expand Down
59 changes: 45 additions & 14 deletions src/transport/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use socket2::{Domain, Socket, Type};
use tokio::net::TcpStream;

use std::{
collections::{HashMap, HashSet},
collections::HashMap,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -121,9 +121,9 @@ pub(crate) struct TcpTransport {
/// Opened raw connection, waiting for approval/rejection from `TransportManager`.
opened_raw: HashMap<ConnectionId, (TcpStream, Multiaddr)>,

/// Canceled raw connections.
canceled: HashSet<ConnectionId>,

/// Cancel raw connections futures.
///
/// This is cancelling `Self::pending_raw_connections`.
cancel_futures: HashMap<ConnectionId, AbortHandle>,

/// Connections which have been opened and negotiated but are being validated by the
Expand Down Expand Up @@ -291,7 +291,6 @@ impl TransportBuilder for TcpTransport {
config,
context,
dial_addresses,
canceled: HashSet::new(),
opened_raw: HashMap::new(),
pending_open: HashMap::new(),
pending_dials: HashMap::new(),
Expand Down Expand Up @@ -516,8 +515,11 @@ impl Transport for TcpTransport {
}

fn cancel(&mut self, connection_id: ConnectionId) {
self.canceled.insert(connection_id);
self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
// Cancel the future if it exists.
// State clean-up happens inside the `poll_next`.
if let Some(handle) = self.cancel_futures.get(&connection_id) {
handle.abort();
}
}
}

Expand Down Expand Up @@ -560,27 +562,56 @@ impl Stream for TcpTransport {
connection_id,
address,
stream,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?address,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
self.opened_raw.insert(connection_id, (stream, address.clone()));

return Poll::Ready(Some(TransportEvent::ConnectionOpened {
connection_id,
address,
}));
},
}
}

RawConnectionResult::Failed {
connection_id,
errors,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?errors,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
return Poll::Ready(Some(TransportEvent::OpenFailure {
connection_id,
errors,
}));
},
}
}
RawConnectionResult::Canceled { connection_id } => {
self.canceled.remove(&connection_id);
if self.cancel_futures.remove(&connection_id).is_none() {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
"raw cancelled connection without a cancel handle",
);
}
}
}
}
Expand Down
59 changes: 45 additions & 14 deletions src/transport/websocket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use url::Url;

use std::{
collections::{HashMap, HashSet},
collections::HashMap,
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -125,9 +125,9 @@ pub(crate) struct WebSocketTransport {
/// Opened raw connection, waiting for approval/rejection from `TransportManager`.
opened_raw: HashMap<ConnectionId, (WebSocketStream<MaybeTlsStream<TcpStream>>, Multiaddr)>,

/// Canceled raw connections.
canceled: HashSet<ConnectionId>,

/// Cancel raw connections futures.
///
/// This is cancelling `Self::pending_raw_connections`.
cancel_futures: HashMap<ConnectionId, AbortHandle>,

/// Negotiated connections waiting validation.
Expand Down Expand Up @@ -321,7 +321,6 @@ impl TransportBuilder for WebSocketTransport {
config,
context,
dial_addresses,
canceled: HashSet::new(),
opened_raw: HashMap::new(),
pending_open: HashMap::new(),
pending_dials: HashMap::new(),
Expand Down Expand Up @@ -562,8 +561,11 @@ impl Transport for WebSocketTransport {
}

fn cancel(&mut self, connection_id: ConnectionId) {
self.canceled.insert(connection_id);
self.cancel_futures.remove(&connection_id).map(|handle| handle.abort());
// Cancel the future if it exists.
// State clean-up happens inside the `poll_next`.
if let Some(handle) = self.cancel_futures.get(&connection_id) {
handle.abort();
}
}
}

Expand Down Expand Up @@ -600,27 +602,56 @@ impl Stream for WebSocketTransport {
connection_id,
address,
stream,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?address,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
self.opened_raw.insert(connection_id, (stream, address.clone()));

return Poll::Ready(Some(TransportEvent::ConnectionOpened {
connection_id,
address,
}));
},
}
}

RawConnectionResult::Failed {
connection_id,
errors,
} =>
if !self.canceled.remove(&connection_id) {
} => {
let Some(handle) = self.cancel_futures.remove(&connection_id) else {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
?errors,
"raw connection without a cancel handle",
);
continue;
};

if !handle.is_aborted() {
return Poll::Ready(Some(TransportEvent::OpenFailure {
connection_id,
errors,
}));
},
}
}
RawConnectionResult::Canceled { connection_id } => {
self.canceled.remove(&connection_id);
if self.cancel_futures.remove(&connection_id).is_none() {
tracing::warn!(
target: LOG_TARGET,
?connection_id,
"raw cancelled connection without a cancel handle",
);
}
}
}
}
Expand Down