Skip to content

Commit

Permalink
feat: Add serde support for arrays (#543)
Browse files Browse the repository at this point in the history
Also adds tests for `ArrayRef` and `MapArray` serialization.
  • Loading branch information
bjchambers authored Jul 21, 2023
1 parent 1037a91 commit 58c7351
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/sparrow-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ static_init.workspace = true
tracing.workspace = true

[dev-dependencies]
bincode.workspace = true
insta.workspace = true
proptest.workspace = true

Expand Down
92 changes: 92 additions & 0 deletions crates/sparrow-arrow/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,51 @@ pub mod array_ref {
}
}

/// Provides serde for specific `Array` values.
///
/// Example:
///
/// ```rust
/// #[derive(serde::Serialize, serde::Deserialize)]
/// struct Foo {
/// #[serde(with = "sparrow_arrow::serde::array")]
/// array: MapArray
/// }
/// ```
pub mod array {
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow_array::cast::downcast_array;
use arrow_array::Array;

pub fn serialize<T: Array + Clone + 'static, S>(
array: &T,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// We want to serialize the array using logic for RecordBatch.
// This requires we get an `ArrayRef` for it.
// This requires creating an owned array.
// This cloning should only do a shallow copy.
let array: T = array.clone();
let array: ArrayRef = Arc::new(array);
super::array_ref::serialize(&array, serializer)
}

pub fn deserialize<'de, T: From<arrow::array::ArrayData>, D>(
deserializer: D,
) -> Result<T, D::Error>
where
D: serde::Deserializer<'de>,
{
let array = super::array_ref::deserialize(deserializer)?;
Ok(downcast_array(array.as_ref()))
}
}

fn encode_batch(batch: &RecordBatch) -> Result<Vec<u8>, ArrowError> {
let c = Cursor::new(Vec::new());

Expand All @@ -112,3 +157,50 @@ fn decode_batch(bytes: Vec<u8>) -> Result<RecordBatch, ArrowError> {
let batches: Vec<_> = file_reader.into_iter().try_collect()?;
arrow::compute::concat_batches(&schema, &batches)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow_array::builder::{MapBuilder, StringBuilder, UInt32Builder};
use arrow_array::{ArrayRef, MapArray, UInt32Array};

#[derive(serde::Serialize, serde::Deserialize)]
struct ArrayRefTest {
#[serde(with = "crate::serde::array_ref")]
array: ArrayRef,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct MapArrayTest {
#[serde(with = "crate::serde::array")]
array: MapArray,
}

#[test]
fn test_array_ref() {
let original = ArrayRefTest {
array: Arc::new(UInt32Array::from(vec![0, 1, 2])),
};
let copied = round_trip_bincode(&original);
assert_eq!(&original.array, &copied.array);
}

#[test]
fn test_map_array() {
let mut builder = MapBuilder::new(None, StringBuilder::new(), UInt32Builder::new());
builder.keys().append_value("hello");
builder.values().append_value(5);
builder.append(true).unwrap();
let array = builder.finish();

let original = MapArrayTest { array };
let copied = round_trip_bincode(&original);
assert_eq!(&original.array, &copied.array);
}

fn round_trip_bincode<T: serde::Serialize + serde::de::DeserializeOwned>(t: &T) -> T {
let serialized = bincode::serialize(t).unwrap();
bincode::deserialize(&serialized).unwrap()
}
}

0 comments on commit 58c7351

Please sign in to comment.