diff --git a/block-buffer/src/lib.rs b/block-buffer/src/lib.rs index ad9a9009..c199062a 100644 --- a/block-buffer/src/lib.rs +++ b/block-buffer/src/lib.rs @@ -12,7 +12,7 @@ use array::{ typenum::{Add1, B1}, Array, ArraySize, }; -use core::{fmt, ops::Add, slice}; +use core::{fmt, mem::MaybeUninit, ops::Add, ptr, slice}; use crypto_common::{BlockSizeUser, BlockSizes}; #[cfg(feature = "zeroize")] @@ -59,9 +59,8 @@ impl fmt::Display for Error { } /// Buffer for block processing of data. -#[derive(Debug)] pub struct BlockBuffer { - buffer: Block, + buffer: MaybeUninit>, pos: K::Pos, } @@ -72,20 +71,29 @@ impl BlockSizeUser for BlockBuffer { impl Default for BlockBuffer { #[inline] fn default() -> Self { - Self { - buffer: Default::default(), - pos: Default::default(), - } + let mut buffer = MaybeUninit::uninit(); + let mut pos = Default::default(); + K::set_pos(&mut buffer, &mut pos, 0); + Self { buffer, pos } } } impl Clone for BlockBuffer { #[inline] fn clone(&self) -> Self { - Self { - buffer: self.buffer.clone(), - pos: self.pos.clone(), - } + // SAFETY: `BlockBuffer` does not implement `Drop` (i.e. it could be a `Copy` type), + // so we can safely clone it using `ptr::read`. + unsafe { ptr::read(self) } + } +} + +impl fmt::Debug for BlockBuffer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.debug_struct(K::NAME) + .field("pos", &self.get_pos()) + .field("block_size", &BS::USIZE) + .field("data", &self.get_data()) + .finish() } } @@ -104,13 +112,14 @@ impl BlockBuffer { /// Returns an error if slice length is not valid for used buffer kind. #[inline(always)] pub fn try_new(buf: &[u8]) -> Result { - let pos = buf.len(); - if !K::invariant(pos, BS::USIZE) { + if !K::invariant(buf.len(), BS::USIZE) { return Err(Error); } let mut res = Self::default(); - res.buffer[..pos].copy_from_slice(buf); - K::set_pos(&mut res.buffer, &mut res.pos, pos); + // SAFETY: we have checked that buffer length satisfies the buffer kind invariant + unsafe { + res.set_data_unchecked(buf); + } Ok(res) } @@ -131,16 +140,27 @@ impl BlockBuffer { // panic branches. Using `unreachable_unchecked` in `get_pos` // we convince compiler that `BlockSize - pos` never underflows. if K::invariant(n, rem) { - // double slicing allows to remove panic branches - self.buffer[pos..][..n].copy_from_slice(input); - self.set_pos_unchecked(pos + n); + // SAFETY: we have checked that length of `input` is smaller than + // number of remaining bytes in `buffer`, so we can safely write data + // into them and update cursor position. + unsafe { + let buf_ptr = self.buffer.as_mut_ptr().cast::().add(pos); + ptr::copy_nonoverlapping(input.as_ptr(), buf_ptr, input.len()); + self.set_pos_unchecked(pos + input.len()); + } return; } if pos != 0 { let (left, right) = input.split_at(rem); input = right; - self.buffer[pos..].copy_from_slice(left); - compress(slice::from_ref(&self.buffer)); + // SAFETY: length of `left` is equal to number of remaining bytes in `buffer`, + // so we can copy data into it and process `buffer` as fully initialized block. + let block = unsafe { + let buf_ptr = self.buffer.as_mut_ptr().cast::().add(pos); + ptr::copy_nonoverlapping(left.as_ptr(), buf_ptr, left.len()); + self.buffer.assume_init_ref() + }; + compress(slice::from_ref(block)); } let (blocks, leftover) = K::split_blocks(input); @@ -148,24 +168,29 @@ impl BlockBuffer { compress(blocks); } - let n = leftover.len(); - self.buffer[..n].copy_from_slice(leftover); - self.set_pos_unchecked(n); + // SAFETY: `leftover` is always smaller than block size, + // so it satisfies the method's safety requirements for all buffer kinds + unsafe { + self.set_data_unchecked(leftover); + } } /// Reset buffer by setting cursor position to zero. #[inline(always)] pub fn reset(&mut self) { - self.set_pos_unchecked(0); + // SAFETY: 0 is always valid position + unsafe { + self.set_pos_unchecked(0); + } } /// Pad remaining data with zeros and return resulting block. #[inline(always)] pub fn pad_with_zeros(&mut self) -> Block { - let pos = self.get_pos(); - let mut res = self.buffer.clone(); - res[pos..].iter_mut().for_each(|b| *b = 0); - self.set_pos_unchecked(0); + let mut res = Block::::default(); + let data = self.get_data(); + res[..data.len()].copy_from_slice(data); + self.reset(); res } @@ -186,7 +211,9 @@ impl BlockBuffer { /// Return slice of data stored inside the buffer. #[inline(always)] pub fn get_data(&self) -> &[u8] { - &self.buffer[..self.get_pos()] + // SAFETY: the `buffer` field is properly initialized up to `self.get_pos()`. + // `get_pos` never returns position bigger than buffer size. + unsafe { slice::from_raw_parts(self.buffer.as_ptr().cast(), self.get_pos()) } } /// Set buffer content and cursor position. @@ -196,8 +223,12 @@ impl BlockBuffer { #[inline] pub fn set(&mut self, buf: Block, pos: usize) { assert!(K::invariant(pos, BS::USIZE)); - self.buffer = buf; - self.set_pos_unchecked(pos); + self.buffer = MaybeUninit::new(buf); + // SAFETY: we have asserted that `pos` satisfies the invariant and + // the `buffer` field is fully initialized + unsafe { + self.set_pos_unchecked(pos); + } } /// Return size of the internal buffer in bytes. @@ -212,11 +243,32 @@ impl BlockBuffer { self.size() - self.get_pos() } + /// Set buffer position. + /// + /// # Safety + /// Bytes in the range of `0..pos` in the `buffer` field must be properly initialized. + /// + /// `pos` must satisfy invariant of buffer kind, i.e. for eager hashes it must be + /// strictly smaller than block size and for lazy hashes it must be smaller or equal + /// to block size. #[inline(always)] - fn set_pos_unchecked(&mut self, pos: usize) { + unsafe fn set_pos_unchecked(&mut self, pos: usize) { debug_assert!(K::invariant(pos, BS::USIZE)); K::set_pos(&mut self.buffer, &mut self.pos, pos) } + + /// Set buffer data. + /// + /// # Safety + /// Length of `buf` must satisfy invariant of buffer kind, i.e. for eager hashes it must be + /// strictly smaller than block size and for lazy hashes it must be smaller or equal + /// to block size. + #[inline(always)] + unsafe fn set_data_unchecked(&mut self, buf: &[u8]) { + self.set_pos_unchecked(buf.len()); + let dst_ptr: *mut u8 = self.buffer.as_mut_ptr().cast(); + ptr::copy_nonoverlapping(buf.as_ptr(), dst_ptr, buf.len()); + } } impl BlockBuffer { @@ -232,22 +284,20 @@ impl BlockBuffer { panic!("suffix is too long"); } let pos = self.get_pos(); - self.buffer[pos] = delim; - for b in &mut self.buffer[pos + 1..] { - *b = 0; - } + let mut buf = self.pad_with_zeros(); + buf[pos] = delim; let n = self.size() - suffix.len(); if self.size() - pos - 1 < suffix.len() { - compress(&self.buffer); - let mut block = Block::::default(); - block[n..].copy_from_slice(suffix); - compress(&block); + compress(&buf); + buf.fill(0); + buf[n..].copy_from_slice(suffix); + compress(&buf); } else { - self.buffer[n..].copy_from_slice(suffix); - compress(&self.buffer); + buf[n..].copy_from_slice(suffix); + compress(&buf); } - self.set_pos_unchecked(0) + self.reset(); } /// Pad message with 0x80, zeros and 64-bit message length using @@ -274,12 +324,10 @@ impl BlockBuffer { /// Serialize buffer into a byte array. #[inline] pub fn serialize(&self) -> Block { - let mut res = self.buffer.clone(); - let pos = self.get_pos(); - // zeroize "garbage" data - for b in res[pos..BS::USIZE - 1].iter_mut() { - *b = 0; - } + let mut res = Block::::default(); + let data = self.get_data(); + res[..data.len()].copy_from_slice(data); + res[BS::USIZE - 1] = data.len() as u8; res } @@ -294,7 +342,7 @@ impl BlockBuffer { return Err(Error); } Ok(Self { - buffer: buffer.clone(), + buffer: MaybeUninit::new(buffer.clone()), pos: Default::default(), }) } @@ -329,8 +377,11 @@ impl BlockBuffer { if buffer[1..][pos as usize..].iter().any(|&b| b != 0) { return Err(Error); } - let buffer = Array::clone_from_slice(&buffer[1..]); - Ok(Self { buffer, pos }) + let buf = Array::clone_from_slice(&buffer[1..]); + Ok(Self { + buffer: MaybeUninit::new(buf), + pos, + }) } } diff --git a/block-buffer/src/sealed.rs b/block-buffer/src/sealed.rs index dcdce823..d4e147a9 100644 --- a/block-buffer/src/sealed.rs +++ b/block-buffer/src/sealed.rs @@ -1,5 +1,7 @@ use crate::array::{Array, ArraySize}; -use core::slice; +use core::{mem::MaybeUninit, ptr, slice}; + +type Block = MaybeUninit>; /// Sealed trait for buffer kinds. pub trait Sealed { @@ -8,9 +10,11 @@ pub trait Sealed { #[cfg(feature = "zeroize")] type Pos: Default + Clone + zeroize::Zeroize; - fn get_pos(buf: &[u8], pos: &Self::Pos) -> usize; + const NAME: &'static str; + + fn get_pos(buf: &Block, pos: &Self::Pos) -> usize; - fn set_pos(buf_val: &mut [u8], pos: &mut Self::Pos, val: usize); + fn set_pos(buf: &mut Block, pos: &mut Self::Pos, val: usize); /// Invariant guaranteed by a buffer kind, i.e. with correct /// buffer code this function always returns true. @@ -22,14 +26,26 @@ pub trait Sealed { impl Sealed for super::Eager { type Pos = (); + const NAME: &'static str = "BlockBuffer"; - fn get_pos(buf: &[u8], _pos: &Self::Pos) -> usize { - buf[buf.len() - 1] as usize + fn get_pos(buf: &Block, _pos: &Self::Pos) -> usize { + // SAFETY: last byte in `buf` for eager hashes is always properly initialized + let pos = unsafe { + let buf_ptr = buf.as_ptr().cast::(); + let last_byte_ptr = buf_ptr.add(N::USIZE - 1); + ptr::read(last_byte_ptr) + }; + pos as usize } - fn set_pos(buf: &mut [u8], _pos: &mut Self::Pos, val: usize) { + fn set_pos(buf: &mut Block, _pos: &mut Self::Pos, val: usize) { debug_assert!(val <= u8::MAX as usize); - buf[buf.len() - 1] = val as u8; + // SAFETY: we write to the last byte of `buf` which is always safe + unsafe { + let buf_ptr = buf.as_mut_ptr().cast::(); + let last_byte_ptr = buf_ptr.add(N::USIZE - 1); + ptr::write(last_byte_ptr, val as u8); + } } #[inline(always)] @@ -42,8 +58,7 @@ impl Sealed for super::Eager { let nb = data.len() / N::USIZE; let blocks_len = nb * N::USIZE; let tail_len = data.len() - blocks_len; - // SAFETY: we guarantee that created slices do not point - // outside of `data` + // SAFETY: we guarantee that created slices do not point outside of `data` unsafe { let blocks_ptr = data.as_ptr() as *const Array; let tail_ptr = data.as_ptr().add(blocks_len); @@ -57,12 +72,13 @@ impl Sealed for super::Eager { impl Sealed for super::Lazy { type Pos = u8; + const NAME: &'static str = "BlockBuffer"; - fn get_pos(_buf_val: &[u8], pos: &Self::Pos) -> usize { + fn get_pos(_buf_val: &Block, pos: &Self::Pos) -> usize { *pos as usize } - fn set_pos(_buf_val: &mut [u8], pos: &mut Self::Pos, val: usize) { + fn set_pos(_: &mut Block, pos: &mut Self::Pos, val: usize) { debug_assert!(val <= u8::MAX as usize); *pos = val as u8; } @@ -84,8 +100,7 @@ impl Sealed for super::Lazy { (nb, data.len() - nb * N::USIZE) }; let blocks_len = nb * N::USIZE; - // SAFETY: we guarantee that created slices do not point - // outside of `data` + // SAFETY: we guarantee that created slices do not point outside of `data` unsafe { let blocks_ptr = data.as_ptr() as *const Array; let tail_ptr = data.as_ptr().add(blocks_len); diff --git a/block-buffer/tests/mod.rs b/block-buffer/tests/mod.rs index 6ee1712b..55563d0c 100644 --- a/block-buffer/tests/mod.rs +++ b/block-buffer/tests/mod.rs @@ -109,7 +109,6 @@ fn test_read() { } #[test] -#[rustfmt::skip] fn test_eager_paddings() { let mut buf_be = EagerBuffer::::new(&[0x42]); let mut buf_le = buf_be.clone(); @@ -119,8 +118,8 @@ fn test_eager_paddings() { buf_be.len64_padding_be(len, |block| out_be.extend(block)); buf_le.len64_padding_le(len, |block| out_le.extend(block)); - assert_eq!(out_be, hex!("42800000000000000001020304050607")); - assert_eq!(out_le, hex!("42800000000000000706050403020100")); + assert_eq!(out_be, hex!("4280000000000000 0001020304050607")); + assert_eq!(out_le, hex!("4280000000000000 0706050403020100")); let mut buf_be = EagerBuffer::::new(&[0x42]); let mut buf_le = buf_be.clone(); @@ -138,14 +137,20 @@ fn test_eager_paddings() { buf.len128_padding_be(len, |block| out.extend(block)); assert_eq!( out, - hex!("42800000000000000000000000000000000102030405060708090a0b0c0d0e0f"), + hex!( + "42800000000000000000000000000000" + "000102030405060708090a0b0c0d0e0f" + ), ); let mut buf = EagerBuffer::::new(&[0x42]); let mut out = Vec::::new(); let len = 0x0001_0203_0405_0607_0809_0a0b_0c0d_0e0f; buf.len128_padding_be(len, |block| out.extend(block)); - assert_eq!(out, hex!("4280000000000000000102030405060708090a0b0c0d0e0f")); + assert_eq!( + out, + hex!("4280000000000000 0001020304050607 08090a0b0c0d0e0f") + ); let mut buf = EagerBuffer::::new(&[0x42]); let mut out = Vec::::new();