From 6de87b015008222d01844f51f2e26b75e4e231bf Mon Sep 17 00:00:00 2001 From: John-John Tedro Date: Mon, 15 Apr 2024 18:21:19 +0200 Subject: [PATCH] Fix array encoding --- .../musli-descriptive/src/integer_encoding.rs | 2 +- crates/musli-wire/src/de.rs | 4 +- crates/musli/src/de/mod.rs | 21 +++ crates/musli/src/en/mod.rs | 12 ++ crates/musli/src/fixed.rs | 123 ++++++++++++++++++ crates/musli/src/impls/mod.rs | 34 ++++- crates/musli/src/lib.rs | 1 + tests/tests/primitives.rs | 70 ++++++++-- 8 files changed, 247 insertions(+), 20 deletions(-) create mode 100644 crates/musli/src/fixed.rs diff --git a/crates/musli-descriptive/src/integer_encoding.rs b/crates/musli-descriptive/src/integer_encoding.rs index 49da41074..0d25f2e58 100644 --- a/crates/musli-descriptive/src/integer_encoding.rs +++ b/crates/musli-descriptive/src/integer_encoding.rs @@ -41,7 +41,7 @@ where Ok(value) } - NumberKind::Unsigned => Ok(value), + NumberKind::Unsigned | NumberKind::Float => Ok(value), kind => Err(cx.message(format_args!( "Expected signed or unsigned number, got {:?}", kind diff --git a/crates/musli-wire/src/de.rs b/crates/musli-wire/src/de.rs index 536a37769..4c3b3188c 100644 --- a/crates/musli-wire/src/de.rs +++ b/crates/musli-wire/src/de.rs @@ -150,7 +150,9 @@ where } else { musli_utils::int::decode_usize::<_, _, OPT>(self.cx, self.reader.borrow_mut())? }), - _ => Err(self.cx.marked_message(start, "Expected prefix")), + kind => Err(self + .cx + .marked_message(start, format_args!("Expected prefix, but got {kind:?}"))), } } } diff --git a/crates/musli/src/de/mod.rs b/crates/musli/src/de/mod.rs index f0fc7fc46..1e7e1ee40 100644 --- a/crates/musli/src/de/mod.rs +++ b/crates/musli/src/de/mod.rs @@ -17,66 +17,87 @@ //! ``` mod skip; +#[doc(inline)] pub use self::skip::Skip; mod as_decoder; +#[doc(inline)] pub use self::as_decoder::AsDecoder; mod decode; +#[doc(inline)] pub use self::decode::{Decode, TraceDecode}; mod decode_unsized; +#[doc(inline)] pub use self::decode_unsized::DecodeUnsized; mod decode_unsized_bytes; +#[doc(inline)] pub use self::decode_unsized_bytes::DecodeUnsizedBytes; mod decode_bytes; +#[doc(inline)] pub use self::decode_bytes::DecodeBytes; mod decoder; +#[doc(inline)] pub use self::decoder::Decoder; mod map_decoder; +#[doc(inline)] pub use self::map_decoder::MapDecoder; mod map_entries_decoder; +#[doc(inline)] pub use self::map_entries_decoder::MapEntriesDecoder; mod map_entry_decoder; +#[doc(inline)] pub use self::map_entry_decoder::MapEntryDecoder; mod number_visitor; +#[doc(inline)] pub use self::number_visitor::NumberVisitor; mod pack_decoder; +#[doc(inline)] pub use self::pack_decoder::PackDecoder; mod tuple_decoder; +#[doc(inline)] pub use self::tuple_decoder::TupleDecoder; mod sequence_decoder; +#[doc(inline)] pub use self::sequence_decoder::SequenceDecoder; mod struct_decoder; +#[doc(inline)] pub use self::struct_decoder::StructDecoder; mod struct_field_decoder; +#[doc(inline)] pub use self::struct_field_decoder::StructFieldDecoder; mod struct_fields_decoder; +#[doc(inline)] pub use self::struct_fields_decoder::StructFieldsDecoder; mod size_hint; +#[doc(inline)] pub use self::size_hint::SizeHint; mod value_visitor; +#[doc(inline)] pub use self::value_visitor::ValueVisitor; mod variant_decoder; +#[doc(inline)] pub use self::variant_decoder::VariantDecoder; mod visitor; +#[doc(inline)] pub use self::visitor::Visitor; use crate::mode::DefaultMode; diff --git a/crates/musli/src/en/mod.rs b/crates/musli/src/en/mod.rs index f0bc0f3a8..77305c2cf 100644 --- a/crates/musli/src/en/mod.rs +++ b/crates/musli/src/en/mod.rs @@ -17,37 +17,49 @@ //! ``` mod encode; +#[doc(inline)] pub use self::encode::{Encode, TraceEncode}; mod encode_bytes; +#[doc(inline)] pub use self::encode_bytes::EncodeBytes; mod encoder; +#[doc(inline)] pub use self::encoder::Encoder; mod sequence_encoder; +#[doc(inline)] pub use self::sequence_encoder::SequenceEncoder; mod tuple_encoder; +#[doc(inline)] pub use self::tuple_encoder::TupleEncoder; mod pack_encoder; +#[doc(inline)] pub use self::pack_encoder::PackEncoder; mod map_encoder; +#[doc(inline)] pub use self::map_encoder::MapEncoder; mod map_entry_encoder; +#[doc(inline)] pub use self::map_entry_encoder::MapEntryEncoder; mod map_entries_encoder; +#[doc(inline)] pub use self::map_entries_encoder::MapEntriesEncoder; mod struct_encoder; +#[doc(inline)] pub use self::struct_encoder::StructEncoder; mod struct_field_encoder; +#[doc(inline)] pub use self::struct_field_encoder::StructFieldEncoder; mod variant_encoder; +#[doc(inline)] pub use self::variant_encoder::VariantEncoder; diff --git a/crates/musli/src/fixed.rs b/crates/musli/src/fixed.rs new file mode 100644 index 000000000..eadb7861c --- /dev/null +++ b/crates/musli/src/fixed.rs @@ -0,0 +1,123 @@ +use core::fmt; +use core::mem::{self, ManuallyDrop, MaybeUninit}; +use core::ops::{Deref, DerefMut}; +use core::ptr; +use core::slice; + +/// An error raised when we are at capacity. +#[derive(Debug)] +#[non_exhaustive] +pub(crate) struct CapacityError; + +impl fmt::Display for CapacityError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Out of capacity when constructing array") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for CapacityError {} + +/// A fixed capacity vector allocated on the stack. +pub(crate) struct FixedVec { + data: [MaybeUninit; N], + len: usize, +} + +impl FixedVec { + /// Construct a new empty fixed vector. + pub(crate) const fn new() -> FixedVec { + unsafe { + FixedVec { + data: MaybeUninit::uninit().assume_init(), + len: 0, + } + } + } + + #[inline] + pub(crate) fn as_ptr(&self) -> *const T { + self.data.as_ptr() as *const T + } + + #[inline] + pub(crate) fn as_mut_ptr(&mut self) -> *mut T { + self.data.as_mut_ptr() as *mut T + } + + #[inline] + pub(crate) fn as_slice(&self) -> &[T] { + unsafe { slice::from_raw_parts(self.as_ptr(), self.len) } + } + + #[inline] + pub(crate) fn as_mut_slice(&mut self) -> &mut [T] { + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) } + } + + /// Try to push an element onto the fixed vector. + pub(crate) fn try_push(&mut self, element: T) -> Result<(), CapacityError> { + if self.len >= N { + return Err(CapacityError); + } + + unsafe { + ptr::write(self.as_mut_ptr().wrapping_add(self.len), element); + self.len += 1; + } + + Ok(()) + } + + pub(crate) fn clear(&mut self) { + if self.len == 0 { + return; + } + + let len = mem::take(&mut self.len); + + if mem::needs_drop::() { + unsafe { + let tail = slice::from_raw_parts_mut(self.as_mut_ptr(), len); + ptr::drop_in_place(tail); + } + } + } + + pub(crate) fn into_inner(self) -> [T; N] { + assert!( + self.len == N, + "into_inner: length mismatch, expected {N} but got {}", + self.len + ); + + // SAFETY: We've asserted that the length is initialized just above. + unsafe { + let this = ManuallyDrop::new(self); + ptr::read(this.data.as_ptr() as *const [T; N]) + } + } +} + +impl Deref for FixedVec { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_slice() + } +} + +impl DerefMut for FixedVec { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut_slice() + } +} + +impl Drop for FixedVec { + #[inline] + fn drop(&mut self) { + self.clear() + } +} diff --git a/crates/musli/src/impls/mod.rs b/crates/musli/src/impls/mod.rs index 9fa8995cd..d2298f7dc 100644 --- a/crates/musli/src/impls/mod.rs +++ b/crates/musli/src/impls/mod.rs @@ -15,7 +15,8 @@ use core::num::{ use core::{fmt, marker}; use crate::de::{ - Decode, DecodeBytes, DecodeUnsized, DecodeUnsizedBytes, Decoder, ValueVisitor, VariantDecoder, + Decode, DecodeBytes, DecodeUnsized, DecodeUnsizedBytes, Decoder, SequenceDecoder, ValueVisitor, + VariantDecoder, }; use crate::en::{Encode, EncodeBytes, Encoder, SequenceEncoder, VariantEncoder}; use crate::hint::SequenceHint; @@ -176,13 +177,36 @@ where } } -impl<'de, M, const N: usize> Decode<'de, M> for [u8; N] { +impl<'de, M, T, const N: usize> Decode<'de, M> for [T; N] +where + T: Decode<'de, M>, +{ #[inline] - fn decode(_: &D::Cx, decoder: D) -> Result + fn decode(cx: &D::Cx, decoder: D) -> Result where - D: Decoder<'de>, + D: Decoder<'de, Mode = M>, { - decoder.decode_array() + let mark = cx.mark(); + + decoder.decode_sequence(|seq| { + let mut array = crate::fixed::FixedVec::new(); + + while let Some(item) = seq.decode_next()? { + array.try_push(item.decode()?).map_err(cx.map())?; + } + + if array.len() != N { + return Err(cx.marked_message( + mark, + format_args!( + "Array with length {} does not have the expected {N} number of elements", + array.len() + ), + )); + } + + Ok(array.into_inner()) + }) } } diff --git a/crates/musli/src/lib.rs b/crates/musli/src/lib.rs index fa96743ee..31e3eb20c 100644 --- a/crates/musli/src/lib.rs +++ b/crates/musli/src/lib.rs @@ -422,6 +422,7 @@ pub mod de; pub mod derives; pub mod en; mod expecting; +mod fixed; pub mod hint; mod impls; mod internal; diff --git a/tests/tests/primitives.rs b/tests/tests/primitives.rs index 6f7504049..2663829cd 100644 --- a/tests/tests/primitives.rs +++ b/tests/tests/primitives.rs @@ -1,13 +1,13 @@ #![cfg(feature = "test")] -use musli::compat::{Bytes, Sequence}; +use musli::compat::Sequence; use musli::{Decode, Encode}; #[derive(Debug, PartialEq, Encode, Decode)] pub struct Inner; #[derive(Debug, PartialEq, Encode, Decode)] -pub struct Numbers { +pub struct Primitives { pub bool_field: bool, pub char_field: char, pub u8_field: u8, @@ -20,18 +20,18 @@ pub struct Numbers { pub i32_field: i32, pub i64_field: i64, pub i128_field: i128, + pub f32_field: f32, + pub f64_field: f64, pub usize_field: usize, pub isize_field: isize, - pub empty_array_field: Bytes<[u8; 0]>, pub empty_tuple: (), - pub empty_sequence: Sequence<()>, } #[test] -fn primitives_max() { +fn primitives() { tests::rt!( full, - Numbers { + Primitives { bool_field: true, char_field: char::MAX, u8_field: u8::MAX, @@ -44,20 +44,17 @@ fn primitives_max() { i32_field: i32::MAX, i64_field: i64::MAX, i128_field: i128::MAX, + f32_field: f32::MAX, + f64_field: f64::MAX, usize_field: usize::MAX, isize_field: isize::MAX, - empty_array_field: Bytes([]), empty_tuple: (), - empty_sequence: Sequence(()), } ); -} -#[test] -fn primitives_min() { tests::rt!( full, - Numbers { + Primitives { bool_field: false, char_field: '\u{0000}', u8_field: u8::MIN, @@ -70,10 +67,57 @@ fn primitives_min() { i32_field: i32::MIN, i64_field: i64::MIN, i128_field: i128::MIN, + f32_field: f32::MIN, + f64_field: f64::MIN, usize_field: usize::MIN, isize_field: isize::MIN, - empty_array_field: Bytes([]), empty_tuple: (), + } + ); +} + +#[derive(Debug, PartialEq, Encode, Decode)] +pub struct Arrays { + pub empty_array_field: [u8; 0], + pub ten_array: [i32; 10], +} + +#[test] +fn arrays() { + tests::rt!( + full, + Arrays { + empty_array_field: [], + ten_array: [i32::MIN; 10], + } + ); + + tests::rt!( + full, + Arrays { + empty_array_field: [], + ten_array: [i32::MAX; 10], + } + ); +} + +#[derive(Debug, PartialEq, Encode, Decode)] +pub struct Sequences { + pub empty_sequence: Sequence<()>, +} + +#[test] +fn sequences() { + tests::rt!( + full, + Sequences { + empty_sequence: Sequence(()), + } + ); + + tests::rt!( + full, + Sequences { empty_sequence: Sequence(()), } );