Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform early first-packet validation before decryption #1789

Merged
merged 4 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 61 additions & 45 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
connection::{Connection, ConnectionError},
crypto::{self, Keys, UnsupportedVersion},
frame,
packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode},
packet::{Header, Packet, PacketDecodeError, PacketNumber, PartialDecode, PlainInitialHeader},
shared::{
ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent,
EndpointEventInner, IssuedCid,
Expand Down Expand Up @@ -210,27 +210,40 @@ impl Endpoint {
}
};

if let Some(version) = first_decode.initial_version() {
if let Some(header) = first_decode.initial_header() {
if datagram_len < MIN_INITIAL_SIZE as usize {
debug!("ignoring short initial for connection {}", dst_cid);
return None;
}

let crypto = match server_config
.crypto
.initial_keys(version, dst_cid, Side::Server)
{
Ok(keys) => keys,
Err(UnsupportedVersion) => {
// This probably indicates that the user set supported_versions incorrectly in
// `EndpointConfig`.
debug!(
let crypto =
match server_config
.crypto
.initial_keys(header.version, dst_cid, Side::Server)
{
Ok(keys) => keys,
Err(UnsupportedVersion) => {
// This probably indicates that the user set supported_versions incorrectly in
// `EndpointConfig`.
debug!(
"ignoring initial packet version {:#x} unsupported by cryptographic layer",
version
header.version
);
return None;
}
};
return None;
}
};

if let Err(reason) = self.early_validate_first_packet(header) {
return Some(DatagramEvent::Response(self.initial_close(
Ralith marked this conversation as resolved.
Show resolved Hide resolved
header.version,
addresses,
&crypto,
&header.src_cid,
reason,
buf,
)));
}

return match first_decode.finish(Some(&*crypto.header.remote)) {
Ok(packet) => {
self.handle_first_packet(now, addresses, ecn, packet, remaining, &crypto, buf)
Expand Down Expand Up @@ -436,36 +449,6 @@ impl Endpoint {

let server_config = self.server_config.as_ref().unwrap().clone();

if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full()
{
debug!("refusing connection");
return Some(DatagramEvent::Response(self.initial_close(
version,
addresses,
crypto,
&src_cid,
TransportError::CONNECTION_REFUSED(""),
buf,
)));
}

if dst_cid.len() < 8
&& (!server_config.use_retry || dst_cid.len() != self.local_cid_generator.cid_len())
{
debug!(
"rejecting connection due to invalid DCID length {}",
dst_cid.len()
);
return Some(DatagramEvent::Response(self.initial_close(
version,
addresses,
crypto,
&src_cid,
TransportError::PROTOCOL_VIOLATION("invalid destination CID length"),
buf,
)));
}

let (retry_src_cid, orig_dst_cid) = if server_config.use_retry {
if token.is_empty() {
// First Initial
Expand Down Expand Up @@ -580,6 +563,39 @@ impl Endpoint {
}
}

/// Check if we should refuse a connection attempt regardless of the packet's contents
fn early_validate_first_packet(
&mut self,
header: &PlainInitialHeader,
) -> Result<(), TransportError> {
let server_config = self.server_config.as_ref().unwrap();
if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full()
{
debug!("refusing connection");
return Err(TransportError::CONNECTION_REFUSED(""));
}

// RFC9000 §7.2 dictates that initial (client-chosen) destination CIDs must be at least 8
// bytes. If this is a Retry packet, then the length must instead match our usual CID
// length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
// also need to validate CID length for those after decoding the token.
if header.dst_cid.len() < 8
&& (!server_config.use_retry
|| (!header.token_pos.is_empty()
&& header.dst_cid.len() != self.local_cid_generator.cid_len()))
{
debug!(
"rejecting connection due to invalid DCID length {}",
header.dst_cid.len()
);
return Err(TransportError::PROTOCOL_VIOLATION(
"invalid destination CID length",
));
}

Ok(())
}

fn add_connection(
&mut self,
ch: ConnectionHandle,
Expand Down
43 changes: 25 additions & 18 deletions quinn-proto/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,8 @@ impl PartialDecode {
self.buf.get_ref()
}

pub(crate) fn initial_version(&self) -> Option<u32> {
match self.plain_header {
PlainHeader::Initial { version, .. } => Some(version),
_ => None,
}
pub(crate) fn initial_header(&self) -> Option<&PlainInitialHeader> {
self.plain_header.as_initial()
}

pub(crate) fn has_long_header(&self) -> bool {
Expand Down Expand Up @@ -119,13 +116,13 @@ impl PartialDecode {
mut buf,
} = self;

if let Initial {
if let Initial(PlainInitialHeader {
dst_cid,
src_cid,
token_pos,
version,
..
} = plain_header
}) = plain_header
{
let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
let header_len = buf.position() as usize;
Expand Down Expand Up @@ -481,13 +478,7 @@ impl PartialEncode {

#[derive(Clone, Debug)]
pub(crate) enum PlainHeader {
Initial {
dst_cid: ConnectionId,
src_cid: ConnectionId,
token_pos: Range<usize>,
len: u64,
version: u32,
},
Initial(PlainInitialHeader),
Long {
ty: LongType,
dst_cid: ConnectionId,
Expand All @@ -512,10 +503,17 @@ pub(crate) enum PlainHeader {
}

impl PlainHeader {
pub(crate) fn as_initial(&self) -> Option<&PlainInitialHeader> {
match self {
Self::Initial(x) => Some(x),
_ => None,
}
}

fn dst_cid(&self) -> &ConnectionId {
use self::PlainHeader::*;
match self {
Initial { dst_cid, .. } => dst_cid,
Initial(header) => &header.dst_cid,
Long { dst_cid, .. } => dst_cid,
Retry { dst_cid, .. } => dst_cid,
Short { dst_cid, .. } => dst_cid,
Expand All @@ -526,7 +524,7 @@ impl PlainHeader {
fn payload_len(&self) -> Option<u64> {
use self::PlainHeader::*;
match self {
Initial { len, .. } | Long { len, .. } => Some(*len),
Initial(PlainInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
_ => None,
}
}
Expand Down Expand Up @@ -587,13 +585,13 @@ impl PlainHeader {
buf.advance(token_len);

let len = buf.get_var()?;
Ok(Self::Initial {
Ok(Self::Initial(PlainInitialHeader {
dst_cid,
src_cid,
token_pos: token_start..token_start + token_len,
len,
version,
})
}))
}
LongHeaderType::Retry => Ok(Self::Retry {
dst_cid,
Expand All @@ -612,6 +610,15 @@ impl PlainHeader {
}
}

#[derive(Clone, Debug)]
pub(crate) struct PlainInitialHeader {
pub(crate) dst_cid: ConnectionId,
pub(crate) src_cid: ConnectionId,
pub(crate) token_pos: Range<usize>,
pub(crate) len: u64,
pub(crate) version: u32,
}

// An encoded packet number
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
Expand Down