From 52db15eb77e4a5d1118a50072a8850fa1febbb2e Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Tue, 9 Jul 2024 00:36:23 +0200 Subject: [PATCH 1/2] Unsafe improvements: core `parquet` crate. --- parquet/src/bloom_filter/mod.rs | 3 ++- parquet/src/data_type.rs | 32 ++++++++++++++++++++++---------- parquet/src/util/bit_util.rs | 16 ++++++++++++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/parquet/src/bloom_filter/mod.rs b/parquet/src/bloom_filter/mod.rs index d2acdcd0b2b8..a8d68d4b6442 100644 --- a/parquet/src/bloom_filter/mod.rs +++ b/parquet/src/bloom_filter/mod.rs @@ -134,7 +134,8 @@ impl Block { #[inline] fn to_ne_bytes(self) -> [u8; 32] { - unsafe { std::mem::transmute(self) } + // SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns. + unsafe { std::mem::transmute(self.0) } } #[inline] diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs index 5bcd2062ca59..b85a75cfd410 100644 --- a/parquet/src/data_type.rs +++ b/parquet/src/data_type.rs @@ -468,6 +468,8 @@ macro_rules! gen_as_bytes { impl AsBytes for $source_ty { #[allow(clippy::size_of_in_element_count)] fn as_bytes(&self) -> &[u8] { + // SAFETY: macro is only used with primitive types that have no padding, so the + // resulting slice always refers to initialized memory. unsafe { std::slice::from_raw_parts( self as *const $source_ty as *const u8, @@ -481,6 +483,8 @@ macro_rules! gen_as_bytes { #[inline] #[allow(clippy::size_of_in_element_count)] fn slice_as_bytes(self_: &[Self]) -> &[u8] { + // SAFETY: macro is only used with primitive types that have no padding, so the + // resulting slice always refers to initialized memory. unsafe { std::slice::from_raw_parts( self_.as_ptr() as *const u8, @@ -492,10 +496,15 @@ macro_rules! gen_as_bytes { #[inline] #[allow(clippy::size_of_in_element_count)] unsafe fn slice_as_bytes_mut(self_: &mut [Self]) -> &mut [u8] { - std::slice::from_raw_parts_mut( - self_.as_mut_ptr() as *mut u8, - std::mem::size_of_val(self_), - ) + // SAFETY: macro is only used with primitive types that have no padding, so the + // resulting slice always refers to initialized memory. Moreover, self has no + // invalid bit patterns, so all writes to the resulting slice will be valid. + unsafe { + std::slice::from_raw_parts_mut( + self_.as_mut_ptr() as *mut u8, + std::mem::size_of_val(self_), + ) + } } } }; @@ -534,12 +543,15 @@ unimplemented_slice_as_bytes!(FixedLenByteArray); impl AsBytes for bool { fn as_bytes(&self) -> &[u8] { + // SAFETY: a bool is guaranteed to be either 0x00 or 0x01 in memory, so the memory is + // valid. unsafe { std::slice::from_raw_parts(self as *const bool as *const u8, 1) } } } impl AsBytes for Int96 { fn as_bytes(&self) -> &[u8] { + // SAFETY: Int96::data is a &[u32; 3]. unsafe { std::slice::from_raw_parts(self.data() as *const [u32] as *const u8, 12) } } } @@ -718,6 +730,7 @@ pub(crate) mod private { #[inline] fn encode(values: &[Self], writer: &mut W, _: &mut BitWriter) -> Result<()> { + // SAFETY: Self is one of i32, i64, f32, f64, which have no padding. let raw = unsafe { std::slice::from_raw_parts( values.as_ptr() as *const u8, @@ -747,9 +760,10 @@ pub(crate) mod private { return Err(eof_err!("Not enough bytes to decode")); } - // SAFETY: Raw types should be as per the standard rust bit-vectors - unsafe { - let raw_buffer = &mut Self::slice_as_bytes_mut(buffer)[..bytes_to_decode]; + { + // SAFETY: Self has no invalid bit patterns, so writing to the slice + // obtained with slice_as_bytes_mut is always safe. + let raw_buffer = &mut unsafe { Self::slice_as_bytes_mut(buffer) }[..bytes_to_decode]; raw_buffer.copy_from_slice(data.slice( decoder.start..decoder.start + bytes_to_decode ).as_ref()); @@ -810,9 +824,7 @@ pub(crate) mod private { _: &mut BitWriter, ) -> Result<()> { for value in values { - let raw = unsafe { - std::slice::from_raw_parts(value.data() as *const [u32] as *const u8, 12) - }; + let raw = SliceAsBytes::slice_as_bytes(value.data()); writer.write_all(raw)?; } Ok(()) diff --git a/parquet/src/util/bit_util.rs b/parquet/src/util/bit_util.rs index eaaf3ee10279..29df74e30669 100644 --- a/parquet/src/util/bit_util.rs +++ b/parquet/src/util/bit_util.rs @@ -435,6 +435,10 @@ impl BitReader { /// This function panics if /// - `num_bits` is larger than the bit-capacity of `T` /// + // FIXME: soundness issue - this method can be used to write arbitrary bytes to any + // T. A possible fix would be to make `FromBytes` an unsafe trait (or to use a + // separate marker trait) which requires all bit patterns of T to be valid (note that this is + // not the case for `T` = `bool`). pub fn get_batch(&mut self, batch: &mut [T], num_bits: usize) -> usize { assert!(num_bits <= size_of::() * 8); @@ -461,6 +465,9 @@ impl BitReader { match size_of::() { 1 => { let ptr = batch.as_mut_ptr() as *mut u8; + // SAFETY: batch is properly aligned and sized. Caller guarantees that T + // can be safely seen as a slice of bytes through FromBytes bound + // (FIXME: not actually true right now) let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 8 { let out_slice = (&mut out[i..i + 8]).try_into().unwrap(); @@ -471,6 +478,9 @@ impl BitReader { } 2 => { let ptr = batch.as_mut_ptr() as *mut u16; + // SAFETY: batch is properly aligned and sized. Caller guarantees that T + // can be safely seen as a slice of bytes through FromBytes bound + // (FIXME: not actually true right now) let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 16 { let out_slice = (&mut out[i..i + 16]).try_into().unwrap(); @@ -481,6 +491,9 @@ impl BitReader { } 4 => { let ptr = batch.as_mut_ptr() as *mut u32; + // SAFETY: batch is properly aligned and sized. Caller guarantees that T + // can be safely seen as a slice of bytes through FromBytes bound + // (FIXME: not actually true right now) let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 32 { let out_slice = (&mut out[i..i + 32]).try_into().unwrap(); @@ -491,6 +504,9 @@ impl BitReader { } 8 => { let ptr = batch.as_mut_ptr() as *mut u64; + // SAFETY: batch is properly aligned and sized. Caller guarantees that T + // can be safely seen as a slice of bytes through FromBytes bound + // (FIXME: not actually true right now) let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 64 { let out_slice = (&mut out[i..i + 64]).try_into().unwrap(); From 290a0d041860ec5c9ea5b7b79911cd9d943962b8 Mon Sep 17 00:00:00 2001 From: Luca Versari Date: Tue, 9 Jul 2024 11:44:48 +0200 Subject: [PATCH 2/2] Make FromBytes an unsafe trait. --- parquet/src/util/bit_util.rs | 61 +++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/parquet/src/util/bit_util.rs b/parquet/src/util/bit_util.rs index 29df74e30669..adbf45014c9d 100644 --- a/parquet/src/util/bit_util.rs +++ b/parquet/src/util/bit_util.rs @@ -42,7 +42,11 @@ fn array_from_slice(bs: &[u8]) -> Result<[u8; N]> { } } -pub trait FromBytes: Sized { +/// # Safety +/// All bit patterns 00000xxxx, where there are `BIT_CAPACITY` `x`s, +/// must be valid, unless BIT_CAPACITY is 0. +pub unsafe trait FromBytes: Sized { + const BIT_CAPACITY: usize; type Buffer: AsMut<[u8]> + Default; fn try_from_le_slice(b: &[u8]) -> Result; fn from_le_bytes(bs: Self::Buffer) -> Self; @@ -51,7 +55,9 @@ pub trait FromBytes: Sized { macro_rules! from_le_bytes { ($($ty: ty),*) => { $( - impl FromBytes for $ty { + // SAFETY: this macro is used for types for which all bit patterns are valid. + unsafe impl FromBytes for $ty { + const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8; type Buffer = [u8; size_of::()]; fn try_from_le_slice(b: &[u8]) -> Result { Ok(Self::from_le_bytes(array_from_slice(b)?)) @@ -66,7 +72,9 @@ macro_rules! from_le_bytes { from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 } -impl FromBytes for bool { +// SAFETY: the 0000000x bit pattern is always valid for `bool`. +unsafe impl FromBytes for bool { + const BIT_CAPACITY: usize = 1; type Buffer = [u8; 1]; fn try_from_le_slice(b: &[u8]) -> Result { @@ -77,7 +85,9 @@ impl FromBytes for bool { } } -impl FromBytes for Int96 { +// SAFETY: BIT_CAPACITY is 0. +unsafe impl FromBytes for Int96 { + const BIT_CAPACITY: usize = 0; type Buffer = [u8; 12]; fn try_from_le_slice(b: &[u8]) -> Result { @@ -95,7 +105,9 @@ impl FromBytes for Int96 { } } -impl FromBytes for ByteArray { +// SAFETY: BIT_CAPACITY is 0. +unsafe impl FromBytes for ByteArray { + const BIT_CAPACITY: usize = 0; type Buffer = Vec; fn try_from_le_slice(b: &[u8]) -> Result { @@ -106,7 +118,9 @@ impl FromBytes for ByteArray { } } -impl FromBytes for FixedLenByteArray { +// SAFETY: BIT_CAPACITY is 0. +unsafe impl FromBytes for FixedLenByteArray { + const BIT_CAPACITY: usize = 0; type Buffer = Vec; fn try_from_le_slice(b: &[u8]) -> Result { @@ -435,10 +449,6 @@ impl BitReader { /// This function panics if /// - `num_bits` is larger than the bit-capacity of `T` /// - // FIXME: soundness issue - this method can be used to write arbitrary bytes to any - // T. A possible fix would be to make `FromBytes` an unsafe trait (or to use a - // separate marker trait) which requires all bit patterns of T to be valid (note that this is - // not the case for `T` = `bool`). pub fn get_batch(&mut self, batch: &mut [T], num_bits: usize) -> usize { assert!(num_bits <= size_of::() * 8); @@ -461,13 +471,17 @@ impl BitReader { } } + assert_ne!(T::BIT_CAPACITY, 0); + assert!(num_bits <= T::BIT_CAPACITY); + // Read directly into output buffer match size_of::() { 1 => { let ptr = batch.as_mut_ptr() as *mut u8; - // SAFETY: batch is properly aligned and sized. Caller guarantees that T - // can be safely seen as a slice of bytes through FromBytes bound - // (FIXME: not actually true right now) + // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns + // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, + // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we + // checked that num_bits <= T::BIT_CAPACITY. let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 8 { let out_slice = (&mut out[i..i + 8]).try_into().unwrap(); @@ -478,9 +492,10 @@ impl BitReader { } 2 => { let ptr = batch.as_mut_ptr() as *mut u16; - // SAFETY: batch is properly aligned and sized. Caller guarantees that T - // can be safely seen as a slice of bytes through FromBytes bound - // (FIXME: not actually true right now) + // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns + // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, + // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we + // checked that num_bits <= T::BIT_CAPACITY. let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 16 { let out_slice = (&mut out[i..i + 16]).try_into().unwrap(); @@ -491,9 +506,10 @@ impl BitReader { } 4 => { let ptr = batch.as_mut_ptr() as *mut u32; - // SAFETY: batch is properly aligned and sized. Caller guarantees that T - // can be safely seen as a slice of bytes through FromBytes bound - // (FIXME: not actually true right now) + // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns + // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, + // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we + // checked that num_bits <= T::BIT_CAPACITY. let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 32 { let out_slice = (&mut out[i..i + 32]).try_into().unwrap(); @@ -504,9 +520,10 @@ impl BitReader { } 8 => { let ptr = batch.as_mut_ptr() as *mut u64; - // SAFETY: batch is properly aligned and sized. Caller guarantees that T - // can be safely seen as a slice of bytes through FromBytes bound - // (FIXME: not actually true right now) + // SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns + // in which only the lowest T::BIT_CAPACITY bits of T are set are valid, + // unpack{8,16,32,64} only set to non0 the lowest num_bits bits, and we + // checked that num_bits <= T::BIT_CAPACITY. let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; while values_to_read - i >= 64 { let out_slice = (&mut out[i..i + 64]).try_into().unwrap();