Skip to content

Commit

Permalink
Use thiserror Errors instead of relying on anyhow
Browse files Browse the repository at this point in the history
This improves matching on particular errors when we need to handle
different conditions downstream.

It's still possible to convert a anyhow::Error to a DecodeError in this
change, but every other error this crate expsoes is now in a variant.
  • Loading branch information
miguelfrde committed Apr 22, 2024
1 parent 51ef451 commit f18501c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 66 deletions.
54 changes: 39 additions & 15 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,53 @@ impl From<anyhow::Error> for EncodeError {
}

#[derive(Debug, Error)]
#[error("Decode error occurred: {inner}")]
pub struct DecodeError {
inner: anyhow::Error,
pub enum DecodeError {
#[error("Invalid MAC address")]
InvalidMACAddress,

#[error("Invalid IPv6 address")]
InvalidIPv6Address,

#[error("Invalid string")]
Utf8Error(#[from] std::string::FromUtf8Error),

#[error("Invalid u8")]
InvalidU8,

#[error("Invalid u16")]
InvalidU16,

#[error("Invalid u32")]
InvalidU32,

#[error("Invalid u64")]
InvalidU64,

#[error("Invalid u128")]
InvalidU128,

#[error("Invalid i32")]
InvalidI32,

#[error("Invalid {name}: length {len} < {buffer_len}")]
InvalidBufferLength {
name: &'static str,
len: usize,
buffer_len: usize,
},

#[error(transparent)]
Other(#[from] anyhow::Error),
}

impl From<&'static str> for DecodeError {
fn from(msg: &'static str) -> Self {
DecodeError {
inner: anyhow!(msg),
}
DecodeError::Other(anyhow!(msg))
}
}

impl From<String> for DecodeError {
fn from(msg: String) -> Self {
DecodeError {
inner: anyhow!(msg),
}
}
}

impl From<anyhow::Error> for DecodeError {
fn from(inner: anyhow::Error) -> DecodeError {
DecodeError { inner }
DecodeError::Other(anyhow!(msg))
}
}
14 changes: 5 additions & 9 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,11 @@ macro_rules! buffer_check_length {
fn check_buffer_length(&self) -> Result<(), DecodeError> {
let len = self.buffer.as_ref().len();
if len < $buffer_len {
Err(format!(
concat!(
"invalid ",
stringify!($name),
": length {} < {}"
),
len, $buffer_len
)
.into())
Err(DecodeError::InvalidBufferLength {
name: stringify!($name),
len,
buffer_len: $buffer_len,
})
} else {
Ok(())
}
Expand Down
57 changes: 32 additions & 25 deletions src/nla.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
// SPDX-License-Identifier: MIT

use core::ops::Range;

use anyhow::Context;
use byteorder::{ByteOrder, NativeEndian};

use crate::{
traits::{Emitable, Parseable},
DecodeError,
};
use byteorder::{ByteOrder, NativeEndian};
use core::ops::Range;
use thiserror::Error;

/// Represent a multi-bytes field with a fixed size in a packet
type Field = Range<usize>;
Expand All @@ -25,6 +23,20 @@ pub const NLA_ALIGNTO: usize = 4;
/// NlA(RTA) header size. (unsigned short rta_len) + (unsigned short rta_type)
pub const NLA_HEADER_SIZE: usize = 4;

#[derive(Debug, Error)]
pub enum NLAError {
#[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)]
BufferTooSmall { buffer_len: usize },

#[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")]
LengthMismatch { buffer_len: usize, nla_len: u16 },

#[error(
"NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end
)]
InvalidLength { nla_len: u16 },
}

#[macro_export]
macro_rules! nla_align {
($len: expr) => {
Expand Down Expand Up @@ -52,33 +64,26 @@ impl<T: AsRef<[u8]>> NlaBuffer<T> {
NlaBuffer { buffer }
}

pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, DecodeError> {
pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, NLAError> {
let buffer = Self::new(buffer);
buffer.check_buffer_length().context("invalid NLA buffer")?;
buffer.check_buffer_length()?;
Ok(buffer)
}

pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
pub fn check_buffer_length(&self) -> Result<(), NLAError> {
let len = self.buffer.as_ref().len();
if len < TYPE.end {
Err(format!(
"buffer has length {}, but an NLA header is {} bytes",
len, TYPE.end
)
.into())
Err(NLAError::BufferTooSmall { buffer_len: len }.into())
} else if len < self.length() as usize {
Err(format!(
"buffer has length: {}, but the NLA is {} bytes",
len,
self.length()
)
Err(NLAError::LengthMismatch {
buffer_len: len,
nla_len: self.length(),
}
.into())
} else if (self.length() as usize) < TYPE.end {
Err(format!(
"NLA has invalid length: {} (should be at least {} bytes",
self.length(),
TYPE.end,
)
Err(NLAError::InvalidLength {
nla_len: self.length(),
}
.into())
} else {
Ok(())
Expand Down Expand Up @@ -204,7 +209,9 @@ impl Nla for DefaultNla {
impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
for DefaultNla
{
fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, DecodeError> {
type Error = DecodeError;

fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, Self::Error> {
let mut kind = buf.kind();

if buf.network_byte_order_flag() {
Expand Down Expand Up @@ -314,7 +321,7 @@ impl<T> NlasIterator<T> {
impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
for NlasIterator<&'buffer T>
{
type Item = Result<NlaBuffer<&'buffer [u8]>, DecodeError>;
type Item = Result<NlaBuffer<&'buffer [u8]>, NLAError>;

fn next(&mut self) -> Option<Self::Item> {
if self.position >= self.buffer.as_ref().len() {
Expand Down
25 changes: 12 additions & 13 deletions src/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
};

use anyhow::Context;
use byteorder::{BigEndian, ByteOrder, NativeEndian};

use crate::DecodeError;

pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> {
if payload.len() != 6 {
return Err(format!("invalid MAC address: {payload:?}").into());
return Err(DecodeError::InvalidMACAddress);
}
let mut address: [u8; 6] = [0; 6];
for (i, byte) in payload.iter().enumerate() {
Expand All @@ -23,7 +22,7 @@ pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> {

pub fn parse_ipv6(payload: &[u8]) -> Result<[u8; 16], DecodeError> {
if payload.len() != 16 {
return Err(format!("invalid IPv6 address: {payload:?}").into());
return Err(DecodeError::InvalidIPv6Address);
}
let mut address: [u8; 16] = [0; 16];
for (i, byte) in payload.iter().enumerate() {
Expand Down Expand Up @@ -57,7 +56,7 @@ pub fn parse_ip(payload: &[u8]) -> Result<IpAddr, DecodeError> {
payload[15],
])
.into()),
_ => Err(format!("invalid IPv6 address: {payload:?}").into()),
_ => Err(DecodeError::InvalidIPv6Address),
}
}

Expand All @@ -71,62 +70,62 @@ pub fn parse_string(payload: &[u8]) -> Result<String, DecodeError> {
} else {
&payload[..payload.len()]
};
let s = String::from_utf8(slice.to_vec()).context("invalid string")?;
let s = String::from_utf8(slice.to_vec())?;
Ok(s)
}

pub fn parse_u8(payload: &[u8]) -> Result<u8, DecodeError> {
if payload.len() != 1 {
return Err(format!("invalid u8: {payload:?}").into());
return Err(DecodeError::InvalidU8);
}
Ok(payload[0])
}

pub fn parse_u32(payload: &[u8]) -> Result<u32, DecodeError> {
if payload.len() != size_of::<u32>() {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidU32);
}
Ok(NativeEndian::read_u32(payload))
}

pub fn parse_u64(payload: &[u8]) -> Result<u64, DecodeError> {
if payload.len() != size_of::<u64>() {
return Err(format!("invalid u64: {payload:?}").into());
return Err(DecodeError::InvalidU64);
}
Ok(NativeEndian::read_u64(payload))
}

pub fn parse_u128(payload: &[u8]) -> Result<u128, DecodeError> {
if payload.len() != size_of::<u128>() {
return Err(format!("invalid u128: {payload:?}").into());
return Err(DecodeError::InvalidU128);
}
Ok(NativeEndian::read_u128(payload))
}

pub fn parse_u16(payload: &[u8]) -> Result<u16, DecodeError> {
if payload.len() != size_of::<u16>() {
return Err(format!("invalid u16: {payload:?}").into());
return Err(DecodeError::InvalidU16);
}
Ok(NativeEndian::read_u16(payload))
}

pub fn parse_i32(payload: &[u8]) -> Result<i32, DecodeError> {
if payload.len() != 4 {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidI32);
}
Ok(NativeEndian::read_i32(payload))
}

pub fn parse_u16_be(payload: &[u8]) -> Result<u16, DecodeError> {
if payload.len() != size_of::<u16>() {
return Err(format!("invalid u16: {payload:?}").into());
return Err(DecodeError::InvalidU16);
}
Ok(BigEndian::read_u16(payload))
}

pub fn parse_u32_be(payload: &[u8]) -> Result<u32, DecodeError> {
if payload.len() != size_of::<u32>() {
return Err(format!("invalid u32: {payload:?}").into());
return Err(DecodeError::InvalidU32);
}
Ok(BigEndian::read_u32(payload))
}
10 changes: 6 additions & 4 deletions src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
// SPDX-License-Identifier: MIT

use crate::DecodeError;

/// A type that implements `Emitable` can be serialized.
pub trait Emitable {
/// Return the length of the serialized data.
Expand All @@ -26,8 +24,10 @@ where
Self: Sized,
T: ?Sized,
{
type Error;

/// Deserialize the current type.
fn parse(buf: &T) -> Result<Self, DecodeError>;
fn parse(buf: &T) -> Result<Self, Self::Error>;
}

/// A `Parseable` type can be used to deserialize data from the type `T` for
Expand All @@ -37,6 +37,8 @@ where
Self: Sized,
T: ?Sized,
{
type Error;

/// Deserialize the current type.
fn parse_with_param(buf: &T, params: P) -> Result<Self, DecodeError>;
fn parse_with_param(buf: &T, params: P) -> Result<Self, Self::Error>;
}

0 comments on commit f18501c

Please sign in to comment.