diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 060fde71..24b3db5f 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -1,5 +1,6 @@ use std::{ backtrace::{Backtrace, BacktraceStatus}, + collections::BTreeMap, convert::Infallible, }; @@ -21,6 +22,7 @@ pub type Result = std::result::Result; #[non_exhaustive] pub enum Error { Custom(CustomError), + Annotated(AnnotatedError), } impl Error { @@ -45,17 +47,52 @@ impl Error { } impl Error { + pub(crate) fn empty() -> Self { + Self::Custom(CustomError { + message: String::new(), + backtrace: Backtrace::disabled(), + cause: None, + }) + } + pub fn message(&self) -> &str { match self { Self::Custom(err) => &err.message, + Self::Annotated(err) => err.error.message(), } } pub fn backtrace(&self) -> &Backtrace { match self { Self::Custom(err) => &err.backtrace, + Self::Annotated(err) => &err.error.backtrace(), + } + } + + pub(crate) fn annotations(&self) -> Option<&BTreeMap> { + match self { + Self::Custom(_) => None, + Self::Annotated(err) => Some(&err.annotations), } } + + /// Ensure the error is annotated and return a mutable reference to the annotations + pub(crate) fn annotations_mut(&mut self) -> &mut BTreeMap { + if !matches!(self, Self::Annotated(_)) { + let mut this = Error::empty(); + std::mem::swap(self, &mut this); + + *self = Self::Annotated(AnnotatedError { + error: Box::new(this), + annotations: BTreeMap::new(), + }); + } + + let Self::Annotated(err) = self else { + unreachable!(); + }; + &mut err.annotations + } } pub struct CustomError { @@ -70,6 +107,17 @@ impl std::cmp::PartialEq for CustomError { } } +pub struct AnnotatedError { + pub(crate) error: Box, + pub(crate) annotations: BTreeMap, +} + +impl std::cmp::PartialEq for AnnotatedError { + fn eq(&self, other: &Self) -> bool { + self.error.eq(&other.error) && self.annotations == other.annotations + } +} + impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "<{self}>") @@ -78,14 +126,35 @@ impl std::fmt::Debug for Error { impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Error::Custom(e) => write!( - f, - "Error: {msg}\n{bt}", - msg = e.message, - bt = BacktraceDisplay(&e.backtrace), - ), + write!( + f, + "Error: {msg}{annotations}\n{bt}", + msg = self.message(), + annotations = AnnotationsDisplay(self.annotations()), + bt = BacktraceDisplay(self.backtrace()), + ) + } +} + +struct AnnotationsDisplay<'a>(Option<&'a BTreeMap>); + +impl<'a> std::fmt::Display for AnnotationsDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Some(annotations) = self.0 else { + return Ok(()); + }; + if annotations.is_empty() { + return Ok(()); + } + + write!(f, "(")?; + for (idx, (key, value)) in annotations.iter().enumerate() { + if idx != 0 { + write!(f, ", ")?; + } + write!(f, "{key}: {value:?}")?; } + write!(f, ")") } } diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 52438114..416ae157 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -25,6 +25,11 @@ use super::ArrayBuilder; pub trait SimpleSerializer: Sized { fn name(&self) -> &str; + // TODO: remove default + fn annotate_error(&self, err: Error) -> Error { + err + } + fn serialize_default(&mut self) -> Result<()> { fail!("serialize_default is not supported for {}", self.name()); } @@ -275,71 +280,105 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { type SerializeTupleVariant = Mut<'a, ArrayBuilder>; fn serialize_unit(self) -> Result<()> { - self.0.serialize_unit() + self.0 + .serialize_unit() + .map_err(|err| self.0.annotate_error(err)) } fn serialize_none(self) -> Result<()> { - self.0.serialize_none() + self.0 + .serialize_none() + .map_err(|err| self.0.annotate_error(err)) } fn serialize_some(self, value: &V) -> Result<()> { - self.0.serialize_some(value) + self.0 + .serialize_some(value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_bool(self, v: bool) -> Result<()> { - self.0.serialize_bool(v) + self.0 + .serialize_bool(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_char(self, v: char) -> Result<()> { - self.0.serialize_char(v) + self.0 + .serialize_char(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u8(self, v: u8) -> Result<()> { - self.0.serialize_u8(v) + self.0 + .serialize_u8(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u16(self, v: u16) -> Result<()> { - self.0.serialize_u16(v) + self.0 + .serialize_u16(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u32(self, v: u32) -> Result<()> { - self.0.serialize_u32(v) + self.0 + .serialize_u32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u64(self, v: u64) -> Result<()> { - self.0.serialize_u64(v) + self.0 + .serialize_u64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i8(self, v: i8) -> Result<()> { - self.0.serialize_i8(v) + self.0 + .serialize_i8(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i16(self, v: i16) -> Result<()> { - self.0.serialize_i16(v) + self.0 + .serialize_i16(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i32(self, v: i32) -> Result<()> { - self.0.serialize_i32(v) + self.0 + .serialize_i32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i64(self, v: i64) -> Result<()> { - self.0.serialize_i64(v) + self.0 + .serialize_i64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_f32(self, v: f32) -> Result<()> { - self.0.serialize_f32(v) + self.0 + .serialize_f32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_f64(self, v: f64) -> Result<()> { - self.0.serialize_f64(v) + self.0 + .serialize_f64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_bytes(self, v: &[u8]) -> Result<()> { - self.0.serialize_bytes(v) + self.0 + .serialize_bytes(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_str(self, v: &str) -> Result<()> { - self.0.serialize_str(v) + self.0 + .serialize_str(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_newtype_struct( @@ -347,7 +386,9 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_newtype_struct(name, value) + self.0 + .serialize_newtype_struct(name, value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_newtype_variant( @@ -359,10 +400,13 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { ) -> Result<()> { self.0 .serialize_newtype_variant(name, variant_index, variant, value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_unit_struct(self, name: &'static str) -> Result<()> { - self.0.serialize_unit_struct(name) + self.0 + .serialize_unit_struct(name) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_unit_variant( @@ -371,26 +415,36 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant_index: u32, variant: &'static str, ) -> Result<()> { - self.0.serialize_unit_variant(name, variant_index, variant) + self.0 + .serialize_unit_variant(name, variant_index, variant) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_map(self, len: Option) -> Result { - self.0.serialize_map_start(len)?; + self.0 + .serialize_map_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_seq(self, len: Option) -> Result { - self.0.serialize_seq_start(len)?; + self.0 + .serialize_seq_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_struct(self, name: &'static str, len: usize) -> Result { - self.0.serialize_struct_start(name, len)?; + self.0 + .serialize_struct_start(name, len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_tuple(self, len: usize) -> Result { - self.0.serialize_tuple_start(len)?; + self.0 + .serialize_tuple_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } @@ -399,7 +453,9 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, len: usize, ) -> Result { - self.0.serialize_tuple_struct_start(name, len)?; + self.0 + .serialize_tuple_struct_start(name, len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } @@ -410,10 +466,16 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - let variant_builder = - self.0 - .serialize_struct_variant_start(name, variant_index, variant, len)?; - Ok(Mut(variant_builder)) + // cannot borrow self immutably, as the result will keep self.0 borrowed mutably + // TODO: figure out how to remove this hack + let annotations_error = self.0.annotate_error(Error::empty()); + match self + .0 + .serialize_struct_variant_start(name, variant_index, variant, len) + { + Ok(variant_builder) => Ok(Mut(variant_builder)), + Err(err) => Err(merge_annotations(err, annotations_error)), + } } fn serialize_tuple_variant( @@ -423,11 +485,32 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - let variant_builder = - self.0 - .serialize_tuple_variant_start(name, variant_index, variant, len)?; - Ok(Mut(variant_builder)) + // cannot borrow self immutably, as the result will keep self.0 borrowed mutably + // TODO: figure out how to remove this hack + let annotations_error = self.0.annotate_error(Error::empty()); + match self + .0 + .serialize_tuple_variant_start(name, variant_index, variant, len) + { + Ok(variant_builder) => Ok(Mut(variant_builder)), + Err(err) => Err(merge_annotations(err, annotations_error)), + } + } +} + +fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { + let extra_annotations = std::mem::take(annotations_err.annotations_mut()); + if extra_annotations.is_empty() { + return err; } + + let result_annotations = err.annotations_mut(); + for (key, value) in extra_annotations { + if !result_annotations.contains_key(&key) { + result_annotations.insert(key, value); + } + } + err } impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { @@ -435,15 +518,21 @@ impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { type Error = Error; fn serialize_key(&mut self, key: &V) -> Result<()> { - self.0.serialize_map_key(key) + self.0 + .serialize_map_key(key) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_value(&mut self, value: &V) -> Result<()> { - self.0.serialize_map_value(value) + self.0 + .serialize_map_value(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_map_end() + self.0 + .serialize_map_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -452,11 +541,15 @@ impl<'a, T: SimpleSerializer> SerializeSeq for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0.serialize_seq_element(value) + self.0 + .serialize_seq_element(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_seq_end() + self.0 + .serialize_seq_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -469,11 +562,15 @@ impl<'a, T: SimpleSerializer> SerializeStruct for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_struct_field(key, value) + self.0 + .serialize_struct_field(key, value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_struct_end() + self.0 + .serialize_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -482,11 +579,15 @@ impl<'a, T: SimpleSerializer> SerializeTuple for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_element(value) + self.0 + .serialize_tuple_element(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_end() + self.0 + .serialize_tuple_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -495,11 +596,15 @@ impl<'a, T: SimpleSerializer> SerializeTupleStruct for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_struct_field(value) + self.0 + .serialize_tuple_struct_field(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_struct_end() + self.0 + .serialize_tuple_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -512,11 +617,15 @@ impl<'a, T: SimpleSerializer> SerializeStructVariant for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_struct_field(key, value) + self.0 + .serialize_struct_field(key, value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_struct_end() + self.0 + .serialize_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -525,10 +634,14 @@ impl<'a, T: SimpleSerializer> SerializeTupleVariant for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_struct_field(value) + self.0 + .serialize_tuple_struct_field(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_struct_end() + self.0 + .serialize_tuple_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index b811b8aa..69f78efd 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,8 +1,10 @@ use crate::internal::{ arrow::{Array, BytesArray}, error::{fail, Result}, - utils::array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - utils::Offset, + utils::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, + Offset, + }, }; use super::simple_serializer::SimpleSerializer; diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index e1c999c8..7730e5aa 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1 +1 @@ -mod push_validity; \ No newline at end of file +mod push_validity; diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs index f7ca8863..77097aeb 100644 --- a/serde_arrow/src/test/error_messages/push_validity.rs +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -1,21 +1,25 @@ use serde::Serialize; use serde_json::json; -use crate::internal::{array_builder::ArrayBuilder, error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema}, testing::assert_error}; - +use crate::internal::{ + array_builder::ArrayBuilder, + error::PanicOnError, + schema::{SchemaLike, SerdeArrowSchema}, + testing::assert_error, +}; #[test] fn push_validity_issue_202() -> PanicOnError<()> { let schema = SerdeArrowSchema::from_value(&json!([ { - "name": "nested", - "data_type": "Struct", + "name": "nested", + "data_type": "Struct", "children": [ {"name": "field", "data_type": "U32"}, ], }, ]))?; - + #[derive(Serialize)] struct Record { nested: Nested, @@ -27,11 +31,15 @@ fn push_validity_issue_202() -> PanicOnError<()> { } let mut array_builder = ArrayBuilder::new(schema)?; - let res = array_builder.push(&Record { nested: Nested { field: Some(5) }}); + let res = array_builder.push(&Record { + nested: Nested { field: Some(5) }, + }); assert_eq!(res, Ok(())); - let res = array_builder.push(&Record { nested: Nested { field: None }}); + let res = array_builder.push(&Record { + nested: Nested { field: None }, + }); assert_error(&res, "field: \"nested.field\""); - + Ok(()) -} \ No newline at end of file +}