diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 3d61881d1..19c1f61e7 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -22,7 +22,7 @@ use crate::{ connection::{Connection, ConnectionError}, crypto::{self, Keys, UnsupportedVersion}, frame, - packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode}, + packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode, PlainInitialHeader}, shared::{ ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent, EndpointEventInner, IssuedCid, @@ -210,27 +210,40 @@ impl Endpoint { } }; - if let Some(version) = first_decode.initial_version() { + if let Some(header) = first_decode.initial_header() { if datagram_len < MIN_INITIAL_SIZE as usize { debug!("ignoring short initial for connection {}", dst_cid); return None; } - let crypto = match server_config - .crypto - .initial_keys(version, dst_cid, Side::Server) - { - Ok(keys) => keys, - Err(UnsupportedVersion) => { - // This probably indicates that the user set supported_versions incorrectly in - // `EndpointConfig`. - debug!( + let crypto = + match server_config + .crypto + .initial_keys(header.version, dst_cid, Side::Server) + { + Ok(keys) => keys, + Err(UnsupportedVersion) => { + // This probably indicates that the user set supported_versions incorrectly in + // `EndpointConfig`. + debug!( "ignoring initial packet version {:#x} unsupported by cryptographic layer", - version + header.version ); - return None; - } - }; + return None; + } + }; + + if let Err(reason) = self.early_validate_first_packet(header) { + return Some(DatagramEvent::Response(self.initial_close( + header.version, + addresses, + &crypto, + &header.src_cid, + reason, + buf, + ))); + } + return match first_decode.finish(Some(&*crypto.header.remote)) { Ok(packet) => { self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto, buf) @@ -436,36 +449,6 @@ impl Endpoint { let server_config = self.server_config.as_ref().unwrap().clone(); - if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full() - { - debug!("refusing connection"); - return Some(DatagramEvent::Response(self.initial_close( - version, - addresses, - crypto, - &src_cid, - TransportError::CONNECTION_REFUSED(""), - buf, - ))); - } - - if dst_cid.len() < 8 - && (!server_config.use_retry || dst_cid.len() != self.local_cid_generator.cid_len()) - { - debug!( - "rejecting connection due to invalid DCID length {}", - dst_cid.len() - ); - return Some(DatagramEvent::Response(self.initial_close( - version, - addresses, - crypto, - &src_cid, - TransportError::PROTOCOL_VIOLATION("invalid destination CID length"), - buf, - ))); - } - let (retry_src_cid, orig_dst_cid) = if server_config.use_retry { if token.is_empty() { // First Initial @@ -580,6 +563,39 @@ impl Endpoint { } } + /// Check if we should refuse a connection attempt regardless of the packet's contents + fn early_validate_first_packet( + &mut self, + header: &PlainInitialHeader, + ) -> Result<(), TransportError> { + let server_config = self.server_config.as_ref().unwrap(); + if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full() + { + debug!("refusing connection"); + return Err(TransportError::CONNECTION_REFUSED("")); + } + + // RFC9000 ยง7.2 dictates that initial (client-chosen) destination CIDs must be at least 8 + // bytes. If this is a Retry packet, then the length must instead match our usual CID + // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll + // also need to validate CID length for those after decoding the token. + if header.dst_cid.len() < 8 + && (!server_config.use_retry + || (!header.token_pos.is_empty() + && header.dst_cid.len() != self.local_cid_generator.cid_len())) + { + debug!( + "rejecting connection due to invalid DCID length {}", + header.dst_cid.len() + ); + return Err(TransportError::PROTOCOL_VIOLATION( + "invalid destination CID length", + )); + } + + Ok(()) + } + fn add_connection( &mut self, ch: ConnectionHandle, diff --git a/quinn-proto/src/packet.rs b/quinn-proto/src/packet.rs index b70eb98f7..d8f2611f7 100644 --- a/quinn-proto/src/packet.rs +++ b/quinn-proto/src/packet.rs @@ -60,11 +60,8 @@ impl PartialDecode { self.buf.get_ref() } - pub(crate) fn initial_version(&self) -> Option { - match self.plain_header { - PlainHeader::Initial { version, .. } => Some(version), - _ => None, - } + pub(crate) fn initial_header(&self) -> Option<&PlainInitialHeader> { + self.plain_header.as_initial() } pub(crate) fn has_long_header(&self) -> bool { @@ -119,13 +116,13 @@ impl PartialDecode { mut buf, } = self; - if let Initial { + if let Initial(PlainInitialHeader { dst_cid, src_cid, token_pos, version, .. - } = plain_header + }) = plain_header { let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?; let header_len = buf.position() as usize; @@ -481,13 +478,7 @@ impl PartialEncode { #[derive(Clone, Debug)] pub(crate) enum PlainHeader { - Initial { - dst_cid: ConnectionId, - src_cid: ConnectionId, - token_pos: Range, - len: u64, - version: u32, - }, + Initial(PlainInitialHeader), Long { ty: LongType, dst_cid: ConnectionId, @@ -512,10 +503,17 @@ pub(crate) enum PlainHeader { } impl PlainHeader { + pub(crate) fn as_initial(&self) -> Option<&PlainInitialHeader> { + match self { + Self::Initial(x) => Some(x), + _ => None, + } + } + fn dst_cid(&self) -> &ConnectionId { use self::PlainHeader::*; match self { - Initial { dst_cid, .. } => dst_cid, + Initial(header) => &header.dst_cid, Long { dst_cid, .. } => dst_cid, Retry { dst_cid, .. } => dst_cid, Short { dst_cid, .. } => dst_cid, @@ -526,7 +524,7 @@ impl PlainHeader { fn payload_len(&self) -> Option { use self::PlainHeader::*; match self { - Initial { len, .. } | Long { len, .. } => Some(*len), + Initial(PlainInitialHeader { len, .. }) | Long { len, .. } => Some(*len), _ => None, } } @@ -587,13 +585,13 @@ impl PlainHeader { buf.advance(token_len); let len = buf.get_var()?; - Ok(Self::Initial { + Ok(Self::Initial(PlainInitialHeader { dst_cid, src_cid, token_pos: token_start..token_start + token_len, len, version, - }) + })) } LongHeaderType::Retry => Ok(Self::Retry { dst_cid, @@ -612,6 +610,15 @@ impl PlainHeader { } } +#[derive(Clone, Debug)] +pub(crate) struct PlainInitialHeader { + pub(crate) dst_cid: ConnectionId, + pub(crate) src_cid: ConnectionId, + pub(crate) token_pos: Range, + pub(crate) len: u64, + pub(crate) version: u32, +} + // An encoded packet number #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) enum PacketNumber {