Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Bool8 #212

Merged
merged 9 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 11 additions & 27 deletions serde_arrow/src/arrow2_impl/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ use crate::{
internal::{
array_builder::ArrayBuilder,
arrow::Field,
deserialization::{
array_deserializer::ArrayDeserializer,
outer_sequence_deserializer::OuterSequenceDeserializer,
},
deserializer::Deserializer,
error::{fail, Result},
schema::{get_strategy_from_metadata, SerdeArrowSchema},
schema::SerdeArrowSchema,
serializer::Serializer,
},
};
Expand Down Expand Up @@ -153,14 +149,7 @@ impl<'de> Deserializer<'de> {
where
A: AsRef<dyn Array>,
{
let fields = fields
.iter()
.map(Field::try_from)
.collect::<Result<Vec<_>>>()?;
let arrays = arrays
.iter()
.map(|array| array.as_ref())
.collect::<Vec<_>>();
use crate::internal::arrow::ArrayView;

if fields.len() != arrays.len() {
fail!(
Expand All @@ -169,21 +158,16 @@ impl<'de> Deserializer<'de> {
arrays.len()
);
}
let len = arrays.first().map(|array| array.len()).unwrap_or_default();

let mut deserializers = Vec::new();
for (field, array) in std::iter::zip(fields, arrays) {
if array.len() != len {
fail!("arrays of different lengths are not supported");
}
let strategy = get_strategy_from_metadata(&field.metadata)?;
let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?;
deserializers.push((field.name.clone(), deserializer));
}

let deserializer = OuterSequenceDeserializer::new(deserializers, len);
let deserializer = Deserializer(deserializer);
let fields = fields
.iter()
.map(Field::try_from)
.collect::<Result<Vec<_>>>()?;
let views = arrays
.iter()
.map(|array| ArrayView::try_from(array.as_ref()))
.collect::<Result<Vec<_>>>()?;

Ok(deserializer)
Deserializer::new(&fields, views)
}
}
2 changes: 1 addition & 1 deletion serde_arrow/src/arrow2_impl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl TryFrom<Array> for ArrayRef {
let child: ArrayRef = child.try_into()?;
let field = field_from_array_and_meta(child.as_ref(), meta);

type_ids.push(type_id.try_into()?);
type_ids.push(type_id.into());
values.push(child);
fields.push(field);
}
Expand Down
4 changes: 2 additions & 2 deletions serde_arrow/src/arrow2_impl/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl TryFrom<&DataType> for ArrowDataType {
if *scale < 0 {
fail!("arrow2 does not support decimals with negative scale");
}
Ok(AT::Decimal((*precision).try_into()?, (*scale).try_into()?))
Ok(AT::Decimal((*precision).into(), (*scale).try_into()?))
}
T::Binary => Ok(AT::Binary),
T::LargeBinary => Ok(AT::LargeBinary),
Expand Down Expand Up @@ -266,7 +266,7 @@ impl TryFrom<&DataType> for ArrowDataType {

for (type_id, field) in in_fields {
fields.push(AF::try_from(field)?);
type_ids.push((*type_id).try_into()?);
type_ids.push((*type_id).into());
}
Ok(AT::Union(fields, Some(type_ids), (*mode).into()))
}
Expand Down
30 changes: 7 additions & 23 deletions serde_arrow/src/arrow_impl/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ use crate::{
},
internal::{
array_builder::ArrayBuilder,
deserialization::{
array_deserializer::ArrayDeserializer,
outer_sequence_deserializer::OuterSequenceDeserializer,
},
deserializer::Deserializer,
error::{fail, Result},
schema::{get_strategy_from_metadata, SerdeArrowSchema},
schema::SerdeArrowSchema,
serializer::Serializer,
},
};
Expand Down Expand Up @@ -241,11 +237,7 @@ impl<'de> Deserializer<'de> {
where
A: AsRef<dyn Array>,
{
let fields = fields_from_field_refs(fields)?;
let arrays = arrays
.iter()
.map(|array| array.as_ref())
.collect::<Vec<_>>();
use crate::internal::arrow::ArrayView;

if fields.len() != arrays.len() {
fail!(
Expand All @@ -254,23 +246,15 @@ impl<'de> Deserializer<'de> {
arrays.len()
);
}
let len = arrays.first().map(|array| array.len()).unwrap_or_default();

let mut deserializers = Vec::new();
for (field, array) in std::iter::zip(&fields, arrays) {
if array.len() != len {
fail!("arrays of different lengths are not supported");
}
let fields = fields_from_field_refs(fields)?;

let strategy = get_strategy_from_metadata(&field.metadata)?;
let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?;
deserializers.push((field.name.clone(), deserializer));
let mut views = Vec::new();
for array in arrays {
views.push(ArrayView::try_from(array.as_ref())?);
}

let deserializer = OuterSequenceDeserializer::new(deserializers, len);
let deserializer = Deserializer(deserializer);

Ok(deserializer)
Deserializer::new(&fields, views)
}

/// Construct a new deserializer from a record batch (*requires one of the
Expand Down
32 changes: 20 additions & 12 deletions serde_arrow/src/arrow_impl/type_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::_impl::arrow::{
use crate::internal::{
arrow::Field,
error::{Error, Result},
schema::extensions::FixedShapeTensorField,
schema::extensions::{Bool8Field, FixedShapeTensorField, VariableShapeTensorField},
};

impl From<ArrowError> for Error {
Expand All @@ -15,22 +15,30 @@ impl From<ArrowError> for Error {
}
}

impl TryFrom<&FixedShapeTensorField> for ArrowField {
type Error = Error;
macro_rules! impl_try_from_ext_type {
($ty:ty) => {
impl TryFrom<&$ty> for ArrowField {
type Error = Error;

fn try_from(value: &FixedShapeTensorField) -> Result<Self, Self::Error> {
Self::try_from(&Field::try_from(value)?)
}
}
fn try_from(value: &$ty) -> Result<Self, Self::Error> {
Self::try_from(&Field::try_from(value)?)
}
}

impl TryFrom<FixedShapeTensorField> for ArrowField {
type Error = Error;
impl TryFrom<$ty> for ArrowField {
type Error = Error;

fn try_from(value: FixedShapeTensorField) -> Result<Self, Self::Error> {
Self::try_from(&value)
}
fn try_from(value: $ty) -> Result<Self, Self::Error> {
Self::try_from(&value)
}
}
};
}

impl_try_from_ext_type!(Bool8Field);
impl_try_from_ext_type!(FixedShapeTensorField);
impl_try_from_ext_type!(VariableShapeTensorField);

pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result<Vec<Field>> {
fields
.iter()
Expand Down
32 changes: 31 additions & 1 deletion serde_arrow/src/internal/deserializer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
use serde::de::Visitor;

use crate::internal::{
deserialization::outer_sequence_deserializer::OuterSequenceDeserializer,
arrow::{ArrayView, Field},
deserialization::{
array_deserializer::ArrayDeserializer,
outer_sequence_deserializer::OuterSequenceDeserializer,
},
error::{fail, Error, Result},
schema::get_strategy_from_metadata,
utils::array_view_ext::ArrayViewExt,
};

/// A structure to deserialize Arrow arrays into Rust objects
Expand All @@ -14,6 +20,30 @@ use crate::internal::{
#[cfg_attr(has_arrow2, doc = r"- [`Deserializer::from_arrow2`]")]
pub struct Deserializer<'de>(pub(crate) OuterSequenceDeserializer<'de>);

impl<'de> Deserializer<'de> {
pub(crate) fn new(fields: &[Field], views: Vec<ArrayView<'de>>) -> Result<Self> {
let len = match views.first() {
Some(view) => view.len(),
None => 0,
};

let mut deserializers = Vec::new();
for (field, view) in std::iter::zip(fields, views) {
if view.len() != len {
fail!("Cannot deserialize from arrays with different lengths");
}
let strategy = get_strategy_from_metadata(&field.metadata)?;
let deserializer = ArrayDeserializer::new(strategy.as_ref(), view)?;
deserializers.push((field.name.clone(), deserializer));
}

let deserializer = OuterSequenceDeserializer::new(deserializers, len);
let deserializer = Deserializer(deserializer);

Ok(deserializer)
}
}

impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> {
type Error = Error;

Expand Down
106 changes: 106 additions & 0 deletions serde_arrow/src/internal/schema/extensions/bool8_field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::collections::HashMap;

use crate::internal::{
arrow::{DataType, Field},
error::{Error, Result},
schema::PrettyField,
};

/// A helper to construct new `Bool8` fields (`arrow.bool8`)
///
/// This extension type can be used with `overwrites` in schema tracing:
///
/// ```rust
/// # use serde_json::json;
/// # use serde_arrow::{Result, schema::{SerdeArrowSchema, SchemaLike, TracingOptions, ext::Bool8Field}};
/// # use serde::Deserialize;
/// # fn main() -> Result<()> {
/// ##[derive(Deserialize)]
/// struct Record {
/// int_field: i32,
/// nested: Nested,
/// }
///
/// ##[derive(Deserialize)]
/// struct Nested {
/// bool_field: bool,
/// }
///
/// let tracing_options = TracingOptions::default()
/// .overwrite("nested.bool_field", Bool8Field::new("bool_field"))?;
///
/// let schema = SerdeArrowSchema::from_type::<Record>(tracing_options)?;
/// # std::mem::drop(schema);
/// # Ok(())
/// # }
/// ```
///
/// It can also be converted to a `arrow` `Field` for manual schema manipulation.
///
pub struct Bool8Field {
name: String,
nullable: bool,
}

impl Bool8Field {
/// Construct a new non-nullable `Bool8Field`
pub fn new(name: &str) -> Self {
Self {
name: name.into(),
nullable: false,
}
}

/// Set the nullability of the field
pub fn nullable(mut self, value: bool) -> Self {
self.nullable = value;
self
}
}

impl TryFrom<&Bool8Field> for Field {
type Error = Error;

fn try_from(value: &Bool8Field) -> Result<Self> {
let mut metadata = HashMap::new();
metadata.insert("ARROW:extension:name".into(), "arrow.bool8".into());
metadata.insert("ARROW:extension:metadata".into(), String::new());

Ok(Field {
name: value.name.to_owned(),
nullable: value.nullable,
data_type: DataType::Int8,
metadata,
})
}
}

impl serde::ser::Serialize for Bool8Field {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use serde::ser::Error;
let field = Field::try_from(self).map_err(S::Error::custom)?;
PrettyField(&field).serialize(serializer)
}
}

#[test]
fn bool8_repr() -> crate::internal::error::PanicOnError<()> {
use serde_json::json;

let field = Bool8Field::new("hello");

let field = Field::try_from(&field)?;
let actual = serde_json::to_value(&PrettyField(&field))?;

let expected = json!({
"name": "hello",
"data_type": "I8",
"metadata": {
"ARROW:extension:name": "arrow.bool8",
"ARROW:extension:metadata": "",
},
});

assert_eq!(actual, expected);
Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::internal::{

use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr};

/// Easily construct a field for tensors with fixed shape
/// Easily construct a fixed shape tensor fields (`arrow.fixed_shape_tensor`)
///
/// See the [arrow docs][fixed-shape-tensor-docs] for details on the different
/// fields.
Expand Down Expand Up @@ -51,7 +51,7 @@ pub struct FixedShapeTensorField {
}

impl FixedShapeTensorField {
/// Construct a new instance
/// Construct a new non-nullable `FixedShapeTensorField`
///
/// Note the element parameter must serialize into a valid field definition
/// with the the name `"element"`. The field type can be any valid Arrow
Expand Down
2 changes: 2 additions & 0 deletions serde_arrow/src/internal/schema/extensions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod bool8_field;
mod fixed_shape_tensor_field;
mod utils;
mod variable_shape_tensor_field;

pub use bool8_field::Bool8Field;
pub use fixed_shape_tensor_field::FixedShapeTensorField;
pub use variable_shape_tensor_field::VariableShapeTensorField;
Loading