From aef2d34d07afc0ca4e2e6bd3fb3003276b2f6ede Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 12:40:43 -0700 Subject: [PATCH 1/6] Fix the build --- src/lib.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d937325..51ea79c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,14 +21,11 @@ use core::ops::{ #[cfg(feature = "serde")] use serde::{Deserialize, Deserializer, Serialize, Serializer}; -#[cfg(feature = "borsh")] -use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; - #[cfg(all(feature = "borsh", not(feature = "std")))] -use alloc::{collections::BTreeMap, string::ToString}; +use alloc::{collections::BTreeMap, string::ToString, vec}; #[cfg(all(feature = "borsh", feature = "std"))] -use std::{collections::BTreeMap, string::ToString}; +use std::{collections::BTreeMap, string::ToString, vec}; #[cfg(feature = "schemars")] use schemars::JsonSchema; @@ -1073,10 +1070,10 @@ where // Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits. // Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway). 
#[cfg(feature = "borsh")] -impl BorshSerialize for UInt +impl borsh::BorshSerialize for UInt where Self: Number, - T: BorshSerialize + T: borsh::BorshSerialize + From + BitAnd + TryInto @@ -1104,9 +1101,9 @@ where #[cfg(feature = "borsh")] impl< - T: BorshDeserialize + core::cmp::PartialOrd< as Number>::UnderlyingType>, + T: borsh::BorshDeserialize + PartialOrd< as Number>::UnderlyingType>, const BITS: usize, - > BorshDeserialize for UInt + > borsh::BorshDeserialize for UInt where Self: Number, { @@ -1126,7 +1123,7 @@ where } #[cfg(feature = "borsh")] -impl BorshSchema for UInt { +impl borsh::BorshSchema for UInt { fn add_definitions_recursively( definitions: &mut BTreeMap, ) { From 2f638e775ee22051fd79c10600c4173525658c56 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 12:40:54 -0700 Subject: [PATCH 2/6] Bump dependencies --- Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ab12cb6..e1ce01e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,11 +32,11 @@ borsh = ["dep:borsh"] schemars = ["dep:schemars", "std"] [dependencies] -num-traits = { version = "0.2.17", default-features = false, optional = true } -defmt = { version = "0.3.5", optional = true } -serde = { version = "1.0", optional = true, default-features = false} +num-traits = { version = "0.2.19", default-features = false, optional = true } +defmt = { version = "0.3.8", optional = true } +serde = { version = "1.0", optional = true, default-features = false } borsh = { version = "1.5.1", optional = true, features = ["unstable__schema"], default-features = false } -schemars = { version = "0.8.1", optional = true, features = ["derive"], default-features = false } +schemars = { version = "0.8.21", optional = true, features = ["derive"], default-features = false } [dev-dependencies] serde_test = "1.0" From d4f6c7cc5a10e5722e76da8f56c3298a9d055e21 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 13:25:56 -0700 
Subject: [PATCH 3/6] Rewrite borsh tests to catch more errors In particular, test at byte boundaries and try some unusual (large) uints --- src/lib.rs | 15 +++--- tests/tests.rs | 134 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 107 insertions(+), 42 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 51ea79c..7dc16bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1066,9 +1066,6 @@ where } } -// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde. -// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits. -// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway). #[cfg(feature = "borsh")] impl borsh::BorshSerialize for UInt where @@ -1084,16 +1081,16 @@ where { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { let value = self.value(); - let length = (BITS + 7) / 8; - let mut bytes = 0; - let mask: T = u8::MAX.into(); - while bytes < length { - let le_byte: u8 = ((value >> (bytes << 3)) & mask) + let total_bytes = (BITS + 7) / 8; + let mut byte_count_written = 0; + let byte_mask: T = u8::MAX.into(); + while byte_count_written < total_bytes { + let le_byte: u8 = ((value >> (byte_count_written << 3)) & byte_mask) .try_into() .ok() .expect("we cut to u8 via mask"); writer.write(&[le_byte])?; - bytes += 1; + byte_count_written += 1; } Ok(()) } diff --git a/tests/tests.rs b/tests/tests.rs index 3687875..ac84078 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1912,42 +1912,110 @@ fn serde() { ); } -#[cfg(all(feature = "borsh", feature = "std"))] -#[test] -fn borsh() { +#[cfg(feature = "borsh")] +mod borsh_tests { + use arbitrary_int::{ + u1, u14, u15, u50, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt, + }; use borsh::schema::BorshSchemaContainer; - use borsh::{BorshDeserialize, BorshSerialize}; - let mut buf = Vec::new(); - let base_input: u8 = 42; - let input = 
u9::new(base_input.into()); - input.serialize(&mut buf).unwrap(); - let output = u9::deserialize(&mut buf.as_ref()).unwrap(); - let fits = u16::new(base_input.into()); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let input = u63::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(buf, fits.to_le_bytes()); - assert_eq!(input, output); - - let schema = BorshSchemaContainer::for_type::(); - match schema.get_definition("u9").expect("exists") { - borsh::schema::Definition::Primitive(2) => {} - _ => panic!("unexpected schema"), + use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; + use std::fmt::Debug; + + fn test_roundtrip( + input: T, + expected_buffer: &[u8], + ) { + let mut buf = Vec::new(); + + // Serialize and compare against expected + input.serialize(&mut buf).unwrap(); + assert_eq!(buf, expected_buffer); + + // Deserialize back and compare against input + let output = T::deserialize(&mut buf.as_ref()).unwrap(); + assert_eq!(input, output); + } + + #[test] + fn test_serialize_deserialize() { + // Run against plain u64 first (not an arbitrary_int) + test_roundtrip( + 0x12345678_9ABCDEF0u64, + &[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12], + ); + + // Now try various arbitrary ints + test_roundtrip(u1::new(0b0), &[0]); + test_roundtrip(u1::new(0b1), &[1]); + test_roundtrip(u6::new(0b101101), &[0b101101]); + test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]); + test_roundtrip( + u72::new(0x36_01234567_89ABCDEF), + &[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36], + ); + + // Pick a byte boundary (80; test one below and one above to ensure we get the right number + // of bytes) + test_roundtrip( + u79::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + test_roundtrip( + u80::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF], 
+ ); + test_roundtrip( + u81::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01, + ], + ); + + // Test actual u128 and arbitrary u128 (which is a legal one, though not a predefined) + test_roundtrip( + u128::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + test_roundtrip( + UInt::::MAX, + &[ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ], + ); + } + + fn verify_byte_count_in_schema(expected_byte_count: u8, name: &str) { + let schema = BorshSchemaContainer::for_type::(); + match schema.get_definition(name).expect("exists") { + borsh::schema::Definition::Primitive(byte_count) => { + assert_eq!(*byte_count, expected_byte_count); + } + _ => panic!("unexpected schema"), + } } - let input = u50::MAX; - let fits = u64::new(input.value()); - let mut buf = Vec::new(); - input.serialize(&mut buf).unwrap(); - assert!(buf.len() < fits.to_le_bytes().len()); - assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]); - let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap(); - assert_eq!(input, output); + #[test] + fn test_schema_byte_count() { + verify_byte_count_in_schema::(1, "u1"); + + verify_byte_count_in_schema::(1, "u7"); + + verify_byte_count_in_schema::>(1, "u8"); + verify_byte_count_in_schema::>(1, "u8"); + + verify_byte_count_in_schema::(2, "u9"); + + verify_byte_count_in_schema::(2, "u15"); + verify_byte_count_in_schema::>(2, "u15"); + + verify_byte_count_in_schema::(8, "u63"); + + verify_byte_count_in_schema::(9, "u65"); + } } #[cfg(feature = "schemars")] From aeb986ebfd99df90b2f78c685774133817715296 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 13:35:18 -0700 Subject: [PATCH 4/6] Replace vec! 
from deserialization with static array --- src/lib.rs | 10 +++++++--- tests/tests.rs | 4 +--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7dc16bf..95ec417 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,10 +22,10 @@ use core::ops::{ use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[cfg(all(feature = "borsh", not(feature = "std")))] -use alloc::{collections::BTreeMap, string::ToString, vec}; +use alloc::{collections::BTreeMap, string::ToString}; #[cfg(all(feature = "borsh", feature = "std"))] -use std::{collections::BTreeMap, string::ToString, vec}; +use std::{collections::BTreeMap, string::ToString}; #[cfg(feature = "schemars")] use schemars::JsonSchema; @@ -1105,7 +1105,11 @@ where Self: Number, { fn deserialize_reader(reader: &mut R) -> borsh::io::Result { - let mut buf = vec![0u8; core::mem::size_of::()]; + // Ideally, we'd want a buffer of size `BITS >> 3` or `size_of::`, but that's not possible + // with arrays. + // So instead we'll do a 16 byte buffer which handles the largest arbitrary-ints possible - + // not ideal, but still pretty small and better than going through an allocator. 
+ let mut buf = [0u8; 16]; reader.read(&mut buf)?; let value = T::deserialize(&mut &buf[..])?; if value >= Self::MIN.value() && value <= Self::MAX.value() { diff --git a/tests/tests.rs b/tests/tests.rs index ac84078..9d8ffaf 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1914,9 +1914,7 @@ fn serde() { #[cfg(feature = "borsh")] mod borsh_tests { - use arbitrary_int::{ - u1, u14, u15, u50, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt, - }; + use arbitrary_int::{u1, u14, u15, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt}; use borsh::schema::BorshSchemaContainer; use borsh::{BorshDeserialize, BorshSchema, BorshSerialize}; use std::fmt::Debug; From b7ca18281793a07cceeb5c4ab368d0e949e53cf3 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 14:16:09 -0700 Subject: [PATCH 5/6] Simplify borsh serialize and deserialize --- src/lib.rs | 35 ++++++++++++++++++----------------- tests/tests.rs | 6 ++++++ 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 95ec417..8357560 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1080,18 +1080,11 @@ where Shr + TryInto + From + BitAnd, { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { - let value = self.value(); - let total_bytes = (BITS + 7) / 8; - let mut byte_count_written = 0; - let byte_mask: T = u8::MAX.into(); - while byte_count_written < total_bytes { - let le_byte: u8 = ((value >> (byte_count_written << 3)) & byte_mask) - .try_into() - .ok() - .expect("we cut to u8 via mask"); - writer.write(&[le_byte])?; - byte_count_written += 1; - } + let serialized_byte_count = (BITS + 7) / 8; + let mut buffer = [0u8; 16]; + self.value.serialize(&mut &mut buffer[..])?; + writer.write(&buffer[0..serialized_byte_count])?; + Ok(()) } } @@ -1106,12 +1099,20 @@ where { fn deserialize_reader(reader: &mut R) -> borsh::io::Result { // Ideally, we'd want a buffer of size `BITS >> 3` or `size_of::`, but that's not possible - // with arrays. 
- // So instead we'll do a 16 byte buffer which handles the largest arbitrary-ints possible - - // not ideal, but still pretty small and better than going through an allocator. + // with arrays at present (feature(generic_const_exprs), once stable, will allow this). + // vec! would be an option, but an allocation is not expected at this level. + // Therefore, allocate a 16 byte buffer and take a slice out of it. + let serialized_byte_count = (BITS + 7) / 8; + let underlying_byte_count = core::mem::size_of::(); let mut buf = [0u8; 16]; - reader.read(&mut buf)?; - let value = T::deserialize(&mut &buf[..])?; + + // Read from the source, advancing cursor by the exact right number of bytes + reader.read(&mut buf[..serialized_byte_count])?; + + // Deserialize the underlying type. We have to pass in the correct number of bytes of the + // underlying type (or more, but let's be precise). The unused bytes are all still zero + let value = T::deserialize(&mut &buf[..underlying_byte_count])?; + if value >= Self::MIN.value() && value <= Self::MAX.value() { Ok(Self { value }) } else { diff --git a/tests/tests.rs b/tests/tests.rs index 9d8ffaf..e674586 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1929,9 +1929,15 @@ mod borsh_tests { input.serialize(&mut buf).unwrap(); assert_eq!(buf, expected_buffer); + // Add to the buffer a second time - this is a better test for the deserialization + // as it ensures we request the correct number of bytes + input.serialize(&mut buf).unwrap(); + // Deserialize back and compare against input let output = T::deserialize(&mut buf.as_ref()).unwrap(); + let output2 = T::deserialize(&mut &buf[buf.len() / 2..]).unwrap(); assert_eq!(input, output); + assert_eq!(input, output2); } #[test] From adb6f6f452008f008712a3ebf927db3c51b84a43 Mon Sep 17 00:00:00 2001 From: Daniel Lehmann Date: Tue, 23 Jul 2024 14:25:46 -0700 Subject: [PATCH 6/6] Simplify BorshSerialize trait bounds --- src/lib.rs | 9 +-------- 1 file changed, 1 insertion(+), 8
deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8357560..8880f46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1070,14 +1070,7 @@ where impl borsh::BorshSerialize for UInt where Self: Number, - T: borsh::BorshSerialize - + From - + BitAnd - + TryInto - + Copy - + Shr, - as Number>::UnderlyingType: - Shr + TryInto + From + BitAnd, + T: borsh::BorshSerialize, { fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { let serialized_byte_count = (BITS + 7) / 8;