diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index 10a75f1f..6127d956 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -10,13 +10,9 @@ use crate::{ internal::{ array_builder::ArrayBuilder, arrow::Field, - deserialization::{ - array_deserializer::ArrayDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - }, deserializer::Deserializer, error::{fail, Result}, - schema::{get_strategy_from_metadata, SerdeArrowSchema}, + schema::SerdeArrowSchema, serializer::Serializer, }, }; @@ -153,14 +149,7 @@ impl<'de> Deserializer<'de> { where A: AsRef, { - let fields = fields - .iter() - .map(Field::try_from) - .collect::>>()?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); + use crate::internal::arrow::ArrayView; if fields.len() != arrays.len() { fail!( @@ -169,21 +158,16 @@ impl<'de> Deserializer<'de> { arrays.len() ); } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - - let mut deserializers = Vec::new(); - for (field, array) in std::iter::zip(fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } - let strategy = get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); - } - let deserializer = OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); + let fields = fields + .iter() + .map(Field::try_from) + .collect::>>()?; + let views = arrays + .iter() + .map(|array| ArrayView::try_from(array.as_ref())) + .collect::>>()?; - Ok(deserializer) + Deserializer::new(&fields, views) } } diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 1150a739..7df8f857 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -149,7 +149,7 @@ impl TryFrom for ArrayRef { let child: ArrayRef = child.try_into()?; let field = field_from_array_and_meta(child.as_ref(), meta); - type_ids.push(type_id.try_into()?); + type_ids.push(type_id.into()); values.push(child); fields.push(field); } diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index c08e2cc0..08aa8930 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -193,7 +193,7 @@ impl TryFrom<&DataType> for ArrowDataType { if *scale < 0 { fail!("arrow2 does not support decimals with negative scale"); } - Ok(AT::Decimal((*precision).try_into()?, (*scale).try_into()?)) + Ok(AT::Decimal((*precision).into(), (*scale).try_into()?)) } T::Binary => Ok(AT::Binary), T::LargeBinary => Ok(AT::LargeBinary), @@ -266,7 +266,7 @@ impl TryFrom<&DataType> for ArrowDataType { for (type_id, field) in in_fields { fields.push(AF::try_from(field)?); - type_ids.push((*type_id).try_into()?); + type_ids.push((*type_id).into()); } Ok(AT::Union(fields, Some(type_ids), (*mode).into())) } diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 852a0198..92644fa7 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -10,13 +10,9 @@ use crate::{ }, internal::{ array_builder::ArrayBuilder, - deserialization::{ - array_deserializer::ArrayDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - }, deserializer::Deserializer, error::{fail, Result}, - schema::{get_strategy_from_metadata, SerdeArrowSchema}, + schema::SerdeArrowSchema, serializer::Serializer, }, }; @@ -241,11 +237,7 @@ impl<'de> Deserializer<'de> { where A: AsRef, { - let fields = fields_from_field_refs(fields)?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); + use crate::internal::arrow::ArrayView; if fields.len() != arrays.len() { fail!( @@ -254,23 +246,15 @@ impl<'de> Deserializer<'de> { arrays.len() ); } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - let mut deserializers = Vec::new(); - for (field, array) in std::iter::zip(&fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } + let fields = fields_from_field_refs(fields)?; - let strategy = get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); + let mut views = Vec::new(); + for array in arrays { + views.push(ArrayView::try_from(array.as_ref())?); } - let deserializer = OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); - - Ok(deserializer) + Deserializer::new(&fields, views) } /// Construct a new deserializer from a record batch (*requires one of the diff --git a/serde_arrow/src/arrow_impl/type_support.rs b/serde_arrow/src/arrow_impl/type_support.rs index 7f34f4c3..8330bb80 100644 --- a/serde_arrow/src/arrow_impl/type_support.rs +++ b/serde_arrow/src/arrow_impl/type_support.rs @@ -6,7 +6,7 @@ use crate::_impl::arrow::{ use crate::internal::{ arrow::Field, error::{Error, Result}, - schema::extensions::FixedShapeTensorField, + schema::extensions::{Bool8Field, FixedShapeTensorField, VariableShapeTensorField}, }; impl From for Error { @@ -15,22 +15,30 @@ impl From for Error { } } -impl TryFrom<&FixedShapeTensorField> for ArrowField { - type Error = Error; +macro_rules! impl_try_from_ext_type { + ($ty:ty) => { + impl TryFrom<&$ty> for ArrowField { + type Error = Error; - fn try_from(value: &FixedShapeTensorField) -> Result { - Self::try_from(&Field::try_from(value)?) - } -} + fn try_from(value: &$ty) -> Result { + Self::try_from(&Field::try_from(value)?) + } + } -impl TryFrom for ArrowField { - type Error = Error; + impl TryFrom<$ty> for ArrowField { + type Error = Error; - fn try_from(value: FixedShapeTensorField) -> Result { - Self::try_from(&value) - } + fn try_from(value: $ty) -> Result { + Self::try_from(&value) + } + } + }; } +impl_try_from_ext_type!(Bool8Field); +impl_try_from_ext_type!(FixedShapeTensorField); +impl_try_from_ext_type!(VariableShapeTensorField); + pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result> { fields .iter() diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index b716c286..723615b3 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -1,8 +1,14 @@ use serde::de::Visitor; use crate::internal::{ - deserialization::outer_sequence_deserializer::OuterSequenceDeserializer, + arrow::{ArrayView, Field}, + deserialization::{ + array_deserializer::ArrayDeserializer, + outer_sequence_deserializer::OuterSequenceDeserializer, + }, error::{fail, Error, Result}, + schema::get_strategy_from_metadata, + utils::array_view_ext::ArrayViewExt, }; /// A structure to deserialize Arrow arrays into Rust objects @@ -14,6 +20,30 @@ use crate::internal::{ #[cfg_attr(has_arrow2, doc = r"- [`Deserializer::from_arrow2`]")] pub struct Deserializer<'de>(pub(crate) OuterSequenceDeserializer<'de>); +impl<'de> Deserializer<'de> { + pub(crate) fn new(fields: &[Field], views: Vec>) -> Result { + let len = match views.first() { + Some(view) => view.len(), + None => 0, + }; + + let mut deserializers = Vec::new(); + for (field, view) in std::iter::zip(fields, views) { + if view.len() != len { + fail!("Cannot deserialize from arrays with different lengths"); + } + let strategy = get_strategy_from_metadata(&field.metadata)?; + let deserializer = ArrayDeserializer::new(strategy.as_ref(), view)?; + deserializers.push((field.name.clone(), deserializer)); + } + + let deserializer = OuterSequenceDeserializer::new(deserializers, len); + let deserializer = Deserializer(deserializer); + + Ok(deserializer) + } +} + impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { type Error = Error; diff --git a/serde_arrow/src/internal/schema/extensions/bool8_field.rs b/serde_arrow/src/internal/schema/extensions/bool8_field.rs new file mode 100644 index 00000000..56ba81af --- /dev/null +++ b/serde_arrow/src/internal/schema/extensions/bool8_field.rs @@ -0,0 +1,106 @@ +use std::collections::HashMap; + +use crate::internal::{ + arrow::{DataType, Field}, + error::{Error, Result}, + schema::PrettyField, +}; + +/// A helper to construct new `Bool8` fields (`arrow.bool8`) +/// +/// This extension type can be used with `overwrites` in schema tracing: +/// +/// ```rust +/// # use serde_json::json; +/// # use serde_arrow::{Result, schema::{SerdeArrowSchema, SchemaLike, TracingOptions, ext::Bool8Field}}; +/// # use serde::Deserialize; +/// # fn main() -> Result<()> { +/// ##[derive(Deserialize)] +/// struct Record { +/// int_field: i32, +/// nested: Nested, +/// } +/// +/// ##[derive(Deserialize)] +/// struct Nested { +/// bool_field: bool, +/// } +/// +/// let tracing_options = TracingOptions::default() +/// .overwrite("nested.bool_field", Bool8Field::new("bool_field"))?; +/// +/// let schema = SerdeArrowSchema::from_type::(tracing_options)?; +/// # std::mem::drop(schema); +/// # Ok(()) +/// # } +/// ``` +/// +/// It can also be converted to a `arrow` `Field` for manual schema manipulation. +/// +pub struct Bool8Field { + name: String, + nullable: bool, +} + +impl Bool8Field { + /// Construct a new non-nullable `Bool8Field` + pub fn new(name: &str) -> Self { + Self { + name: name.into(), + nullable: false, + } + } + + /// Set the nullability of the field + pub fn nullable(mut self, value: bool) -> Self { + self.nullable = value; + self + } +} + +impl TryFrom<&Bool8Field> for Field { + type Error = Error; + + fn try_from(value: &Bool8Field) -> Result { + let mut metadata = HashMap::new(); + metadata.insert("ARROW:extension:name".into(), "arrow.bool8".into()); + metadata.insert("ARROW:extension:metadata".into(), String::new()); + + Ok(Field { + name: value.name.to_owned(), + nullable: value.nullable, + data_type: DataType::Int8, + metadata, + }) + } +} + +impl serde::ser::Serialize for Bool8Field { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) + } +} + +#[test] +fn bool8_repr() -> crate::internal::error::PanicOnError<()> { + use serde_json::json; + + let field = Bool8Field::new("hello"); + + let field = Field::try_from(&field)?; + let actual = serde_json::to_value(&PrettyField(&field))?; + + let expected = json!({ + "name": "hello", + "data_type": "I8", + "metadata": { + "ARROW:extension:name": "arrow.bool8", + "ARROW:extension:metadata": "", + }, + }); + + assert_eq!(actual, expected); + Ok(()) +} diff --git a/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs index c5239a55..b7f0ca22 100644 --- a/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; -/// Easily construct a field for tensors with fixed shape +/// Easily construct a fixed shape tensor fields (`arrow.fixed_shape_tensor`) /// /// See the [arrow docs][fixed-shape-tensor-docs] for details on the different /// fields. @@ -51,7 +51,7 @@ pub struct FixedShapeTensorField { } impl FixedShapeTensorField { - /// Construct a new instance + /// Construct a new non-nullable `FixedShapeTensorField` /// /// Note the element parameter must serialize into a valid field definition /// with the the name `"element"`. The field type can be any valid Arrow diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs index abf2db9d..fa879d7c 100644 --- a/serde_arrow/src/internal/schema/extensions/mod.rs +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -1,6 +1,8 @@ +mod bool8_field; mod fixed_shape_tensor_field; mod utils; mod variable_shape_tensor_field; +pub use bool8_field::Bool8Field; pub use fixed_shape_tensor_field::FixedShapeTensorField; pub use variable_shape_tensor_field::VariableShapeTensorField; diff --git a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs index 6f7849ab..3cc1936c 100644 --- a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; -/// Helper to build fields for tensors with variable shape +/// Helper to build variable shape tensor fields (`arrow.variable_shape_tensor`) /// /// See the [arrow docs][variable-shape-tensor-field-docs] for details on the /// different fields. @@ -26,6 +26,7 @@ pub struct VariableShapeTensorField { } impl VariableShapeTensorField { + /// Create a new non-nullable `VariableShapeTensorField` pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { let element = transmute_field(element)?; if element.name != "element" { @@ -134,27 +135,28 @@ impl TryFrom<&VariableShapeTensorField> for Field { ); metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - let mut fields = Vec::new(); - fields.push(Field { - name: String::from("data"), - data_type: DataType::List(Box::new(value.element.clone())), - nullable: false, - metadata: HashMap::new(), - }); - fields.push(Field { - name: String::from("shape"), - data_type: DataType::FixedSizeList( - Box::new(Field { - name: String::from("element"), - data_type: DataType::Int32, - nullable: false, - metadata: HashMap::new(), - }), - value.ndim.try_into()?, - ), - nullable: false, - metadata: HashMap::new(), - }); + let fields = vec![ + Field { + name: String::from("data"), + data_type: DataType::List(Box::new(value.element.clone())), + nullable: false, + metadata: HashMap::new(), + }, + Field { + name: String::from("shape"), + data_type: DataType::FixedSizeList( + Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: HashMap::new(), + }), + value.ndim.try_into()?, + ), + nullable: false, + metadata: HashMap::new(), + }, + ]; Ok(Field { name: value.name.clone(), diff --git a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs index 5c9b0fef..b0e8c007 100644 --- a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs @@ -5,9 +5,9 @@ use std::collections::HashMap; use serde::Deserialize; use serde_json::json; -use crate::{ - internal::testing::assert_error, +use crate::internal::{ schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, + testing::assert_error, }; #[test] diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 78ebba3c..8ff90182 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -135,22 +135,20 @@ pub trait SchemaLike: Sized + Sealed { /// fn from_value(value: &T) -> Result; - /// Determine the schema from the given record type. See [`TracingOptions`] - /// for customization options. + /// Determine the schema from the given record type. See [`TracingOptions`] for customization + /// options. /// - /// This approach requires the type `T` to implement - /// [`Deserialize`][serde::Deserialize]. As only type information is used, - /// it is not possible to detect data dependent properties. Examples of - /// unsupported features: + /// This approach requires the type `T` to implement [`Deserialize`][::serde::Deserialize]. As + /// only type information is used, it is not possible to detect data dependent properties. + /// Examples of unsupported features: /// /// - auto detection of date time strings /// - non self-describing types such as `serde_json::Value` /// - flattened structure (`#[serde(flatten)]`) - /// - types that require specific data to be deserialized, such as the - /// `DateTime` type of `chrono` or the `Uuid` type of the `uuid` package + /// - types that require specific data to be deserialized, such as the `DateTime` type of + /// `chrono` or the `Uuid` type of the `uuid` package /// - /// Consider using [`from_samples`][SchemaLike::from_samples] in these - /// cases. + /// Consider using [`from_samples`][SchemaLike::from_samples] in these cases. /// /// ```rust /// # #[cfg(has_arrow)] @@ -199,20 +197,17 @@ pub trait SchemaLike: Sized + Sealed { /// ``` fn from_type<'de, T: Deserialize<'de> + ?Sized>(options: TracingOptions) -> Result; - /// Determine the schema from samples. See [`TracingOptions`] for - /// customization options. + /// Determine the schema from samples. See [`TracingOptions`] for customization options. /// - /// This approach requires the type `T` to implement - /// [`Serialize`][serde::Serialize] and the samples to include all relevant - /// values. It uses only the information encoded in the samples to generate - /// the schema. Therefore, the following requirements must be met: + /// This approach requires the type `T` to implement [`Serialize`][::serde::Serialize] and the + /// samples to include all relevant values. It uses only the information encoded in the samples + /// to generate the schema. Therefore, the following requirements must be met: /// /// - at least one `Some` value for `Option<..>` fields /// - all variants of enum fields /// - at least one element for sequence fields (e.g., `Vec<..>`) - /// - at least one example for map types (e.g., `HashMap<.., ..>`). All - /// possible keys must be given, if [`options.map_as_struct == - /// true`][TracingOptions::map_as_struct]). + /// - at least one example for map types (e.g., `HashMap<.., ..>`). All possible keys must be + /// given, if [`options.map_as_struct == true`][TracingOptions::map_as_struct]). /// /// ```rust /// # #[cfg(has_arrow)] diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index f31cfeed..4d3937b6 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -4,8 +4,8 @@ use std::collections::HashMap; use serde::ser::{SerializeSeq, SerializeStruct}; -use crate::{ - internal::arrow::{DataType, Field}, +use crate::internal::{ + arrow::{DataType, Field}, schema::{SerdeArrowSchema, STRATEGY_KEY}, }; diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 861b386c..3f2a5e1a 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -20,7 +20,7 @@ const RECURSIVE_TYPE_WARNING: &str = fn default_dictionary_field(name: &str, nullable: bool) -> Field { Field { name: name.to_owned(), - nullable: nullable, + nullable, metadata: HashMap::new(), data_type: DataType::Dictionary( Box::new(DataType::UInt32), @@ -560,7 +560,7 @@ impl Tracer { (ty, ev) => fail!( "Cannot accept event {ev} for tracer of primitive type {ty}", ev = DataTypeDisplay(&ev), - ty = DataTypeDisplay(&ty), + ty = DataTypeDisplay(ty), ), }; tracer.item_type = item_type; diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 2bc1ea5a..8bb2adc5 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -3,13 +3,13 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, error::Result, - utils::{Mut, Offset}, + utils::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, + Mut, Offset, + }, }; -use super::{ - array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 1614bcfe..c1d23683 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, BooleanArray}, error::Result, + utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, }; -use super::{ - array_ext::{set_bit_buffer, set_validity, set_validity_default}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct BoolBuilder(BooleanArray); diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 1aaa3079..d4feb160 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -3,12 +3,10 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Date32Builder(PrimitiveArray); diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 576229c3..7001b558 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, error::{fail, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Date64Builder { diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 1caf8802..2f851d47 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,13 +1,11 @@ use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::decimal::{self, DecimalParser}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DecimalBuilder { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 17ecc99a..1b835a69 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DurationBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 2c82b46e..329d4fc2 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -3,13 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 89007756..1c222304 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -3,14 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index a9be17b2..f54ed275 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -3,13 +3,11 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::Mut, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct FloatBuilder(PrimitiveArray); diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index acea49ff..f385667f 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray}, error::{Error, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct IntBuilder(PrimitiveArray); @@ -77,6 +75,11 @@ where self.0.push_scalar_none() } + fn serialize_bool(&mut self, v: bool) -> Result<()> { + let v: u8 = if v { 1 } else { 0 }; + self.0.push_scalar_value(I::try_from(v)?) + } + fn serialize_i8(&mut self, v: i8) -> Result<()> { self.0.push_scalar_value(I::try_from(v)?) } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 8a2bdad1..7578433e 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -3,14 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, error::Result, + utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, utils::{Mut, Offset}, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 8c0499d1..737d1df6 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -3,13 +3,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct MapBuilder { diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index ac137b17..f6af48eb 100644 --- a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -1,7 +1,6 @@ //! A serialization implementation without the event model pub mod array_builder; -pub mod array_ext; pub mod binary_builder; pub mod bool_builder; pub mod date32_builder; diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index a76a6851..c9e837ae 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -114,7 +114,7 @@ impl OuterSequenceBuilder { build_builder(entry_field.as_ref())?, field.nullable, )?), - T::Struct(children) => A::Struct(build_struct(&children, field.nullable)?), + T::Struct(children) => A::Struct(build_struct(children, field.nullable)?), T::Dictionary(key, value, _) => { let key_field = Field { name: "key".to_string(), diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 264a7b86..68753f13 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -5,14 +5,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; const UNKNOWN_KEY: usize = usize::MAX; diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index ed6dae67..550cc17a 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -3,12 +3,10 @@ use chrono::Timelike; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::{Error, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct TimeBuilder { diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index b6b6c0f3..c1c4b4bb 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -1,11 +1,8 @@ use serde::Serialize; -use crate::{ - internal::{ - arrow::{Array, NullArray}, - error::fail, - }, - Result, +use crate::internal::{ + arrow::{Array, NullArray}, + error::{fail, Result}, }; use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 788123cf..b811b8aa 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,13 +1,11 @@ use crate::internal::{ arrow::{Array, BytesArray}, error::{fail, Result}, + utils::array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, utils::Offset, }; -use super::{ - array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Utf8Builder(BytesArray); diff --git a/serde_arrow/src/internal/serialization/array_ext.rs b/serde_arrow/src/internal/utils/array_ext.rs similarity index 100% rename from serde_arrow/src/internal/serialization/array_ext.rs rename to serde_arrow/src/internal/utils/array_ext.rs diff --git a/serde_arrow/src/internal/utils/array_view_ext.rs b/serde_arrow/src/internal/utils/array_view_ext.rs new file mode 100644 index 00000000..df26ae13 --- /dev/null +++ b/serde_arrow/src/internal/utils/array_view_ext.rs @@ -0,0 +1,48 @@ +use crate::internal::arrow::ArrayView; + +pub trait ArrayViewExt { + fn len(&self) -> usize; +} + +impl<'a> ArrayViewExt for ArrayView<'a> { + fn len(&self) -> usize { + use ArrayView as V; + match self { + V::Null(view) => view.len, + V::Boolean(view) => view.len, + V::Int8(view) => view.values.len(), + V::Int16(view) => view.values.len(), + V::Int32(view) => view.values.len(), + V::Int64(view) => view.values.len(), + V::UInt8(view) => view.values.len(), + V::UInt16(view) => view.values.len(), + V::UInt32(view) => view.values.len(), + V::UInt64(view) => view.values.len(), + V::Float16(view) => view.values.len(), + V::Float32(view) => view.values.len(), + V::Float64(view) => view.values.len(), + V::Date32(view) => view.values.len(), + V::Date64(view) => view.values.len(), + V::Time32(view) => view.values.len(), + V::Time64(view) => view.values.len(), + V::Timestamp(view) => view.values.len(), + V::Duration(view) => view.values.len(), + V::Decimal128(view) => view.values.len(), + V::Utf8(view) => view.offsets.len().saturating_sub(1), + V::LargeUtf8(view) => view.offsets.len().saturating_sub(1), + V::Binary(view) => view.offsets.len().saturating_sub(1), + V::LargeBinary(view) => view.offsets.len().saturating_sub(1), + V::FixedSizeBinary(view) => match usize::try_from(view.n) { + Ok(n) if n > 0 => view.data.len() / n, + _ => 0, + }, + V::FixedSizeList(view) => view.len, + V::List(view) => view.offsets.len().saturating_sub(1), + V::LargeList(view) => view.offsets.len().saturating_sub(1), + V::DenseUnion(view) => view.types.len(), + V::Map(view) => view.offsets.len().saturating_sub(1), + V::Struct(view) => view.len, + V::Dictionary(view) => view.indices.len(), + } + } +} diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 6995f332..cee187fe 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -1,3 +1,5 @@ +pub mod array_ext; +pub mod array_view_ext; pub mod decimal; pub mod dsl; pub mod value; diff --git a/serde_arrow/src/internal/utils/value.rs b/serde_arrow/src/internal/utils/value.rs index 91bdb1e2..9f7eea00 100644 --- a/serde_arrow/src/internal/utils/value.rs +++ b/serde_arrow/src/internal/utils/value.rs @@ -1,7 +1,7 @@ //! Serialize values into a in-memory representation use serde::{de::DeserializeOwned, forward_to_deserialize_any, Serialize}; -use crate::{internal::error::fail, Error, Result}; +use crate::internal::error::{fail, Error, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Variant(u32, &'static str); diff --git a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index b3a85ce5..1ce4d392 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -391,7 +391,7 @@ pub mod schema { /// [ext-docs]: https://arrow.apache.org/docs/format/CanonicalExtensions.html pub mod ext { pub use crate::internal::schema::extensions::{ - FixedShapeTensorField, VariableShapeTensorField, + Bool8Field, FixedShapeTensorField, VariableShapeTensorField, }; } } diff --git a/serde_arrow/src/test_with_arrow/impls/bool8.rs b/serde_arrow/src/test_with_arrow/impls/bool8.rs new file mode 100644 index 00000000..14fb52c2 --- /dev/null +++ b/serde_arrow/src/test_with_arrow/impls/bool8.rs @@ -0,0 +1,77 @@ +use serde::Deserialize; +use serde_json::json; + +use crate::internal::{ + arrow::{ArrayView, DataType, Field, PrimitiveArrayView}, + deserializer::Deserializer, + schema::{extensions::Bool8Field, TracingOptions}, + utils::{Item, Items}, +}; + +use super::utils::Test; + +#[test] +fn bool_as_int8() { + let items = &[Item(true), Item(false)]; + Test::new() + .with_schema(json!([{"name": "item", "data_type": "I8"}])) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, false]]); +} + +#[test] +fn nullable_bool_as_int8() { + let items = &[Item(Some(true)), Item(None), Item(Some(false))]; + Test::new() + .with_schema(json!([{"name": "item", "data_type": "I8", "nullable": true}])) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, true, false]]); +} + +// from the bool8 specs: false is denoted by the value 0. true can be specified using any non-zero +// value. Preferably 1. +#[test] +fn deserialize_from_not_01_ints() -> crate::internal::error::PanicOnError<()> { + let field = Field { + name: String::from("item"), + data_type: DataType::Int8, + nullable: false, + metadata: Default::default(), + }; + let view = ArrayView::Int8(PrimitiveArrayView { + validity: None, + values: &[0, -1, 2, 3, -31, 100, 0, 0], + }); + let deserializer = Deserializer::new(&[field], vec![view])?; + + let Items(actual) = Items::>::deserialize(deserializer)?; + let expected = vec![false, true, true, true, true, true, false, false]; + assert_eq!(actual, expected); + + Ok(()) +} + +#[test] +fn overwrites() -> crate::internal::error::PanicOnError<()> { + let tracing_options = TracingOptions::new().overwrite("item", Bool8Field::new("item"))?; + + let items = &[Item(true), Item(false)]; + Test::new() + .with_schema(json!([{ + "name": "item", + "data_type": "I8", + "metadata": { + "ARROW:extension:name": "arrow.bool8", + "ARROW:extension:metadata": "", + }, + }])) + .trace_schema_from_samples(&items, tracing_options.clone()) + .trace_schema_from_type::>(tracing_options) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, false]]); + + Ok(()) +} diff --git a/serde_arrow/src/test_with_arrow/impls/mod.rs b/serde_arrow/src/test_with_arrow/impls/mod.rs index 90695d09..72db9255 100644 --- a/serde_arrow/src/test_with_arrow/impls/mod.rs +++ b/serde_arrow/src/test_with_arrow/impls/mod.rs @@ -1,5 +1,6 @@ mod utils; +mod bool8; mod bytes; mod chrono; mod dictionary;