Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine the borsh implementation #45

Merged
merged 6 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ borsh = ["dep:borsh"]
schemars = ["dep:schemars", "std"]

[dependencies]
num-traits = { version = "0.2.17", default-features = false, optional = true }
defmt = { version = "0.3.5", optional = true }
serde = { version = "1.0", optional = true, default-features = false}
num-traits = { version = "0.2.19", default-features = false, optional = true }
defmt = { version = "0.3.8", optional = true }
serde = { version = "1.0", optional = true, default-features = false }
borsh = { version = "1.5.1", optional = true, features = ["unstable__schema"], default-features = false }
schemars = { version = "0.8.1", optional = true, features = ["derive"], default-features = false }
schemars = { version = "0.8.21", optional = true, features = ["derive"], default-features = false }

[dev-dependencies]
serde_test = "1.0"
58 changes: 25 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ use core::ops::{
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[cfg(feature = "borsh")]
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};

#[cfg(all(feature = "borsh", not(feature = "std")))]
use alloc::{collections::BTreeMap, string::ToString};

Expand Down Expand Up @@ -1069,51 +1066,46 @@ where
}
}

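// Borsh is a byte-oriented, little-endian serialization format. It is not self-describing
// (deserialization needs an external schema) and performs no bit-level compression.
// The current ser/de implementation is not optimal, because const math is not yet stable and the
// primitive integer types do not expose traits describing their bit width.
// It uses the minimal number of bytes needed to hold the required number of bits, without
// compression (Borsh does not have any anyway).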
#[cfg(feature = "borsh")]
impl<T, const BITS: usize> borsh::BorshSerialize for UInt<T, BITS>
where
    Self: Number,
    T: borsh::BorshSerialize,
{
    /// Writes the value using the minimal number of whole bytes that can hold `BITS` bits,
    /// i.e. the first `(BITS + 7) / 8` bytes of the underlying value's Borsh encoding.
    ///
    /// # Errors
    /// Returns any error produced while serializing the underlying value into the scratch
    /// buffer or while writing to `writer`.
    fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
        let serialized_byte_count = (BITS + 7) / 8;

        // Fixed-size scratch buffer: an array sized by `BITS` or `size_of::<T>()` is not
        // expressible here, so serialize the whole underlying value into 16 bytes (enough for a
        // u128-backed value) and emit only the leading bytes that are actually needed.
        let mut buffer = [0u8; 16];
        self.value.serialize(&mut &mut buffer[..])?;

        // `write` is allowed to perform a short write and merely reports how many bytes it
        // accepted; ignoring that count could silently truncate the output. `write_all` keeps
        // writing until every byte has been consumed or an error occurs.
        writer.write_all(&buffer[..serialized_byte_count])?;

        Ok(())
    }
}

#[cfg(feature = "borsh")]
impl<
T: BorshDeserialize + core::cmp::PartialOrd<<UInt<T, BITS> as Number>::UnderlyingType>,
T: borsh::BorshDeserialize + PartialOrd<<UInt<T, BITS> as Number>::UnderlyingType>,
const BITS: usize,
> BorshDeserialize for UInt<T, BITS>
> borsh::BorshDeserialize for UInt<T, BITS>
where
Self: Number,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
let mut buf = vec![0u8; core::mem::size_of::<T>()];
reader.read(&mut buf)?;
let value = T::deserialize(&mut &buf[..])?;
// Ideally, we'd want a buffer of size `(BITS + 7) / 8` or `size_of::<T>()`, but that's not
// possible with arrays at present (`feature(generic_const_exprs)`, once stable, will allow this).
// `vec!` would be an option, but a heap allocation is not expected at this level.
// Therefore, use a fixed 16-byte stack buffer and take slices out of it.
let serialized_byte_count = (BITS + 7) / 8;
let underlying_byte_count = core::mem::size_of::<T>();
let mut buf = [0u8; 16];
danlehmann marked this conversation as resolved.
Show resolved Hide resolved

// Read from the source, advancing cursor by the exact right number of bytes
reader.read(&mut buf[..serialized_byte_count])?;

// Deserialize the underlying type. We have to pass in the correct number of bytes of the
// underlying type (or more, but let's be precise). The unused bytes are all still zero
let value = T::deserialize(&mut &buf[..underlying_byte_count])?;

if value >= Self::MIN.value() && value <= Self::MAX.value() {
Ok(Self { value })
} else {
Expand All @@ -1126,7 +1118,7 @@ where
}

#[cfg(feature = "borsh")]
impl<T, const BITS: usize> BorshSchema for UInt<T, BITS> {
impl<T, const BITS: usize> borsh::BorshSchema for UInt<T, BITS> {
fn add_definitions_recursively(
definitions: &mut BTreeMap<borsh::schema::Declaration, borsh::schema::Definition>,
) {
Expand Down
138 changes: 105 additions & 33 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1912,42 +1912,114 @@ fn serde() {
);
}

#[cfg(all(feature = "borsh", feature = "std"))]
#[test]
fn borsh() {
#[cfg(feature = "borsh")]
mod borsh_tests {
use arbitrary_int::{u1, u14, u15, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt};
use borsh::schema::BorshSchemaContainer;
use borsh::{BorshDeserialize, BorshSerialize};
let mut buf = Vec::new();
let base_input: u8 = 42;
let input = u9::new(base_input.into());
input.serialize(&mut buf).unwrap();
let output = u9::deserialize(&mut buf.as_ref()).unwrap();
let fits = u16::new(base_input.into());
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let input = u63::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let schema = BorshSchemaContainer::for_type::<u9>();
match schema.get_definition("u9").expect("exists") {
borsh::schema::Definition::Primitive(2) => {}
_ => panic!("unexpected schema"),
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use std::fmt::Debug;

/// Serializes `input`, checks the produced bytes against `expected_buffer`, and then verifies
/// that two back-to-back serialized copies deserialize to `input` again while consuming exactly
/// `expected_buffer.len()` bytes each.
fn test_roundtrip<T: Number + BorshSerialize + BorshDeserialize + PartialEq + Eq + Debug>(
    input: T,
    expected_buffer: &[u8],
) {
    let mut buf = Vec::new();

    // Serialize and compare against expected
    input.serialize(&mut buf).unwrap();
    assert_eq!(buf, expected_buffer);

    // Add to the buffer a second time - this is a better test for the deserialization
    // as it ensures we request the correct number of bytes
    input.serialize(&mut buf).unwrap();

    // Read both copies through a single cursor. Deserializing from two independent slices
    // would not notice an implementation that consumes too few or too many bytes, so check
    // the cursor position after each value.
    let mut remaining: &[u8] = &buf;

    let output = T::deserialize(&mut remaining).unwrap();
    assert_eq!(
        remaining.len(),
        expected_buffer.len(),
        "first deserialize consumed the wrong number of bytes"
    );

    let output2 = T::deserialize(&mut remaining).unwrap();
    assert!(
        remaining.is_empty(),
        "second deserialize consumed the wrong number of bytes"
    );

    // Deserialized values must match the input
    assert_eq!(input, output);
    assert_eq!(input, output2);
}

/// Round-trips a set of values through `test_roundtrip`, pinning the exact byte sequence each
/// one is expected to serialize to (least significant byte first in every expected buffer).
#[test]
fn test_serialize_deserialize() {
// Baseline: run against a plain u64 first (not an arbitrary_int) to validate the helper itself
test_roundtrip(
0x12345678_9ABCDEF0u64,
&[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12],
);

// Now try various arbitrary ints; the expected buffers show that only as many whole bytes
// as the bit width requires are emitted (1 byte for u1/u6, 2 for u14, 9 for u72)
test_roundtrip(u1::new(0b0), &[0]);
test_roundtrip(u1::new(0b1), &[1]);
test_roundtrip(u6::new(0b101101), &[0b101101]);
test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]);
test_roundtrip(
u72::new(0x36_01234567_89ABCDEF),
&[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36],
);

// Pick a byte boundary (80 bits) and test one bit width below, at, and above it to ensure
// we get the right number of bytes: 10 bytes for u79 and u80, 11 bytes for u81
test_roundtrip(
u79::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
test_roundtrip(
u80::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
);
test_roundtrip(
u81::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
],
);

// Test the built-in u128 and the arbitrary-int `UInt<u128, 128>` (a legal type, though not
// one of the predefined aliases); both are expected to produce the same 16 bytes
test_roundtrip(
u128::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
test_roundtrip(
UInt::<u128, 128>::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
}

/// Builds the Borsh schema container for `T`, looks up the definition registered under `name`,
/// and asserts that it is a `Primitive` definition of `expected_byte_count` bytes.
///
/// Panics if no definition named `name` exists or if it is not a `Primitive`.
fn verify_byte_count_in_schema<T: BorshSchema + ?Sized>(expected_byte_count: u8, name: &str) {
    let container = BorshSchemaContainer::for_type::<T>();
    let definition = container.get_definition(name).expect("exists");

    if let borsh::schema::Definition::Primitive(actual_byte_count) = definition {
        assert_eq!(*actual_byte_count, expected_byte_count);
    } else {
        panic!("unexpected schema");
    }
}

let input = u50::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
assert!(buf.len() < fits.to_le_bytes().len());
assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]);
let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(input, output);
/// Checks, via `verify_byte_count_in_schema`, the schema name and primitive byte count
/// expected for a range of bit widths.
#[test]
fn test_schema_byte_count() {
// Widths that fit into a single byte
verify_byte_count_in_schema::<u1>(1, "u1");

verify_byte_count_in_schema::<u7>(1, "u7");

// Same bit width on different underlying types: both are expected under the name "u8"
// with a byte count of 1
verify_byte_count_in_schema::<UInt<u8, 8>>(1, "u8");
verify_byte_count_in_schema::<UInt<u32, 8>>(1, "u8");

// One bit past a byte boundary requires a second byte
verify_byte_count_in_schema::<u9>(2, "u9");

// 15 bits: expected name and byte count are the same for the predefined alias and for a
// u128-backed UInt
verify_byte_count_in_schema::<u15>(2, "u15");
verify_byte_count_in_schema::<UInt<u128, 15>>(2, "u15");

// Just below and just above the 64-bit boundary
verify_byte_count_in_schema::<u63>(8, "u63");

verify_byte_count_in_schema::<u65>(9, "u65");
}
}

#[cfg(feature = "schemars")]
Expand Down
Loading