Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero-length connection IDs #1883

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use crate::{
},
token::TokenDecodeError,
transport_parameters::{PreferredAddress, TransportParameters},
ResetToken, RetryToken, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE,
MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
ResetToken, RetryToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU,
MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
};

/// The main entry point to the library
Expand Down Expand Up @@ -457,11 +457,15 @@ impl Endpoint {
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
loop {
let cid = self.local_cid_generator.generate_cid();
if cid.len() == 0 {
// Zero-length CID; nothing to track
debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
return cid;
}
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
e.insert(ch);
break cid;
}
assert!(self.local_cid_generator.cid_len() > 0);
}
}

Expand Down Expand Up @@ -814,6 +818,10 @@ impl Endpoint {
) -> Connection {
let mut rng_seed = [0; 32];
self.rng.fill_bytes(&mut rng_seed);
let side = match server_config.is_some() {
true => Side::Server,
false => Side::Client,
};
let conn = Connection::new(
self.config.clone(),
server_config,
Expand Down Expand Up @@ -854,7 +862,7 @@ impl Endpoint {
});
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");

self.index.insert_conn(addresses, loc_cid, ch);
self.index.insert_conn(addresses, loc_cid, ch, side);

conn
}
Expand Down Expand Up @@ -913,7 +921,8 @@ impl Endpoint {
// Not all connections have known reset tokens
debug_assert!(x >= self.index.connection_reset_tokens.0.len());
// Not all connections have unique remotes, and 0-length CIDs might not be in use.
debug_assert!(x >= self.index.connection_remotes.len());
debug_assert!(x >= self.index.incoming_connection_remotes.len());
debug_assert!(x >= self.index.outgoing_connection_remotes.len());
x
}

Expand Down Expand Up @@ -978,10 +987,19 @@ struct ConnectionIndex {
///
/// Uses a cheaper hash function since keys are locally created
connection_ids: FxHashMap<ConnectionId, ConnectionHandle>,
/// Identifies connections with zero-length CIDs
/// Identifies incoming connections with zero-length CIDs
///
/// Uses a standard `HashMap` to protect against hash collision attacks.
incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
/// Identifies outgoing connections with zero-length CIDs
///
/// We don't yet support explicit source addresses for client connections, and zero-length CIDs
/// require a unique four-tuple, so at most one client connection with zero-length local CIDs
/// may be established per remote. We must omit the local address from the key because we don't
/// necessarily know what address we're sending from, and hence receiving at.
///
/// Uses a standard `HashMap` to protect against hash collision attacks.
connection_remotes: HashMap<FourTuple, ConnectionHandle>,
outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
/// Reset tokens provided by the peer for the CID each connection is currently sending to
///
/// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct
Expand Down Expand Up @@ -1014,11 +1032,19 @@ impl ConnectionIndex {
addresses: FourTuple,
dst_cid: ConnectionId,
connection: ConnectionHandle,
side: Side,
) {
match dst_cid.len() {
0 => {
self.connection_remotes.insert(addresses, connection);
}
0 => match side {
Side::Server => {
self.incoming_connection_remotes
.insert(addresses, connection);
}
Side::Client => {
self.outgoing_connection_remotes
.insert(addresses.remote, connection);
}
},
_ => {
self.connection_ids.insert(dst_cid, connection);
}
Expand All @@ -1038,7 +1064,9 @@ impl ConnectionIndex {
for cid in conn.loc_cids.values() {
self.connection_ids.remove(cid);
}
self.connection_remotes.remove(&conn.addresses);
self.incoming_connection_remotes.remove(&conn.addresses);
self.outgoing_connection_remotes
.remove(&conn.addresses.remote);
if let Some((remote, token)) = conn.reset_token {
self.connection_reset_tokens.remove(remote, token);
}
Expand All @@ -1057,7 +1085,10 @@ impl ConnectionIndex {
}
}
if datagram.dst_cid().len() == 0 {
if let Some(&ch) = self.connection_remotes.get(addresses) {
if let Some(&ch) = self.incoming_connection_remotes.get(addresses) {
return Some(RouteDatagramTo::Connection(ch));
}
if let Some(&ch) = self.outgoing_connection_remotes.get(&addresses.remote) {
return Some(RouteDatagramTo::Connection(ch));
}
}
Expand Down
79 changes: 69 additions & 10 deletions quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{

use crate::runtime::TokioRuntime;
use bytes::Bytes;
use proto::crypto::rustls::QuicClientConfig;
use proto::{crypto::rustls::QuicClientConfig, RandomConnectionIdGenerator};
use rand::{rngs::StdRng, RngCore, SeedableRng};
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
Expand All @@ -24,7 +24,7 @@ use tracing::{error_span, info};
use tracing_futures::Instrument as _;
use tracing_subscriber::EnvFilter;

use super::{ClientConfig, Endpoint, RecvStream, SendStream, TransportConfig};
use super::{ClientConfig, Endpoint, EndpointConfig, RecvStream, SendStream, TransportConfig};

#[test]
fn handshake_timeout() {
Expand Down Expand Up @@ -264,29 +264,37 @@ fn endpoint_with_config(transport_config: TransportConfig) -> Endpoint {
}

/// Constructs endpoints suitable for connecting to themselves and each other
struct EndpointFactory(rcgen::CertifiedKey);
struct EndpointFactory {
cert: rcgen::CertifiedKey,
Ralith marked this conversation as resolved.
Show resolved Hide resolved
endpoint_config: EndpointConfig,
}

impl EndpointFactory {
fn new() -> Self {
Self(rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap())
Self {
cert: rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(),
endpoint_config: EndpointConfig::default(),
}
}

fn endpoint(&self) -> Endpoint {
self.endpoint_with_config(TransportConfig::default())
}

fn endpoint_with_config(&self, transport_config: TransportConfig) -> Endpoint {
let key = PrivateKeyDer::Pkcs8(self.0.key_pair.serialize_der().into());
let key = PrivateKeyDer::Pkcs8(self.cert.key_pair.serialize_der().into());
let transport_config = Arc::new(transport_config);
let mut server_config =
crate::ServerConfig::with_single_cert(vec![self.0.cert.der().clone()], key).unwrap();
crate::ServerConfig::with_single_cert(vec![self.cert.cert.der().clone()], key).unwrap();
server_config.transport_config(transport_config.clone());

let mut roots = rustls::RootCertStore::empty();
roots.add(self.0.cert.der().clone()).unwrap();
let mut endpoint = Endpoint::server(
server_config,
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
roots.add(self.cert.cert.der().clone()).unwrap();
let mut endpoint = Endpoint::new(
self.endpoint_config.clone(),
Some(server_config),
UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(),
Arc::new(TokioRuntime),
)
.unwrap();
let mut client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
Expand Down Expand Up @@ -803,3 +811,54 @@ async fn two_datagram_readers() {
assert!(*a == *b"one" || *b == *b"one");
assert!(*a == *b"two" || *b == *b"two");
}

#[tokio::test]
async fn multiple_conns_with_zero_length_cids() {
let _guard = subscribe();
let mut factory = EndpointFactory::new();
factory
.endpoint_config
.cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0)));
let server = {
let _guard = error_span!("server").entered();
factory.endpoint()
};
let server_addr = server.local_addr().unwrap();

let client1 = {
let _guard = error_span!("client1").entered();
factory.endpoint()
};
let client2 = {
let _guard = error_span!("client2").entered();
factory.endpoint()
};

let client1 = async move {
let conn = client1
.connect(server_addr, "localhost")
.unwrap()
.await
.unwrap();
conn.closed().await;
}
.instrument(error_span!("client1"));
let client2 = async move {
let conn = client2
.connect(server_addr, "localhost")
.unwrap()
.await
.unwrap();
conn.closed().await;
}
.instrument(error_span!("client2"));
let server = async move {
let client1 = server.accept().await.unwrap().await.unwrap();
let client2 = server.accept().await.unwrap().await.unwrap();
// Both connections are now concurrently live.
client1.close(42u32.into(), &[]);
client2.close(42u32.into(), &[]);
}
.instrument(error_span!("server"));
tokio::join!(client1, client2, server);
}
Loading