Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoder improvements #2259

Merged
merged 14 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
*~
/.vscode/
/lcov.info
/mutants.out*/
/target/
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions neqo-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ qlog = { workspace = true }
windows = { version = "0.58", default-features = false, features = ["Win32_Media"] }

[dev-dependencies]
criterion = { version = "0.5", default-features = false }
neqo-crypto = { path = "../neqo-crypto" }
test-fixture = { path = "../test-fixture" }
regex = { workspace = true }

Expand All @@ -38,3 +40,7 @@ build-fuzzing-corpus = ["hex"]
[lib]
# See https://github.com/bheisler/criterion.rs/blob/master/book/src/faq.md#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options
bench = false

[[bench]]
name = "decoder"
harness = false
48 changes: 48 additions & 0 deletions neqo-common/benches/decoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use neqo_common::Decoder;
use neqo_crypto::{init, randomize};

fn randomize_buffer(n: usize, mask: u8) -> Vec<u8> {
let mut buf = vec![0; n];
for chunk in buf.chunks_mut(1024) {
// NSS doesn't like randomizing larger buffers, so chunk them up.
randomize(chunk);
}
for x in &mut buf[..] {
*x &= mask;
}
buf
}

fn decoder(c: &mut Criterion, count: usize, mask: u8) {
c.bench_function(&format!("decode {count} bytes, mask {mask:x}"), |b| {
b.iter_batched_ref(
|| randomize_buffer(count, mask),
|buf| {
let mut dec = Decoder::new(&buf[..]);
while black_box(dec.decode_varint()).is_some() {
// Do nothing;
}
},
criterion::BatchSize::SmallInput,
);
});
}

fn benchmark_decoder(c: &mut Criterion) {
init().unwrap();
for mask in [0xff, 0x7f, 0x3f] {
for count in [10, 15, 20] {
decoder(c, 1 << count, mask);
}
}
}

criterion_group!(benches, benchmark_decoder);
criterion_main!(benches);
84 changes: 48 additions & 36 deletions neqo-common/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<'a> Decoder<'a> {
/// Skip a vector. Panics if there isn't enough space.
/// Only use this for tests because we panic rather than reporting a result.
pub fn skip_vec(&mut self, n: usize) {
let len = self.decode_uint(n);
let len = self.decode_n(n);
self.skip_inner(len);
}

Expand All @@ -62,16 +62,6 @@ impl<'a> Decoder<'a> {
self.skip_inner(len);
}

/// Decodes (reads) a single byte.
pub fn decode_byte(&mut self) -> Option<u8> {
if self.remaining() < 1 {
return None;
}
let b = self.buf[self.offset];
self.offset += 1;
Some(b)
}

/// Provides the next byte without moving the read position.
#[must_use]
pub const fn peek_byte(&self) -> Option<u8> {
Expand All @@ -92,33 +82,43 @@ impl<'a> Decoder<'a> {
Some(res)
}

/// Decodes an unsigned integer of length 1..=8.
///
/// # Panics
///
/// This panics if `n` is not in the range `1..=8`.
pub fn decode_uint(&mut self, n: usize) -> Option<u64> {
assert!(n > 0 && n <= 8);
#[inline]
pub(crate) fn decode_n(&mut self, n: usize) -> Option<u64> {
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
debug_assert!(n > 0 && n <= 8);
if self.remaining() < n {
return None;
}
let mut v = 0_u64;
for i in 0..n {
let b = self.buf[self.offset + i];
v = v << 8 | u64::from(b);
}
self.offset += n;
Some(v)
Some(if n == 1 {
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
let v = u64::from(self.buf[self.offset]);
self.offset += 1;
v
} else {
let mut buf = [0; 8];
buf[8 - n..].copy_from_slice(&self.buf[self.offset..self.offset + n]);
self.offset += n;
u64::from_be_bytes(buf)
})
}

/// Decodes a big-endian, unsigned integer value into the target type.
/// This returns `None` if there is not enough data remaining
/// or if the conversion to the identified type fails.
/// Conversion is via `u64`, so failures are impossible for
/// unsigned integer types: `u8`, `u16`, `u32`, or `u64`.
/// Signed types will fail if the high bit is set.
pub fn decode_uint<T: TryFrom<u64>>(&mut self) -> Option<T> {
let v = self.decode_n(size_of::<T>());
v.and_then(|v| T::try_from(v).ok())
}

/// Decodes a QUIC varint.
pub fn decode_varint(&mut self) -> Option<u64> {
let b1 = self.decode_byte()?;
let b1 = self.decode_n(1)?;
match b1 >> 6 {
0 => Some(u64::from(b1 & 0x3f)),
1 => Some((u64::from(b1 & 0x3f) << 8) | self.decode_uint(1)?),
2 => Some((u64::from(b1 & 0x3f) << 24) | self.decode_uint(3)?),
3 => Some((u64::from(b1 & 0x3f) << 56) | self.decode_uint(7)?),
0 => Some(b1),
1 => Some((b1 & 0x3f) << 8 | self.decode_n(1)?),
2 => Some((b1 & 0x3f) << 24 | self.decode_n(3)?),
3 => Some((b1 & 0x3f) << 56 | self.decode_n(7)?),
_ => unreachable!(),
}
}
Expand All @@ -143,7 +143,7 @@ impl<'a> Decoder<'a> {

/// Decodes a TLS-style length-prefixed buffer.
pub fn decode_vec(&mut self, n: usize) -> Option<&'a [u8]> {
let len = self.decode_uint(n);
let len = self.decode_n(n);
self.decode_checked(len)
}

Expand Down Expand Up @@ -481,16 +481,28 @@ mod tests {
let enc = Encoder::from_hex("0123");
let mut dec = enc.as_decoder();

assert_eq!(dec.decode_byte().unwrap(), 0x01);
assert_eq!(dec.decode_byte().unwrap(), 0x23);
assert!(dec.decode_byte().is_none());
assert_eq!(dec.decode_uint::<u8>().unwrap(), 0x01);
assert_eq!(dec.decode_uint::<u8>().unwrap(), 0x23);
assert!(dec.decode_uint::<u8>().is_none());
}

#[test]
fn peek_byte() {
let enc = Encoder::from_hex("01");
let mut dec = enc.as_decoder();

assert_eq!(dec.offset(), 0);
assert_eq!(dec.peek_byte().unwrap(), 0x01);
dec.skip(1);
assert_eq!(dec.offset(), 1);
assert!(dec.peek_byte().is_none());
}

#[test]
fn decode_byte_short() {
let enc = Encoder::from_hex("");
let mut dec = enc.as_decoder();
assert!(dec.decode_byte().is_none());
assert!(dec.decode_uint::<u8>().is_none());
}

#[test]
Expand All @@ -501,7 +513,7 @@ mod tests {
assert!(dec.decode(2).is_none());

let mut dec = Decoder::from(&[]);
assert_eq!(dec.decode_remainder().len(), 0);
assert!(dec.decode_remainder().is_empty());
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions neqo-common/src/incrdecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ impl IncrementalDecoderUint {
if amount < 8 {
self.v <<= amount * 8;
}
self.v |= dv.decode_uint(amount).unwrap();
self.v |= dv.decode_n(amount).unwrap();
*r -= amount;
if *r == 0 {
Some(self.v)
} else {
None
}
} else {
let (v, remaining) = dv.decode_byte().map_or_else(
let (v, remaining) = dv.decode_uint::<u8>().map_or_else(
|| unreachable!(),
|b| {
(
Expand Down
3 changes: 2 additions & 1 deletion neqo-http3/src/frames/hframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ impl HFrame {
Self::PriorityUpdateRequest { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_REQUEST,
Self::PriorityUpdatePush { .. } => H3_FRAME_TYPE_PRIORITY_UPDATE_PUSH,
Self::Grease => {
HFrameType(Decoder::from(&random::<7>()).decode_uint(7).unwrap() * 0x1f + 0x21)
let r = Decoder::from(&random::<8>()).decode_uint::<u64>().unwrap();
HFrameType((r >> 5) * 0x1f + 0x21)
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions neqo-http3/src/frames/wtframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ impl FrameDecoder<Self> for WebTransportFrame {
if frame_len > WT_FRAME_CLOSE_MAX_MESSAGE_SIZE + 4 {
return Err(Error::HttpMessageError);
}
let error =
u32::try_from(dec.decode_uint(4).ok_or(Error::HttpMessageError)?).unwrap();
let error = dec.decode_uint().ok_or(Error::HttpMessageError)?;
let Ok(message) = String::from_utf8(dec.decode_remainder().to_vec()) else {
return Err(Error::HttpMessageError);
};
Expand Down
4 changes: 2 additions & 2 deletions neqo-transport/src/addr_valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ impl AddressValidation {
let peer_addr = Self::encode_aad(peer_address, retry);
let data = self.self_encrypt.open(peer_addr.as_ref(), token).ok()?;
let mut dec = Decoder::new(&data);
match dec.decode_uint(4) {
match dec.decode_uint::<u32>() {
Some(d) => {
let end = self.start_time + Duration::from_millis(d);
let end = self.start_time + Duration::from_millis(u64::from(d));
if end < now {
qtrace!("Expired token: {:?} vs. {:?}", end, now);
return None;
Expand Down
7 changes: 4 additions & 3 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,10 @@ impl Connection {
);
let mut dec = Decoder::from(token.as_ref());

let version = Version::try_from(u32::try_from(
dec.decode_uint(4).ok_or(Error::InvalidResumptionToken)?,
)?)?;
let version = Version::try_from(
dec.decode_uint::<WireVersion>()
.ok_or(Error::InvalidResumptionToken)?,
)?;
qtrace!([self], " version {:?}", version);
if !self.conn_params.get_versions().all().contains(&version) {
return Err(Error::DisabledVersion);
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ impl<'a> Frame<'a> {
return Err(Error::FrameEncodingError);
}
let delay = dv(dec)?;
let ignore_order = match d(dec.decode_uint(1))? {
let ignore_order = match d(dec.decode_uint::<u8>())? {
0 => false,
1 => true,
_ => return Err(Error::FrameEncodingError),
Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ impl<'a> PublicPacket<'a> {
#[allow(clippy::similar_names)]
pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
let mut decoder = Decoder::new(data);
let first = Self::opt(decoder.decode_byte())?;
let first = Self::opt(decoder.decode_uint::<u8>())?;

if first & 0x80 == PACKET_BIT_SHORT {
// Conveniently, this also guarantees that there is enough space
Expand Down Expand Up @@ -638,7 +638,7 @@ impl<'a> PublicPacket<'a> {
}

// Generic long header.
let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
let version = Self::opt(decoder.decode_uint())?;
let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);

Expand Down Expand Up @@ -893,7 +893,7 @@ impl<'a> PublicPacket<'a> {
let mut decoder = Decoder::new(&self.data[self.header_len..]);
let mut res = Vec::new();
while decoder.remaining() > 0 {
let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
let version = Self::opt(decoder.decode_uint::<WireVersion>())?;
res.push(version);
}
Ok(res)
Expand Down
11 changes: 5 additions & 6 deletions neqo-transport/src/tparams.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl TransportParameter {
fn decode_preferred_address(d: &mut Decoder) -> Res<Self> {
// IPv4 address (maybe)
let v4ip = Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?)?);
let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?;
let v4port = d.decode_uint::<u16>().ok_or(Error::NoMoreData)?;
// Can't have non-zero IP and zero port, or vice versa.
if v4ip.is_unspecified() ^ (v4port == 0) {
return Err(Error::TransportParameterError);
Expand All @@ -200,7 +200,7 @@ impl TransportParameter {
let v6ip = Ipv6Addr::from(<[u8; 16]>::try_from(
d.decode(16).ok_or(Error::NoMoreData)?,
)?);
let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?;
let v6port = d.decode_uint().ok_or(Error::NoMoreData)?;
if v6ip.is_unspecified() ^ (v6port == 0) {
return Err(Error::TransportParameterError);
}
Expand Down Expand Up @@ -229,11 +229,11 @@ impl TransportParameter {

fn decode_versions(dec: &mut Decoder) -> Res<Self> {
fn dv(dec: &mut Decoder) -> Res<WireVersion> {
let v = dec.decode_uint(4).ok_or(Error::NoMoreData)?;
let v = dec.decode_uint::<WireVersion>().ok_or(Error::NoMoreData)?;
if v == 0 {
Err(Error::TransportParameterError)
} else {
Ok(WireVersion::try_from(v)?)
Ok(v)
}
}

Expand Down Expand Up @@ -457,8 +457,7 @@ impl TransportParameters {
let rbuf = random::<4>();
let mut other = Vec::with_capacity(versions.all().len() + 1);
let mut dec = Decoder::new(&rbuf);
let grease =
(u32::try_from(dec.decode_uint(4).unwrap()).unwrap()) & 0xf0f0_f0f0 | 0x0a0a_0a0a;
let grease = dec.decode_uint::<u32>().unwrap() & 0xf0f0_f0f0 | 0x0a0a_0a0a;
other.push(grease);
for &v in versions.all() {
if role == Role::Client && !versions.initial().is_compatible(v) {
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ mod tests {
assert_eq!(stats.ack, 1);

let mut dec = builder.as_decoder();
_ = dec.decode_byte().unwrap(); // Skip the short header.
_ = dec.decode_uint::<u8>().unwrap(); // Skip the short header.
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
let frame = Frame::decode(&mut dec).unwrap();
if let Frame::Ack { ack_ranges, .. } = frame {
assert_eq!(ack_ranges.len(), 0);
Expand Down
Loading
Loading