diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 2ac95074..5e564447 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::TimeUnit, deserialization::{ array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, @@ -32,7 +33,6 @@ use crate::_impl::arrow2::{ datatypes::{DataType, Field, UnionMode}, types::{f16, NativeType, Offset as ArrowOffset}, }; -use crate::internal::schema::GenericTimeUnit; impl<'de> Deserializer<'de> { /// Build a deserializer from `arrow2` arrays (*requires one of the @@ -215,7 +215,7 @@ pub fn build_date64_deserializer<'a>( Ok(Date64Deserializer::new( as_primitive_values(array)?, get_validity(array), - GenericTimeUnit::Millisecond, + TimeUnit::Millisecond, field.is_utc()?, ) .into()) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 5f137994..07ab3843 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -1,12 +1,15 @@ use std::collections::HashMap; use crate::{ - _impl::arrow2::datatypes::{DataType, Field, IntegerType, TimeUnit, UnionMode}, + _impl::arrow2::datatypes::{ + DataType, Field, IntegerType, TimeUnit as ArrowTimeUnit, UnionMode, + }, internal::{ + arrow::TimeUnit, error::{error, fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, GenericTimeUnit, SchemaLike, Sealed, SerdeArrowSchema, + GenericField, SchemaLike, Sealed, SerdeArrowSchema, }, }, }; @@ -67,7 +70,7 @@ impl TryFrom<&Field> for GenericField { type Error = Error; fn try_from(field: &Field) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; let metadata = field .metadata @@ -104,26 +107,26 @@ impl TryFrom<&Field> for GenericField { } T::Decimal128(*precision as u8, *scale as i8) } - DataType::Time32(TimeUnit::Second) => T::Time32(U::Second), - DataType::Time32(TimeUnit::Millisecond) => T::Time32(U::Millisecond), + DataType::Time32(ArrowTimeUnit::Second) => T::Time32(U::Second), + DataType::Time32(ArrowTimeUnit::Millisecond) => T::Time32(U::Millisecond), DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(TimeUnit::Microsecond) => T::Time64(U::Microsecond), - DataType::Time64(TimeUnit::Nanosecond) => T::Time64(U::Nanosecond), + DataType::Time64(ArrowTimeUnit::Microsecond) => T::Time64(U::Microsecond), + DataType::Time64(ArrowTimeUnit::Nanosecond) => T::Time64(U::Nanosecond), DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(TimeUnit::Second, tz) => T::Timestamp(U::Second, tz.clone()), - DataType::Timestamp(TimeUnit::Millisecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Second, tz) => T::Timestamp(U::Second, tz.clone()), + DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => { T::Timestamp(U::Millisecond, tz.clone()) } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => { T::Timestamp(U::Microsecond, tz.clone()) } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => { T::Timestamp(U::Nanosecond, tz.clone()) } - DataType::Duration(TimeUnit::Second) => T::Duration(U::Second), - DataType::Duration(TimeUnit::Millisecond) => T::Duration(U::Millisecond), - DataType::Duration(TimeUnit::Microsecond) => T::Duration(U::Microsecond), - DataType::Duration(TimeUnit::Nanosecond) => T::Duration(U::Nanosecond), + DataType::Duration(ArrowTimeUnit::Second) => T::Duration(U::Second), + DataType::Duration(ArrowTimeUnit::Millisecond) => T::Duration(U::Millisecond), + DataType::Duration(ArrowTimeUnit::Microsecond) => T::Duration(U::Microsecond), + DataType::Duration(ArrowTimeUnit::Nanosecond) => T::Duration(U::Nanosecond), DataType::List(field) => { children.push(GenericField::try_from(field.as_ref())?); T::List @@ -194,7 +197,7 @@ impl TryFrom<&GenericField> for Field { type Error = Error; fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; let data_type = match &value.data_type { T::Null => DataType::Null, @@ -212,11 +215,11 @@ impl TryFrom<&GenericField> for Field { T::F64 => DataType::Float64, T::Date32 => DataType::Date32, T::Date64 => DataType::Date64, - T::Time32(U::Second) => DataType::Time32(TimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), + T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), T::Time32(unit) => fail!("Invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(TimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), + T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), T::Time64(unit) => fail!("Invalid time unit {unit} for Time64"), T::Timestamp(unit, tz) => DataType::Timestamp((*unit).into(), tz.clone()), T::Duration(unit) => DataType::Duration((*unit).into()), @@ -306,13 +309,13 @@ impl TryFrom<&GenericField> for Field { } } -impl From for TimeUnit { - fn from(value: GenericTimeUnit) -> Self { +impl From for ArrowTimeUnit { + fn from(value: TimeUnit) -> Self { match value { - GenericTimeUnit::Second => Self::Second, - GenericTimeUnit::Millisecond => Self::Millisecond, - GenericTimeUnit::Microsecond => Self::Microsecond, - GenericTimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Second => Self::Second, + TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Nanosecond => Self::Nanosecond, } } } diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 788d33ef..2ff608bd 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::TimeUnit, deserialization::{ array_deserializer::ArrayDeserializer, binary_deserializer::BinaryDeserializer, @@ -22,7 +23,7 @@ use crate::internal::{ }, deserializer::Deserializer, error::{fail, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit}, + schema::{GenericDataType, GenericField}, utils::Offset, }; @@ -124,7 +125,7 @@ pub fn build_array_deserializer<'a>( field: &GenericField, array: &'a dyn Array, ) -> Result> { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; match &field.data_type { T::Null => Ok(NullDeserializer.into()), T::Bool => build_bool_deserializer(field, array), @@ -284,7 +285,7 @@ pub fn build_date64_deserializer<'a>( Ok(Date64Deserializer::new( as_primitive_values::(array)?, get_validity(array), - GenericTimeUnit::Millisecond, + TimeUnit::Millisecond, field.is_utc()?, ) .into()) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 5d637400..242440c1 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -1,12 +1,13 @@ use std::sync::Arc; use crate::{ - _impl::arrow::datatypes::{DataType, Field, FieldRef, TimeUnit, UnionMode}, + _impl::arrow::datatypes::{DataType, Field, FieldRef, TimeUnit as ArrowTimeUnit, UnionMode}, internal::{ + arrow::TimeUnit, error::{error, fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, GenericTimeUnit, SchemaLike, Sealed, SerdeArrowSchema, + GenericField, SchemaLike, Sealed, SerdeArrowSchema, }, }, }; @@ -123,7 +124,7 @@ impl TryFrom<&DataType> for GenericDataType { type Error = Error; fn try_from(value: &DataType) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; match value { DataType::Boolean => Ok(T::Bool), DataType::Null => Ok(T::Null), @@ -143,31 +144,31 @@ impl TryFrom<&DataType> for GenericDataType { DataType::Date32 => Ok(T::Date32), DataType::Date64 => Ok(T::Date64), DataType::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - DataType::Time32(TimeUnit::Second) => Ok(T::Time32(U::Second)), - DataType::Time32(TimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), + DataType::Time32(ArrowTimeUnit::Second) => Ok(T::Time32(U::Second)), + DataType::Time32(ArrowTimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(TimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), - DataType::Time64(TimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), + DataType::Time64(ArrowTimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), + DataType::Time64(ArrowTimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(TimeUnit::Second, tz) => { + DataType::Timestamp(ArrowTimeUnit::Second, tz) => { Ok(T::Timestamp(U::Second, tz.as_ref().map(|s| s.to_string()))) } - DataType::Timestamp(TimeUnit::Millisecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => Ok(T::Timestamp( U::Millisecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(TimeUnit::Microsecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => Ok(T::Timestamp( U::Microsecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(TimeUnit::Nanosecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => Ok(T::Timestamp( U::Nanosecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Duration(TimeUnit::Second) => Ok(T::Duration(U::Second)), - DataType::Duration(TimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), - DataType::Duration(TimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), - DataType::Duration(TimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), + DataType::Duration(ArrowTimeUnit::Second) => Ok(T::Duration(U::Second)), + DataType::Duration(ArrowTimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), + DataType::Duration(ArrowTimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), + DataType::Duration(ArrowTimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), DataType::Binary => Ok(T::Binary), DataType::LargeBinary => Ok(T::LargeBinary), DataType::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), @@ -253,7 +254,7 @@ impl TryFrom<&GenericField> for Field { type Error = Error; fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; let data_type = match &value.data_type { T::Null => DataType::Null, @@ -354,11 +355,11 @@ impl TryFrom<&GenericField> for Field { DataType::Dictionary(Box::new(key_type), Box::new(val_field.data_type().clone())) } - T::Time32(U::Second) => DataType::Time32(TimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), + T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), T::Time32(unit) => fail!("invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(TimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), + T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), T::Time64(unit) => fail!("invalid time unit {unit} for Time64"), T::Timestamp(unit, tz) => { DataType::Timestamp((*unit).into(), tz.clone().map(|s| s.into())) @@ -376,13 +377,39 @@ impl TryFrom<&GenericField> for Field { } } -impl From for TimeUnit { - fn from(value: GenericTimeUnit) -> Self { +impl From for ArrowTimeUnit { + fn from(value: TimeUnit) -> Self { match value { - GenericTimeUnit::Second => Self::Second, - GenericTimeUnit::Millisecond => Self::Millisecond, - GenericTimeUnit::Microsecond => Self::Microsecond, - GenericTimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Second => Self::Second, + TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Nanosecond => Self::Nanosecond, + } + } +} + +impl TryFrom for DataType { + type Error = Error; + + fn try_from(value: crate::internal::arrow::DataType) -> Result { + use {crate::internal::arrow::DataType as DT, DataType as ArrowDT}; + + match value { + DT::Int8 => Ok(ArrowDT::Int8), + DT::Int16 => Ok(ArrowDT::Int16), + DT::Int32 => Ok(ArrowDT::Int32), + DT::Int64 => Ok(ArrowDT::Int64), + DT::UInt8 => Ok(ArrowDT::UInt8), + DT::UInt16 => Ok(ArrowDT::UInt16), + DT::UInt32 => Ok(ArrowDT::UInt32), + DT::UInt64 => Ok(ArrowDT::UInt64), + DT::Float16 => Ok(ArrowDT::Float16), + DT::Float32 => Ok(ArrowDT::Float32), + DT::Float64 => Ok(ArrowDT::Float64), + dt => fail!( + "{} not supported", + crate::internal::arrow::BaseDataTypeDisplay(&dt) + ), } } } diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index e07fe32d..d47c823c 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -1,22 +1,22 @@ #![allow(missing_docs)] use std::sync::Arc; +use half::f16; + use crate::{ _impl::arrow::{ array::{make_array, Array, ArrayData, ArrayRef, NullArray, RecordBatch}, buffer::{Buffer, ScalarBuffer}, datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Field, FieldRef, Float16Type, Schema, + UnionMode, }, }, internal::{ - error::{fail, Result}, + arrow::FieldMeta, + error::{fail, Error, Result}, schema::{GenericField, SerdeArrowSchema}, - serialization::{ - utils::{MutableBitBuffer, MutableOffsetBuffer}, - ArrayBuilder, OuterSequenceBuilder, - }, - utils::Offset, + serialization::{ArrayBuilder, OuterSequenceBuilder}, }, }; @@ -60,289 +60,256 @@ impl OuterSequenceBuilder { } fn build_array(builder: ArrayBuilder) -> Result { - let data = build_array_data(builder)?; + let data = builder.into_array()?.try_into()?; Ok(make_array(data)) } -fn build_array_data(builder: ArrayBuilder) -> Result { - use {ArrayBuilder as A, DataType as T}; - match builder { - A::Null(builder) => Ok(NullArray::new(builder.count).into_data()), - A::UnknownVariant(_) => Ok(NullArray::new(0).into_data()), - A::Bool(builder) => build_array_data_primitive_with_len( - T::Boolean, - builder.buffer.len(), - builder.buffer.buffer, - builder.validity, - ), - A::I8(builder) => build_array_data_primitive(T::Int8, builder.buffer, builder.validity), - A::I16(builder) => build_array_data_primitive(T::Int16, builder.buffer, builder.validity), - A::I32(builder) => build_array_data_primitive(T::Int32, builder.buffer, builder.validity), - A::I64(builder) => build_array_data_primitive(T::Int64, builder.buffer, builder.validity), - A::U8(builder) => build_array_data_primitive(T::UInt8, builder.buffer, builder.validity), - A::U16(builder) => build_array_data_primitive(T::UInt16, builder.buffer, builder.validity), - A::U32(builder) => build_array_data_primitive(T::UInt32, builder.buffer, builder.validity), - A::U64(builder) => build_array_data_primitive(T::UInt64, builder.buffer, builder.validity), - A::F16(builder) => build_array_data_primitive( - T::Float16, - builder - .buffer - .into_iter() - .map(|v| ::Native::from_bits(v.to_bits())) - .collect(), - builder.validity, - ), - A::F32(builder) => build_array_data_primitive(T::Float32, builder.buffer, builder.validity), - A::F64(builder) => build_array_data_primitive(T::Float64, builder.buffer, builder.validity), - A::Date32(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Date64(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Time32(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Time64(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Duration(builder) => build_array_data_primitive( - T::Duration(builder.unit.into()), - builder.buffer, - builder.validity, - ), - A::Decimal128(builder) => build_array_data_primitive( - T::Decimal128(builder.precision, builder.scale), - builder.buffer, - builder.validity, - ), - A::Utf8(builder) => build_array_data_utf8( - T::Utf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), - A::LargeUtf8(builder) => build_array_data_utf8( - T::LargeUtf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), - A::LargeList(builder) => build_array_data_list( - T::LargeList(Arc::new(Field::try_from(&builder.field)?)), - builder.offsets.offsets.len() - 1, - builder.offsets.offsets, - build_array_data(*builder.element)?, - builder.validity, - ), - A::List(builder) => build_array_data_list( - T::List(Arc::new(Field::try_from(&builder.field)?)), - builder.offsets.offsets.len() - 1, - builder.offsets.offsets, - build_array_data(*builder.element)?, - builder.validity, - ), - A::FixedSizedList(builder) => { - let data_type = T::FixedSizeList( - Arc::new(Field::try_from(&builder.field)?), - builder.n.try_into()?, - ); - let child_data = build_array_data(*builder.element)?; - let validity = if let Some(validity) = builder.validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; +impl TryFrom for ArrayData { + type Error = Error; - Ok(ArrayData::builder(data_type) - .len(builder.num_elements) - .null_bit_buffer(validity) - .add_child_data(child_data) - .build()?) - } - A::Binary(builder) => { - build_array_data_binary(T::Binary, builder.offsets, builder.buffer, builder.validity) - } - A::LargeBinary(builder) => build_array_data_binary( - T::LargeBinary, - builder.offsets, - builder.buffer, - builder.validity, - ), - A::FixedSizeBinary(builder) => { - let data_buffer = ScalarBuffer::from(builder.buffer).into_inner(); - let validity = if let Some(validity) = builder.validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; + fn try_from(value: crate::internal::arrow::Array) -> Result { + use {crate::internal::arrow::Array as A, DataType as T}; + type ArrowF16 = ::Native; - Ok( - ArrayData::builder(T::FixedSizeBinary(builder.n.try_into()?)) - .len(builder.len) - .null_bit_buffer(validity) - .add_buffer(data_buffer) - .build()?, - ) + fn f16_to_f16(v: f16) -> ArrowF16 { + ArrowF16::from_bits(v.to_bits()) } - A::Struct(builder) => { - let mut data = Vec::new(); - for (_, field) in builder.named_fields { - data.push(build_array_data(field)?); + + match value { + A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), + A::Boolean(arr) => Ok(ArrayData::try_new( + T::Boolean, + // NOTE: use the explicit len + arr.len, + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), + A::Int8(arr) => primitive_into_data(T::Int8, arr.validity, arr.values), + A::Int16(arr) => primitive_into_data(T::Int16, arr.validity, arr.values), + A::Int32(arr) => primitive_into_data(T::Int32, arr.validity, arr.values), + A::Int64(arr) => primitive_into_data(T::Int64, arr.validity, arr.values), + A::UInt8(arr) => primitive_into_data(T::UInt8, arr.validity, arr.values), + A::UInt16(arr) => primitive_into_data(T::UInt16, arr.validity, arr.values), + A::UInt32(arr) => primitive_into_data(T::UInt32, arr.validity, arr.values), + A::UInt64(arr) => primitive_into_data(T::UInt64, arr.validity, arr.values), + A::Float16(arr) => primitive_into_data( + T::Float16, + arr.validity, + arr.values.into_iter().map(f16_to_f16).collect(), + ), + A::Float32(arr) => primitive_into_data(T::Float32, arr.validity, arr.values), + A::Float64(arr) => primitive_into_data(T::Float64, arr.validity, arr.values), + A::Date32(arr) => primitive_into_data(T::Date32, arr.validity, arr.values), + A::Date64(arr) => primitive_into_data(T::Date64, arr.validity, arr.values), + A::Timestamp(arr) => primitive_into_data( + T::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), + arr.validity, + arr.values, + ), + A::Time32(arr) => { + primitive_into_data(T::Time32(arr.unit.into()), arr.validity, arr.values) } + A::Time64(arr) => { + primitive_into_data(T::Time64(arr.unit.into()), arr.validity, arr.values) + } + A::Duration(arr) => { + primitive_into_data(T::Duration(arr.unit.into()), arr.validity, arr.values) + } + A::Decimal128(arr) => primitive_into_data( + T::Decimal128(arr.precision, arr.scale), + arr.validity, + arr.values, + ), + A::Utf8(arr) => bytes_into_data(T::Utf8, arr.offsets, arr.data, arr.validity), + A::LargeUtf8(arr) => bytes_into_data(T::LargeUtf8, arr.offsets, arr.data, arr.validity), + A::Binary(arr) => bytes_into_data(T::Binary, arr.offsets, arr.data, arr.validity), + A::LargeBinary(arr) => { + bytes_into_data(T::LargeBinary, arr.offsets, arr.data, arr.validity) + } + A::Struct(arr) => { + let mut fields = Vec::new(); + let mut data = Vec::new(); - let (validity, len) = if let Some(validity) = builder.validity { - (Some(Buffer::from(validity.buffer)), validity.len) - } else { - if data.is_empty() { - fail!("cannot built non-nullable structs without fields"); + for (field, meta) in arr.fields { + let child: ArrayData = field.try_into()?; + let field = Field::new(meta.name, child.data_type().clone(), meta.nullable) + .with_metadata(meta.metadata); + fields.push(Arc::new(field)); + data.push(child); } - (None, data[0].len()) - }; + let data_type = T::Struct(fields.into()); - let fields = builder - .fields - .iter() - .map(Field::try_from) - .collect::>>()?; - let data_type = T::Struct(fields.into()); + Ok(ArrayData::builder(data_type) + .len(arr.len) + .null_bit_buffer(arr.validity.map(Buffer::from)) + .child_data(data) + .build()?) + } + A::List(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + T::List(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::LargeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + T::LargeList(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::FixedSizeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + if (child.len() % usize::try_from(arr.n)?) != 0 { + fail!( + "Invalid FixedSizeList: number of child elements ({}) not divisible by n ({})", + child.len(), + arr.n, + ); + } + let field = field_from_data_and_meta(&child, arr.meta); + Ok(ArrayData::try_new( + T::FixedSizeList(Arc::new(field), arr.n), + child.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![], + vec![child], + )?) + } + A::FixedSizeBinary(arr) => { + if (arr.data.len() % usize::try_from(arr.n)?) != 0 { + fail!( + "Invalid FixedSizeBinary: number of child elements ({}) not divisible by n ({})", + arr.data.len(), + arr.n, + ); + } + Ok(ArrayData::try_new( + T::FixedSizeBinary(arr.n), + arr.data.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.data).into_inner()], + vec![], + )?) + } + A::Dictionary(arr) => { + let indices: ArrayData = (*arr.indices).try_into()?; + let values: ArrayData = (*arr.values).try_into()?; + let data_type = T::Dictionary( + Box::new(indices.data_type().clone()), + Box::new(values.data_type().clone()), + ); - Ok(ArrayData::builder(data_type) - .len(len) - .null_bit_buffer(validity) - .child_data(data) - .build()?) - } - A::Map(builder) => Ok(ArrayData::builder(T::Map( - Arc::new(Field::try_from(&builder.entry_field)?), - false, - )) - .len(builder.offsets.offsets.len() - 1) - .add_buffer(ScalarBuffer::from(builder.offsets.offsets).into_inner()) - .add_child_data(build_array_data(*builder.entry)?) - .null_bit_buffer(builder.validity.map(|b| Buffer::from(b.buffer))) - .build()?), - A::DictionaryUtf8(builder) => { - let indices = build_array_data(*builder.indices)?; - let values = build_array_data(*builder.values)?; - let data_type = Field::try_from(&builder.field)?.data_type().clone(); + Ok(indices + .into_builder() + .data_type(data_type) + .child_data(vec![values]) + .build()?) + } + A::Map(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + Ok(ArrayData::try_new( + T::Map(Arc::new(field), false), + arr.offsets.len().saturating_sub(1), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.offsets).into_inner()], + vec![child], + )?) + } + A::DenseUnion(arr) => { + let mut fields = Vec::new(); + let mut child_data = Vec::new(); - Ok(indices - .into_builder() - .data_type(data_type) - .child_data(vec![values]) - .build()?) - } - A::Union(builder) => { - let data_type = Field::try_from(&builder.field)?.data_type().clone(); - let children = builder - .fields - .into_iter() - .map(build_array_data) - .collect::>>()?; - let len = builder.types.len(); + for (idx, (array, meta)) in arr.fields.into_iter().enumerate() { + let child: ArrayData = array.try_into()?; + let field = field_from_data_and_meta(&child, meta); - Ok(ArrayData::builder(data_type) - .len(len) - .add_buffer(Buffer::from_vec(builder.types)) - .add_buffer(Buffer::from_vec(builder.offsets)) - .child_data(children) - .build()?) + fields.push((idx as i8, Arc::new(field))); + child_data.push(child); + } + + Ok(ArrayData::try_new( + DataType::Union(fields.into_iter().collect(), UnionMode::Dense), + arr.types.len(), + None, + 0, + vec![ + ScalarBuffer::from(arr.types).into_inner(), + ScalarBuffer::from(arr.offsets).into_inner(), + ], + child_data, + )?) + } } } } -fn build_array_data_primitive( - data_type: DataType, - data: Vec, - validity: Option, -) -> Result { - let len = data.len(); - build_array_data_primitive_with_len(data_type, len, data, validity) +fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> Field { + Field::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) } -fn build_array_data_primitive_with_len( +fn primitive_into_data( data_type: DataType, - len: usize, - data: Vec, - validity: Option, + validity: Option>, + values: Vec, ) -> Result { Ok(ArrayData::try_new( data_type, - len, - validity.map(|b| Buffer::from(b.buffer)), + values.len(), + validity.map(Buffer::from), 0, - vec![ScalarBuffer::from(data).into_inner()], + vec![ScalarBuffer::from(values).into_inner()], vec![], )?) } -fn build_array_data_utf8( +fn bytes_into_data( data_type: DataType, offsets: Vec, data: Vec, - validity: Option, + validity: Option>, ) -> Result { - let values_len = offsets.len() - 1; - - let offsets = ScalarBuffer::from(offsets).into_inner(); - let data = ScalarBuffer::from(data).into_inner(); - let validity = validity.map(|b| Buffer::from(b.buffer)); - Ok(ArrayData::try_new( data_type, - values_len, - validity, + offsets.len().saturating_sub(1), + validity.map(Buffer::from), 0, - vec![offsets, data], + vec![ + ScalarBuffer::from(offsets).into_inner(), + ScalarBuffer::from(data).into_inner(), + ], vec![], )?) } -fn build_array_data_binary( - data_type: DataType, - offsets: MutableOffsetBuffer, - data: Vec, - validity: Option, -) -> Result { - let len = offsets.len(); - let offset_buffer = ScalarBuffer::from(offsets.offsets).into_inner(); - let data_buffer = ScalarBuffer::from(data).into_inner(); - let validity = if let Some(validity) = validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; - Ok(ArrayData::builder(data_type) - .len(len) - .null_bit_buffer(validity) - .add_buffer(offset_buffer) - .add_buffer(data_buffer) - .build()?) -} - -fn build_array_data_list( +fn list_into_data( data_type: DataType, len: usize, offsets: Vec, child_data: ArrayData, - validity: Option, + validity: Option>, ) -> Result { - let offset_buffer = ScalarBuffer::from(offsets).into_inner(); - let validity = validity.map(|b| Buffer::from(b.buffer)); - - Ok(ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(child_data) - .null_bit_buffer(validity) - .build()?) + Ok(ArrayData::try_new( + data_type, + len, + validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(offsets).into_inner()], + vec![child_data], + )?) } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs new file mode 100644 index 00000000..6c349f96 --- /dev/null +++ b/serde_arrow/src/internal/arrow/array.rs @@ -0,0 +1,143 @@ +//! Owned versions of the different array types +use std::collections::HashMap; + +use half::f16; + +use crate::internal::arrow::data_type::TimeUnit; + +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum Array { + Null(NullArray), + Boolean(BooleanArray), + Int8(PrimitiveArray), + Int16(PrimitiveArray), + Int32(PrimitiveArray), + Int64(PrimitiveArray), + UInt8(PrimitiveArray), + UInt16(PrimitiveArray), + UInt32(PrimitiveArray), + UInt64(PrimitiveArray), + Float16(PrimitiveArray), + Float32(PrimitiveArray), + Float64(PrimitiveArray), + Date32(PrimitiveArray), + Date64(PrimitiveArray), + Time32(TimeArray), + Time64(TimeArray), + Timestamp(TimestampArray), + Duration(TimeArray), + Utf8(BytesArray), + LargeUtf8(BytesArray), + Binary(BytesArray), + LargeBinary(BytesArray), + FixedSizeBinary(FixedSizeBinaryArray), + Decimal128(DecimalArray), + Struct(StructArray), + List(ListArray), + LargeList(ListArray), + FixedSizeList(FixedSizeListArray), + Dictionary(DictionaryArray), + Map(ListArray), + DenseUnion(DenseUnionArray), +} + +#[derive(Clone, Debug)] +pub struct NullArray { + pub len: usize, +} + +#[derive(Clone, Debug)] +pub struct BooleanArray { + // Note: len is required to know how many bits of values are used + pub len: usize, + pub validity: Option>, + pub values: Vec, +} + +#[derive(Clone, Debug)] +pub struct PrimitiveArray { + pub validity: Option>, + pub values: Vec, +} + +#[derive(Debug, Clone)] +pub struct TimeArray { + pub unit: TimeUnit, + pub validity: Option>, + pub values: Vec, +} + +#[derive(Debug, Clone)] + +pub struct TimestampArray { + pub unit: TimeUnit, + pub timezone: Option, + pub validity: Option>, + pub values: Vec, +} + +#[derive(Clone, Debug)] +pub struct StructArray { + pub len: usize, + pub validity: Option>, + pub fields: Vec<(Array, FieldMeta)>, +} + +#[derive(Clone, Debug)] +pub struct FieldMeta { + pub name: String, + pub nullable: bool, + pub metadata: HashMap, +} + +#[derive(Clone, Debug)] +pub struct ListArray { + pub validity: Option>, + pub offsets: Vec, + pub meta: FieldMeta, + pub element: Box, +} + +#[derive(Clone, Debug)] +pub struct FixedSizeListArray { + pub n: i32, + pub validity: Option>, + pub meta: FieldMeta, + pub element: Box, +} + +#[derive(Clone, Debug)] +pub struct BytesArray { + pub validity: Option>, + pub offsets: Vec, + pub data: Vec, +} + +#[derive(Clone, Debug)] +pub struct FixedSizeBinaryArray { + pub n: i32, + pub validity: Option>, + pub data: Vec, +} + +#[derive(Clone, Debug)] +pub struct DecimalArray { + pub precision: u8, + pub scale: i8, + pub validity: Option>, + pub values: Vec, +} + +#[derive(Clone, Debug)] +pub struct DictionaryArray { + pub indices: Box, + pub values: Box, +} + +#[derive(Clone, Debug)] +pub struct DenseUnionArray { + pub types: Vec, + pub offsets: Vec, + pub fields: Vec<(Array, FieldMeta)>, +} diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs new file mode 100644 index 00000000..81cc7eb4 --- /dev/null +++ b/serde_arrow/src/internal/arrow/array_view.rs @@ -0,0 +1,79 @@ +#![allow(dead_code, unused)] +use half::f16; + +use crate::internal::arrow::data_type::TimeUnit; + +pub enum ArrayView<'a> { + Null(NullArrayView), + Boolean(BooleanArrayView<'a>), + Int8(PrimitiveArrayView<'a, i8>), + Int16(PrimitiveArrayView<'a, i16>), + Int32(PrimitiveArrayView<'a, i32>), + Int64(PrimitiveArrayView<'a, i64>), + UInt8(PrimitiveArrayView<'a, u8>), + UInt16(PrimitiveArrayView<'a, u16>), + UInt32(PrimitiveArrayView<'a, u32>), + UInt64(PrimitiveArrayView<'a, u64>), + Float16(PrimitiveArrayView<'a, f16>), + Float32(PrimitiveArrayView<'a, f32>), + Float64(PrimitiveArrayView<'a, f64>), + Date32(PrimitiveArrayView<'a, i32>), + Date64(PrimitiveArrayView<'a, i64>), + Time32(TimeArrayView<'a, i32>), + Time64(TimeArrayView<'a, i64>), + Utf8(Utf8ArrayView<'a, i32>), + LargeUtf8(Utf8ArrayView<'a, i64>), + Binary(Utf8ArrayView<'a, i32>), + LargeBinary(Utf8ArrayView<'a, i64>), + Decimal128(PrimitiveArrayView<'a, i128>), + Struct(StructArrayView<'a>), + List(ListArrayView<'a, i32>), + LargeList(ListArrayView<'a, i64>), +} + +pub struct NullArrayView { + pub len: usize, +} + +#[derive(Debug, Clone, Copy)] +pub struct BitsWithOffset<'a> { + pub offset: usize, + pub data: &'a [u8], +} + +pub struct BooleanArrayView<'a> { + pub len: usize, + pub validity: Option>, + pub values: BitsWithOffset<'a>, +} + +pub struct PrimitiveArrayView<'a, T> { + pub validity: Option>, + pub values: &'a [T], +} + +pub struct TimeArrayView<'a, T> { + pub unit: TimeUnit, + pub validity: Option>, + pub values: &'a [T], +} + +pub struct StructArrayView<'a> { + pub len: usize, + pub validity: Option>, + pub fields: Vec>, +} + +pub struct ListArrayView<'a, O> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub offsets: &'a [O], + pub element: Box>, +} + +pub struct Utf8ArrayView<'a, O> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub offsets: &'a [O], + pub data: &'a [u8], +} diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs new file mode 100644 index 00000000..a732e466 --- /dev/null +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -0,0 +1,111 @@ +use std::{collections::HashMap, sync::Arc}; + +use serde::{Deserialize, Serialize}; + +use crate::internal::error::{fail, Error, Result}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Field { + pub name: String, + pub data_type: DataType, + pub metadata: HashMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum DataType { + Null, + Boolean, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + Utf8, + LargeUtf8, + Binary, + LargeBinary, + Date32, + Date64, + Timestamp(TimeUnit, Option>), + Time32(TimeUnit), + Time64(TimeUnit), + Decimal128, + Struct(Vec), + List(Box), + LargeList(Box), +} + +pub struct BaseDataTypeDisplay<'a>(pub &'a DataType); + +impl<'a> std::fmt::Display for BaseDataTypeDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + DataType::Null => write!(f, "Null"), + DataType::Boolean => write!(f, "Boolean"), + DataType::Int8 => write!(f, "Int8"), + DataType::Int16 => write!(f, "Int16"), + DataType::Int32 => write!(f, "Int32"), + DataType::Int64 => write!(f, "Int64"), + DataType::UInt8 => write!(f, "UInt8"), + DataType::UInt16 => write!(f, "UInt16"), + DataType::UInt32 => write!(f, "UInt32"), + DataType::UInt64 => write!(f, "UInt64"), + DataType::Float16 => write!(f, "Float16"), + DataType::Float32 => write!(f, "Float32"), + DataType::Float64 => write!(f, "Float64"), + DataType::Utf8 => write!(f, "Utf8"), + DataType::LargeUtf8 => write!(f, "LargeUtf8"), + DataType::Binary => write!(f, "Binary"), + DataType::LargeBinary => write!(f, "LargeBinary"), + DataType::Date32 => write!(f, "Date32"), + DataType::Date64 => write!(f, "Date64"), + DataType::Timestamp(_, _) => write!(f, "Timestamp"), + DataType::Time32(_) => write!(f, "Time32"), + DataType::Time64(_) => write!(f, "Time64"), + DataType::Decimal128 => write!(f, "Decimal128"), + DataType::Struct(_) => write!(f, "Struct"), + DataType::List(_) => write!(f, "List"), + DataType::LargeList(_) => write!(f, "LargeList"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] +pub enum TimeUnit { + Second, + Millisecond, + Microsecond, + Nanosecond, +} + +impl std::fmt::Display for TimeUnit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeUnit::Second => write!(f, "Second"), + TimeUnit::Millisecond => write!(f, "Millisecond"), + TimeUnit::Microsecond => write!(f, "Microsecond"), + TimeUnit::Nanosecond => write!(f, "Nanosecond"), + } + } +} + +impl std::str::FromStr for TimeUnit { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "Second" => Ok(Self::Second), + "Millisecond" => Ok(Self::Millisecond), + "Microsecond" => Ok(Self::Microsecond), + "Nanosecond" => Ok(Self::Nanosecond), + s => fail!("Invalid time unit {s}"), + } + } +} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs new file mode 100644 index 00000000..17048a17 --- /dev/null +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -0,0 +1,12 @@ +//! A common arrow abstraction to simplify conversion between different arrow +//! implementations +mod array; +mod array_view; +mod data_type; + +pub use array::{ + Array, BooleanArray, BytesArray, DecimalArray, DenseUnionArray, DictionaryArray, FieldMeta, + FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, + TimeArray, TimestampArray, +}; +pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index bc1b6a11..bd905560 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -2,8 +2,8 @@ use chrono::DateTime; use serde::de::Visitor; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::GenericTimeUnit, utils::Mut, }; @@ -12,13 +12,13 @@ use super::{ utils::{ArrayBufferIterator, BitBuffer}, }; -pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, GenericTimeUnit, bool); +pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, TimeUnit, bool); impl<'a> Date64Deserializer<'a> { pub fn new( buffer: &'a [i64], validity: Option>, - unit: GenericTimeUnit, + unit: TimeUnit, is_utc: bool, ) -> Self { Self(ArrayBufferIterator::new(buffer, validity), unit, is_utc) @@ -26,10 +26,10 @@ impl<'a> Date64Deserializer<'a> { pub fn get_string_repr(&self, ts: i64) -> Result { let Some(date_time) = (match self.1 { - GenericTimeUnit::Second => DateTime::from_timestamp(ts, 0), - GenericTimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), - GenericTimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), - GenericTimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), + TimeUnit::Second => DateTime::from_timestamp(ts, 0), + TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), + TimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), + TimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), }) else { fail!("Unsupported timestamp value: {ts}"); }; diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 05f9775e..9c755fbb 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -2,8 +2,8 @@ use chrono::NaiveTime; use serde::de::Visitor; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::GenericTimeUnit, utils::Mut, }; @@ -16,8 +16,13 @@ use super::{ pub struct TimeDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>, i64, i64); impl<'a, T: Integer> TimeDeserializer<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>, unit: GenericTimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = unit.get_factors(); + pub fn new(buffer: &'a [T], validity: Option>, unit: TimeUnit) -> Self { + let (seconds_factor, nanoseconds_factor) = match unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; Self( ArrayBufferIterator::new(buffer, validity), diff --git a/serde_arrow/src/internal/mod.rs b/serde_arrow/src/internal/mod.rs index 47a1c1d6..1c7188d6 100644 --- a/serde_arrow/src/internal/mod.rs +++ b/serde_arrow/src/internal/mod.rs @@ -1,4 +1,5 @@ pub mod array_builder; +pub mod arrow; pub mod deserialization; pub mod deserializer; pub mod error; diff --git a/serde_arrow/src/internal/schema/data_type.rs b/serde_arrow/src/internal/schema/data_type.rs index bfd70025..0e9e9aa8 100644 --- a/serde_arrow/src/internal/schema/data_type.rs +++ b/serde_arrow/src/internal/schema/data_type.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use crate::internal::{ + arrow::TimeUnit, error::{fail, Error, Result}, utils::dsl::Term, }; @@ -25,9 +26,9 @@ pub enum GenericDataType { LargeUtf8, Date32, Date64, - Time32(GenericTimeUnit), - Time64(GenericTimeUnit), - Duration(GenericTimeUnit), + Time32(TimeUnit), + Time64(TimeUnit), + Duration(TimeUnit), Struct, List, LargeList, @@ -38,7 +39,7 @@ pub enum GenericDataType { Union, Map, Dictionary, - Timestamp(GenericTimeUnit, Option), + Timestamp(TimeUnit, Option), Decimal128(u8, i8), } @@ -123,7 +124,7 @@ impl std::str::FromStr for GenericDataType { ("Map", []) => T::Map, ("Dictionary", []) => T::Dictionary, ("Timestamp", [unit, timezone]) => { - let unit: GenericTimeUnit = unit.as_ident()?.parse()?; + let unit: TimeUnit = unit.as_ident()?.parse()?; let timezone = timezone .as_option()? .map(|term| term.as_string()) @@ -158,47 +159,3 @@ impl From for GenericDataTypeString { Self(value.to_string()) } } - -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] -pub enum GenericTimeUnit { - Second, - Millisecond, - Microsecond, - Nanosecond, -} - -impl GenericTimeUnit { - pub fn get_factors(&self) -> (i64, i64) { - match self { - GenericTimeUnit::Nanosecond => (1_000_000_000, 1), - GenericTimeUnit::Microsecond => (1_000_000, 1_000), - GenericTimeUnit::Millisecond => (1_000, 1_000_000), - GenericTimeUnit::Second => (1, 1_000_000_000), - } - } -} - -impl std::fmt::Display for GenericTimeUnit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - GenericTimeUnit::Second => write!(f, "Second"), - GenericTimeUnit::Millisecond => write!(f, "Millisecond"), - GenericTimeUnit::Microsecond => write!(f, "Microsecond"), - GenericTimeUnit::Nanosecond => write!(f, "Nanosecond"), - } - } -} - -impl std::str::FromStr for GenericTimeUnit { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "Second" => Ok(Self::Second), - "Millisecond" => Ok(Self::Millisecond), - "Microsecond" => Ok(Self::Microsecond), - "Nanosecond" => Ok(Self::Nanosecond), - s => fail!("Invalid time unit {s}"), - } - } -} diff --git a/serde_arrow/src/internal/schema/deserialization.rs b/serde_arrow/src/internal/schema/deserialization.rs index 1a95783c..cb3e909f 100644 --- a/serde_arrow/src/internal/schema/deserialization.rs +++ b/serde_arrow/src/internal/schema/deserialization.rs @@ -6,10 +6,11 @@ use std::{collections::HashMap, str::FromStr}; use serde::{de::Visitor, Deserialize}; use crate::internal::{ + arrow::TimeUnit, error::{fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, GenericField, - GenericTimeUnit, SerdeArrowSchema, Strategy, + SerdeArrowSchema, Strategy, }, }; @@ -40,7 +41,7 @@ pub enum ArrowTimeUnit { Nanosecond, } -impl From for GenericTimeUnit { +impl From for TimeUnit { fn from(value: ArrowTimeUnit) -> Self { match value { ArrowTimeUnit::Second => Self::Second, diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index fdec33da..3af0ea09 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -14,13 +14,14 @@ mod test; use std::collections::HashMap; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, utils::value, }; use serde::{Deserialize, Serialize}; -pub use data_type::{GenericDataType, GenericTimeUnit}; +pub use data_type::GenericDataType; pub use strategy::{ merge_strategy_with_metadata, split_strategy_from_metadata, Strategy, STRATEGY_KEY, }; @@ -492,7 +493,7 @@ impl GenericField { } if !matches!( self.data_type, - GenericDataType::Time32(GenericTimeUnit::Second | GenericTimeUnit::Millisecond) + GenericDataType::Time32(TimeUnit::Second | TimeUnit::Millisecond) ) { fail!("Time32 field must have Second or Millisecond unit"); } @@ -511,7 +512,7 @@ impl GenericField { } if !matches!( self.data_type, - GenericDataType::Time64(GenericTimeUnit::Microsecond | GenericTimeUnit::Nanosecond) + GenericDataType::Time64(TimeUnit::Microsecond | TimeUnit::Nanosecond) ) { fail!("Time64 field must have Microsecond or Nanosecond unit"); } diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index e976e567..8e19822c 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -1,6 +1,7 @@ use serde_json::json; use crate::internal::{ + arrow::TimeUnit, schema::{GenericDataType, GenericField, SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, testing::{assert_error, hash_map}, }; @@ -104,7 +105,7 @@ fn date64_with_strategy() { #[test] fn timestamp_second_serialization() { - let dt = super::GenericDataType::Timestamp(super::GenericTimeUnit::Second, None); + let dt = super::GenericDataType::Timestamp(TimeUnit::Second, None); let s = serde_json::to_string(&dt).unwrap(); assert_eq!(s, r#""Timestamp(Second, None)""#); @@ -115,10 +116,7 @@ fn timestamp_second_serialization() { #[test] fn timestamp_second_utc_serialization() { - let dt = super::GenericDataType::Timestamp( - super::GenericTimeUnit::Second, - Some(String::from("Utc")), - ); + let dt = super::GenericDataType::Timestamp(TimeUnit::Second, Some(String::from("Utc"))); let s = serde_json::to_string(&dt).unwrap(); assert_eq!(s, r#""Timestamp(Second, Some(\"Utc\"))""#); @@ -129,7 +127,7 @@ fn timestamp_second_utc_serialization() { #[test] fn test_date32() { - use super::GenericDataType as DT; + use GenericDataType as DT; assert_eq!(DT::Date32.to_string(), "Date32"); assert_eq!("Date32".parse::
().unwrap(), DT::Date32); @@ -137,7 +135,7 @@ fn test_date32() { #[test] fn time64_data_type_format() { - use super::{GenericDataType as DT, GenericTimeUnit as TU}; + use {GenericDataType as DT, TimeUnit as TU}; for (dt, s) in [ (DT::Time64(TU::Microsecond), "Time64(Microsecond)"), diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 68e74e14..fc73c70d 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -1,7 +1,7 @@ use half::f16; use serde::Serialize; -use crate::internal::error::Result; +use crate::internal::{arrow::Array, error::Result}; use super::{ binary_builder::BinaryBuilder, bool_builder::BoolBuilder, date32_builder::Date32Builder, @@ -131,10 +131,15 @@ impl ArrayBuilder { pub fn is_nullable(&self) -> bool { dispatch!(self, Self(builder) => builder.is_nullable()) } + + pub fn into_array(self) -> Result { + dispatch!(self, Self(builder) => builder.into_array()) + } } impl ArrayBuilder { /// Take the contained array builder, while leaving structure intact + // TODO: use ArrayBuilder as return type for the impls and use dispatch here pub fn take(&mut self) -> ArrayBuilder { match self { Self::Null(builder) => Self::Null(builder.take()), diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index d1dcac33..192195ac 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::{Array, BytesArray}, error::Result, utils::{Mut, Offset}, }; @@ -39,6 +40,26 @@ impl BinaryBuilder { } } +impl BinaryBuilder { + pub fn into_array(self) -> Result { + Ok(Array::Binary(BytesArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + })) + } +} + +impl BinaryBuilder { + pub fn into_array(self) -> Result { + Ok(Array::LargeBinary(BytesArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + })) + } +} + impl BinaryBuilder { fn start(&mut self) -> Result<()> { push_validity(&mut self.validity, true) diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 805e76a2..a884e341 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,4 +1,7 @@ -use crate::internal::error::Result; +use crate::internal::{ + arrow::{Array, BooleanArray}, + error::Result, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -26,6 +29,14 @@ impl BoolBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::Boolean(BooleanArray { + len: self.buffer.len, + validity: self.validity.map(|v| v.buffer), + values: self.buffer.buffer, + })) + } } impl SimpleSerializer for BoolBuilder { diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 3eee20f2..8537c5df 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -1,6 +1,10 @@ use chrono::{NaiveDate, NaiveDateTime}; -use crate::internal::{error::Result, schema::GenericField}; +use crate::internal::{ + arrow::{Array, PrimitiveArray}, + error::Result, + schema::GenericField, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -31,6 +35,13 @@ impl Date32Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::Date32(PrimitiveArray { + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + })) + } } impl SimpleSerializer for Date32Builder { diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index fe17903a..9668a199 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,6 +1,7 @@ use crate::internal::{ + arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, error::{Error, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit}, + schema::{GenericDataType, GenericField}, }; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -35,6 +36,22 @@ impl Date64Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + if let GenericDataType::Timestamp(unit, timezone) = self.field.data_type { + Ok(Array::Timestamp(TimestampArray { + unit, + timezone, + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + })) + } else { + Ok(Array::Date64(PrimitiveArray { + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + })) + } + } } impl SimpleSerializer for Date64Builder { @@ -64,14 +81,14 @@ impl SimpleSerializer for Date64Builder { }; let timestamp = match self.field.data_type { - GenericDataType::Timestamp(GenericTimeUnit::Nanosecond, _) => { + GenericDataType::Timestamp(TimeUnit::Nanosecond, _) => { date_time .timestamp_nanos_opt() .ok_or_else(|| Error::custom(format!("Timestamp '{v}' cannot be converted to nanoseconds. The dates that can be represented as nanoseconds are between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.")))? }, - GenericDataType::Timestamp(GenericTimeUnit::Microsecond, _) => date_time.timestamp_micros(), - GenericDataType::Timestamp(GenericTimeUnit::Millisecond, _) => date_time.timestamp_millis(), - GenericDataType::Timestamp(GenericTimeUnit::Second, _) => date_time.timestamp(), + GenericDataType::Timestamp(TimeUnit::Microsecond, _) => date_time.timestamp_micros(), + GenericDataType::Timestamp(TimeUnit::Millisecond, _) => date_time.timestamp_millis(), + GenericDataType::Timestamp(TimeUnit::Second, _) => date_time.timestamp(), _ => date_time.timestamp_millis(), }; diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 830e893f..5bb48de3 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::{Array, DecimalArray}, error::Result, utils::decimal::{self, DecimalParser}, }; @@ -44,6 +45,15 @@ impl DecimalBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::Decimal128(DecimalArray { + precision: self.precision, + scale: self.scale, + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + })) + } } impl SimpleSerializer for DecimalBuilder { diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index a3c64ae5..cbbbd028 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use serde::Serialize; use crate::internal::{ + arrow::{Array, DictionaryArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -40,6 +41,13 @@ impl DictionaryUtf8Builder { pub fn is_nullable(&self) -> bool { self.indices.is_nullable() } + + pub fn into_array(self) -> Result { + Ok(Array::Dictionary(DictionaryArray { + indices: Box::new((*self.indices).into_array()?), + values: Box::new((*self.values).into_array()?), + })) + } } impl SimpleSerializer for DictionaryUtf8Builder { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index dc807ba8..8b43845a 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,16 +1,19 @@ -use crate::internal::{error::Result, schema::GenericTimeUnit}; +use crate::internal::{ + arrow::{Array, TimeArray, TimeUnit}, + error::Result, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; #[derive(Debug, Clone)] pub struct DurationBuilder { - pub unit: GenericTimeUnit, + pub unit: TimeUnit, pub validity: Option, pub buffer: Vec, } impl DurationBuilder { - pub fn new(unit: GenericTimeUnit, is_nullable: bool) -> Self { + pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { Self { unit, validity: is_nullable.then(MutableBitBuffer::default), @@ -29,6 +32,14 @@ impl DurationBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::Duration(TimeArray { + unit: self.unit, + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + })) + } } impl SimpleSerializer for DurationBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 1ead7495..06fc3193 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::{Array, FixedSizeBinaryArray}, error::{fail, Result}, utils::Mut, }; @@ -41,6 +42,14 @@ impl FixedSizeBinaryBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::FixedSizeBinary(FixedSizeBinaryArray { + n: self.n.try_into()?, + validity: self.validity.map(|v| v.buffer), + data: self.buffer, + })) + } } impl FixedSizeBinaryBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 22d1ee49..06754e27 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::{Array, FixedSizeListArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -8,7 +9,9 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}, + utils::{ + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, + }, }; #[derive(Debug, Clone)] @@ -48,6 +51,15 @@ impl FixedSizeListBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::FixedSizeList(FixedSizeListArray { + n: self.n.try_into()?, + meta: meta_from_field(self.field)?, + validity: self.validity.map(|v| v.buffer), + element: Box::new((*self.element).into_array()?), + })) + } } impl FixedSizeListBuilder { diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index a07cf66f..6381f905 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -1,6 +1,10 @@ use half::f16; -use crate::internal::{error::Result, utils::Mut}; +use crate::internal::{ + arrow::{Array, PrimitiveArray}, + error::Result, + utils::Mut, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -36,6 +40,23 @@ impl FloatBuilder { } } +macro_rules! impl_into_array { + ($ty:ty, $var:ident) => { + impl FloatBuilder<$ty> { + pub fn into_array(self) -> Result { + Ok(Array::$var(PrimitiveArray { + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + })) + } + } + }; +} + +impl_into_array!(f16, Float16); +impl_into_array!(f32, Float32); +impl_into_array!(f64, Float64); + impl SimpleSerializer for FloatBuilder { fn name(&self) -> &str { "FloatBuilder" diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 57516d3a..391a1f41 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -1,4 +1,7 @@ -use crate::internal::error::{Error, Result}; +use crate::internal::{ + arrow::{Array, PrimitiveArray}, + error::{Error, Result}, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -28,6 +31,28 @@ impl IntBuilder { } } +macro_rules! impl_into_array { + ($ty:ty, $var:ident) => { + impl IntBuilder<$ty> { + pub fn into_array(self) -> Result { + Ok(Array::$var(PrimitiveArray { + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + })) + } + } + }; +} + +impl_into_array!(i8, Int8); +impl_into_array!(i16, Int16); +impl_into_array!(i32, Int32); +impl_into_array!(i64, Int64); +impl_into_array!(u8, UInt8); +impl_into_array!(u16, UInt16); +impl_into_array!(u32, UInt32); +impl_into_array!(u64, UInt64); + impl SimpleSerializer for IntBuilder where I: Default diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index ec112b47..784e2f22 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::{Array, ListArray}, error::Result, schema::GenericField, utils::{Mut, Offset}, @@ -9,8 +10,8 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, - SimpleSerializer, + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, + MutableOffsetBuffer, SimpleSerializer, }, }; @@ -47,6 +48,28 @@ impl ListBuilder { } } +impl ListBuilder { + pub fn into_array(self) -> Result { + Ok(Array::List(ListArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + element: Box::new(self.element.into_array()?), + meta: meta_from_field(self.field)?, + })) + } +} + +impl ListBuilder { + pub fn into_array(self) -> Result { + Ok(Array::LargeList(ListArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + element: Box::new(self.element.into_array()?), + meta: meta_from_field(self.field)?, + })) + } +} + impl ListBuilder { fn start(&mut self) -> Result<()> { push_validity(&mut self.validity, true) diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index bae4b05f..6679219d 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -1,12 +1,16 @@ use serde::Serialize; -use crate::internal::{error::Result, schema::GenericField}; +use crate::internal::{ + arrow::{Array, ListArray}, + error::Result, + schema::GenericField, +}; use super::{ array_builder::ArrayBuilder, utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, - SimpleSerializer, + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, + MutableOffsetBuffer, SimpleSerializer, }, }; @@ -40,6 +44,15 @@ impl MapBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + Ok(Array::Map(ListArray { + meta: meta_from_field(self.entry_field)?, + element: Box::new((*self.entry).into_array()?), + validity: self.validity.map(|v| v.buffer), + offsets: self.offsets.offsets, + })) + } } impl SimpleSerializer for MapBuilder { diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index 5a7e3122..eb02acdb 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -1,4 +1,7 @@ -use crate::Result; +use crate::internal::{ + arrow::{Array, NullArray}, + error::Result, +}; use super::utils::SimpleSerializer; @@ -21,6 +24,10 @@ impl NullBuilder { pub fn is_nullable(&self) -> bool { true } + + pub fn into_array(self) -> Result { + Ok(Array::Null(NullArray { len: self.count })) + } } impl SimpleSerializer for NullBuilder { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index e815c534..98e66ab5 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -1,8 +1,9 @@ use serde::Serialize; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit, SerdeArrowSchema, Strategy}, + schema::{GenericDataType, GenericField, SerdeArrowSchema, Strategy}, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, @@ -78,16 +79,13 @@ impl OuterSequenceBuilder { Some(tz) => fail!("Timezone {tz} is not supported"), }, T::Time32(unit) => { - if !matches!(unit, GenericTimeUnit::Second | GenericTimeUnit::Millisecond) { + if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { fail!("Only timestamps with second or millisecond unit are supported"); } A::Time32(TimeBuilder::new(field.clone(), field.nullable, *unit)) } T::Time64(unit) => { - if !matches!( - unit, - GenericTimeUnit::Nanosecond | GenericTimeUnit::Microsecond - ) { + if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { fail!("Only timestamps with nanosecond or microsecond unit are supported"); } A::Time64(TimeBuilder::new(field.clone(), field.nullable, *unit)) diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index d48bb3eb..2753c03e 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use serde::Serialize; use crate::internal::{ + arrow::{Array, StructArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -10,13 +11,16 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}, + utils::{ + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, + }, }; const UNKNOWN_KEY: usize = usize::MAX; #[derive(Debug, Clone)] pub struct StructBuilder { + // TODO: clean this up pub fields: Vec, pub validity: Option, pub named_fields: Vec<(String, ArrayBuilder)>, @@ -24,6 +28,7 @@ pub struct StructBuilder { pub seen: Vec, pub next: usize, pub index: BTreeMap, + pub len: usize, } impl StructBuilder { @@ -32,15 +37,11 @@ impl StructBuilder { named_fields: Vec<(String, ArrayBuilder)>, is_nullable: bool, ) -> Result { - let mut index = BTreeMap::new(); - let cached_names = vec![None; named_fields.len()]; - let seen = vec![false; named_fields.len()]; - let next = 0; - if fields.len() != named_fields.len() { fail!("mismatched number of fields and builders"); } + let mut index = BTreeMap::new(); for (idx, (name, _)) in named_fields.iter().enumerate() { if index.contains_key(name) { fail!("Duplicate field {name}"); @@ -50,12 +51,13 @@ impl StructBuilder { Ok(Self { fields, + seen: vec![false; named_fields.len()], + cached_names: vec![None; named_fields.len()], validity: is_nullable.then(MutableBitBuffer::default), named_fields, - cached_names, - seen, - next, + next: 0, index, + len: 0, }) } @@ -74,6 +76,7 @@ impl StructBuilder { ), seen: std::mem::replace(&mut self.seen, vec![false; self.named_fields.len()]), next: std::mem::take(&mut self.next), + len: std::mem::take(&mut self.len), index: self.index.clone(), } } @@ -81,11 +84,27 @@ impl StructBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Result { + let mut fields = Vec::new(); + for (field, (_, builder)) in self.fields.into_iter().zip(self.named_fields) { + let meta = meta_from_field(field)?; + let array = builder.into_array()?; + fields.push((array, meta)); + } + + Ok(Array::Struct(StructArray { + len: self.len, + validity: self.validity.map(|b| b.buffer), + fields, + })) + } } impl StructBuilder { fn start(&mut self) -> Result<()> { push_validity(&mut self.validity, true)?; + self.len += 1; self.reset(); Ok(()) } @@ -130,6 +149,7 @@ impl SimpleSerializer for StructBuilder { fn serialize_default(&mut self) -> Result<()> { push_validity_default(&mut self.validity); + self.len += 1; for (_, field) in &mut self.named_fields { field.serialize_default()?; } @@ -139,11 +159,11 @@ impl SimpleSerializer for StructBuilder { fn serialize_none(&mut self) -> Result<()> { push_validity(&mut self.validity, false)?; + self.len += 1; for (_, field) in &mut self.named_fields { field.serialize_default()?; } - Ok(()) } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 198a519a..71ea9fa1 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -1,8 +1,9 @@ use chrono::Timelike; use crate::internal::{ + arrow::{Array, TimeArray, TimeUnit}, error::{Error, Result}, - schema::{GenericField, GenericTimeUnit}, + schema::GenericField, }; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -12,20 +13,16 @@ pub struct TimeBuilder { pub field: GenericField, pub validity: Option, pub buffer: Vec, - pub seconds_factor: i64, - pub nanoseconds_factor: i64, + pub unit: TimeUnit, } impl TimeBuilder { - pub fn new(field: GenericField, nullable: bool, unit: GenericTimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = unit.get_factors(); - + pub fn new(field: GenericField, nullable: bool, unit: TimeUnit) -> Self { Self { field, validity: nullable.then(MutableBitBuffer::default), buffer: Vec::new(), - seconds_factor, - nanoseconds_factor, + unit, } } @@ -34,8 +31,7 @@ impl TimeBuilder { field: self.field.clone(), validity: self.validity.as_mut().map(std::mem::take), buffer: std::mem::take(&mut self.buffer), - seconds_factor: self.seconds_factor, - nanoseconds_factor: self.nanoseconds_factor, + unit: self.unit, } } @@ -44,6 +40,26 @@ impl TimeBuilder { } } +impl TimeBuilder { + pub fn into_array(self) -> Result { + Ok(Array::Time32(TimeArray { + unit: self.unit, + validity: self.validity.map(|v| v.buffer), + values: self.buffer, + })) + } +} + +impl TimeBuilder { + pub fn into_array(self) -> Result { + Ok(Array::Time64(TimeArray { + unit: self.unit, + validity: self.validity.map(|v| v.buffer), + values: self.buffer, + })) + } +} + impl SimpleSerializer for TimeBuilder where I: TryFrom + TryFrom + Default, @@ -67,10 +83,17 @@ where } fn serialize_str(&mut self, v: &str) -> Result<()> { + let (seconds_factor, nanoseconds_factor) = match self.unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; + use chrono::naive::NaiveTime; let time = v.parse::()?; - let timestamp = time.num_seconds_from_midnight() as i64 * self.seconds_factor - + time.nanosecond() as i64 / self.nanoseconds_factor; + let timestamp = time.num_seconds_from_midnight() as i64 * seconds_factor + + time.nanosecond() as i64 / nanoseconds_factor; push_validity(&mut self.validity, true)?; self.buffer.push(timestamp.try_into()?); diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index f3885ad6..720e3dc8 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,10 +1,14 @@ use crate::internal::{ + arrow::{Array, DenseUnionArray}, error::{fail, Result}, schema::GenericField, utils::Mut, }; -use super::{utils::SimpleSerializer, ArrayBuilder}; +use super::{ + utils::{meta_from_field, SimpleSerializer}, + ArrayBuilder, +}; #[derive(Debug, Clone)] pub struct UnionBuilder { @@ -39,6 +43,21 @@ impl UnionBuilder { pub fn is_nullable(&self) -> bool { false } + + pub fn into_array(self) -> Result { + let mut fields = Vec::new(); + for (field, builder) in self.field.children.into_iter().zip(self.fields) { + let meta = meta_from_field(field)?; + let array = builder.into_array()?; + fields.push((array, meta)); + } + + Ok(Array::DenseUnion(DenseUnionArray { + types: self.types, + offsets: self.offsets, + fields, + })) + } } impl UnionBuilder { diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 4ae7e653..845574de 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -1,6 +1,12 @@ use serde::Serialize; -use crate::{internal::error::fail, Result}; +use crate::{ + internal::{ + arrow::{Array, NullArray}, + error::fail, + }, + Result, +}; use super::{utils::SimpleSerializer, ArrayBuilder}; @@ -15,6 +21,10 @@ impl UnknownVariantBuilder { pub fn is_nullable(&self) -> bool { false } + + pub fn into_array(self) -> Result { + Ok(Array::Null(NullArray { len: 0 })) + } } impl SimpleSerializer for UnknownVariantBuilder { diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 3940df16..0ebd6051 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::{Array, BytesArray}, error::{fail, Result}, utils::Offset, }; @@ -36,6 +37,26 @@ impl Utf8Builder { } } +impl Utf8Builder { + pub fn into_array(self) -> Result { + Ok(Array::Utf8(BytesArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + })) + } +} + +impl Utf8Builder { + pub fn into_array(self) -> Result { + Ok(Array::LargeUtf8(BytesArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + })) + } +} + impl SimpleSerializer for Utf8Builder { fn name(&self) -> &str { "Utf8Builder" diff --git a/serde_arrow/src/internal/serialization/utils.rs b/serde_arrow/src/internal/serialization/utils.rs index 398ad714..4a65d9ac 100644 --- a/serde_arrow/src/internal/serialization/utils.rs +++ b/serde_arrow/src/internal/serialization/utils.rs @@ -7,12 +7,22 @@ use serde::{ }; use crate::internal::{ + arrow::FieldMeta, error::{fail, Error, Result}, + schema::{merge_strategy_with_metadata, GenericField}, utils::{Mut, Offset}, }; use super::ArrayBuilder; +pub fn meta_from_field(field: GenericField) -> Result { + Ok(FieldMeta { + name: field.name, + nullable: field.nullable, + metadata: merge_strategy_with_metadata(field.metadata, field.strategy)?, + }) +} + #[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct MutableBitBuffer { pub(crate) buffer: Vec, diff --git a/x.py b/x.py index 9a106aa1..cbcfc488 100644 --- a/x.py +++ b/x.py @@ -169,13 +169,13 @@ def format(): @cmd(help="Run the linters") -@arg("--fast", action="store_true") -def check(fast=False): +@arg("--all", action="store_true") +def check(all=False): check_cargo_toml() _sh(f"cargo check --features {default_features}") _sh(f"cargo clippy --features {default_features}") - if not fast: + if all: for arrow2_feature in (*all_arrow2_features, *all_arrow_features): _sh(f"cargo check --features {arrow2_feature}")