diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 0403eb530..e4df5787d 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -457,11 +457,15 @@ impl Endpoint { fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId { loop { let cid = self.local_cid_generator.generate_cid(); + if cid.len() == 0 { + // Zero-length CID; nothing to track + debug_assert_eq!(self.local_cid_generator.cid_len(), 0); + return cid; + } if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) { e.insert(ch); break cid; } - assert!(self.local_cid_generator.cid_len() > 0); } } diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index c62b2bcca..802509cbd 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -10,7 +10,7 @@ use std::{ use crate::runtime::TokioRuntime; use bytes::Bytes; -use proto::crypto::rustls::QuicClientConfig; +use proto::{crypto::rustls::QuicClientConfig, RandomConnectionIdGenerator}; use rand::{rngs::StdRng, RngCore, SeedableRng}; use rustls::{ pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, @@ -266,12 +266,14 @@ fn endpoint_with_config(transport_config: TransportConfig) -> Endpoint { /// Constructs endpoints suitable for connecting to themselves and each other struct EndpointFactory { cert: rcgen::CertifiedKey, + endpoint_config: EndpointConfig, } impl EndpointFactory { fn new() -> Self { Self { cert: rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(), + endpoint_config: EndpointConfig::default(), } } @@ -289,7 +291,7 @@ impl EndpointFactory { let mut roots = rustls::RootCertStore::empty(); roots.add(self.cert.cert.der().clone()).unwrap(); let mut endpoint = Endpoint::new( - EndpointConfig::default(), + self.endpoint_config.clone(), Some(server_config), UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(), Arc::new(TokioRuntime), @@ -809,3 +811,54 @@ async fn two_datagram_readers() { assert!(*a == *b"one" || *b == *b"one"); assert!(*a == *b"two" || *b == *b"two"); } + +#[tokio::test] +async fn multiple_conns_with_zero_length_cids() { + let _guard = subscribe(); + let mut factory = EndpointFactory::new(); + factory + .endpoint_config + .cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0))); + let server = { + let _guard = error_span!("server").entered(); + factory.endpoint() + }; + let server_addr = server.local_addr().unwrap(); + + let client1 = { + let _guard = error_span!("client1").entered(); + factory.endpoint() + }; + let client2 = { + let _guard = error_span!("client2").entered(); + factory.endpoint() + }; + + let client1 = async move { + let conn = client1 + .connect(server_addr, "localhost") + .unwrap() + .await + .unwrap(); + conn.closed().await; + } + .instrument(error_span!("client1")); + let client2 = async move { + let conn = client2 + .connect(server_addr, "localhost") + .unwrap() + .await + .unwrap(); + conn.closed().await; + } + .instrument(error_span!("client2")); + let server = async move { + let client1 = server.accept().await.unwrap().await.unwrap(); + let client2 = server.accept().await.unwrap().await.unwrap(); + // Both connections are now concurrently live. + client1.close(42u32.into(), &[]); + client2.close(42u32.into(), &[]); + } + .instrument(error_span!("server")); + tokio::join!(client1, client2, server); +}