Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: initialize scratch buffer #752

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions crates/hyperion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,26 +383,21 @@ impl HyperionCore {
/// A scratch buffer for intermediate operations. This will return an empty [`Vec`] when calling [`Scratch::obtain`].
#[derive(Debug)]
pub struct Scratch<A: Allocator = std::alloc::Global> {
inner: Vec<u8, A>,
inner: Box<[u8], A>,
}

impl Default for Scratch<std::alloc::Global> {
fn default() -> Self {
let inner = Vec::with_capacity(MAX_PACKET_SIZE);
Self { inner }
std::alloc::Global.into()
}
}

/// Nice for getting a buffer that can be used for intermediate work
///
/// # Safety
/// - every single time [`ScratchBuffer::obtain`] is called, the buffer will be cleared before returning
/// - the buffer has capacity of at least `MAX_PACKET_SIZE`
pub unsafe trait ScratchBuffer: sealed::Sealed + Debug {
pub trait ScratchBuffer: sealed::Sealed + Debug {
/// The type of the allocator the [`Vec`] uses.
type Allocator: Allocator;
/// Obtains a buffer that can be used for intermediate work.
fn obtain(&mut self) -> &mut Vec<u8, Self::Allocator>;
/// Obtains a buffer that can be used for intermediate work. The contents are unspecified.
fn obtain(&mut self) -> &mut [u8];
}

mod sealed {
Expand All @@ -411,20 +406,23 @@ mod sealed {

impl<A: Allocator + Debug> sealed::Sealed for Scratch<A> {}

unsafe impl<A: Allocator + Debug> ScratchBuffer for Scratch<A> {
impl<A: Allocator + Debug> ScratchBuffer for Scratch<A> {
type Allocator = A;

fn obtain(&mut self) -> &mut Vec<u8, Self::Allocator> {
self.inner.clear();
fn obtain(&mut self) -> &mut [u8] {
&mut self.inner
}
}

impl<A: Allocator> From<A> for Scratch<A> {
fn from(allocator: A) -> Self {
Self {
inner: Vec::with_capacity_in(MAX_PACKET_SIZE, allocator),
}
// A zeroed slice is allocated to avoid reading from uninitialized memory, which is UB.
// Allocating zeroed memory is usually very cheap, so there are minimal performance
// penalties from this.
let inner = Box::new_zeroed_slice_in(MAX_PACKET_SIZE, allocator);
// SAFETY: The box was initialized to zero, and u8 can be represented by zero
let inner = unsafe { inner.assume_init() };
Self { inner }
}
}

Expand Down
29 changes: 9 additions & 20 deletions crates/hyperion/src/net/encoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
use std::{
fmt::Debug,
io::{Cursor, Write},
mem::MaybeUninit,
};

use anyhow::ensure;
Expand Down Expand Up @@ -129,37 +128,27 @@ impl PacketEncoder {
if data_len > threshold {
let scratch = scratch.obtain();

debug_assert!(scratch.is_empty());

let data_slice = &mut slice
[usize::try_from(data_write_start)?..usize::try_from(end_data_position_exclusive)?];

{
// todo: I think this kinda safe maybe??? ... lol. well I know at least scratch is always large enough
let written = {
let scratch = scratch.spare_capacity_mut();
let scratch = unsafe { MaybeUninit::slice_assume_init_mut(scratch) };

let len = data_slice.len();
let span = tracing::trace_span!("zlib_compress", bytes = len);
let _enter = span.enter();
compressor.zlib_compress(data_slice, scratch)?
};
let written = {
let len = data_slice.len();
let span = tracing::trace_span!("zlib_compress", bytes = len);
let _enter = span.enter();
compressor.zlib_compress(data_slice, scratch).unwrap()
};

unsafe {
scratch.set_len(scratch.len() + written);
}
}
let compressed = &scratch[..written];

let data_len = VarInt(data_len as u32 as i32);

let packet_len = data_len.written_size() + scratch.len();
let packet_len = data_len.written_size() + compressed.len();
let packet_len = VarInt(packet_len as u32 as i32);

let mut write = Cursor::new(&mut slice[..]);
packet_len.encode(&mut write)?;
data_len.encode(&mut write)?;
write.write_all(scratch)?;
write.write_all(compressed)?;

let len = write.position();

Expand Down
Loading