From dc180a444b35d00ec436f54ca5f74bd7fe295172 Mon Sep 17 00:00:00 2001 From: Koki Ueha Date: Sun, 29 Jan 2023 11:09:18 +0000 Subject: [PATCH] fix(rust, python): implement ser/de for BinaryChunked --- polars/polars-core/src/serde/chunked_array.rs | 5 ++++- polars/polars-core/src/serde/mod.rs | 21 +++++++++++++++++++ polars/polars-core/src/serde/series.rs | 10 +++++++++ py-polars/tests/unit/test_serde.py | 20 ++++++++++++++++++ 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/polars/polars-core/src/serde/chunked_array.rs b/polars/polars-core/src/serde/chunked_array.rs index fb617306cdd4..f0f2999bc8ec 100644 --- a/polars/polars-core/src/serde/chunked_array.rs +++ b/polars/polars-core/src/serde/chunked_array.rs @@ -120,6 +120,9 @@ impl_serialize!(Utf8Chunked); impl_serialize!(BooleanChunked); impl_serialize!(ListChunked); +#[cfg(feature = "dtype-binary")] +impl_serialize!(BinaryChunked); + #[cfg(feature = "dtype-categorical")] impl Serialize for CategoricalChunked { fn serialize<S>( @@ -140,7 +143,7 @@ impl Serialize for CategoricalChunked { } } -#[cfg(feature = "dtype-categorical")] +#[cfg(feature = "dtype-struct")] impl Serialize for StructChunked { fn serialize<S>( &self, diff --git a/polars/polars-core/src/serde/mod.rs b/polars/polars-core/src/serde/mod.rs index ea37ba3e5014..7c9ae54b2a54 100644 --- a/polars/polars-core/src/serde/mod.rs +++ b/polars/polars-core/src/serde/mod.rs @@ -21,6 +21,7 @@ enum DeDataType<'a> { Float32, Float64, Utf8, + Binary, Date, Datetime(TimeUnit, Option<TimeZone>), Duration(TimeUnit), @@ -49,6 +50,8 @@ impl From<&DataType> for DeDataType<'_> { DataType::Boolean => DeDataType::Boolean, DataType::Null => DeDataType::Null, DataType::List(_) => DeDataType::List, + #[cfg(feature = "dtype-binary")] + DataType::Binary => DeDataType::Binary, #[cfg(feature = "object")] DataType::Object(s) => DeDataType::Object(s), #[cfg(feature = "dtype-struct")] @@ -128,6 +131,24 @@ mod test { let out = serde_json::from_reader::<_, 
DataFrame>(json.as_bytes()).unwrap(); // uses `DeserializeOwned` assert!(df.frame_equal_missing(&out)); } + + #[test] + #[cfg(feature = "dtype-binary")] + fn test_serde_binary_series_owned_bincode() { + let s1 = Series::new( + "foo", + &[ + vec![1u8, 2u8, 3u8], + vec![4u8, 5u8, 6u8, 7u8], + vec![8u8, 9u8], + ], + ); + let df = DataFrame::new(vec![s1]).unwrap(); + let bytes = bincode::serialize(&df).unwrap(); + let out = bincode::deserialize_from::<_, DataFrame>(bytes.as_slice()).unwrap(); + assert!(df.frame_equal_missing(&out)); + } + #[test] #[cfg(feature = "dtype-struct")] fn test_serde_struct_series_owned_json() { diff --git a/polars/polars-core/src/serde/series.rs b/polars/polars-core/src/serde/series.rs index 075c23584fa9..3a20655b656d 100644 --- a/polars/polars-core/src/serde/series.rs +++ b/polars/polars-core/src/serde/series.rs @@ -35,6 +35,11 @@ impl Serialize for Series { ca.serialize(serializer) } else { match self.dtype() { + #[cfg(feature = "dtype-binary")] + DataType::Binary => { + let ca = self.binary().unwrap(); + ca.serialize(serializer) + } #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { let ca = self.struct_().unwrap(); @@ -201,6 +206,11 @@ impl<'de> Deserialize<'de> for Series { let values: Vec<Option<Series>> = map.next_value()?; Ok(Series::new(&name, values)) } + #[cfg(feature = "dtype-binary")] + DeDataType::Binary => { + let values: Vec<Option<Cow<[u8]>>> = map.next_value()?; + Ok(Series::new(&name, values)) + } #[cfg(feature = "dtype-struct")] DeDataType::Struct => { let values: Vec<Series> = map.next_value()?; diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 0d2c36759870..108aa812ec7c 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -61,3 +61,23 @@ def test_serde_duration() -> None: def test_serde_expression_5461() -> None: e = pl.col("a").sqrt() / pl.col("b").alias("c") assert pickle.loads(pickle.dumps(e)).meta == e.meta + + +def test_serde_binary() -> None: + data = pl.Series( + 
"binary_data", + [ + b"\xba\x9b\xca\xd3y\xcb\xc9#", + b"9\x04\xab\xe2\x11\xf3\x85", + b"\xb8\xcb\xc9^\\\xa9-\x94\xe0H\x9d ", + b"S\xbc:\xcb\xf0\xf5r\xfe\x18\xfeH", + b",\xf5)y\x00\xe5\xf7", + b"\xfd\xf6\xf1\xc2X\x0cn\xb9#", + b"\x06\xef\xa6\xa2\xb7", + b"@\xff\x95\xda\xff\xd2\x18", + ], + ) + assert_series_equal( + data, + pickle.loads(pickle.dumps(data)), + )