Skip to content

Commit

Permalink
block-buffer: improve safety comments, use MaybeUninit for internal…
Browse files Browse the repository at this point in the history
… buffer (#1081)

Marks the private `unchecked` methods as `unsafe` and documents their safety
requirements. Adds a `SAFETY` comment to every use of `unsafe`. Using
`copy_nonoverlapping` ensures that the compiler will not generate
unreachable panic branches. Using `MaybeUninit` removes unnecessary
initialization and helps to test that we do not read bytes which were
not written by us.

The code successfully passes the Miri tests, but I plan to test this
implementation more thoroughly later.
  • Loading branch information
newpavlov authored Jun 24, 2024
1 parent ba7ede1 commit cb18faa
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 70 deletions.
155 changes: 103 additions & 52 deletions block-buffer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use array::{
typenum::{Add1, B1},
Array, ArraySize,
};
use core::{fmt, ops::Add, slice};
use core::{fmt, mem::MaybeUninit, ops::Add, ptr, slice};
use crypto_common::{BlockSizeUser, BlockSizes};

#[cfg(feature = "zeroize")]
Expand Down Expand Up @@ -59,9 +59,8 @@ impl fmt::Display for Error {
}

/// Buffer for block processing of data.
#[derive(Debug)]
pub struct BlockBuffer<BS: BlockSizes, K: BufferKind> {
buffer: Block<Self>,
buffer: MaybeUninit<Block<Self>>,
pos: K::Pos,
}

Expand All @@ -72,20 +71,29 @@ impl<BS: BlockSizes, K: BufferKind> BlockSizeUser for BlockBuffer<BS, K> {
impl<BS: BlockSizes, K: BufferKind> Default for BlockBuffer<BS, K> {
#[inline]
fn default() -> Self {
Self {
buffer: Default::default(),
pos: Default::default(),
}
let mut buffer = MaybeUninit::uninit();
let mut pos = Default::default();
K::set_pos(&mut buffer, &mut pos, 0);
Self { buffer, pos }
}
}

impl<BS: BlockSizes, K: BufferKind> Clone for BlockBuffer<BS, K> {
#[inline]
fn clone(&self) -> Self {
Self {
buffer: self.buffer.clone(),
pos: self.pos.clone(),
}
// SAFETY: `BlockBuffer` does not implement `Drop` (i.e. it could be a `Copy` type),
// so we can safely clone it using `ptr::read`.
unsafe { ptr::read(self) }
}
}

impl<BS: BlockSizes, K: BufferKind> fmt::Debug for BlockBuffer<BS, K> {
    /// Formats the buffer as a struct named after the buffer kind (`K::NAME`).
    ///
    /// The `data` field is taken from `get_data()` rather than from the raw
    /// `buffer` field, so only the bytes reported as stored are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let mut dbg = f.debug_struct(K::NAME);
        dbg.field("pos", &self.get_pos());
        dbg.field("block_size", &BS::USIZE);
        dbg.field("data", &self.get_data());
        dbg.finish()
    }
}

Expand All @@ -104,13 +112,14 @@ impl<BS: BlockSizes, K: BufferKind> BlockBuffer<BS, K> {
/// Returns an error if slice length is not valid for used buffer kind.
#[inline(always)]
pub fn try_new(buf: &[u8]) -> Result<Self, Error> {
let pos = buf.len();
if !K::invariant(pos, BS::USIZE) {
if !K::invariant(buf.len(), BS::USIZE) {
return Err(Error);
}
let mut res = Self::default();
res.buffer[..pos].copy_from_slice(buf);
K::set_pos(&mut res.buffer, &mut res.pos, pos);
// SAFETY: we have checked that buffer length satisfies the buffer kind invariant
unsafe {
res.set_data_unchecked(buf);
}
Ok(res)
}

Expand All @@ -131,41 +140,57 @@ impl<BS: BlockSizes, K: BufferKind> BlockBuffer<BS, K> {
// panic branches. Using `unreachable_unchecked` in `get_pos`
// we convince compiler that `BlockSize - pos` never underflows.
if K::invariant(n, rem) {
// double slicing allows to remove panic branches
self.buffer[pos..][..n].copy_from_slice(input);
self.set_pos_unchecked(pos + n);
// SAFETY: we have checked that length of `input` is smaller than
// number of remaining bytes in `buffer`, so we can safely write data
// into them and update cursor position.
unsafe {
let buf_ptr = self.buffer.as_mut_ptr().cast::<u8>().add(pos);
ptr::copy_nonoverlapping(input.as_ptr(), buf_ptr, input.len());
self.set_pos_unchecked(pos + input.len());
}
return;
}
if pos != 0 {
let (left, right) = input.split_at(rem);
input = right;
self.buffer[pos..].copy_from_slice(left);
compress(slice::from_ref(&self.buffer));
// SAFETY: length of `left` is equal to number of remaining bytes in `buffer`,
// so we can copy data into it and process `buffer` as fully initialized block.
let block = unsafe {
let buf_ptr = self.buffer.as_mut_ptr().cast::<u8>().add(pos);
ptr::copy_nonoverlapping(left.as_ptr(), buf_ptr, left.len());
self.buffer.assume_init_ref()
};
compress(slice::from_ref(block));
}

let (blocks, leftover) = K::split_blocks(input);
if !blocks.is_empty() {
compress(blocks);
}

let n = leftover.len();
self.buffer[..n].copy_from_slice(leftover);
self.set_pos_unchecked(n);
// SAFETY: `leftover` is always smaller than block size,
// so it satisfies the method's safety requirements for all buffer kinds
unsafe {
self.set_data_unchecked(leftover);
}
}

/// Reset buffer by setting cursor position to zero.
#[inline(always)]
pub fn reset(&mut self) {
self.set_pos_unchecked(0);
// SAFETY: 0 is always valid position
unsafe {
self.set_pos_unchecked(0);
}
}

/// Pad remaining data with zeros and return resulting block.
#[inline(always)]
pub fn pad_with_zeros(&mut self) -> Block<Self> {
let pos = self.get_pos();
let mut res = self.buffer.clone();
res[pos..].iter_mut().for_each(|b| *b = 0);
self.set_pos_unchecked(0);
let mut res = Block::<Self>::default();
let data = self.get_data();
res[..data.len()].copy_from_slice(data);
self.reset();
res
}

Expand All @@ -186,7 +211,9 @@ impl<BS: BlockSizes, K: BufferKind> BlockBuffer<BS, K> {
/// Return slice of data stored inside the buffer.
#[inline(always)]
pub fn get_data(&self) -> &[u8] {
&self.buffer[..self.get_pos()]
// SAFETY: the `buffer` field is properly initialized up to `self.get_pos()`.
// `get_pos` never returns position bigger than buffer size.
unsafe { slice::from_raw_parts(self.buffer.as_ptr().cast(), self.get_pos()) }
}

/// Set buffer content and cursor position.
Expand All @@ -196,8 +223,12 @@ impl<BS: BlockSizes, K: BufferKind> BlockBuffer<BS, K> {
#[inline]
pub fn set(&mut self, buf: Block<Self>, pos: usize) {
assert!(K::invariant(pos, BS::USIZE));
self.buffer = buf;
self.set_pos_unchecked(pos);
self.buffer = MaybeUninit::new(buf);
// SAFETY: we have asserted that `pos` satisfies the invariant and
// the `buffer` field is fully initialized
unsafe {
self.set_pos_unchecked(pos);
}
}

/// Return size of the internal buffer in bytes.
Expand All @@ -212,11 +243,32 @@ impl<BS: BlockSizes, K: BufferKind> BlockBuffer<BS, K> {
self.size() - self.get_pos()
}

/// Set buffer position.
///
/// # Safety
/// Bytes in the range of `0..pos` in the `buffer` field must be properly initialized.
///
/// `pos` must satisfy invariant of buffer kind, i.e. for eager hashes it must be
/// strictly smaller than block size and for lazy hashes it must be smaller or equal
/// to block size.
#[inline(always)]
fn set_pos_unchecked(&mut self, pos: usize) {
unsafe fn set_pos_unchecked(&mut self, pos: usize) {
debug_assert!(K::invariant(pos, BS::USIZE));
K::set_pos(&mut self.buffer, &mut self.pos, pos)
}

/// Set buffer data.
///
/// Copies `buf` to the start of the internal buffer and sets the cursor
/// position to `buf.len()`.
///
/// # Safety
/// Length of `buf` must satisfy invariant of buffer kind, i.e. for eager hashes it must be
/// strictly smaller than block size and for lazy hashes it must be smaller or equal
/// to block size.
#[inline(always)]
unsafe fn set_data_unchecked(&mut self, buf: &[u8]) {
// The cursor is updated before the bytes are copied: `set_pos_unchecked`
// debug-asserts the buffer kind invariant, so in debug builds an oversized
// `buf` is reported before the raw copy below runs.
self.set_pos_unchecked(buf.len());
// View the (possibly uninitialized) block as raw bytes.
let dst_ptr: *mut u8 = self.buffer.as_mut_ptr().cast();
// SAFETY (relies on the caller's contract): `buf.len()` does not exceed the
// block size, so the write stays inside `self.buffer`. `buf` is a shared
// borrow while `self` is borrowed mutably, so the two regions cannot overlap.
ptr::copy_nonoverlapping(buf.as_ptr(), dst_ptr, buf.len());
}
}

impl<BS: BlockSizes> BlockBuffer<BS, Eager> {
Expand All @@ -232,22 +284,20 @@ impl<BS: BlockSizes> BlockBuffer<BS, Eager> {
panic!("suffix is too long");
}
let pos = self.get_pos();
self.buffer[pos] = delim;
for b in &mut self.buffer[pos + 1..] {
*b = 0;
}
let mut buf = self.pad_with_zeros();
buf[pos] = delim;

let n = self.size() - suffix.len();
if self.size() - pos - 1 < suffix.len() {
compress(&self.buffer);
let mut block = Block::<Self>::default();
block[n..].copy_from_slice(suffix);
compress(&block);
compress(&buf);
buf.fill(0);
buf[n..].copy_from_slice(suffix);
compress(&buf);
} else {
self.buffer[n..].copy_from_slice(suffix);
compress(&self.buffer);
buf[n..].copy_from_slice(suffix);
compress(&buf);
}
self.set_pos_unchecked(0)
self.reset();
}

/// Pad message with 0x80, zeros and 64-bit message length using
Expand All @@ -274,12 +324,10 @@ impl<BS: BlockSizes> BlockBuffer<BS, Eager> {
/// Serialize buffer into a byte array.
#[inline]
pub fn serialize(&self) -> Block<Self> {
let mut res = self.buffer.clone();
let pos = self.get_pos();
// zeroize "garbage" data
for b in res[pos..BS::USIZE - 1].iter_mut() {
*b = 0;
}
let mut res = Block::<Self>::default();
let data = self.get_data();
res[..data.len()].copy_from_slice(data);
res[BS::USIZE - 1] = data.len() as u8;
res
}

Expand All @@ -294,7 +342,7 @@ impl<BS: BlockSizes> BlockBuffer<BS, Eager> {
return Err(Error);
}
Ok(Self {
buffer: buffer.clone(),
buffer: MaybeUninit::new(buffer.clone()),
pos: Default::default(),
})
}
Expand Down Expand Up @@ -329,8 +377,11 @@ impl<BS: BlockSizes> BlockBuffer<BS, Lazy> {
if buffer[1..][pos as usize..].iter().any(|&b| b != 0) {
return Err(Error);
}
let buffer = Array::clone_from_slice(&buffer[1..]);
Ok(Self { buffer, pos })
let buf = Array::clone_from_slice(&buffer[1..]);
Ok(Self {
buffer: MaybeUninit::new(buf),
pos,
})
}
}

Expand Down
41 changes: 28 additions & 13 deletions block-buffer/src/sealed.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::array::{Array, ArraySize};
use core::slice;
use core::{mem::MaybeUninit, ptr, slice};

type Block<N> = MaybeUninit<Array<u8, N>>;

/// Sealed trait for buffer kinds.
pub trait Sealed {
Expand All @@ -8,9 +10,11 @@ pub trait Sealed {
#[cfg(feature = "zeroize")]
type Pos: Default + Clone + zeroize::Zeroize;

fn get_pos(buf: &[u8], pos: &Self::Pos) -> usize;
const NAME: &'static str;

fn get_pos<N: ArraySize>(buf: &Block<N>, pos: &Self::Pos) -> usize;

fn set_pos(buf_val: &mut [u8], pos: &mut Self::Pos, val: usize);
fn set_pos<N: ArraySize>(buf: &mut Block<N>, pos: &mut Self::Pos, val: usize);

/// Invariant guaranteed by a buffer kind, i.e. with correct
/// buffer code this function always returns true.
Expand All @@ -22,14 +26,26 @@ pub trait Sealed {

impl Sealed for super::Eager {
type Pos = ();
const NAME: &'static str = "BlockBuffer<Eager>";

fn get_pos(buf: &[u8], _pos: &Self::Pos) -> usize {
buf[buf.len() - 1] as usize
fn get_pos<N: ArraySize>(buf: &Block<N>, _pos: &Self::Pos) -> usize {
// SAFETY: last byte in `buf` for eager hashes is always properly initialized
let pos = unsafe {
let buf_ptr = buf.as_ptr().cast::<u8>();
let last_byte_ptr = buf_ptr.add(N::USIZE - 1);
ptr::read(last_byte_ptr)
};
pos as usize
}

fn set_pos(buf: &mut [u8], _pos: &mut Self::Pos, val: usize) {
fn set_pos<N: ArraySize>(buf: &mut Block<N>, _pos: &mut Self::Pos, val: usize) {
debug_assert!(val <= u8::MAX as usize);
buf[buf.len() - 1] = val as u8;
// SAFETY: we write to the last byte of `buf` which is always safe
unsafe {
let buf_ptr = buf.as_mut_ptr().cast::<u8>();
let last_byte_ptr = buf_ptr.add(N::USIZE - 1);
ptr::write(last_byte_ptr, val as u8);
}
}

#[inline(always)]
Expand All @@ -42,8 +58,7 @@ impl Sealed for super::Eager {
let nb = data.len() / N::USIZE;
let blocks_len = nb * N::USIZE;
let tail_len = data.len() - blocks_len;
// SAFETY: we guarantee that created slices do not point
// outside of `data`
// SAFETY: we guarantee that created slices do not point outside of `data`
unsafe {
let blocks_ptr = data.as_ptr() as *const Array<u8, N>;
let tail_ptr = data.as_ptr().add(blocks_len);
Expand All @@ -57,12 +72,13 @@ impl Sealed for super::Eager {

impl Sealed for super::Lazy {
type Pos = u8;
const NAME: &'static str = "BlockBuffer<Lazy>";

fn get_pos(_buf_val: &[u8], pos: &Self::Pos) -> usize {
fn get_pos<N: ArraySize>(_buf_val: &Block<N>, pos: &Self::Pos) -> usize {
*pos as usize
}

fn set_pos(_buf_val: &mut [u8], pos: &mut Self::Pos, val: usize) {
fn set_pos<N: ArraySize>(_: &mut Block<N>, pos: &mut Self::Pos, val: usize) {
debug_assert!(val <= u8::MAX as usize);
*pos = val as u8;
}
Expand All @@ -84,8 +100,7 @@ impl Sealed for super::Lazy {
(nb, data.len() - nb * N::USIZE)
};
let blocks_len = nb * N::USIZE;
// SAFETY: we guarantee that created slices do not point
// outside of `data`
// SAFETY: we guarantee that created slices do not point outside of `data`
unsafe {
let blocks_ptr = data.as_ptr() as *const Array<u8, N>;
let tail_ptr = data.as_ptr().add(blocks_len);
Expand Down
Loading

0 comments on commit cb18faa

Please sign in to comment.