diff --git a/core/src/register_data.rs b/core/src/register_data.rs index ac7b421..28ac522 100644 --- a/core/src/register_data.rs +++ b/core/src/register_data.rs @@ -95,27 +95,36 @@ impl RegisterData for ArrayRegisterD } } -impl FromIterator for ArrayRegisterData +impl ArrayRegisterData where RB: funty::Integral, RB::Bytes: for<'a> TryFrom<&'a [u8]>, { - fn from_iter>(iter: T) -> Self { - // Get the iterator. We assume that it is in the same format as the bytes function outputs + /// Try to build a [ArrayRegisterData] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use RegisterDataFromIterError::*; + // Get the iterator. let mut iter = iter.into_iter(); - assert_eq!( - iter.next().unwrap(), - REGISTER_DATA_IDENTIFIER, - "The given iterator is not for register data" - ); + match iter.next() { + Some(REGISTER_DATA_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } // First the starting number is encoded - let starting_register_number = - u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let starting_register_number = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); // Second is how many registers there are - let register_count = u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let register_count = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); // Create the buffer we're storing the registers in let mut registers = ArrayVec::new(); @@ -123,28 +132,48 @@ where // We process everything byte-by-byte generically so every register has an unknown length // So we need to store the bytes temporarily until we have enough to fully read the bytes as a register let register_size = core::mem::size_of::(); + + // Check that all register bytes will fit in `registers` + let num_register_bytes = register_count as usize * register_size; + if num_register_bytes > SIZE { + return Err(LengthTooBig(register_count, register_size)); + } + let mut register_bytes_buffer = ArrayVec::::new(); - for byte in (0..register_count as usize * register_size).map(|_| iter.next().unwrap()) { - register_bytes_buffer.push(byte); + for byte in (0..num_register_bytes).map(|_| iter.next().ok_or(NotEnoughItems)) { + let byte = byte?; + register_bytes_buffer.try_push(byte).map_err(|_| Corrupt)?; if register_bytes_buffer.len() == register_size { registers.push(RB::from_le_bytes( register_bytes_buffer .as_slice() .try_into() - .unwrap_or_else(|_| panic!()), + .map_err(|_| Corrupt)?, )); register_bytes_buffer = ArrayVec::new(); } } - assert!(register_bytes_buffer.is_empty()); + if !register_bytes_buffer.is_empty() { + return Err(Corrupt); + } - Self { + Ok(Self { starting_register_number, registers, - } + }) + } +} + +impl FromIterator for ArrayRegisterData +where + RB: funty::Integral, + RB::Bytes: for<'a> TryFrom<&'a [u8]>, +{ + fn from_iter>(iter: T) -> Self { + Self::try_from_iter(iter).unwrap() } } @@ -223,50 +252,76 @@ impl RegisterData for VecRegisterData { } #[cfg(feature = "std")] -impl FromIterator for VecRegisterData +impl VecRegisterData where RB: funty::Integral, RB::Bytes: for<'a> TryFrom<&'a [u8]>, { - fn from_iter>(iter: T) -> Self { + /// Try to build a [VecRegisterData] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use RegisterDataFromIterError::*; + let mut iter = iter.into_iter(); - assert_eq!( - iter.next().unwrap(), - REGISTER_DATA_IDENTIFIER, - "The given iterator is not for register data" - ); + match iter.next() { + Some(REGISTER_DATA_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } - let starting_register_number = - u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let starting_register_number = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); - let register_count = u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let register_count = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); let mut registers = Vec::new(); let register_size = core::mem::size_of::(); let mut register_bytes_buffer = ArrayVec::::new(); - for byte in (0..register_count as usize * register_size).map(|_| iter.next().unwrap()) { - register_bytes_buffer.push(byte); + for byte in + (0..register_count as usize * register_size).map(|_| iter.next().ok_or(NotEnoughItems)) + { + let byte = byte?; + register_bytes_buffer.try_push(byte).map_err(|_| Corrupt)?; if register_bytes_buffer.len() == register_size { registers.push(RB::from_le_bytes( register_bytes_buffer .as_slice() .try_into() - .unwrap_or_else(|_| panic!()), + .map_err(|_| Corrupt)?, )); register_bytes_buffer.clear(); } } - assert!(register_bytes_buffer.is_empty()); + if !register_bytes_buffer.is_empty() { + return Err(Corrupt); + } - Self { + Ok(Self { starting_register_number, registers, - } + }) + } +} + +#[cfg(feature = "std")] +impl FromIterator for VecRegisterData +where + RB: funty::Integral, + RB::Bytes: for<'a> TryFrom<&'a [u8]>, +{ + fn from_iter>(iter: T) -> Self { + Self::try_from_iter(iter).unwrap() } } @@ -324,6 +379,35 @@ impl<'a, RB: funty::Integral> Iterator for RegisterDataBytesIterator<'a, RB> { impl<'a, RB: funty::Integral> ExactSizeIterator for RegisterDataBytesIterator<'a, RB> {} +#[derive(Debug)] +/// Specifies what went wrong building a [RegisterData] from an iterator +pub enum RegisterDataFromIterError { + /// The given iterator is not for a register set. + /// First item from iterator yielded invalid identifier. Expected [REGISTER_DATA_IDENTIFIER] + InvalidIdentifier(u8), + /// Iterator specified length too big for declared register set + LengthTooBig(u16, usize), + /// Iterator did not yield enough items to build register set + NotEnoughItems, + /// Iterator data is corrupt in some other way + Corrupt, +} + +impl core::fmt::Display for RegisterDataFromIterError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use RegisterDataFromIterError::*; + match self { + InvalidIdentifier(id) => write!(f, "Iterator is not for a register set. Started with {id}, expected {REGISTER_DATA_IDENTIFIER}"), + LengthTooBig(count, size) => write!(f, "Iterator specified length too big for register set: {len}", len = *count as usize * size), + NotEnoughItems => write!(f, "Iterator did not yield enough items to build register set"), + Corrupt => write!(f, "Iterator data is corrupt") + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RegisterDataFromIterError {} + #[cfg(test)] mod tests { use super::*;