Skip to content

Commit

Permalink
fix: initialize scratch buffer (#752)
Browse files Browse the repository at this point in the history
Previously, the code took a reference to uninitialized memory, which is
undefined behavior. Allocating zeroed memory is relatively cheap, so the
scratch buffer allocates zeroed memory to initialize the memory.

Co-authored-by: Andrew Gazelka <andrew.gazelka@gmail.com>
  • Loading branch information
TestingPlant and andrewgazelka authored Dec 18, 2024
1 parent 3aac1a2 commit 7afedab
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 36 deletions.
30 changes: 14 additions & 16 deletions crates/hyperion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,26 +383,21 @@ impl HyperionCore {
/// A scratch buffer for intermediate operations. This will return an empty [`Vec`] when calling [`Scratch::obtain`].
#[derive(Debug)]
pub struct Scratch<A: Allocator = std::alloc::Global> {
inner: Vec<u8, A>,
inner: Box<[u8], A>,
}

impl Default for Scratch<std::alloc::Global> {
fn default() -> Self {
let inner = Vec::with_capacity(MAX_PACKET_SIZE);
Self { inner }
std::alloc::Global.into()
}
}

/// Nice for getting a buffer that can be used for intermediate work
///
/// # Safety
/// - every single time [`ScratchBuffer::obtain`] is called, the buffer will be cleared before returning
/// - the buffer has capacity of at least `MAX_PACKET_SIZE`
pub unsafe trait ScratchBuffer: sealed::Sealed + Debug {
pub trait ScratchBuffer: sealed::Sealed + Debug {
/// The type of the allocator the [`Vec`] uses.
type Allocator: Allocator;
/// Obtains a buffer that can be used for intermediate work.
fn obtain(&mut self) -> &mut Vec<u8, Self::Allocator>;
/// Obtains a buffer that can be used for intermediate work. The contents are unspecified.
fn obtain(&mut self) -> &mut [u8];
}

mod sealed {
Expand All @@ -411,20 +406,23 @@ mod sealed {

impl<A: Allocator + Debug> sealed::Sealed for Scratch<A> {}

unsafe impl<A: Allocator + Debug> ScratchBuffer for Scratch<A> {
impl<A: Allocator + Debug> ScratchBuffer for Scratch<A> {
type Allocator = A;

fn obtain(&mut self) -> &mut Vec<u8, Self::Allocator> {
self.inner.clear();
fn obtain(&mut self) -> &mut [u8] {
&mut self.inner
}
}

impl<A: Allocator> From<A> for Scratch<A> {
fn from(allocator: A) -> Self {
Self {
inner: Vec::with_capacity_in(MAX_PACKET_SIZE, allocator),
}
// A zeroed slice is allocated to avoid reading from uninitialized memory, which is UB.
// Allocating zeroed memory is usually very cheap, so there are minimal performance
// penalties from this.
let inner = Box::new_zeroed_slice_in(MAX_PACKET_SIZE, allocator);
// SAFETY: The box was initialized to zero, and u8 can be represented by zero
let inner = unsafe { inner.assume_init() };
Self { inner }
}
}

Expand Down
29 changes: 9 additions & 20 deletions crates/hyperion/src/net/encoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
use std::{
fmt::Debug,
io::{Cursor, Write},
mem::MaybeUninit,
};

use anyhow::ensure;
Expand Down Expand Up @@ -129,37 +128,27 @@ impl PacketEncoder {
if data_len > threshold {
let scratch = scratch.obtain();

debug_assert!(scratch.is_empty());

let data_slice = &mut slice
[usize::try_from(data_write_start)?..usize::try_from(end_data_position_exclusive)?];

{
// todo: I think this kinda safe maybe??? ... lol. well I know at least scratch is always large enough
let written = {
let scratch = scratch.spare_capacity_mut();
let scratch = unsafe { MaybeUninit::slice_assume_init_mut(scratch) };

let len = data_slice.len();
let span = tracing::trace_span!("zlib_compress", bytes = len);
let _enter = span.enter();
compressor.zlib_compress(data_slice, scratch)?
};
let written = {
let len = data_slice.len();
let span = tracing::trace_span!("zlib_compress", bytes = len);
let _enter = span.enter();
compressor.zlib_compress(data_slice, scratch).unwrap()
};

unsafe {
scratch.set_len(scratch.len() + written);
}
}
let compressed = &scratch[..written];

let data_len = VarInt(data_len as u32 as i32);

let packet_len = data_len.written_size() + scratch.len();
let packet_len = data_len.written_size() + compressed.len();
let packet_len = VarInt(packet_len as u32 as i32);

let mut write = Cursor::new(&mut slice[..]);
packet_len.encode(&mut write)?;
data_len.encode(&mut write)?;
write.write_all(scratch)?;
write.write_all(compressed)?;

let len = write.position();

Expand Down

0 comments on commit 7afedab

Please sign in to comment.