Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the deserializer compile time #350

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 178 additions & 70 deletions rmp-serde/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ macro_rules! depth_count(
}
);

/// Inspired by serde. We have our own error type and can save some compile time by avoiding `?`.
macro_rules! tri {
($expr:expr) => {
match $expr {
Ok(val) => val,
Err(err) => return Err(Error::from(err)),
}
};
}

impl error::Error for Error {
#[cold]
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
Expand Down Expand Up @@ -340,60 +350,169 @@ fn read_128_buf<'de, R: ReadSlice<'de>>(rd: &mut R, len: u8) -> Result<i128, Err
if len != 16 {
return Err(Error::LengthMismatch(16));
}
let buf = match read_bin_data(rd, 16)? {
let buf = match read_bin_content(rd, 16)? {
Reference::Borrowed(buf) => buf,
Reference::Copied(buf) => buf,
};
Ok(i128::from_be_bytes(buf.try_into().map_err(|_| Error::LengthMismatch(16))?))
}

fn read_str_data<'de, V, R>(rd: &mut R, len: u32, visitor: V) -> Result<V::Value, Error>
where V: Visitor<'de>, R: ReadSlice<'de>
fn read_str_len<'de, R>(rd: &mut R, marker: Marker) -> Result<u32, Error>
where R: ReadSlice<'de>{
match marker {
Marker::FixStr(len) => Ok(len.into()),
Marker::Str8 => read_u8(rd).map(u32::from),
Marker::Str16 => read_u16(rd).map(u32::from),
Marker::Str32 => read_u32(rd),
_ => Err(Error::TypeMismatch(Marker::Reserved)),
}
}

enum StrData<'de, 'r> {
Str(&'r str),
StrError(Utf8Error, &'r [u8]),
BorrowedStr(&'de str),
BorrowedStrError(Utf8Error, &'de [u8]),
}

fn read_str_data<'de, 'r, R>(rd: &'r mut R, marker: Marker) -> Result<StrData<'de, 'r>, Error>
where R: ReadSlice<'de>
{
match read_bin_data(rd, len)? {
let len = tri!(read_str_len(rd, marker));
match read_bin_content(rd, len)? {
Reference::Borrowed(buf) => {
match str::from_utf8(buf) {
Ok(s) => visitor.visit_borrowed_str(s),
Ok(s) => Ok(StrData::BorrowedStr(s)),
Err(err) => {
// Allow to unpack invalid UTF-8 bytes into a byte array.
match visitor.visit_borrowed_bytes::<Error>(buf) {
Ok(buf) => Ok(buf),
Err(..) => Err(Error::Utf8Error(err)),
}
Ok(StrData::BorrowedStrError(err, buf))
}
}
}
Reference::Copied(buf) => {
match str::from_utf8(buf) {
Ok(s) => visitor.visit_str(s),
Ok(s) => Ok(StrData::Str(s)),
Err(err) => {
// Allow to unpack invalid UTF-8 bytes into a byte array.
match visitor.visit_bytes::<Error>(buf) {
Ok(buf) => Ok(buf),
Err(..) => Err(Error::Utf8Error(err)),
}
Ok(StrData::StrError(err, buf))
}
}
}
}
}

fn read_bin_data<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, len: u32) -> Result<Reference<'de, 'a, [u8]>, Error> {
rd.read_slice(len as usize).map_err(Error::InvalidDataRead)
fn visit_str_data<'de, V>(visitor: V, data: StrData<'de, '_>) -> Result<V::Value, Error> where V: Visitor<'de> {
match data {
StrData::Str(s) => visitor.visit_str(s),
StrData::StrError(err, buf) => {
// Allow to unpack invalid UTF-8 bytes into a byte array.
match visitor.visit_bytes::<Error>(buf) {
Ok(buf) => Ok(buf),
Err(..) => Err(Error::Utf8Error(err)),
}
},
StrData::BorrowedStr(s) => visitor.visit_borrowed_str(s),
StrData::BorrowedStrError(err, buf) => {
// Allow to unpack invalid UTF-8 bytes into a byte array.
match visitor.visit_borrowed_bytes::<Error>(buf) {
Ok(buf) => Ok(buf),
Err(..) => Err(Error::Utf8Error(err)),
}
},
}
}

fn read_array_len<'de, R>(rd: &mut R, marker: Marker) -> Result<u32, Error>
where R: ReadSlice<'de> {
match marker {
Marker::FixArray(len) => Ok(len.into()),
Marker::Array16 => read_u16(rd).map(u32::from),
Marker::Array32 => read_u32(rd),
_ => Err(Error::TypeMismatch(Marker::Reserved)),
}
}

fn read_bin_len<'de, R>(rd: &mut R, marker: Marker) -> Result<u32, Error>
where R: ReadSlice<'de>{
match marker {
Marker::Bin8 => read_u8(rd).map(u32::from),
Marker::Bin16 => read_u16(rd).map(u32::from),
Marker::Bin32 => read_u32(rd),
_ => Err(Error::TypeMismatch(Marker::Reserved)),
}
}

fn read_map_len<'de, R>(rd: &mut R, marker: Marker) -> Result<u32, Error>
where R: ReadSlice<'de> {
match marker {
Marker::FixMap(len) => Ok(len.into()),
Marker::Map16 => read_u16(rd).map(u32::from),
Marker::Map32 => read_u32(rd),
_ => Err(Error::TypeMismatch(Marker::Reserved)),
}
}

#[inline(never)]
fn read_bin_data<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, marker: Marker) -> Result<Reference<'de, 'a, [u8]>, Error> {
let len = tri!(read_bin_len(rd, marker));
read_bin_content(rd, len)
}

fn read_bin_content<'a, 'de, R: ReadSlice<'de>>(rd: &'a mut R, len: u32) -> Result<Reference<'de, 'a, [u8]>, Error> {
match rd.read_slice(len as usize) {
Ok(b) => Ok(b),
Err(e) => Err(Error::InvalidDataRead(e)),
}
}

fn read_u8<R: Read>(rd: &mut R) -> Result<u8, Error> {
byteorder::ReadBytesExt::read_u8(rd).map_err(Error::InvalidDataRead)
match byteorder::ReadBytesExt::read_u8(rd) {
Ok(v) => Ok(v),
Err(e) => Err(Error::InvalidDataRead(e)),
}
}

fn read_u16<R: Read>(rd: &mut R) -> Result<u16, Error> {
rd.read_u16::<byteorder::BigEndian>()
.map_err(Error::InvalidDataRead)
match rd.read_u16::<byteorder::BigEndian>() {
Ok(v) => Ok(v),
Err(e) => Err(Error::InvalidDataRead(e)),
}
}

fn read_u32<R: Read>(rd: &mut R) -> Result<u32, Error> {
rd.read_u32::<byteorder::BigEndian>()
.map_err(Error::InvalidDataRead)
match rd.read_u32::<byteorder::BigEndian>() {
Ok(v) => Ok(v),
Err(e) => Err(Error::InvalidDataRead(e)),
}
}

enum AnyNumber {
U8(u8),
U16(u16),
U32(u32),
U64(u64),
I8(i8),
I16(i16),
I32(i32),
I64(i64),
F32(f32),
F64(f64),
}

#[inline(never)]
fn read_any_number<'de, R>(rd: &mut R, marker: Marker) -> Result<AnyNumber, Error>
where R: ReadSlice<'de> {
match marker {
Marker::U8 => Ok(AnyNumber::U8(rd.read_data_u8()?)),
Marker::U16 => Ok(AnyNumber::U16(rd.read_data_u16()?)),
Marker::U32 => Ok(AnyNumber::U32(rd.read_data_u32()?)),
Marker::U64 => Ok(AnyNumber::U64(rd.read_data_u64()?)),
Marker::I8 => Ok(AnyNumber::I8(rd.read_data_i8()?)),
Marker::I16 => Ok(AnyNumber::I16(rd.read_data_i16()?)),
Marker::I32 => Ok(AnyNumber::I32(rd.read_data_i32()?)),
Marker::I64 => Ok(AnyNumber::I64(rd.read_data_i64()?)),
Marker::F32 => Ok(AnyNumber::F32(rd.read_data_f32()?)),
Marker::F64 => Ok(AnyNumber::F64(rd.read_data_f64()?)),
other_marker => Err(Error::TypeMismatch(other_marker)),
}
}

fn ext_len<R: Read>(rd: &mut R, marker: Marker) -> Result<u32, Error> {
Expand Down Expand Up @@ -511,23 +630,36 @@ fn any_num<'de, R: ReadSlice<'de>, V: Visitor<'de>>(rd: &mut R, visitor: V, mark
Marker::False => visitor.visit_bool(marker == Marker::True),
Marker::FixPos(val) => visitor.visit_u8(val),
Marker::FixNeg(val) => visitor.visit_i8(val),
Marker::U8 => visitor.visit_u8(rd.read_data_u8()?),
Marker::U16 => visitor.visit_u16(rd.read_data_u16()?),
Marker::U32 => visitor.visit_u32(rd.read_data_u32()?),
Marker::U64 => visitor.visit_u64(rd.read_data_u64()?),
Marker::I8 => visitor.visit_i8(rd.read_data_i8()?),
Marker::I16 => visitor.visit_i16(rd.read_data_i16()?),
Marker::I32 => visitor.visit_i32(rd.read_data_i32()?),
Marker::I64 => visitor.visit_i64(rd.read_data_i64()?),
Marker::F32 => visitor.visit_f32(rd.read_data_f32()?),
Marker::F64 => visitor.visit_f64(rd.read_data_f64()?),
Marker::U8 |
Marker::U16 |
Marker::U32 |
Marker::U64 |
Marker::I8 |
Marker::I16 |
Marker::I32 |
Marker::I64 |
Marker::F32 |
Marker::F64 => {
match tri!(read_any_number(rd, marker)) {
AnyNumber::U8(n) => visitor.visit_u8(n),
AnyNumber::U16(n) => visitor.visit_u16(n),
AnyNumber::U32(n) => visitor.visit_u32(n),
AnyNumber::U64(n) => visitor.visit_u64(n),
AnyNumber::I8(n) => visitor.visit_i8(n),
AnyNumber::I16(n) => visitor.visit_i16(n),
AnyNumber::I32(n) => visitor.visit_i32(n),
AnyNumber::I64(n) => visitor.visit_i64(n),
AnyNumber::F32(n) => visitor.visit_f32(n),
AnyNumber::F64(n) => visitor.visit_f64(n),
}
}
other_marker => Err(Error::TypeMismatch(other_marker)),
}
}

impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer<R, C> {
fn any_inner<V: Visitor<'de>>(&mut self, visitor: V, allow_bytes: bool) -> Result<V::Value, Error> {
let marker = self.take_or_read_marker()?;
let marker = tri!(self.take_or_read_marker());
match marker {
Marker::Null |
Marker::True |
Expand All @@ -545,28 +677,16 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer<R, C> {
Marker::F32 |
Marker::F64 => any_num(&mut self.rd, visitor, marker),
Marker::FixStr(_) | Marker::Str8 | Marker::Str16 | Marker::Str32 => {
let len = match marker {
Marker::FixStr(len) => Ok(len.into()),
Marker::Str8 => read_u8(&mut self.rd).map(u32::from),
Marker::Str16 => read_u16(&mut self.rd).map(u32::from),
Marker::Str32 => read_u32(&mut self.rd),
_ => return Err(Error::TypeMismatch(Marker::Reserved)),
}?;
read_str_data(&mut self.rd, len, visitor)
let data = tri!(read_str_data(&mut self.rd, marker));
visit_str_data(visitor, data)
}
Marker::FixArray(_) |
Marker::Array16 |
Marker::Array32 => {
let len = match marker {
Marker::FixArray(len) => len.into(),
Marker::Array16 => read_u16(&mut self.rd)?.into(),
Marker::Array32 => read_u32(&mut self.rd)?,
_ => return Err(Error::TypeMismatch(Marker::Reserved)),
};

let len = tri!(read_array_len(&mut self.rd, marker));
depth_count!(self.depth, {
let mut seq = SeqAccess::new(self, len);
let res = visitor.visit_seq(&mut seq)?;
let res = tri!(visitor.visit_seq(&mut seq));
match seq.left {
0 => Ok(res),
excess => Err(Error::LengthMismatch(len - excess)),
Expand All @@ -576,30 +696,18 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer<R, C> {
Marker::FixMap(_) |
Marker::Map16 |
Marker::Map32 => {
let len = match marker {
Marker::FixMap(len) => len.into(),
Marker::Map16 => read_u16(&mut self.rd)?.into(),
Marker::Map32 => read_u32(&mut self.rd)?,
_ => return Err(Error::TypeMismatch(Marker::Reserved)),
};

let len = tri!(read_map_len(&mut self.rd, marker));
depth_count!(self.depth, {
let mut seq = MapAccess::new(self, len);
let res = visitor.visit_map(&mut seq)?;
let res = tri!(visitor.visit_map(&mut seq));
match seq.left {
0 => Ok(res),
excess => Err(Error::LengthMismatch(len - excess)),
}
})
}
Marker::Bin8 | Marker::Bin16 | Marker::Bin32 => {
let len = match marker {
Marker::Bin8 => read_u8(&mut self.rd).map(u32::from),
Marker::Bin16 => read_u16(&mut self.rd).map(u32::from),
Marker::Bin32 => read_u32(&mut self.rd),
_ => return Err(Error::TypeMismatch(Marker::Reserved)),
}?;
match read_bin_data(&mut self.rd, len)? {
match tri!(read_bin_data(&mut self.rd, marker)) {
Reference::Borrowed(buf) if allow_bytes => visitor.visit_borrowed_bytes(buf),
Reference::Copied(buf) if allow_bytes => visitor.visit_bytes(buf),
Reference::Borrowed(buf) | Reference::Copied(buf) => {
Expand All @@ -615,7 +723,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer<R, C> {
Marker::Ext8 |
Marker::Ext16 |
Marker::Ext32 => {
let len = ext_len(&mut self.rd, marker)?;
let len = tri!(ext_len(&mut self.rd, marker));
depth_count!(self.depth, visitor.visit_newtype_struct(ExtDeserializer::new(self, len)))
}
Marker::Reserved => Err(Error::TypeMismatch(Marker::Reserved)),
Expand Down Expand Up @@ -833,7 +941,7 @@ impl<'de, 'a, R: ReadSlice<'de> + 'a, C: SerializerConfig> de::SeqAccess<'de> fo
{
if self.left > 0 {
self.left -= 1;
Ok(Some(seed.deserialize(&mut *self.de)?))
Ok(Some(tri!(seed.deserialize(&mut *self.de))))
} else {
Ok(None)
}
Expand Down Expand Up @@ -906,7 +1014,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::EnumAccess<'de>
where
V: de::DeserializeSeed<'de>,
{
let variant = seed.deserialize(&mut *self.de)?;
let variant = tri!(seed.deserialize(&mut *self.de));
Ok((variant, self))
}
}
Expand Down Expand Up @@ -973,7 +1081,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::EnumAccess<'de> for Varian
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self), Error>
where V: de::DeserializeSeed<'de>,
{
Ok((seed.deserialize(&mut *self.de)?, self))
Ok((tri!(seed.deserialize(&mut *self.de)), self))
}
}

Expand All @@ -982,7 +1090,7 @@ impl<'de, R: ReadSlice<'de>, C: SerializerConfig> de::VariantAccess<'de> for Var

#[inline]
fn unit_variant(self) -> Result<(), Error> {
decode::read_nil(&mut self.de.rd)?;
tri!(decode::read_nil(&mut self.de.rd));
Ok(())
}

Expand Down
5 changes: 4 additions & 1 deletion rmp/src/decode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ pub trait RmpRead: sealed::Sealed {
#[inline]
#[doc(hidden)]
fn read_data_u8(&mut self) -> Result<u8, ValueReadError<Self::Error>> {
self.read_u8().map_err(ValueReadError::InvalidDataRead)
match self.read_u8() {
Ok(v) => Ok(v),
Err(e) => Err(ValueReadError::InvalidDataRead(e)),
}
}
/// Read a single (signed) byte from this stream.
#[inline]
Expand Down