Skip to content

Commit

Permalink
Fix packet dispatch for outgoing connections with zero-length CIDs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralith committed May 25, 2024
1 parent 772c269 commit 45132ae
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
49 changes: 38 additions & 11 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use crate::{
},
token::TokenDecodeError,
transport_parameters::{PreferredAddress, TransportParameters},
ResetToken, RetryToken, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE,
MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU,
MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
};

/// The main entry point to the library
Expand Down Expand Up @@ -814,6 +814,10 @@ impl Endpoint {
) -> Connection {
let mut rng_seed = [0; 32];
self.rng.fill_bytes(&mut rng_seed);
let side = match server_config.is_some() {
true => Side::Server,
false => Side::Client,
};
let conn = Connection::new(
self.config.clone(),
server_config,
Expand Down Expand Up @@ -854,7 +858,7 @@ impl Endpoint {
});
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");

self.index.insert_conn(addresses, loc_cid, ch);
self.index.insert_conn(addresses, loc_cid, ch, side);

conn
}
Expand Down Expand Up @@ -913,7 +917,8 @@ impl Endpoint {
// Not all connections have known reset tokens
debug_assert!(x >= self.index.connection_reset_tokens.0.len());
// Not all connections have unique remotes, and 0-length CIDs might not be in use.
debug_assert!(x >= self.index.connection_remotes.len());
debug_assert!(x >= self.index.incoming_connection_remotes.len());
debug_assert!(x >= self.index.outgoing_connection_remotes.len());
x
}

Expand Down Expand Up @@ -978,10 +983,19 @@ struct ConnectionIndex {
///
/// Uses a cheaper hash function since keys are locally created
connection_ids: FxHashMap<ConnectionId, ConnectionHandle>,
/// Identifies connections with zero-length CIDs
/// Identifies incoming connections with zero-length CIDs
///
/// Uses a standard `HashMap` to protect against hash collision attacks.
connection_remotes: HashMap<FourTuple, ConnectionHandle>,
incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
/// Identifies outgoing connections with zero-length CIDs
///
/// We don't yet support explicit source addresses for client connections, and zero-length CIDs
/// require a unique four-tuple, so at most one client connection with zero-length local CIDs
/// may be established per remote. We must omit the source address from the key because we don't
/// necessarily know what address we're sending from.
///
/// Uses a standard `HashMap` to protect against hash collision attacks.
outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
/// Reset tokens provided by the peer for the CID each connection is currently sending to
///
/// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct
Expand Down Expand Up @@ -1014,11 +1028,19 @@ impl ConnectionIndex {
addresses: FourTuple,
dst_cid: ConnectionId,
connection: ConnectionHandle,
side: Side,
) {
match dst_cid.len() {
0 => {
self.connection_remotes.insert(addresses, connection);
}
0 => match side {
Side::Server => {
self.incoming_connection_remotes
.insert(addresses, connection);
}
Side::Client => {
self.outgoing_connection_remotes
.insert(addresses.remote, connection);
}
},
_ => {
self.connection_ids.insert(dst_cid, connection);
}
Expand All @@ -1038,7 +1060,9 @@ impl ConnectionIndex {
for cid in conn.loc_cids.values() {
self.connection_ids.remove(cid);
}
self.connection_remotes.remove(&conn.addresses);
self.incoming_connection_remotes.remove(&conn.addresses);
self.outgoing_connection_remotes
.remove(&conn.addresses.remote);
if let Some((remote, token)) = conn.reset_token {
self.connection_reset_tokens.remove(remote, token);
}
Expand All @@ -1057,7 +1081,10 @@ impl ConnectionIndex {
}
}
if datagram.dst_cid().len() == 0 {
if let Some(&ch) = self.connection_remotes.get(addresses) {
if let Some(&ch) = self.incoming_connection_remotes.get(addresses) {
return Some(RouteDatagramTo::Connection(ch));
}
if let Some(&ch) = self.outgoing_connection_remotes.get(&addresses.remote) {
return Some(RouteDatagramTo::Connection(ch));
}
}
Expand Down
24 changes: 15 additions & 9 deletions quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use tracing::{error_span, info};
use tracing_futures::Instrument as _;
use tracing_subscriber::EnvFilter;

use super::{ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig};
use super::{ClientConfig, Endpoint, EndpointConfig, RecvStream, SendStream, TransportConfig};

#[test]
fn handshake_timeout() {
Expand Down Expand Up @@ -264,29 +264,35 @@ fn endpoint_with_config(transport_config: TransportConfig) -> Endpoint {
}

/// Constructs endpoints suitable for connecting to themselves and each other
struct EndpointFactory(rcgen::CertifiedKey);
struct EndpointFactory {
cert: rcgen::CertifiedKey,
}

impl EndpointFactory {
fn new() -> Self {
Self(rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap())
Self {
cert: rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(),
}
}

fn endpoint(&self) -> Endpoint {
self.endpoint_with_config(TransportConfig::default())
}

fn endpoint_with_config(&self, transport_config: TransportConfig) -> Endpoint {
let key = PrivateKeyDer::Pkcs8(self.0.key_pair.serialize_der().into());
let key = PrivateKeyDer::Pkcs8(self.cert.key_pair.serialize_der().into());
let transport_config = Arc::new(transport_config);
let mut server_config =
crate::ServerConfig::with_single_cert(vec![self.0.cert.der().clone()], key).unwrap();
crate::ServerConfig::with_single_cert(vec![self.cert.cert.der().clone()], key).unwrap();
server_config.transport_config(transport_config.clone());

let mut roots = rustls::RootCertStore::empty();
roots.add(self.0.cert.der().clone()).unwrap();
let mut endpoint = Endpoint::server(
server_config,
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
roots.add(self.cert.cert.der().clone()).unwrap();
let mut endpoint = Endpoint::new(
EndpointConfig::default(),
Some(server_config),
UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(),
Arc::new(TokioRuntime),
)
.unwrap();
let mut client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
Expand Down

0 comments on commit 45132ae

Please sign in to comment.