From 3a22612154d8a40796f8247fe8ed2986246322be Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 20:46:33 +0200 Subject: [PATCH 001/178] Fix the deps of example and integration test --- Cargo.lock | 162 ++++++++++++----------------------- example/Cargo.toml | 6 +- integration_tests/Cargo.toml | 6 +- 3 files changed, 62 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e036c5c4..73d47c08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,35 +76,35 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" +checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" dependencies = [ "arrow-arith", - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-cast 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-cast", "arrow-csv", - "arrow-data 51.0.0", + "arrow-data 52.0.0", "arrow-ipc", - "arrow-json 51.0.0", + "arrow-json", "arrow-ord", "arrow-row", - "arrow-schema 51.0.0", - "arrow-select 51.0.0", + "arrow-schema 52.0.0", + "arrow-select", "arrow-string", ] [[package]] name = "arrow-arith" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" +checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", "chrono", "half 2.2.1", "num", @@ -534,26 +534,6 @@ dependencies = [ "num", ] -[[package]] -name = "arrow-cast" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" -dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", - "arrow-select 51.0.0", - "atoi", - "base64", - "chrono", - "half 2.2.1", - "lexical-core", - "num", - "ryu", -] - [[package]] name = "arrow-cast" version = "52.0.0" @@ -564,7 +544,7 @@ dependencies = [ "arrow-buffer 52.0.0", "arrow-data 52.0.0", "arrow-schema 52.0.0", - "arrow-select 52.0.0", + "arrow-select", "atoi", "base64", "chrono", @@ -576,15 +556,15 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" +checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc" dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-cast 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-cast", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", "chrono", "csv", "csv-core", @@ -787,38 +767,18 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" +checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8" dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-cast 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-cast", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", "flatbuffers", ] -[[package]] -name = "arrow-json" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" -dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 
51.0.0", - "arrow-cast 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", - "chrono", - "half 2.2.1", - "indexmap", - "lexical-core", - "num", - "serde", - "serde_json", -] - [[package]] name = "arrow-json" version = "52.0.0" @@ -827,7 +787,7 @@ checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1" dependencies = [ "arrow-array 52.0.0", "arrow-buffer 52.0.0", - "arrow-cast 52.0.0", + "arrow-cast", "arrow-data 52.0.0", "arrow-schema 52.0.0", "chrono", @@ -841,30 +801,30 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" +checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", - "arrow-select 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", + "arrow-select", "half 2.2.1", "num", ] [[package]] name = "arrow-row" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" +checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" dependencies = [ "ahash 0.8.3", - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", "half 2.2.1", "hashbrown 0.14.0", ] @@ -1013,20 +973,6 @@ dependencies = [ "serde", ] -[[package]] -name = "arrow-select" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" -dependencies = [ - "ahash 0.8.3", - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 
51.0.0", - "arrow-schema 51.0.0", - "num", -] - [[package]] name = "arrow-select" version = "52.0.0" @@ -1043,15 +989,15 @@ dependencies = [ [[package]] name = "arrow-string" -version = "51.0.0" +version = "52.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" +checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" dependencies = [ - "arrow-array 51.0.0", - "arrow-buffer 51.0.0", - "arrow-data 51.0.0", - "arrow-schema 51.0.0", - "arrow-select 51.0.0", + "arrow-array 52.0.0", + "arrow-buffer 52.0.0", + "arrow-data 52.0.0", + "arrow-schema 52.0.0", + "arrow-select", "memchr", "num", "regex", @@ -1564,9 +1510,9 @@ dependencies = [ [[package]] name = "flatbuffers" -version = "23.5.26" +version = "24.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" dependencies = [ "bitflags", "rustc_version", @@ -2334,7 +2280,7 @@ dependencies = [ "arrow-data 50.0.0", "arrow-data 51.0.0", "arrow-data 52.0.0", - "arrow-json 52.0.0", + "arrow-json", "arrow-schema 37.0.0", "arrow-schema 38.0.0", "arrow-schema 39.0.0", diff --git a/example/Cargo.toml b/example/Cargo.toml index 231a4e23..95731707 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -7,9 +7,11 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -arrow = {version = "51.0", features = ["ipc"] } +# arrow-version:replace: arrow = {{ version = "52.0", features = [{version}] }} +arrow = {version = "52.0", features = ["ipc"] } chrono = { version = "0.4", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } -serde_arrow = { path = "../serde_arrow", features = ["arrow-51"] } +# arrow-version:replace: serde_arrow = {{ path = 
"../serde_arrow", features = ["arrow-{version}"] }} +serde_arrow = { path = "../serde_arrow", features = ["arrow-52"] } diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index 00d2122c..fd398ca3 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -4,10 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] -arrow = {version = "51.0", features = ["ipc"] } +# arrow-version:replace: arrow = {{ version = "52.0", features = [{version}] }} +arrow = {version = "52.0", features = ["ipc"] } chrono = { version = "0.4", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1" -serde_arrow = { path = "../serde_arrow", features = ["arrow-51"] } +# arrow-version:replace: serde_arrow = {{ path = "../serde_arrow", features = ["arrow-{version}"] }} +serde_arrow = { path = "../serde_arrow", features = ["arrow-52"] } From 3c9ee4975716921bbdc97b88ff51eb78ab800f61 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 21:08:21 +0200 Subject: [PATCH 002/178] Remove deprecated API --- Changes.md | 6 ++ serde_arrow/src/arrow2_impl/api.rs | 89 +------------------ serde_arrow/src/arrow2_impl/schema.rs | 9 +- serde_arrow/src/arrow_impl/api.rs | 92 +------------------- serde_arrow/src/arrow_impl/schema.rs | 9 +- serde_arrow/src/internal/schema/data_type.rs | 9 +- serde_arrow/src/lib.rs | 12 --- 7 files changed, 15 insertions(+), 211 deletions(-) diff --git a/Changes.md b/Changes.md index a33b7bb2..9aabb455 100644 --- a/Changes.md +++ b/Changes.md @@ -1,5 +1,11 @@ # Change log +## 0.12 + +- Remove `serde_arrow::schema::Schema` +- Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` +- Use `impl serde::Serialize` instead of `&(impl serde::Serialize + ?Sized)` + ## 0.11.6 - Add `arrow=52` support diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index cf95b246..bd997275 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ 
b/serde_arrow/src/arrow2_impl/api.rs @@ -8,97 +8,10 @@ use crate::{ _impl::arrow2::{array::Array, datatypes::Field}, internal::{ array_builder::ArrayBuilder, deserializer::Deserializer, error::Result, - schema::SerdeArrowSchema, serializer::Serializer, + serializer::Serializer, }, }; -/// Build arrow2 arrays record by record (*requires one of the `arrow2-*` -/// features*) -/// -/// The given items should be records (e.g., structs). To serialize items -/// encoding single values consider the [`Items`][crate::utils::Items] and -/// [`Item`][crate::utils::Item] wrappers. -/// -/// Example: -/// -/// ```rust -/// # fn main() -> serde_arrow::Result<()> { -/// # use serde_arrow::_impl::arrow2 as arrow2; -/// use arrow2::datatypes::{DataType, Field}; -/// use serde::Serialize; -/// use serde_arrow::Arrow2Builder; -/// -/// ##[derive(Serialize)] -/// struct Record { -/// a: Option, -/// b: u64, -/// } -/// -/// let mut builder = Arrow2Builder::new(&[ -/// Field::new("a", DataType::Float32, true), -/// Field::new("b", DataType::UInt64, false), -/// ])?; -/// -/// builder.push(&Record { a: Some(1.0), b: 2})?; -/// builder.push(&Record { a: Some(3.0), b: 4})?; -/// builder.push(&Record { a: Some(5.0), b: 5})?; -/// -/// builder.extend(&[ -/// Record { a: Some(6.0), b: 7}, -/// Record { a: Some(8.0), b: 9}, -/// Record { a: Some(10.0), b: 11}, -/// ])?; -/// -/// let arrays = builder.build_arrays()?; -/// # -/// # assert_eq!(arrays.len(), 2); -/// # assert_eq!(arrays[0].len(), 6); -/// # Ok(()) -/// # } -/// ``` -#[deprecated = "`Arrow2Builder` is deprecated. 
Use `ArrayBuilder` instead"] -pub struct Arrow2Builder(ArrayBuilder); - -#[allow(deprecated)] -impl std::fmt::Debug for Arrow2Builder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Arrow2Builder<...>") - } -} - -#[allow(deprecated)] -impl Arrow2Builder { - /// Build a new Arrow2Builder for the given fields - /// - /// This method may fail when unsupported data types are encountered in the - /// given fields. - /// - pub fn new(fields: &[Field]) -> Result { - let schema = SerdeArrowSchema::from_arrow2_fields(fields)?; - Ok(Self(ArrayBuilder::new(schema)?)) - } - - /// Add a single record to the arrays - /// - pub fn push(&mut self, item: &T) -> Result<()> { - self.0.push(item) - } - - /// Add multiple records to the arrays - /// - pub fn extend(&mut self, items: &T) -> Result<()> { - self.0.extend(items) - } - - /// Build the arrays from the rows pushed to far. - /// - /// This operation will reset the underlying buffers and start a new batch. - /// - pub fn build_arrays(&mut self) -> Result>> { - self.0.to_arrow2() - } -} - /// Build arrow2 arrays from the given items (*requires one of the `arrow2-*` /// features*) /// diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 7b730ad2..f5a156b5 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -18,8 +18,7 @@ impl SerdeArrowSchema { Self::try_from(fields) } - /// This method is deprecated. Use - /// [`to_arrow2_fields`][SerdeArrowSchema::to_arrow2_fields] instead: + /// Build a vec of fields from a Schema object /// /// ```rust /// # fn main() -> serde_arrow::_impl::PanicOnError<()> { @@ -33,12 +32,6 @@ impl SerdeArrowSchema { /// # Ok(()) /// # } /// ``` - #[deprecated = "The method `get_arrow2_fields` is deprecated. 
Use `to_arrow2_fields` instead"] - pub fn get_arrow2_fields(&self) -> Result> { - Vec::::try_from(self) - } - - /// Build a vec of fields from a Schema object pub fn to_arrow2_fields(&self) -> Result> { Vec::::try_from(self) } diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 50f6cf2b..ad0815c4 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -14,94 +14,6 @@ use crate::{ }, }; -/// Build arrow arrays record by record (*requires one of the `arrow-*` -/// features*) -/// -/// The given items should be records (e.g., structs). To serialize items -/// encoding single values consider the [`Items`][crate::utils::Items] and -/// [`Item`][crate::utils::Item] wrappers. -/// -/// Example: -/// -/// ```rust -/// # fn main() -> serde_arrow::Result<()> { -/// # use serde_arrow::_impl::arrow as arrow; -/// use arrow::datatypes::{DataType, Field}; -/// use serde::Serialize; -/// use serde_arrow::ArrowBuilder; -/// -/// ##[derive(Serialize)] -/// struct Record { -/// a: Option, -/// b: u64, -/// } -/// -/// let mut builder = ArrowBuilder::new(&[ -/// Field::new("a", DataType::Float32, true), -/// Field::new("b", DataType::UInt64, false), -/// ])?; -/// -/// builder.push(&Record { a: Some(1.0), b: 2})?; -/// builder.push(&Record { a: Some(3.0), b: 4})?; -/// builder.push(&Record { a: Some(5.0), b: 5})?; -/// -/// builder.extend(&[ -/// Record { a: Some(6.0), b: 7}, -/// Record { a: Some(8.0), b: 9}, -/// Record { a: Some(10.0), b: 11}, -/// ])?; -/// -/// let arrays = builder.build_arrays()?; -/// # -/// # assert_eq!(arrays.len(), 2); -/// # assert_eq!(arrays[0].len(), 6); -/// # Ok(()) -/// # } -/// ``` -#[deprecated = "`ArrowBuilder` is deprecated. 
Use `ArrayBuilder` instead"] -pub struct ArrowBuilder(ArrayBuilder); - -#[allow(deprecated)] -impl std::fmt::Debug for ArrowBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ArrowBuilder<...>") - } -} - -#[allow(deprecated)] -impl ArrowBuilder { - /// Build a new ArrowBuilder for the given fields - /// - /// This method may fail when unsupported data types are encountered in the - /// given fields. - /// - pub fn new(fields: &[Field]) -> Result { - Ok(Self(ArrayBuilder::new(SerdeArrowSchema::try_from( - fields, - )?)?)) - } - - /// Add a single record to the arrays - /// - pub fn push(&mut self, item: &T) -> Result<()> { - self.0.push(item) - } - - /// Add multiple records to the arrays - /// - pub fn extend(&mut self, items: &T) -> Result<()> { - self.0.extend(items) - } - - /// Build the arrays from the rows pushed to far. - /// - /// This operation will reset the underlying buffers and start a new batch. - /// - pub fn build_arrays(&mut self) -> Result> { - self.0.to_arrow() - } -} - /// Build arrow arrays from the given items (*requires one of the `arrow-*` /// features*) /// @@ -109,7 +21,7 @@ impl ArrowBuilder { /// structs). To serialize items encoding single values consider the /// [`Items`][crate::utils::Items] wrapper. /// -/// To build arrays record by record use [`ArrowBuilder`]. To construct a record +/// To build arrays record by record use [`ArrayBuilder`]. To construct a record /// batch, consider using [`to_record_batch`]. /// /// Example: @@ -193,7 +105,7 @@ where /// structs). To serialize items encoding single values consider the /// [`Items`][crate::utils::Items] wrapper. /// -/// To build arrays record by record use [`ArrowBuilder`]. +/// To build arrays record by record use [`ArrayBuilder`]. 
/// /// Example: /// diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 16784a64..0f872e27 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -18,8 +18,7 @@ impl SerdeArrowSchema { Self::try_from(fields) } - /// This method is deprecated. Use - /// [`to_arrow_fields`][SerdeArrowSchema::to_arrow_fields] instead: + /// Build a vec of fields from a Schema object /// /// ```rust /// # fn main() -> serde_arrow::_impl::PanicOnError<()> { @@ -33,12 +32,6 @@ impl SerdeArrowSchema { /// # Ok(()) /// # } /// ``` - #[deprecated = "The method `get_arrow_fields` is deprecated. Use `to_arrow_fields` instead"] - pub fn get_arrow_fields(&self) -> Result> { - Vec::::try_from(self) - } - - /// Build a vec of fields from a Schema object pub fn to_arrow_fields(&self) -> Result> { Vec::::try_from(self) } diff --git a/serde_arrow/src/internal/schema/data_type.rs b/serde_arrow/src/internal/schema/data_type.rs index a384f6fb..bfd70025 100644 --- a/serde_arrow/src/internal/schema/data_type.rs +++ b/serde_arrow/src/internal/schema/data_type.rs @@ -194,11 +194,10 @@ impl std::str::FromStr for GenericTimeUnit { fn from_str(s: &str) -> Result { match s { - // TODO: remove plural forms (were incorrectly supported in 0.11.x) - "Second" | "Seconds" => Ok(Self::Second), - "Millisecond" | "Milliseconds" => Ok(Self::Millisecond), - "Microsecond" | "Microseconds" => Ok(Self::Microsecond), - "Nanosecond" | "Nanoseconds" => Ok(Self::Nanosecond), + "Second" => Ok(Self::Second), + "Millisecond" => Ok(Self::Millisecond), + "Microsecond" => Ok(Self::Microsecond), + "Nanosecond" => Ok(Self::Nanosecond), s => fail!("Invalid time unit {s}"), } } diff --git a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index 64f6f8a9..44be12f0 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -325,20 +325,12 @@ mod arrow_impl; #[cfg(has_arrow)] pub use arrow_impl::api::{from_arrow, from_record_batch, to_arrow, 
to_record_batch}; -#[cfg(has_arrow)] -#[allow(deprecated)] -pub use arrow_impl::api::ArrowBuilder; - #[cfg(has_arrow2)] mod arrow2_impl; #[cfg(has_arrow2)] pub use arrow2_impl::api::{from_arrow2, to_arrow2}; -#[cfg(has_arrow2)] -#[allow(deprecated)] -pub use arrow2_impl::api::Arrow2Builder; - #[deny(missing_docs)] /// Helpers that may be useful when using `serde_arrow` pub mod utils { @@ -395,10 +387,6 @@ pub mod schema { Overwrites, SchemaLike, SerdeArrowSchema, Strategy, TracingOptions, STRATEGY_KEY, }; - /// Renamed to [`SerdeArrowSchema`] - #[deprecated = "serde_arrow::schema::Schema is deprecated. Use serde_arrow::schema::SerdeArrowSchema instead"] - pub type Schema = SerdeArrowSchema; - /// Support for [canonical extension types][ext-docs]. This module is experimental without semver guarantees. /// /// [ext-docs]: https://arrow.apache.org/docs/format/CanonicalExtensions.html From 5f08a069fbb1c8ce283808f944b351ba121b7507 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 21:38:50 +0200 Subject: [PATCH 003/178] Use &[FieldRef] instead of &[Field] in arrow API --- Changes.md | 1 + serde_arrow/benches/groups/impls.rs | 11 +-- serde_arrow/src/_impl/docs/defs.rs | 4 +- serde_arrow/src/_impl/docs/quickstart.rs | 17 ++-- serde_arrow/src/arrow2_impl/api.rs | 2 +- serde_arrow/src/arrow_impl/api.rs | 20 ++--- serde_arrow/src/arrow_impl/deserialization.rs | 5 +- serde_arrow/src/arrow_impl/serialization.rs | 6 +- .../src/test_with_arrow/impls/union.rs | 84 +++++-------------- .../src/test_with_arrow/impls/utils.rs | 13 +-- .../issue_90_top_level_nulls_in_structs.rs | 6 +- .../src/test_with_arrow/items_wrapper.rs | 11 +-- .../serializer_deserializer.rs | 12 +-- 13 files changed, 69 insertions(+), 123 deletions(-) diff --git a/Changes.md b/Changes.md index 9aabb455..606c01cd 100644 --- a/Changes.md +++ b/Changes.md @@ -5,6 +5,7 @@ - Remove `serde_arrow::schema::Schema` - Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` - Use `impl 
serde::Serialize` instead of `&(impl serde::Serialize + ?Sized)` +- Use `&[FieldRef]` instead of `&[Field]` in arrow APIs ## 0.11.6 diff --git a/serde_arrow/benches/groups/impls.rs b/serde_arrow/benches/groups/impls.rs index d8a044f8..1d5e2a80 100644 --- a/serde_arrow/benches/groups/impls.rs +++ b/serde_arrow/benches/groups/impls.rs @@ -13,6 +13,7 @@ macro_rules! define_benchmark { ) => { pub fn benchmark_serialize(c: &mut criterion::Criterion) { use serde_arrow::schema::{SerdeArrowSchema, SchemaLike}; + use serde_arrow::_impl::arrow::datatypes::FieldRef; for n in [$($n),*] { let mut group = c.benchmark_group(format!("{}_serialize({})", stringify!($name), n)); @@ -33,7 +34,7 @@ macro_rules! define_benchmark { .map(|_| <$ty>::random(&mut rng)) .collect::>(); let schema = SerdeArrowSchema::from_samples(&items, Default::default()).unwrap(); - let arrow_fields = schema.to_arrow_fields().unwrap(); + let arrow_fields = Vec::::try_from(&schema).unwrap(); let arrow2_fields = schema.to_arrow2_fields().unwrap(); #[allow(unused)] @@ -93,10 +94,10 @@ pub mod serde_arrow_arrow { use serde::Serialize; use serde_arrow::{ Result, - _impl::arrow::{array::ArrayRef, datatypes::Field}, + _impl::arrow::{array::ArrayRef, datatypes::FieldRef}, }; - pub fn serialize(fields: &[Field], items: &T) -> Result> + pub fn serialize(fields: &[FieldRef], items: &T) -> Result> where T: Serialize + ?Sized, { @@ -132,10 +133,10 @@ pub mod arrow { use serde_arrow::{ Error, Result, - _impl::arrow::{array::ArrayRef, datatypes::Field}, + _impl::arrow::{array::ArrayRef, datatypes::FieldRef}, }; - pub fn serialize(fields: &[Field], items: &[T]) -> Result> + pub fn serialize(fields: &[FieldRef], items: &[T]) -> Result> where T: Serialize, { diff --git a/serde_arrow/src/_impl/docs/defs.rs b/serde_arrow/src/_impl/docs/defs.rs index 295ea335..3320c2fa 100644 --- a/serde_arrow/src/_impl/docs/defs.rs +++ b/serde_arrow/src/_impl/docs/defs.rs @@ -22,12 +22,12 @@ pub fn example_record_batch() -> 
crate::_impl::arrow::array::RecordBatch { } #[cfg(has_arrow)] -pub fn example_arrow_arrays() -> (Vec, Vec) { +pub fn example_arrow_arrays() -> (Vec, Vec) { use crate::schema::{SchemaLike, TracingOptions}; let items = example_records(); - let fields = Vec::::from_type::(TracingOptions::default()).unwrap(); + let fields = Vec::::from_type::(TracingOptions::default()).unwrap(); let arrays = crate::to_arrow(&fields, &items).unwrap(); (fields, arrays) diff --git a/serde_arrow/src/_impl/docs/quickstart.rs b/serde_arrow/src/_impl/docs/quickstart.rs index 08800dc5..280a9a8b 100644 --- a/serde_arrow/src/_impl/docs/quickstart.rs +++ b/serde_arrow/src/_impl/docs/quickstart.rs @@ -16,7 +16,7 @@ //! # #[cfg(has_arrow)] //! # fn main() { //! # use serde_arrow::_impl::arrow as arrow; -//! use arrow::datatypes::{DataType, Field}; +//! use arrow::datatypes::{DataType, FieldRef}; //! use serde_arrow::{ //! schema::{SchemaLike, Strategy, TracingOptions}, //! utils::{Item, Items}, @@ -37,7 +37,7 @@ //! ```rust //! # #[cfg(has_arrow)] //! # fn main() -> serde_arrow::_impl::PanicOnError<()> { -//! # use serde_arrow::_impl::arrow::datatypes::{DataType, Field}; +//! # use serde_arrow::_impl::arrow::datatypes::{DataType, FieldRef}; //! # use serde_arrow::{schema::{SchemaLike, TracingOptions}, utils::Item}; //! use chrono::NaiveDateTime; //! @@ -46,7 +46,7 @@ //! // ... //! ]; //! -//! let fields = Vec::::from_samples(items, TracingOptions::default())?; +//! let fields = Vec::::from_samples(items, TracingOptions::default())?; //! assert_eq!(fields[0].data_type(), &DataType::LargeUtf8); //! # Ok(()) //! # } @@ -74,6 +74,7 @@ //! ```rust //! # #[cfg(has_arrow)] //! # fn main() -> serde_arrow::_impl::PanicOnError<()> { +//! # use std::sync::Arc; //! # use serde_arrow::_impl::arrow::datatypes::{DataType, Field}; //! # use serde_arrow::utils::Item; //! let records: &[Item] = &[ @@ -81,7 +82,7 @@ //! Item(9 * 60 * 60 * 24 * 1000), //! ]; //! -//! 
let fields = vec![Field::new("item", DataType::Date64, false)]; +//! let fields = vec![Arc::new(Field::new("item", DataType::Date64, false))]; //! let arrays = serde_arrow::to_arrow(&fields, records)?; //! # Ok(()) //! # } @@ -95,7 +96,7 @@ //! ```rust //! # #[cfg(has_arrow)] //! # fn main() -> serde_arrow::_impl::PanicOnError<()> { -//! # use serde_arrow::_impl::arrow::datatypes::Field; +//! # use serde_arrow::_impl::arrow::datatypes::FieldRef; //! # use serde_arrow::{schema::SchemaLike, utils::Item}; //! use std::str::FromStr; //! @@ -107,7 +108,7 @@ //! Item(BigDecimal::from_str("4.56").unwrap()), //! ]; //! -//! let fields = Vec::::from_value(&json!([ +//! let fields = Vec::::from_value(&json!([ //! {"name": "item", "data_type": "Decimal128(5, 2)"}, //! ]))?; //! @@ -145,10 +146,10 @@ //! ```rust //! # #[cfg(has_arrow)] //! # fn main() -> serde_arrow::_impl::PanicOnError<()> { -//! # use serde_arrow::_impl::arrow::datatypes::Field; +//! # use serde_arrow::_impl::arrow::datatypes::FieldRef; //! # use serde_arrow::{schema::{SchemaLike, TracingOptions}, utils::Item}; //! let items = &[Item("foo"), Item("bar")]; -//! let fields = Vec::::from_samples( +//! let fields = Vec::::from_samples( //! items, //! TracingOptions::default().string_dictionary_encoding(true), //! )?; diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index bd997275..c0352f8e 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -19,7 +19,7 @@ use crate::{ /// structs). To serialize items encoding single values consider the /// [`Items`][crate::utils::Items] wrapper. /// -/// To build arrays record by record use [`Arrow2Builder`]. +/// To build arrays record by record use [`ArrayBuilder`]. 
/// /// ```rust /// # fn main() -> serde_arrow::Result<()> { diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index ad0815c4..c73e906c 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -1,12 +1,10 @@ #![deny(missing_docs)] -use std::borrow::Cow; - use serde::{Deserialize, Serialize}; use crate::{ _impl::arrow::{ array::{Array, ArrayRef, RecordBatch}, - datatypes::{Field, FieldRef}, + datatypes::FieldRef, }, internal::{ array_builder::ArrayBuilder, deserializer::Deserializer, error::Result, @@ -29,7 +27,7 @@ use crate::{ /// ```rust /// # fn main() -> serde_arrow::Result<()> { /// # use serde_arrow::_impl::arrow; -/// use arrow::datatypes::Field; +/// use arrow::datatypes::FieldRef; /// use serde::{Serialize, Deserialize}; /// use serde_arrow::schema::{SchemaLike, TracingOptions}; /// @@ -44,7 +42,7 @@ use crate::{ /// // ... /// ]; /// -/// let fields = Vec::::from_type::(TracingOptions::default())?; +/// let fields = Vec::::from_type::(TracingOptions::default())?; /// let arrays = serde_arrow::to_arrow(&fields, &items)?; /// # /// # assert_eq!(arrays.len(), 2); @@ -52,7 +50,7 @@ use crate::{ /// # } /// ``` /// -pub fn to_arrow(fields: &[Field], items: &T) -> Result> { +pub fn to_arrow(fields: &[FieldRef], items: &T) -> Result> { let builder = ArrayBuilder::new(SerdeArrowSchema::try_from(fields)?)?; items .serialize(Serializer::new(builder))? 
@@ -70,7 +68,7 @@ pub fn to_arrow(fields: &[Field], items: &T) -> Result serde_arrow::Result<()> { /// # use serde_arrow::_impl::arrow; -/// use arrow::datatypes::Field; +/// use arrow::datatypes::FieldRef; /// use serde::{Deserialize, Serialize}; /// use serde_arrow::schema::{SchemaLike, TracingOptions}; /// @@ -82,20 +80,18 @@ pub fn to_arrow(fields: &[Field], items: &T) -> Result::from_type::(TracingOptions::default())?; +/// let fields = Vec::::from_type::(TracingOptions::default())?; /// let items: Vec = serde_arrow::from_arrow(&fields, &arrays)?; /// # Ok(()) /// # } /// ``` /// -pub fn from_arrow<'de, T, A>(fields: &[Field], arrays: &'de [A]) -> Result +pub fn from_arrow<'de, T, A>(fields: &[FieldRef], arrays: &'de [A]) -> Result where T: Deserialize<'de>, A: AsRef, { - let fields = fields.iter().map(Cow::Borrowed).collect::>(); - let deserializer = Deserializer::from_arrow(&fields, arrays)?; - T::deserialize(deserializer) + T::deserialize(Deserializer::from_arrow(fields, arrays)?) 
} /// Build a record batch from the given items (*requires one of the `arrow-*` diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 82dc6116..788d33ef 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -35,7 +35,7 @@ use crate::_impl::arrow::{ datatypes::{ ArrowDictionaryKeyType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Field, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, + DurationSecondType, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, @@ -69,9 +69,8 @@ impl<'de> Deserializer<'de> { /// # Ok(()) /// # } /// ``` - pub fn from_arrow(fields: &[F], arrays: &'de [A]) -> Result + pub fn from_arrow(fields: &[FieldRef], arrays: &'de [A]) -> Result where - F: AsRef, A: AsRef, { let fields = fields diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index b149fddb..0ca01941 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -5,7 +5,9 @@ use crate::{ _impl::arrow::{ array::{make_array, Array, ArrayData, ArrayRef, NullArray, RecordBatch}, buffer::{Buffer, ScalarBuffer}, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field, Float16Type, Schema}, + datatypes::{ + ArrowNativeType, ArrowPrimitiveType, DataType, Field, FieldRef, Float16Type, Schema, + }, }, internal::{ error::{fail, Result}, @@ -22,7 +24,7 @@ use crate::{ impl crate::internal::array_builder::ArrayBuilder { /// Build an ArrayBuilder from `arrow` fields (*requires one of 
the /// `arrow-*` features*) - pub fn from_arrow>(fields: &[F]) -> Result { + pub fn from_arrow(fields: &[FieldRef]) -> Result { let fields = fields .iter() .map(|f| GenericField::try_from(f.as_ref())) diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index 46571d1b..c16b5f3b 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -3,8 +3,8 @@ use serde_json::json; use crate::{ internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions}, - utils::Item, + schema::{SchemaLike, Strategy, TracingOptions}, + utils::{Item, Items}, }; use super::utils::Test; @@ -302,68 +302,28 @@ fn enums_union() { .deserialize(&values); } -macro_rules! test_generic { - ( - $(#[ignore = $ignore:literal])? - fn $name:ident() { - $($stmt:stmt)* - } - ) => { - #[allow(unused)] - mod $name { - use crate::{ - schema::{SchemaLike, TracingOptions}, - utils::{Items, Item} - }; - use crate::internal::schema::{GenericField, GenericDataType}; - - mod arrow { - use super::*; - use crate::{to_arrow, from_arrow}; - use crate::_impl::arrow::datatypes::Field; - - $(#[ignore = $ignore])? - #[test] - fn test() { - $($stmt)* - } - } - mod arrow2 { - use super::*; - use crate::{to_arrow2 as to_arrow, from_arrow2 as from_arrow}; - use crate::_impl::arrow2::datatypes::Field; - - $(#[ignore = $ignore])? 
- #[test] - fn test() { - $($stmt)* - } - } - } - }; -} +#[test] +fn missing_union_variants() { + use crate::_impl::arrow::datatypes::FieldRef; + + use crate::internal::testing::assert_error; + use crate::schema::TracingOptions; + use serde::{Deserialize, Serialize}; -test_generic!( - fn missing_union_variants() { - use crate::internal::testing::assert_error; - use crate::schema::TracingOptions; - use serde::{Deserialize, Serialize}; - - #[derive(Serialize, Deserialize, Debug, PartialEq)] - enum U { - A, - B, - C, - } - - let tracing_options = TracingOptions::default().allow_null_fields(true); - let fields = Vec::::from_samples(&Items(&[U::A, U::C]), tracing_options).unwrap(); - - // NOTE: variant B was never encountered during tracing - let res = to_arrow(&fields, &Items(&[U::A, U::B, U::C])); - assert_error(&res, "Serialization failed: an unknown variant"); + #[derive(Serialize, Deserialize, Debug, PartialEq)] + enum U { + A, + B, + C, } -); + + let tracing_options = TracingOptions::default().allow_null_fields(true); + let fields = Vec::::from_samples(&Items(&[U::A, U::C]), tracing_options).unwrap(); + + // NOTE: variant B was never encountered during tracing + let res = crate::to_arrow(&fields, &Items(&[U::A, U::B, U::C])); + assert_error(&res, "Serialization failed: an unknown variant"); +} #[test] fn fieldless_unions_as_dictionary() { diff --git a/serde_arrow/src/test_with_arrow/impls/utils.rs b/serde_arrow/src/test_with_arrow/impls/utils.rs index a48cd381..84f1aeb0 100644 --- a/serde_arrow/src/test_with_arrow/impls/utils.rs +++ b/serde_arrow/src/test_with_arrow/impls/utils.rs @@ -16,7 +16,7 @@ pub struct Arrays { #[derive(Default)] pub struct Fields { - pub arrow: Option>, + pub arrow: Option>, pub arrow2: Option>, } @@ -77,11 +77,10 @@ impl Test { } impl Test { - pub fn get_arrow_fields(&self) -> Cow<'_, Vec> { + pub fn get_arrow_fields(&self) -> Cow<'_, Vec> { match self.schema.as_ref() { Some(schema) => Cow::Owned( - schema - .to_arrow_fields() + 
Vec::::try_from(schema) .expect("Cannot covert schema to arrow fields"), ), None => Cow::Borrowed( @@ -146,10 +145,6 @@ impl Test { pub fn try_serialize_arrow(&mut self, items: &T) -> Result<()> { let fields = self.get_arrow_fields().to_vec(); - let field_refs = fields - .iter() - .map(|f| Arc::new(f.clone())) - .collect::>(); let arrays = crate::to_arrow(&fields, items)?; assert_eq!(fields.len(), arrays.len()); @@ -168,7 +163,7 @@ impl Test { self.arrays.arrow = Some(arrays); - let mut builder = crate::ArrayBuilder::from_arrow(&field_refs)?; + let mut builder = crate::ArrayBuilder::from_arrow(&fields)?; builder.extend(items)?; let arrays = builder.to_arrow()?; assert_eq!(self.arrays.arrow, Some(arrays)); diff --git a/serde_arrow/src/test_with_arrow/issue_90_top_level_nulls_in_structs.rs b/serde_arrow/src/test_with_arrow/issue_90_top_level_nulls_in_structs.rs index 6365ce5b..76030bf5 100644 --- a/serde_arrow/src/test_with_arrow/issue_90_top_level_nulls_in_structs.rs +++ b/serde_arrow/src/test_with_arrow/issue_90_top_level_nulls_in_structs.rs @@ -9,7 +9,7 @@ use crate::{ schema::{SchemaLike, TracingOptions}, }; -use crate::_impl::arrow::{_raw::schema::Schema, array::RecordBatch, datatypes::Field}; +use crate::_impl::arrow::{_raw::schema::Schema, array::RecordBatch, datatypes::FieldRef}; #[derive(Deserialize, Serialize, Debug, PartialEq, Clone)] pub struct Distribution { @@ -40,7 +40,7 @@ fn example() -> PanicOnError<()> { VectorMetric { distribution: None }, ]; - let fields = Vec::::from_type::(TracingOptions::default())?; + let fields = Vec::::from_type::(TracingOptions::default())?; let arrays = serde_arrow::to_arrow(&fields, &metrics)?; let batch = RecordBatch::try_new(Arc::new(Schema::new(fields.clone())), arrays.clone())?; @@ -55,7 +55,7 @@ fn example() -> PanicOnError<()> { #[test] fn example_top_level_none() -> PanicOnError<()> { // top-level options are not supported if fields are are extracted - let res = Vec::::from_type::>(TracingOptions::default()); + 
let res = Vec::::from_type::>(TracingOptions::default()); assert!(res.is_err()); Ok(()) } diff --git a/serde_arrow/src/test_with_arrow/items_wrapper.rs b/serde_arrow/src/test_with_arrow/items_wrapper.rs index 60d4b637..38835a6b 100644 --- a/serde_arrow/src/test_with_arrow/items_wrapper.rs +++ b/serde_arrow/src/test_with_arrow/items_wrapper.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use crate::{ self as serde_arrow, - _impl::arrow::{_raw::schema::Schema, array::RecordBatch, datatypes::Field}, + _impl::arrow::{_raw::schema::Schema, array::RecordBatch, datatypes::FieldRef}, internal::error::PanicOnError, schema::{SchemaLike, TracingOptions}, utils::{Item, Items}, @@ -10,14 +10,11 @@ use crate::{ #[test] fn example() -> PanicOnError<()> { - use serde_arrow::schema::SerdeArrowSchema; - let items: Vec = vec![1, 2, 3, 4, 5]; - let fields_from_type: Vec = - SerdeArrowSchema::from_type::>(TracingOptions::default())?.try_into()?; - let fields_from_samples: Vec = - SerdeArrowSchema::from_samples(&Items(&items), TracingOptions::default())?.try_into()?; + let fields_from_type = Vec::::from_type::>(TracingOptions::default())?; + let fields_from_samples = + Vec::::from_samples(&Items(&items), TracingOptions::default())?; assert_eq!(fields_from_type, fields_from_samples); let fields = fields_from_type; diff --git a/serde_arrow/src/test_with_arrow/serializer_deserializer.rs b/serde_arrow/src/test_with_arrow/serializer_deserializer.rs index fe4b6af0..c6009738 100644 --- a/serde_arrow/src/test_with_arrow/serializer_deserializer.rs +++ b/serde_arrow/src/test_with_arrow/serializer_deserializer.rs @@ -1,10 +1,7 @@ use serde::{Deserialize, Serialize}; use crate::{ - _impl::arrow::{ - array::ArrayRef, - datatypes::{Field, FieldRef}, - }, + _impl::arrow::{array::ArrayRef, datatypes::FieldRef}, schema::{SchemaLike, TracingOptions}, ArrayBuilder, Deserializer, }; @@ -27,7 +24,7 @@ enum Enum<'a> { TupleVariant(Record, Record), } -fn serialize(fields: &[impl AsRef], items: &I) -> Vec { +fn 
serialize(fields: &[FieldRef], items: &I) -> Vec { let builder = ArrayBuilder::from_arrow(&fields).unwrap(); items .serialize(crate::Serializer::new(builder)) @@ -37,10 +34,7 @@ fn serialize(fields: &[impl AsRef], items: &I) -> .unwrap() } -fn deserialize<'de, I: Deserialize<'de>>( - fields: &[impl AsRef], - arrays: &'de [ArrayRef], -) -> I { +fn deserialize<'de, I: Deserialize<'de>>(fields: &[FieldRef], arrays: &'de [ArrayRef]) -> I { I::deserialize(Deserializer::<'de>::from_arrow(fields, arrays).unwrap()).unwrap() } From d797a697e08a814a8ad7ee59eb9a4eb08257da96 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 21:57:32 +0200 Subject: [PATCH 004/178] Remove from_arrow_fields / to_arrow_fields in SerdeArrowSchema --- Changes.md | 2 + serde_arrow/benches/groups/impls.rs | 4 +- serde_arrow/src/arrow2_impl/schema.rs | 32 +--------- serde_arrow/src/arrow_impl/schema.rs | 60 +++++-------------- serde_arrow/src/arrow_impl/serialization.rs | 2 +- .../src/test_with_arrow/impls/utils.rs | 3 +- 6 files changed, 24 insertions(+), 79 deletions(-) diff --git a/Changes.md b/Changes.md index 606c01cd..5f866886 100644 --- a/Changes.md +++ b/Changes.md @@ -6,6 +6,8 @@ - Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` - Use `impl serde::Serialize` instead of `&(impl serde::Serialize + ?Sized)` - Use `&[FieldRef]` instead of `&[Field]` in arrow APIs +- Remove `from_arrow_fields` / `to_arrow_fields` for `SerdeArrowSchema`, use the + `TryFrom` conversions to convert between fields and `SerdeArrowSchema` ## 0.11.6 diff --git a/serde_arrow/benches/groups/impls.rs b/serde_arrow/benches/groups/impls.rs index 1d5e2a80..c8210479 100644 --- a/serde_arrow/benches/groups/impls.rs +++ b/serde_arrow/benches/groups/impls.rs @@ -13,7 +13,7 @@ macro_rules! 
define_benchmark { ) => { pub fn benchmark_serialize(c: &mut criterion::Criterion) { use serde_arrow::schema::{SerdeArrowSchema, SchemaLike}; - use serde_arrow::_impl::arrow::datatypes::FieldRef; + use serde_arrow::_impl::{arrow::datatypes::FieldRef, arrow2::datatypes::Field as Arrow2Field}; for n in [$($n),*] { let mut group = c.benchmark_group(format!("{}_serialize({})", stringify!($name), n)); @@ -35,7 +35,7 @@ macro_rules! define_benchmark { .collect::>(); let schema = SerdeArrowSchema::from_samples(&items, Default::default()).unwrap(); let arrow_fields = Vec::::try_from(&schema).unwrap(); - let arrow2_fields = schema.to_arrow2_fields().unwrap(); + let arrow2_fields = Vec::::try_from(&schema).unwrap(); #[allow(unused)] let bench_serde_arrow = true; diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index f5a156b5..5f137994 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -11,32 +11,6 @@ use crate::{ }, }; -/// Support for arrow2 types (*requires one of the `arrow2-*` features*) -impl SerdeArrowSchema { - /// Build a new Schema object from fields - pub fn from_arrow2_fields(fields: &[Field]) -> Result { - Self::try_from(fields) - } - - /// Build a vec of fields from a Schema object - /// - /// ```rust - /// # fn main() -> serde_arrow::_impl::PanicOnError<()> { - /// # use serde_arrow::schema::{SerdeArrowSchema, SchemaLike, TracingOptions}; - /// # #[derive(serde::Deserialize)] - /// # struct Item { a: u32 } - /// # let schema = SerdeArrowSchema::from_type::(TracingOptions::default()).unwrap(); - /// # let fields = - /// schema.to_arrow2_fields()? 
- /// # ; - /// # Ok(()) - /// # } - /// ``` - pub fn to_arrow2_fields(&self) -> Result> { - Vec::::try_from(self) - } -} - impl TryFrom for Vec { type Error = Error; @@ -72,20 +46,20 @@ impl Sealed for Vec {} /// `arrow2-*` features*) impl SchemaLike for Vec { fn from_value(value: &T) -> Result { - SerdeArrowSchema::from_value(value)?.to_arrow2_fields() + SerdeArrowSchema::from_value(value)?.try_into() } fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( options: crate::schema::TracingOptions, ) -> Result { - SerdeArrowSchema::from_type::(options)?.to_arrow2_fields() + SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, options: crate::schema::TracingOptions, ) -> Result { - SerdeArrowSchema::from_samples(samples, options)?.to_arrow2_fields() + SerdeArrowSchema::from_samples(samples, options)?.try_into() } } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 0f872e27..5d637400 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -11,37 +11,11 @@ use crate::{ }, }; -/// Support for arrow types (*requires one of the `arrow-*` features*) -impl SerdeArrowSchema { - /// Build a new Schema object from fields - pub fn from_arrow_fields(fields: &[Field]) -> Result { - Self::try_from(fields) - } - - /// Build a vec of fields from a Schema object - /// - /// ```rust - /// # fn main() -> serde_arrow::_impl::PanicOnError<()> { - /// # use serde_arrow::schema::{SerdeArrowSchema, SchemaLike, TracingOptions}; - /// # #[derive(serde::Deserialize)] - /// # struct Item { a: u32 } - /// # let schema = SerdeArrowSchema::from_type::(TracingOptions::default()).unwrap(); - /// # let fields = - /// schema.to_arrow_fields()? 
- /// # ; - /// # Ok(()) - /// # } - /// ``` - pub fn to_arrow_fields(&self) -> Result> { - Vec::::try_from(self) - } -} - impl TryFrom for Vec { type Error = Error; fn try_from(value: SerdeArrowSchema) -> Result { - value.fields.iter().map(Field::try_from).collect() + (&value).try_into() } } @@ -53,6 +27,14 @@ impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { } } +impl TryFrom for Vec { + type Error = Error; + + fn try_from(value: SerdeArrowSchema) -> Result { + (&value).try_into() + } +} + impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { type Error = Error; @@ -97,20 +79,20 @@ impl Sealed for Vec {} /// `arrow-*` features*) impl SchemaLike for Vec { fn from_value(value: &T) -> Result { - SerdeArrowSchema::from_value(value)?.to_arrow_fields() + SerdeArrowSchema::from_value(value)?.try_into() } fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( options: crate::schema::TracingOptions, ) -> Result { - SerdeArrowSchema::from_type::(options)?.to_arrow_fields() + SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, options: crate::schema::TracingOptions, ) -> Result { - SerdeArrowSchema::from_samples(samples, options)?.to_arrow_fields() + SerdeArrowSchema::from_samples(samples, options)?.try_into() } } @@ -120,32 +102,20 @@ impl Sealed for Vec {} /// `arrow-*` features*) impl SchemaLike for Vec { fn from_value(value: &T) -> Result { - Ok(SerdeArrowSchema::from_value(value)? - .to_arrow_fields()? - .into_iter() - .map(Arc::new) - .collect()) + SerdeArrowSchema::from_value(value)?.try_into() } fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( options: crate::schema::TracingOptions, ) -> Result { - Ok(SerdeArrowSchema::from_type::(options)? - .to_arrow_fields()? - .into_iter() - .map(Arc::new) - .collect()) + SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, options: crate::schema::TracingOptions, ) -> Result { - Ok(SerdeArrowSchema::from_samples(samples, options)? - .to_arrow_fields()? 
- .into_iter() - .map(Arc::new) - .collect()) + SerdeArrowSchema::from_samples(samples, options)?.try_into() } } diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 0ca01941..e07fe32d 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -42,7 +42,7 @@ impl crate::internal::array_builder::ArrayBuilder { /// `arrow-*` features*) pub fn to_record_batch(&mut self) -> Result { let arrays = self.builder.build_arrow()?; - let fields = self.schema.to_arrow_fields()?; + let fields = Vec::::try_from(&self.schema)?; let schema = Schema::new(fields); Ok(RecordBatch::try_new(Arc::new(schema), arrays)?) } diff --git a/serde_arrow/src/test_with_arrow/impls/utils.rs b/serde_arrow/src/test_with_arrow/impls/utils.rs index 84f1aeb0..4721f8f2 100644 --- a/serde_arrow/src/test_with_arrow/impls/utils.rs +++ b/serde_arrow/src/test_with_arrow/impls/utils.rs @@ -95,8 +95,7 @@ impl Test { pub fn get_arrow2_fields(&self) -> Cow<'_, Vec> { match self.schema.as_ref() { Some(schema) => Cow::Owned( - schema - .to_arrow2_fields() + Vec::::try_from(schema) .expect("Cannot covert schema to arrow fields"), ), None => Cow::Borrowed( From 8525e5f295752cb990e48e900d14cdf2bfd47d14 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 22:00:26 +0200 Subject: [PATCH 005/178] Update docs --- serde_arrow/src/internal/schema/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 469e2224..10981ece 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -281,6 +281,9 @@ pub trait SchemaLike: Sized + Sealed { } /// A collection of fields as understood by `serde_arrow` +/// +/// It can be converted from / to arrow or arrow2 fields. 
+/// #[derive(Default, Debug, PartialEq, Clone, Serialize)] pub struct SerdeArrowSchema { pub(crate) fields: Vec, From 01b496f8ae9b9ca96a169dc6e17d252cf88991e9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 19 Jun 2024 22:07:10 +0200 Subject: [PATCH 006/178] Remove SerdeArrowSchema::new(), Overwrites::new() --- Changes.md | 1 + serde_arrow/src/internal/schema/mod.rs | 7 ------- serde_arrow/src/internal/schema/test.rs | 10 +++++----- serde_arrow/src/internal/schema/tracing_options.rs | 7 ------- 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/Changes.md b/Changes.md index 5f866886..c4588156 100644 --- a/Changes.md +++ b/Changes.md @@ -8,6 +8,7 @@ - Use `&[FieldRef]` instead of `&[Field]` in arrow APIs - Remove `from_arrow_fields` / `to_arrow_fields` for `SerdeArrowSchema`, use the `TryFrom` conversions to convert between fields and `SerdeArrowSchema` +- Remove `SerdeArrowSchema::new()`, `Overwrites::new()` ## 0.11.6 diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 10981ece..fdec33da 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -289,13 +289,6 @@ pub struct SerdeArrowSchema { pub(crate) fields: Vec, } -impl SerdeArrowSchema { - /// Return a new schema without any fields - pub fn new() -> Self { - Self::default() - } -} - impl Sealed for SerdeArrowSchema {} impl SchemaLike for SerdeArrowSchema { diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index a8e72bca..e976e567 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -14,7 +14,7 @@ impl SerdeArrowSchema { #[test] fn example() { - let schema = SerdeArrowSchema::new() + let schema = SerdeArrowSchema::default() .with_field(GenericField::new("foo", GenericDataType::U8, false)) .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); @@ -30,7 +30,7 @@ fn example() { #[test] fn 
example_without_wrapper() { - let expected = SerdeArrowSchema::new() + let expected = SerdeArrowSchema::default() .with_field(GenericField::new("foo", GenericDataType::U8, false)) .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); @@ -41,7 +41,7 @@ fn example_without_wrapper() { #[test] fn list() { - let schema = SerdeArrowSchema::new().with_field( + let schema = SerdeArrowSchema::default().with_field( GenericField::new("value", GenericDataType::List, false).with_child(GenericField::new( "element", GenericDataType::I32, @@ -69,7 +69,7 @@ fn doc_schema() { "#; let actual: SerdeArrowSchema = serde_json::from_str(&schema).unwrap(); - let expected = SerdeArrowSchema::new() + let expected = SerdeArrowSchema::default() .with_field(GenericField::new("foo", GenericDataType::U8, false)) .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); @@ -78,7 +78,7 @@ fn doc_schema() { #[test] fn date64_with_strategy() { - let schema = SerdeArrowSchema::new().with_field( + let schema = SerdeArrowSchema::default().with_field( GenericField::new("item", GenericDataType::Date64, false) .with_strategy(Strategy::NaiveStrAsDate64), ); diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index a61797bb..1acad363 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -300,10 +300,3 @@ impl TracingOptions { /// An opaque mapping of field paths to field definitions #[derive(Debug, Clone, Default, PartialEq)] pub struct Overwrites(pub(crate) HashMap); - -impl Overwrites { - /// Create a new empty instance - pub fn new() -> Self { - Self::default() - } -} From 6e268f72715710a5f894082cb4c7a584842c6adc Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 19:05:42 +0200 Subject: [PATCH 007/178] Start to implement a common array abstraction --- .../src/arrow2_impl/deserialization.rs | 4 +- 
serde_arrow/src/arrow2_impl/schema.rs | 55 ++--- serde_arrow/src/arrow_impl/deserialization.rs | 7 +- serde_arrow/src/arrow_impl/schema.rs | 79 ++++--- serde_arrow/src/arrow_impl/serialization.rs | 97 ++++++--- serde_arrow/src/internal/arrow/array.rs | 196 ++++++++++++++++++ serde_arrow/src/internal/arrow/array_view.rs | 177 ++++++++++++++++ serde_arrow/src/internal/arrow/data_type.rs | 111 ++++++++++ serde_arrow/src/internal/arrow/mod.rs | 14 ++ .../deserialization/date64_deserializer.rs | 14 +- .../deserialization/time_deserializer.rs | 11 +- serde_arrow/src/internal/mod.rs | 1 + serde_arrow/src/internal/schema/data_type.rs | 55 +---- .../src/internal/schema/deserialization.rs | 5 +- serde_arrow/src/internal/schema/mod.rs | 7 +- serde_arrow/src/internal/schema/test.rs | 12 +- .../internal/serialization/array_builder.rs | 7 +- .../internal/serialization/binary_builder.rs | 5 + .../internal/serialization/bool_builder.rs | 13 +- .../internal/serialization/date32_builder.rs | 6 +- .../internal/serialization/date64_builder.rs | 15 +- .../internal/serialization/decimal_builder.rs | 5 + .../serialization/dictionary_utf8_builder.rs | 5 + .../serialization/duration_builder.rs | 13 +- .../fixed_size_binary_builder.rs | 5 + .../serialization/fixed_size_list_builder.rs | 5 + .../internal/serialization/float_builder.rs | 23 +- .../src/internal/serialization/int_builder.rs | 27 ++- .../internal/serialization/list_builder.rs | 5 + .../src/internal/serialization/map_builder.rs | 6 +- .../internal/serialization/null_builder.rs | 9 +- .../serialization/outer_sequence_builder.rs | 10 +- .../internal/serialization/struct_builder.rs | 5 + .../internal/serialization/time_builder.rs | 16 +- .../internal/serialization/union_builder.rs | 5 + .../serialization/unknown_variant_builder.rs | 12 +- .../internal/serialization/utf8_builder.rs | 5 + 37 files changed, 866 insertions(+), 181 deletions(-) create mode 100644 serde_arrow/src/internal/arrow/array.rs create mode 100644 
serde_arrow/src/internal/arrow/array_view.rs create mode 100644 serde_arrow/src/internal/arrow/data_type.rs create mode 100644 serde_arrow/src/internal/arrow/mod.rs diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 2ac95074..5e564447 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::TimeUnit, deserialization::{ array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, @@ -32,7 +33,6 @@ use crate::_impl::arrow2::{ datatypes::{DataType, Field, UnionMode}, types::{f16, NativeType, Offset as ArrowOffset}, }; -use crate::internal::schema::GenericTimeUnit; impl<'de> Deserializer<'de> { /// Build a deserializer from `arrow2` arrays (*requires one of the @@ -215,7 +215,7 @@ pub fn build_date64_deserializer<'a>( Ok(Date64Deserializer::new( as_primitive_values(array)?, get_validity(array), - GenericTimeUnit::Millisecond, + TimeUnit::Millisecond, field.is_utc()?, ) .into()) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 5f137994..07ab3843 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -1,12 +1,15 @@ use std::collections::HashMap; use crate::{ - _impl::arrow2::datatypes::{DataType, Field, IntegerType, TimeUnit, UnionMode}, + _impl::arrow2::datatypes::{ + DataType, Field, IntegerType, TimeUnit as ArrowTimeUnit, UnionMode, + }, internal::{ + arrow::TimeUnit, error::{error, fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, GenericTimeUnit, SchemaLike, Sealed, SerdeArrowSchema, + GenericField, SchemaLike, Sealed, SerdeArrowSchema, }, }, }; @@ -67,7 +70,7 @@ impl TryFrom<&Field> for GenericField { type Error = Error; fn try_from(field: &Field) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use 
{GenericDataType as T, TimeUnit as U}; let metadata = field .metadata @@ -104,26 +107,26 @@ impl TryFrom<&Field> for GenericField { } T::Decimal128(*precision as u8, *scale as i8) } - DataType::Time32(TimeUnit::Second) => T::Time32(U::Second), - DataType::Time32(TimeUnit::Millisecond) => T::Time32(U::Millisecond), + DataType::Time32(ArrowTimeUnit::Second) => T::Time32(U::Second), + DataType::Time32(ArrowTimeUnit::Millisecond) => T::Time32(U::Millisecond), DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(TimeUnit::Microsecond) => T::Time64(U::Microsecond), - DataType::Time64(TimeUnit::Nanosecond) => T::Time64(U::Nanosecond), + DataType::Time64(ArrowTimeUnit::Microsecond) => T::Time64(U::Microsecond), + DataType::Time64(ArrowTimeUnit::Nanosecond) => T::Time64(U::Nanosecond), DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(TimeUnit::Second, tz) => T::Timestamp(U::Second, tz.clone()), - DataType::Timestamp(TimeUnit::Millisecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Second, tz) => T::Timestamp(U::Second, tz.clone()), + DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => { T::Timestamp(U::Millisecond, tz.clone()) } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => { T::Timestamp(U::Microsecond, tz.clone()) } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => { T::Timestamp(U::Nanosecond, tz.clone()) } - DataType::Duration(TimeUnit::Second) => T::Duration(U::Second), - DataType::Duration(TimeUnit::Millisecond) => T::Duration(U::Millisecond), - DataType::Duration(TimeUnit::Microsecond) => T::Duration(U::Microsecond), - DataType::Duration(TimeUnit::Nanosecond) => T::Duration(U::Nanosecond), + DataType::Duration(ArrowTimeUnit::Second) => T::Duration(U::Second), + DataType::Duration(ArrowTimeUnit::Millisecond) => T::Duration(U::Millisecond), + 
DataType::Duration(ArrowTimeUnit::Microsecond) => T::Duration(U::Microsecond), + DataType::Duration(ArrowTimeUnit::Nanosecond) => T::Duration(U::Nanosecond), DataType::List(field) => { children.push(GenericField::try_from(field.as_ref())?); T::List @@ -194,7 +197,7 @@ impl TryFrom<&GenericField> for Field { type Error = Error; fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; let data_type = match &value.data_type { T::Null => DataType::Null, @@ -212,11 +215,11 @@ impl TryFrom<&GenericField> for Field { T::F64 => DataType::Float64, T::Date32 => DataType::Date32, T::Date64 => DataType::Date64, - T::Time32(U::Second) => DataType::Time32(TimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), + T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), T::Time32(unit) => fail!("Invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(TimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), + T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), T::Time64(unit) => fail!("Invalid time unit {unit} for Time64"), T::Timestamp(unit, tz) => DataType::Timestamp((*unit).into(), tz.clone()), T::Duration(unit) => DataType::Duration((*unit).into()), @@ -306,13 +309,13 @@ impl TryFrom<&GenericField> for Field { } } -impl From for TimeUnit { - fn from(value: GenericTimeUnit) -> Self { +impl From for ArrowTimeUnit { + fn from(value: TimeUnit) -> Self { match value { - GenericTimeUnit::Second => Self::Second, - GenericTimeUnit::Millisecond => Self::Millisecond, - GenericTimeUnit::Microsecond => Self::Microsecond, - GenericTimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Second => Self::Second, + 
TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Nanosecond => Self::Nanosecond, } } } diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 788d33ef..2ff608bd 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::TimeUnit, deserialization::{ array_deserializer::ArrayDeserializer, binary_deserializer::BinaryDeserializer, @@ -22,7 +23,7 @@ use crate::internal::{ }, deserializer::Deserializer, error::{fail, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit}, + schema::{GenericDataType, GenericField}, utils::Offset, }; @@ -124,7 +125,7 @@ pub fn build_array_deserializer<'a>( field: &GenericField, array: &'a dyn Array, ) -> Result> { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; match &field.data_type { T::Null => Ok(NullDeserializer.into()), T::Bool => build_bool_deserializer(field, array), @@ -284,7 +285,7 @@ pub fn build_date64_deserializer<'a>( Ok(Date64Deserializer::new( as_primitive_values::(array)?, get_validity(array), - GenericTimeUnit::Millisecond, + TimeUnit::Millisecond, field.is_utc()?, ) .into()) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 5d637400..242440c1 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -1,12 +1,13 @@ use std::sync::Arc; use crate::{ - _impl::arrow::datatypes::{DataType, Field, FieldRef, TimeUnit, UnionMode}, + _impl::arrow::datatypes::{DataType, Field, FieldRef, TimeUnit as ArrowTimeUnit, UnionMode}, internal::{ + arrow::TimeUnit, error::{error, fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, GenericTimeUnit, SchemaLike, Sealed, SerdeArrowSchema, + GenericField, SchemaLike, Sealed, 
SerdeArrowSchema, }, }, }; @@ -123,7 +124,7 @@ impl TryFrom<&DataType> for GenericDataType { type Error = Error; fn try_from(value: &DataType) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; match value { DataType::Boolean => Ok(T::Bool), DataType::Null => Ok(T::Null), @@ -143,31 +144,31 @@ impl TryFrom<&DataType> for GenericDataType { DataType::Date32 => Ok(T::Date32), DataType::Date64 => Ok(T::Date64), DataType::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - DataType::Time32(TimeUnit::Second) => Ok(T::Time32(U::Second)), - DataType::Time32(TimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), + DataType::Time32(ArrowTimeUnit::Second) => Ok(T::Time32(U::Second)), + DataType::Time32(ArrowTimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(TimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), - DataType::Time64(TimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), + DataType::Time64(ArrowTimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), + DataType::Time64(ArrowTimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(TimeUnit::Second, tz) => { + DataType::Timestamp(ArrowTimeUnit::Second, tz) => { Ok(T::Timestamp(U::Second, tz.as_ref().map(|s| s.to_string()))) } - DataType::Timestamp(TimeUnit::Millisecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => Ok(T::Timestamp( U::Millisecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(TimeUnit::Microsecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => Ok(T::Timestamp( U::Microsecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(TimeUnit::Nanosecond, tz) => Ok(T::Timestamp( + DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => 
Ok(T::Timestamp( U::Nanosecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Duration(TimeUnit::Second) => Ok(T::Duration(U::Second)), - DataType::Duration(TimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), - DataType::Duration(TimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), - DataType::Duration(TimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), + DataType::Duration(ArrowTimeUnit::Second) => Ok(T::Duration(U::Second)), + DataType::Duration(ArrowTimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), + DataType::Duration(ArrowTimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), + DataType::Duration(ArrowTimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), DataType::Binary => Ok(T::Binary), DataType::LargeBinary => Ok(T::LargeBinary), DataType::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), @@ -253,7 +254,7 @@ impl TryFrom<&GenericField> for Field { type Error = Error; fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, GenericTimeUnit as U}; + use {GenericDataType as T, TimeUnit as U}; let data_type = match &value.data_type { T::Null => DataType::Null, @@ -354,11 +355,11 @@ impl TryFrom<&GenericField> for Field { DataType::Dictionary(Box::new(key_type), Box::new(val_field.data_type().clone())) } - T::Time32(U::Second) => DataType::Time32(TimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), + T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), T::Time32(unit) => fail!("invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(TimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), + T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), T::Time64(unit) => fail!("invalid time unit {unit} for Time64"), 
T::Timestamp(unit, tz) => { DataType::Timestamp((*unit).into(), tz.clone().map(|s| s.into())) @@ -376,13 +377,39 @@ impl TryFrom<&GenericField> for Field { } } -impl From for TimeUnit { - fn from(value: GenericTimeUnit) -> Self { +impl From for ArrowTimeUnit { + fn from(value: TimeUnit) -> Self { match value { - GenericTimeUnit::Second => Self::Second, - GenericTimeUnit::Millisecond => Self::Millisecond, - GenericTimeUnit::Microsecond => Self::Microsecond, - GenericTimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Second => Self::Second, + TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Nanosecond => Self::Nanosecond, + } + } +} + +impl TryFrom for DataType { + type Error = Error; + + fn try_from(value: crate::internal::arrow::DataType) -> Result { + use {crate::internal::arrow::DataType as DT, DataType as ArrowDT}; + + match value { + DT::Int8 => Ok(ArrowDT::Int8), + DT::Int16 => Ok(ArrowDT::Int16), + DT::Int32 => Ok(ArrowDT::Int32), + DT::Int64 => Ok(ArrowDT::Int64), + DT::UInt8 => Ok(ArrowDT::UInt8), + DT::UInt16 => Ok(ArrowDT::UInt16), + DT::UInt32 => Ok(ArrowDT::UInt32), + DT::UInt64 => Ok(ArrowDT::UInt64), + DT::Float16 => Ok(ArrowDT::Float16), + DT::Float32 => Ok(ArrowDT::Float32), + DT::Float64 => Ok(ArrowDT::Float64), + dt => fail!( + "{} not supported", + crate::internal::arrow::BaseDataTypeDisplay(&dt) + ), } } } diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index e07fe32d..21639855 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -1,6 +1,8 @@ #![allow(missing_docs)] use std::sync::Arc; +use half::f16; + use crate::{ _impl::arrow::{ array::{make_array, Array, ArrayData, ArrayRef, NullArray, RecordBatch}, @@ -10,7 +12,7 @@ use crate::{ }, }, internal::{ - error::{fail, Result}, + error::{fail, Error, Result}, schema::{GenericField, SerdeArrowSchema}, serialization::{ 
utils::{MutableBitBuffer, MutableOffsetBuffer}, @@ -66,34 +68,22 @@ fn build_array(builder: ArrayBuilder) -> Result { fn build_array_data(builder: ArrayBuilder) -> Result { use {ArrayBuilder as A, DataType as T}; + match builder { - A::Null(builder) => Ok(NullArray::new(builder.count).into_data()), - A::UnknownVariant(_) => Ok(NullArray::new(0).into_data()), - A::Bool(builder) => build_array_data_primitive_with_len( - T::Boolean, - builder.buffer.len(), - builder.buffer.buffer, - builder.validity, - ), - A::I8(builder) => build_array_data_primitive(T::Int8, builder.buffer, builder.validity), - A::I16(builder) => build_array_data_primitive(T::Int16, builder.buffer, builder.validity), - A::I32(builder) => build_array_data_primitive(T::Int32, builder.buffer, builder.validity), - A::I64(builder) => build_array_data_primitive(T::Int64, builder.buffer, builder.validity), - A::U8(builder) => build_array_data_primitive(T::UInt8, builder.buffer, builder.validity), - A::U16(builder) => build_array_data_primitive(T::UInt16, builder.buffer, builder.validity), - A::U32(builder) => build_array_data_primitive(T::UInt32, builder.buffer, builder.validity), - A::U64(builder) => build_array_data_primitive(T::UInt64, builder.buffer, builder.validity), - A::F16(builder) => build_array_data_primitive( - T::Float16, - builder - .buffer - .into_iter() - .map(|v| ::Native::from_bits(v.to_bits())) - .collect(), - builder.validity, - ), - A::F32(builder) => build_array_data_primitive(T::Float32, builder.buffer, builder.validity), - A::F64(builder) => build_array_data_primitive(T::Float64, builder.buffer, builder.validity), + builder @ (A::UnknownVariant(_) + | A::Null(_) + | A::Bool(_) + | A::I8(_) + | A::I16(_) + | A::I32(_) + | A::I64(_) + | A::U8(_) + | A::U16(_) + | A::U32(_) + | A::U64(_) + | A::F16(_) + | A::F32(_) + | A::F64(_)) => builder.into_array().try_into(), A::Date32(builder) => build_array_data_primitive( Field::try_from(&builder.field)?.data_type().clone(), builder.buffer, @@ 
-260,6 +250,57 @@ fn build_array_data(builder: ArrayBuilder) -> Result { } } +impl TryFrom for ArrayData { + type Error = Error; + + fn try_from(value: crate::internal::arrow::Array) -> Result { + use {crate::internal::arrow::Array as A, DataType as ArrowT}; + type ArrowF16 = ::Native; + + fn f16_to_f16(v: f16) -> ArrowF16 { + ArrowF16::from_bits(v.to_bits()) + } + + match value { + A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), + A::Boolean(arr) => Ok(ArrayData::try_new( + ArrowT::Boolean, + arr.len, + arr.validity.map(|buffer| Buffer::from(buffer)), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), + A::Int8(arr) => primitive_into_data(ArrowT::Int8, arr), + A::Int16(arr) => primitive_into_data(ArrowT::Int16, arr), + A::Int32(arr) => primitive_into_data(ArrowT::Int32, arr), + A::Int64(arr) => primitive_into_data(ArrowT::Int64, arr), + A::UInt8(arr) => primitive_into_data(ArrowT::UInt8, arr), + A::UInt16(arr) => primitive_into_data(ArrowT::UInt16, arr), + A::UInt32(arr) => primitive_into_data(ArrowT::UInt32, arr), + A::UInt64(arr) => primitive_into_data(ArrowT::UInt64, arr), + A::Float16(arr) => primitive_into_data(ArrowT::Float16, arr.map_values(f16_to_f16)), + A::Float32(arr) => primitive_into_data(ArrowT::Float32, arr), + A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr), + array => fail!("{:?} not implemented", array), + } + } +} + +fn primitive_into_data( + data_type: DataType, + array: crate::internal::arrow::PrimitiveArray, +) -> Result { + Ok(ArrayData::try_new( + data_type, + array.values.len(), + array.validity.map(|buffer| Buffer::from(buffer)), + 0, + vec![ScalarBuffer::from(array.values).into_inner()], + vec![], + )?) +} + fn build_array_data_primitive( data_type: DataType, data: Vec, diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs new file mode 100644 index 00000000..4f979f0e --- /dev/null +++ b/serde_arrow/src/internal/arrow/array.rs @@ -0,0 +1,196 @@ +//! 
Owned versions of the different array types +use half::f16; + +use crate::internal::arrow::{ + array_view::{ + ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, + StructArrayView, Utf8ArrayView, + }, + data_type::TimeUnit, +}; + +use super::array_view::TimeArrayView; + +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum Array { + Null(NullArray), + Boolean(BooleanArray), + Int8(PrimitiveArray), + Int16(PrimitiveArray), + Int32(PrimitiveArray), + Int64(PrimitiveArray), + UInt8(PrimitiveArray), + UInt16(PrimitiveArray), + UInt32(PrimitiveArray), + UInt64(PrimitiveArray), + Float16(PrimitiveArray), + Float32(PrimitiveArray), + Float64(PrimitiveArray), + Date32(PrimitiveArray), + Date64(PrimitiveArray), + Time32(TimeArray), + Time64(TimeArray), + Utf8(Utf8Array), + LargeUtf8(Utf8Array), + Binary(Utf8Array), + LargeBinary(Utf8Array), + Decimal128(PrimitiveArray), + Struct(StructArray), + List(ListArray), + LargeList(ListArray), +} + +impl Array { + pub fn view(&self) -> ArrayView { + match self { + Self::Null(array) => ArrayView::Null(array.view()), + Self::Boolean(array) => ArrayView::Boolean(array.view()), + Self::Int8(array) => ArrayView::Int8(array.view()), + Self::Int16(array) => ArrayView::Int16(array.view()), + Self::Int32(array) => ArrayView::Int32(array.view()), + Self::Int64(array) => ArrayView::Int64(array.view()), + Self::UInt8(array) => ArrayView::UInt8(array.view()), + Self::UInt16(array) => ArrayView::UInt16(array.view()), + Self::UInt32(array) => ArrayView::UInt32(array.view()), + Self::UInt64(array) => ArrayView::UInt64(array.view()), + Self::Float16(array) => ArrayView::Float16(array.view()), + Self::Float32(array) => ArrayView::Float32(array.view()), + Self::Float64(array) => ArrayView::Float64(array.view()), + Self::Date32(array) => ArrayView::Date32(array.view()), + Self::Date64(array) => ArrayView::Date64(array.view()), + Self::Time32(array) => ArrayView::Time32(array.view()), + Self::Time64(array) => 
ArrayView::Time64(array.view()), + Self::Utf8(array) => ArrayView::Utf8(array.view()), + Self::LargeUtf8(array) => ArrayView::LargeUtf8(array.view()), + Self::Binary(array) => ArrayView::Binary(array.view()), + Self::LargeBinary(array) => ArrayView::LargeBinary(array.view()), + Self::Decimal128(array) => ArrayView::Decimal128(array.view()), + Self::Struct(array) => ArrayView::Struct(array.view()), + Self::List(array) => ArrayView::List(array.view()), + Self::LargeList(array) => ArrayView::LargeList(array.view()), + } + } +} + +#[derive(Clone, Debug)] +pub struct NullArray { + pub len: usize, +} + +impl NullArray { + pub fn view(&self) -> NullArrayView { + NullArrayView { len: self.len } + } +} + +#[derive(Clone, Debug)] +pub struct BooleanArray { + pub len: usize, + pub validity: Option>, + pub values: Vec, +} + +impl BooleanArray { + pub fn view(&self) -> BooleanArrayView { + BooleanArrayView { + len: self.len, + validity: self.validity.as_deref(), + values: &self.values, + } + } +} + +#[derive(Clone, Debug)] +pub struct PrimitiveArray { + pub validity: Option>, + pub values: Vec, +} + +impl PrimitiveArray { + pub fn view(&self) -> PrimitiveArrayView { + PrimitiveArrayView { + validity: self.validity.as_deref(), + values: &self.values, + } + } + + pub fn map_values(self, func: impl Fn(T) -> R) -> PrimitiveArray { + PrimitiveArray { + validity: self.validity, + values: self.values.into_iter().map(func).collect(), + } + } +} + +#[derive(Debug, Clone)] +pub struct TimeArray { + pub unit: TimeUnit, + pub validity: Option>, + pub values: Vec, +} + +impl TimeArray { + pub fn view(&self) -> TimeArrayView { + TimeArrayView { + unit: self.unit, + validity: self.validity.as_deref(), + values: &self.values, + } + } +} + +#[derive(Clone, Debug)] +pub struct StructArray { + pub len: usize, + pub validity: Option>, + pub fields: Vec, +} + +impl StructArray { + pub fn view(&self) -> StructArrayView { + StructArrayView { + len: self.len, + validity: self.validity.as_deref(), + 
fields: self.fields.iter().map(|f| f.view()).collect(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ListArray { + pub len: usize, + pub validity: Option>, + pub offsets: Vec, + pub element: Box, +} + +impl ListArray { + pub fn view(&self) -> ListArrayView { + ListArrayView { + len: self.len, + validity: self.validity.as_deref(), + offsets: &self.offsets, + element: Box::new(self.element.view()), + } + } +} + +#[derive(Clone, Debug)] +pub struct Utf8Array { + pub len: usize, + pub validity: Option>, + pub offsets: Vec, + pub data: Vec, +} + +impl Utf8Array { + pub fn view(&self) -> Utf8ArrayView { + Utf8ArrayView { + len: self.len, + validity: self.validity.as_deref(), + offsets: &self.offsets, + data: &self.data, + } + } +} diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs new file mode 100644 index 00000000..40f80bb9 --- /dev/null +++ b/serde_arrow/src/internal/arrow/array_view.rs @@ -0,0 +1,177 @@ +use half::f16; + +use crate::internal::arrow::{ + array::{ + Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, + Utf8Array, + }, + data_type::TimeUnit, +}; + +pub enum ArrayView<'a> { + Null(NullArrayView), + Boolean(BooleanArrayView<'a>), + Int8(PrimitiveArrayView<'a, i8>), + Int16(PrimitiveArrayView<'a, i16>), + Int32(PrimitiveArrayView<'a, i32>), + Int64(PrimitiveArrayView<'a, i64>), + UInt8(PrimitiveArrayView<'a, u8>), + UInt16(PrimitiveArrayView<'a, u16>), + UInt32(PrimitiveArrayView<'a, u32>), + UInt64(PrimitiveArrayView<'a, u64>), + Float16(PrimitiveArrayView<'a, f16>), + Float32(PrimitiveArrayView<'a, f32>), + Float64(PrimitiveArrayView<'a, f64>), + Date32(PrimitiveArrayView<'a, i32>), + Date64(PrimitiveArrayView<'a, i64>), + Time32(TimeArrayView<'a, i32>), + Time64(TimeArrayView<'a, i64>), + Utf8(Utf8ArrayView<'a, i32>), + LargeUtf8(Utf8ArrayView<'a, i64>), + Binary(Utf8ArrayView<'a, i32>), + LargeBinary(Utf8ArrayView<'a, i64>), + Decimal128(PrimitiveArrayView<'a, 
i128>), + Struct(StructArrayView<'a>), + List(ListArrayView<'a, i32>), + LargeList(ListArrayView<'a, i64>), +} + +impl<'a> ArrayView<'a> { + pub fn array(&self) -> Array { + match self { + Self::Null(view) => Array::Null(view.array()), + Self::Boolean(view) => Array::Boolean(view.array()), + Self::Int8(view) => Array::Int8(view.array()), + Self::Int16(view) => Array::Int16(view.array()), + Self::Int32(view) => Array::Int32(view.array()), + Self::Int64(view) => Array::Int64(view.array()), + Self::UInt8(view) => Array::UInt8(view.array()), + Self::UInt16(view) => Array::UInt16(view.array()), + Self::UInt32(view) => Array::UInt32(view.array()), + Self::UInt64(view) => Array::UInt64(view.array()), + Self::Float16(view) => Array::Float16(view.array()), + Self::Float32(view) => Array::Float32(view.array()), + Self::Float64(view) => Array::Float64(view.array()), + Self::Date32(view) => Array::Date32(view.array()), + Self::Date64(view) => Array::Date64(view.array()), + Self::Time32(view) => Array::Time32(view.array()), + Self::Time64(view) => Array::Time64(view.array()), + Self::Utf8(view) => Array::Utf8(view.array()), + Self::Binary(view) => Array::Binary(view.array()), + Self::LargeBinary(view) => Array::LargeBinary(view.array()), + Self::LargeUtf8(view) => Array::LargeUtf8(view.array()), + Self::Decimal128(view) => Array::Decimal128(view.array()), + Self::Struct(view) => Array::Struct(view.array()), + Self::List(view) => Array::List(view.array()), + Self::LargeList(view) => Array::LargeList(view.array()), + } + } +} + +pub struct NullArrayView { + pub len: usize, +} + +impl NullArrayView { + pub fn array(&self) -> NullArray { + NullArray { len: self.len } + } +} + +pub struct BooleanArrayView<'a> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub values: &'a [u8], +} + +impl<'a> BooleanArrayView<'a> { + pub fn array(&self) -> BooleanArray { + BooleanArray { + len: self.len, + validity: self.validity.map(<[_]>::to_vec), + values: self.values.to_owned(), + } + } 
+} + +pub struct PrimitiveArrayView<'a, T> { + pub validity: Option<&'a [u8]>, + pub values: &'a [T], +} + +impl<'a, T: Clone> PrimitiveArrayView<'a, T> { + pub fn array(&self) -> PrimitiveArray { + PrimitiveArray { + validity: self.validity.map(<[_]>::to_vec), + values: self.values.to_vec(), + } + } +} + +pub struct TimeArrayView<'a, T> { + pub unit: TimeUnit, + pub validity: Option<&'a [u8]>, + pub values: &'a [T], +} + +impl<'a, T: Clone> TimeArrayView<'a, T> { + pub fn array(&self) -> TimeArray { + TimeArray { + unit: self.unit, + validity: self.validity.map(<[_]>::to_vec), + values: self.values.to_vec(), + } + } +} + +pub struct StructArrayView<'a> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub fields: Vec>, +} + +impl<'a> StructArrayView<'a> { + pub fn array(&self) -> StructArray { + StructArray { + len: self.len, + validity: self.validity.map(<[_]>::to_vec), + fields: self.fields.iter().map(|f| f.array()).collect(), + } + } +} + +pub struct ListArrayView<'a, O> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub offsets: &'a [O], + pub element: Box>, +} + +impl<'a, O: Clone> ListArrayView<'a, O> { + pub fn array(&self) -> ListArray { + ListArray { + len: self.len, + validity: self.validity.map(<[_]>::to_vec), + offsets: self.offsets.to_vec(), + element: Box::new(self.element.array()), + } + } +} + +pub struct Utf8ArrayView<'a, O> { + pub len: usize, + pub validity: Option<&'a [u8]>, + pub offsets: &'a [O], + pub data: &'a [u8], +} + +impl<'a, O: Clone> Utf8ArrayView<'a, O> { + pub fn array(&self) -> Utf8Array { + Utf8Array { + len: self.len, + validity: self.validity.map(<[_]>::to_vec), + offsets: self.offsets.to_vec(), + data: self.data.to_vec(), + } + } +} diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs new file mode 100644 index 00000000..a732e466 --- /dev/null +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -0,0 +1,111 @@ +use std::{collections::HashMap, sync::Arc}; + +use 
serde::{Deserialize, Serialize}; + +use crate::internal::error::{fail, Error, Result}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Field { + pub name: String, + pub data_type: DataType, + pub metadata: HashMap, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum DataType { + Null, + Boolean, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float16, + Float32, + Float64, + Utf8, + LargeUtf8, + Binary, + LargeBinary, + Date32, + Date64, + Timestamp(TimeUnit, Option>), + Time32(TimeUnit), + Time64(TimeUnit), + Decimal128, + Struct(Vec), + List(Box), + LargeList(Box), +} + +pub struct BaseDataTypeDisplay<'a>(pub &'a DataType); + +impl<'a> std::fmt::Display for BaseDataTypeDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + DataType::Null => write!(f, "Null"), + DataType::Boolean => write!(f, "Boolean"), + DataType::Int8 => write!(f, "Int8"), + DataType::Int16 => write!(f, "Int16"), + DataType::Int32 => write!(f, "Int32"), + DataType::Int64 => write!(f, "Int64"), + DataType::UInt8 => write!(f, "UInt8"), + DataType::UInt16 => write!(f, "UInt16"), + DataType::UInt32 => write!(f, "UInt32"), + DataType::UInt64 => write!(f, "UInt64"), + DataType::Float16 => write!(f, "Float16"), + DataType::Float32 => write!(f, "Float32"), + DataType::Float64 => write!(f, "Float64"), + DataType::Utf8 => write!(f, "Utf8"), + DataType::LargeUtf8 => write!(f, "LargeUtf8"), + DataType::Binary => write!(f, "Binary"), + DataType::LargeBinary => write!(f, "LargeBinary"), + DataType::Date32 => write!(f, "Date32"), + DataType::Date64 => write!(f, "Date64"), + DataType::Timestamp(_, _) => write!(f, "Timestamp"), + DataType::Time32(_) => write!(f, "Time32"), + DataType::Time64(_) => write!(f, "Time64"), + DataType::Decimal128 => write!(f, "Decimal128"), + DataType::Struct(_) => write!(f, "Struct"), + DataType::List(_) => write!(f, "List"), + DataType::LargeList(_) => write!(f, "LargeList"), + 
} + } +} + +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] +pub enum TimeUnit { + Second, + Millisecond, + Microsecond, + Nanosecond, +} + +impl std::fmt::Display for TimeUnit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TimeUnit::Second => write!(f, "Second"), + TimeUnit::Millisecond => write!(f, "Millisecond"), + TimeUnit::Microsecond => write!(f, "Microsecond"), + TimeUnit::Nanosecond => write!(f, "Nanosecond"), + } + } +} + +impl std::str::FromStr for TimeUnit { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "Second" => Ok(Self::Second), + "Millisecond" => Ok(Self::Millisecond), + "Microsecond" => Ok(Self::Microsecond), + "Nanosecond" => Ok(Self::Nanosecond), + s => fail!("Invalid time unit {s}"), + } + } +} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs new file mode 100644 index 00000000..086d3e08 --- /dev/null +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -0,0 +1,14 @@ +//! A common arrow abstraction to simplify conversion between different arrow +//! 
implementations +mod array; +mod array_view; +mod data_type; + +pub use array::{ + Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, +}; +pub use array_view::{ + ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, + Utf8ArrayView, +}; +pub use data_type::{BaseDataTypeDisplay, DataType, Field, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index bc1b6a11..bd905560 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -2,8 +2,8 @@ use chrono::DateTime; use serde::de::Visitor; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::GenericTimeUnit, utils::Mut, }; @@ -12,13 +12,13 @@ use super::{ utils::{ArrayBufferIterator, BitBuffer}, }; -pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, GenericTimeUnit, bool); +pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, TimeUnit, bool); impl<'a> Date64Deserializer<'a> { pub fn new( buffer: &'a [i64], validity: Option>, - unit: GenericTimeUnit, + unit: TimeUnit, is_utc: bool, ) -> Self { Self(ArrayBufferIterator::new(buffer, validity), unit, is_utc) @@ -26,10 +26,10 @@ impl<'a> Date64Deserializer<'a> { pub fn get_string_repr(&self, ts: i64) -> Result { let Some(date_time) = (match self.1 { - GenericTimeUnit::Second => DateTime::from_timestamp(ts, 0), - GenericTimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), - GenericTimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), - GenericTimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), + TimeUnit::Second => DateTime::from_timestamp(ts, 0), + TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), + TimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), + TimeUnit::Nanosecond => 
Some(DateTime::from_timestamp_nanos(ts)), }) else { fail!("Unsupported timestamp value: {ts}"); }; diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 05f9775e..9c755fbb 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -2,8 +2,8 @@ use chrono::NaiveTime; use serde::de::Visitor; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::GenericTimeUnit, utils::Mut, }; @@ -16,8 +16,13 @@ use super::{ pub struct TimeDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>, i64, i64); impl<'a, T: Integer> TimeDeserializer<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>, unit: GenericTimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = unit.get_factors(); + pub fn new(buffer: &'a [T], validity: Option>, unit: TimeUnit) -> Self { + let (seconds_factor, nanoseconds_factor) = match unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; Self( ArrayBufferIterator::new(buffer, validity), diff --git a/serde_arrow/src/internal/mod.rs b/serde_arrow/src/internal/mod.rs index 47a1c1d6..1c7188d6 100644 --- a/serde_arrow/src/internal/mod.rs +++ b/serde_arrow/src/internal/mod.rs @@ -1,4 +1,5 @@ pub mod array_builder; +pub mod arrow; pub mod deserialization; pub mod deserializer; pub mod error; diff --git a/serde_arrow/src/internal/schema/data_type.rs b/serde_arrow/src/internal/schema/data_type.rs index bfd70025..0e9e9aa8 100644 --- a/serde_arrow/src/internal/schema/data_type.rs +++ b/serde_arrow/src/internal/schema/data_type.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use crate::internal::{ + arrow::TimeUnit, error::{fail, Error, Result}, utils::dsl::Term, }; @@ -25,9 +26,9 @@ pub enum GenericDataType { LargeUtf8, 
Date32, Date64, - Time32(GenericTimeUnit), - Time64(GenericTimeUnit), - Duration(GenericTimeUnit), + Time32(TimeUnit), + Time64(TimeUnit), + Duration(TimeUnit), Struct, List, LargeList, @@ -38,7 +39,7 @@ pub enum GenericDataType { Union, Map, Dictionary, - Timestamp(GenericTimeUnit, Option), + Timestamp(TimeUnit, Option), Decimal128(u8, i8), } @@ -123,7 +124,7 @@ impl std::str::FromStr for GenericDataType { ("Map", []) => T::Map, ("Dictionary", []) => T::Dictionary, ("Timestamp", [unit, timezone]) => { - let unit: GenericTimeUnit = unit.as_ident()?.parse()?; + let unit: TimeUnit = unit.as_ident()?.parse()?; let timezone = timezone .as_option()? .map(|term| term.as_string()) @@ -158,47 +159,3 @@ impl From for GenericDataTypeString { Self(value.to_string()) } } - -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] -pub enum GenericTimeUnit { - Second, - Millisecond, - Microsecond, - Nanosecond, -} - -impl GenericTimeUnit { - pub fn get_factors(&self) -> (i64, i64) { - match self { - GenericTimeUnit::Nanosecond => (1_000_000_000, 1), - GenericTimeUnit::Microsecond => (1_000_000, 1_000), - GenericTimeUnit::Millisecond => (1_000, 1_000_000), - GenericTimeUnit::Second => (1, 1_000_000_000), - } - } -} - -impl std::fmt::Display for GenericTimeUnit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - GenericTimeUnit::Second => write!(f, "Second"), - GenericTimeUnit::Millisecond => write!(f, "Millisecond"), - GenericTimeUnit::Microsecond => write!(f, "Microsecond"), - GenericTimeUnit::Nanosecond => write!(f, "Nanosecond"), - } - } -} - -impl std::str::FromStr for GenericTimeUnit { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "Second" => Ok(Self::Second), - "Millisecond" => Ok(Self::Millisecond), - "Microsecond" => Ok(Self::Microsecond), - "Nanosecond" => Ok(Self::Nanosecond), - s => fail!("Invalid time unit {s}"), - } - } -} diff --git 
a/serde_arrow/src/internal/schema/deserialization.rs b/serde_arrow/src/internal/schema/deserialization.rs index 1a95783c..cb3e909f 100644 --- a/serde_arrow/src/internal/schema/deserialization.rs +++ b/serde_arrow/src/internal/schema/deserialization.rs @@ -6,10 +6,11 @@ use std::{collections::HashMap, str::FromStr}; use serde::{de::Visitor, Deserialize}; use crate::internal::{ + arrow::TimeUnit, error::{fail, Error, Result}, schema::{ merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, GenericField, - GenericTimeUnit, SerdeArrowSchema, Strategy, + SerdeArrowSchema, Strategy, }, }; @@ -40,7 +41,7 @@ pub enum ArrowTimeUnit { Nanosecond, } -impl From for GenericTimeUnit { +impl From for TimeUnit { fn from(value: ArrowTimeUnit) -> Self { match value { ArrowTimeUnit::Second => Self::Second, diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index fdec33da..3af0ea09 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -14,13 +14,14 @@ mod test; use std::collections::HashMap; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, utils::value, }; use serde::{Deserialize, Serialize}; -pub use data_type::{GenericDataType, GenericTimeUnit}; +pub use data_type::GenericDataType; pub use strategy::{ merge_strategy_with_metadata, split_strategy_from_metadata, Strategy, STRATEGY_KEY, }; @@ -492,7 +493,7 @@ impl GenericField { } if !matches!( self.data_type, - GenericDataType::Time32(GenericTimeUnit::Second | GenericTimeUnit::Millisecond) + GenericDataType::Time32(TimeUnit::Second | TimeUnit::Millisecond) ) { fail!("Time32 field must have Second or Millisecond unit"); } @@ -511,7 +512,7 @@ impl GenericField { } if !matches!( self.data_type, - GenericDataType::Time64(GenericTimeUnit::Microsecond | GenericTimeUnit::Nanosecond) + GenericDataType::Time64(TimeUnit::Microsecond | TimeUnit::Nanosecond) ) { fail!("Time64 field must have Microsecond or Nanosecond 
unit"); } diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index e976e567..8e19822c 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -1,6 +1,7 @@ use serde_json::json; use crate::internal::{ + arrow::TimeUnit, schema::{GenericDataType, GenericField, SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, testing::{assert_error, hash_map}, }; @@ -104,7 +105,7 @@ fn date64_with_strategy() { #[test] fn timestamp_second_serialization() { - let dt = super::GenericDataType::Timestamp(super::GenericTimeUnit::Second, None); + let dt = super::GenericDataType::Timestamp(TimeUnit::Second, None); let s = serde_json::to_string(&dt).unwrap(); assert_eq!(s, r#""Timestamp(Second, None)""#); @@ -115,10 +116,7 @@ fn timestamp_second_serialization() { #[test] fn timestamp_second_utc_serialization() { - let dt = super::GenericDataType::Timestamp( - super::GenericTimeUnit::Second, - Some(String::from("Utc")), - ); + let dt = super::GenericDataType::Timestamp(TimeUnit::Second, Some(String::from("Utc"))); let s = serde_json::to_string(&dt).unwrap(); assert_eq!(s, r#""Timestamp(Second, Some(\"Utc\"))""#); @@ -129,7 +127,7 @@ fn timestamp_second_utc_serialization() { #[test] fn test_date32() { - use super::GenericDataType as DT; + use GenericDataType as DT; assert_eq!(DT::Date32.to_string(), "Date32"); assert_eq!("Date32".parse::
().unwrap(), DT::Date32); @@ -137,7 +135,7 @@ fn test_date32() { #[test] fn time64_data_type_format() { - use super::{GenericDataType as DT, GenericTimeUnit as TU}; + use {GenericDataType as DT, TimeUnit as TU}; for (dt, s) in [ (DT::Time64(TU::Microsecond), "Time64(Microsecond)"), diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 68e74e14..faa3a6a3 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -1,7 +1,7 @@ use half::f16; use serde::Serialize; -use crate::internal::error::Result; +use crate::internal::{arrow::Array, error::Result}; use super::{ binary_builder::BinaryBuilder, bool_builder::BoolBuilder, date32_builder::Date32Builder, @@ -131,10 +131,15 @@ impl ArrayBuilder { pub fn is_nullable(&self) -> bool { dispatch!(self, Self(builder) => builder.is_nullable()) } + + pub fn into_array(self) -> Array { + dispatch!(self, Self(builder) => builder.into_array()) + } } impl ArrayBuilder { /// Take the contained array builder, while leaving structure intact + // TODO: use ArrayBuilder as return type for the impls and use dispatch here pub fn take(&mut self) -> ArrayBuilder { match self { Self::Null(builder) => Self::Null(builder.take()), diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index d1dcac33..49bdcdff 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::Array, error::Result, utils::{Mut, Offset}, }; @@ -37,6 +38,10 @@ impl BinaryBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl BinaryBuilder { diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs 
b/serde_arrow/src/internal/serialization/bool_builder.rs index 805e76a2..0670a637 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,4 +1,7 @@ -use crate::internal::error::Result; +use crate::internal::{ + arrow::{Array, BooleanArray}, + error::Result, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -26,6 +29,14 @@ impl BoolBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + Array::Boolean(BooleanArray { + len: self.buffer.len, + validity: self.validity.map(|v| v.buffer), + values: self.buffer.buffer, + }) + } } impl SimpleSerializer for BoolBuilder { diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 3eee20f2..e0e8d9b0 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -1,6 +1,6 @@ use chrono::{NaiveDate, NaiveDateTime}; -use crate::internal::{error::Result, schema::GenericField}; +use crate::internal::{arrow::Array, error::Result, schema::GenericField}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -31,6 +31,10 @@ impl Date32Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for Date32Builder { diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index fe17903a..9952cd68 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,6 +1,7 @@ use crate::internal::{ + arrow::{Array, TimeUnit}, error::{Error, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit}, + 
schema::{GenericDataType, GenericField}, }; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -35,6 +36,10 @@ impl Date64Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for Date64Builder { @@ -64,14 +69,14 @@ impl SimpleSerializer for Date64Builder { }; let timestamp = match self.field.data_type { - GenericDataType::Timestamp(GenericTimeUnit::Nanosecond, _) => { + GenericDataType::Timestamp(TimeUnit::Nanosecond, _) => { date_time .timestamp_nanos_opt() .ok_or_else(|| Error::custom(format!("Timestamp '{v}' cannot be converted to nanoseconds. The dates that can be represented as nanoseconds are between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.")))? }, - GenericDataType::Timestamp(GenericTimeUnit::Microsecond, _) => date_time.timestamp_micros(), - GenericDataType::Timestamp(GenericTimeUnit::Millisecond, _) => date_time.timestamp_millis(), - GenericDataType::Timestamp(GenericTimeUnit::Second, _) => date_time.timestamp(), + GenericDataType::Timestamp(TimeUnit::Microsecond, _) => date_time.timestamp_micros(), + GenericDataType::Timestamp(TimeUnit::Millisecond, _) => date_time.timestamp_millis(), + GenericDataType::Timestamp(TimeUnit::Second, _) => date_time.timestamp(), _ => date_time.timestamp_millis(), }; diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 830e893f..e3444cca 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::Array, error::Result, utils::decimal::{self, DecimalParser}, }; @@ -44,6 +45,10 @@ impl DecimalBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for 
DecimalBuilder { diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index a3c64ae5..b7f068d3 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use serde::Serialize; use crate::internal::{ + arrow::Array, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -40,6 +41,10 @@ impl DictionaryUtf8Builder { pub fn is_nullable(&self) -> bool { self.indices.is_nullable() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for DictionaryUtf8Builder { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index dc807ba8..4514fd45 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,16 +1,19 @@ -use crate::internal::{error::Result, schema::GenericTimeUnit}; +use crate::internal::{ + arrow::{Array, TimeUnit}, + error::Result, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; #[derive(Debug, Clone)] pub struct DurationBuilder { - pub unit: GenericTimeUnit, + pub unit: TimeUnit, pub validity: Option, pub buffer: Vec, } impl DurationBuilder { - pub fn new(unit: GenericTimeUnit, is_nullable: bool) -> Self { + pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { Self { unit, validity: is_nullable.then(MutableBitBuffer::default), @@ -29,6 +32,10 @@ impl DurationBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for DurationBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs 
index 1ead7495..f8960459 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::Array, error::{fail, Result}, utils::Mut, }; @@ -41,6 +42,10 @@ impl FixedSizeBinaryBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl FixedSizeBinaryBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 22d1ee49..157a8ba6 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::Array, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -48,6 +49,10 @@ impl FixedSizeListBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl FixedSizeListBuilder { diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index a07cf66f..1cc9007f 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -1,6 +1,10 @@ use half::f16; -use crate::internal::{error::Result, utils::Mut}; +use crate::internal::{ + arrow::{Array, DataType, PrimitiveArray}, + error::Result, + utils::Mut, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -36,6 +40,23 @@ impl FloatBuilder { } } +macro_rules! 
impl_into_array { + ($ty:ty, $var:ident) => { + impl FloatBuilder<$ty> { + pub fn into_array(self) -> Array { + Array::$var(PrimitiveArray { + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + }) + } + } + }; +} + +impl_into_array!(f16, Float16); +impl_into_array!(f32, Float32); +impl_into_array!(f64, Float64); + impl SimpleSerializer for FloatBuilder { fn name(&self) -> &str { "FloatBuilder" diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 57516d3a..bafa0aec 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -1,4 +1,7 @@ -use crate::internal::error::{Error, Result}; +use crate::internal::{ + arrow::{Array, DataType, PrimitiveArray}, + error::{Error, Result}, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -28,6 +31,28 @@ impl IntBuilder { } } +macro_rules! impl_into_array { + ($ty:ty, $var:ident) => { + impl IntBuilder<$ty> { + pub fn into_array(self) -> Array { + Array::$var(PrimitiveArray { + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + }) + } + } + }; +} + +impl_into_array!(i8, Int8); +impl_into_array!(i16, Int16); +impl_into_array!(i32, Int32); +impl_into_array!(i64, Int64); +impl_into_array!(u8, UInt8); +impl_into_array!(u16, UInt16); +impl_into_array!(u32, UInt32); +impl_into_array!(u64, UInt64); + impl SimpleSerializer for IntBuilder where I: Default diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index ec112b47..c981947b 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -1,6 +1,7 @@ use serde::Serialize; use crate::internal::{ + arrow::Array, error::Result, schema::GenericField, utils::{Mut, Offset}, @@ -45,6 +46,10 @@ impl ListBuilder { pub fn 
is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl ListBuilder { diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index bae4b05f..e950553a 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -1,6 +1,6 @@ use serde::Serialize; -use crate::internal::{error::Result, schema::GenericField}; +use crate::internal::{arrow::Array, error::Result, schema::GenericField}; use super::{ array_builder::ArrayBuilder, @@ -40,6 +40,10 @@ impl MapBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for MapBuilder { diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index 5a7e3122..f574250d 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -1,4 +1,7 @@ -use crate::Result; +use crate::internal::{ + arrow::{Array, NullArray}, + error::Result, +}; use super::utils::SimpleSerializer; @@ -21,6 +24,10 @@ impl NullBuilder { pub fn is_nullable(&self) -> bool { true } + + pub fn into_array(self) -> Array { + Array::Null(NullArray { len: self.count }) + } } impl SimpleSerializer for NullBuilder { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index e815c534..98e66ab5 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -1,8 +1,9 @@ use serde::Serialize; use crate::internal::{ + arrow::TimeUnit, error::{fail, Result}, - schema::{GenericDataType, GenericField, GenericTimeUnit, SerdeArrowSchema, Strategy}, + 
schema::{GenericDataType, GenericField, SerdeArrowSchema, Strategy}, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, @@ -78,16 +79,13 @@ impl OuterSequenceBuilder { Some(tz) => fail!("Timezone {tz} is not supported"), }, T::Time32(unit) => { - if !matches!(unit, GenericTimeUnit::Second | GenericTimeUnit::Millisecond) { + if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { fail!("Only timestamps with second or millisecond unit are supported"); } A::Time32(TimeBuilder::new(field.clone(), field.nullable, *unit)) } T::Time64(unit) => { - if !matches!( - unit, - GenericTimeUnit::Nanosecond | GenericTimeUnit::Microsecond - ) { + if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { fail!("Only timestamps with nanosecond or microsecond unit are supported"); } A::Time64(TimeBuilder::new(field.clone(), field.nullable, *unit)) diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index d48bb3eb..8b12b1b2 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use serde::Serialize; use crate::internal::{ + arrow::Array, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -81,6 +82,10 @@ impl StructBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl StructBuilder { diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 198a519a..35cc84d9 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -1,8 +1,9 @@ use chrono::Timelike; use crate::internal::{ + arrow::{Array, TimeUnit}, error::{Error, Result}, - schema::{GenericField, 
GenericTimeUnit}, + schema::GenericField, }; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -17,8 +18,13 @@ pub struct TimeBuilder { } impl TimeBuilder { - pub fn new(field: GenericField, nullable: bool, unit: GenericTimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = unit.get_factors(); + pub fn new(field: GenericField, nullable: bool, unit: TimeUnit) -> Self { + let (seconds_factor, nanoseconds_factor) = match unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; Self { field, @@ -42,6 +48,10 @@ impl TimeBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for TimeBuilder diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index f3885ad6..564f1fd5 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::Array, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -39,6 +40,10 @@ impl UnionBuilder { pub fn is_nullable(&self) -> bool { false } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl UnionBuilder { diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 4ae7e653..35e895e2 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -1,6 +1,12 @@ use serde::Serialize; -use crate::{internal::error::fail, Result}; +use crate::{ + internal::{ + arrow::{Array, NullArray}, + error::fail, + }, + Result, +}; use super::{utils::SimpleSerializer, 
ArrayBuilder}; @@ -15,6 +21,10 @@ impl UnknownVariantBuilder { pub fn is_nullable(&self) -> bool { false } + + pub fn into_array(self) -> Array { + Array::Null(NullArray { len: 0 }) + } } impl SimpleSerializer for UnknownVariantBuilder { diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 3940df16..f7eb710a 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,4 +1,5 @@ use crate::internal::{ + arrow::Array, error::{fail, Result}, utils::Offset, }; @@ -34,6 +35,10 @@ impl Utf8Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } + + pub fn into_array(self) -> Array { + unimplemented!() + } } impl SimpleSerializer for Utf8Builder { From 3dd78286ba9c8f468fcf81f0b74c4f10bc76f72c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 19:21:48 +0200 Subject: [PATCH 008/178] No longer check all feature combinations per default --- x.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/x.py b/x.py index 9a106aa1..cbcfc488 100644 --- a/x.py +++ b/x.py @@ -169,13 +169,13 @@ def format(): @cmd(help="Run the linters") -@arg("--fast", action="store_true") -def check(fast=False): +@arg("--all", action="store_true") +def check(all=False): check_cargo_toml() _sh(f"cargo check --features {default_features}") _sh(f"cargo clippy --features {default_features}") - if not fast: + if all: for arrow2_feature in (*all_arrow2_features, *all_arrow_features): _sh(f"cargo check --features {arrow2_feature}") From f0fce3a59aa2a56be2b237cab3e1a36e13e4806f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 19:22:01 +0200 Subject: [PATCH 009/178] Address warnings / clippy --- serde_arrow/src/arrow_impl/serialization.rs | 4 ++-- serde_arrow/src/internal/arrow/array.rs | 1 + serde_arrow/src/internal/arrow/mod.rs | 3 +++ 
serde_arrow/src/internal/serialization/float_builder.rs | 2 +- serde_arrow/src/internal/serialization/int_builder.rs | 2 +- 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 21639855..81125592 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -266,7 +266,7 @@ impl TryFrom for ArrayData { A::Boolean(arr) => Ok(ArrayData::try_new( ArrowT::Boolean, arr.len, - arr.validity.map(|buffer| Buffer::from(buffer)), + arr.validity.map(Buffer::from), 0, vec![ScalarBuffer::from(arr.values).into_inner()], vec![], @@ -294,7 +294,7 @@ fn primitive_into_data( Ok(ArrayData::try_new( data_type, array.values.len(), - array.validity.map(|buffer| Buffer::from(buffer)), + array.validity.map(Buffer::from), 0, vec![ScalarBuffer::from(array.values).into_inner()], vec![], diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 4f979f0e..9439761b 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -86,6 +86,7 @@ impl NullArray { #[derive(Clone, Debug)] pub struct BooleanArray { + // Note: len is required to know how many bits of values are used pub len: usize, pub validity: Option>, pub values: Vec, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 086d3e08..218182c8 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -4,11 +4,14 @@ mod array; mod array_view; mod data_type; +#[allow(unused)] pub use array::{ Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, }; +#[allow(unused)] pub use array_view::{ ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, Utf8ArrayView, }; +#[allow(unused)] pub use data_type::{BaseDataTypeDisplay, DataType, Field, TimeUnit}; diff --git 
a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 1cc9007f..efd0bf66 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -1,7 +1,7 @@ use half::f16; use crate::internal::{ - arrow::{Array, DataType, PrimitiveArray}, + arrow::{Array, PrimitiveArray}, error::Result, utils::Mut, }; diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index bafa0aec..89d4d2f4 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -1,5 +1,5 @@ use crate::internal::{ - arrow::{Array, DataType, PrimitiveArray}, + arrow::{Array, PrimitiveArray}, error::{Error, Result}, }; From c5efc8193e6d561c2a527dfa24f633c17076a9e9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 19:52:45 +0200 Subject: [PATCH 010/178] Remove unused View <-> Array conversions, add offset to bitbuffer views --- serde_arrow/src/internal/arrow/array.rs | 108 +---------------- serde_arrow/src/internal/arrow/array_view.rs | 119 ++----------------- serde_arrow/src/internal/arrow/mod.rs | 4 +- 3 files changed, 13 insertions(+), 218 deletions(-) diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 9439761b..af8ca95a 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -1,15 +1,7 @@ //! 
Owned versions of the different array types use half::f16; -use crate::internal::arrow::{ - array_view::{ - ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, - StructArrayView, Utf8ArrayView, - }, - data_type::TimeUnit, -}; - -use super::array_view::TimeArrayView; +use crate::internal::arrow::data_type::TimeUnit; #[derive(Clone, Debug)] #[non_exhaustive] @@ -41,49 +33,11 @@ pub enum Array { LargeList(ListArray), } -impl Array { - pub fn view(&self) -> ArrayView { - match self { - Self::Null(array) => ArrayView::Null(array.view()), - Self::Boolean(array) => ArrayView::Boolean(array.view()), - Self::Int8(array) => ArrayView::Int8(array.view()), - Self::Int16(array) => ArrayView::Int16(array.view()), - Self::Int32(array) => ArrayView::Int32(array.view()), - Self::Int64(array) => ArrayView::Int64(array.view()), - Self::UInt8(array) => ArrayView::UInt8(array.view()), - Self::UInt16(array) => ArrayView::UInt16(array.view()), - Self::UInt32(array) => ArrayView::UInt32(array.view()), - Self::UInt64(array) => ArrayView::UInt64(array.view()), - Self::Float16(array) => ArrayView::Float16(array.view()), - Self::Float32(array) => ArrayView::Float32(array.view()), - Self::Float64(array) => ArrayView::Float64(array.view()), - Self::Date32(array) => ArrayView::Date32(array.view()), - Self::Date64(array) => ArrayView::Date64(array.view()), - Self::Time32(array) => ArrayView::Time32(array.view()), - Self::Time64(array) => ArrayView::Time64(array.view()), - Self::Utf8(array) => ArrayView::Utf8(array.view()), - Self::LargeUtf8(array) => ArrayView::LargeUtf8(array.view()), - Self::Binary(array) => ArrayView::Binary(array.view()), - Self::LargeBinary(array) => ArrayView::LargeBinary(array.view()), - Self::Decimal128(array) => ArrayView::Decimal128(array.view()), - Self::Struct(array) => ArrayView::Struct(array.view()), - Self::List(array) => ArrayView::List(array.view()), - Self::LargeList(array) => ArrayView::LargeList(array.view()), - } - } -} - 
#[derive(Clone, Debug)] pub struct NullArray { pub len: usize, } -impl NullArray { - pub fn view(&self) -> NullArrayView { - NullArrayView { len: self.len } - } -} - #[derive(Clone, Debug)] pub struct BooleanArray { // Note: len is required to know how many bits of values are used @@ -92,15 +46,6 @@ pub struct BooleanArray { pub values: Vec, } -impl BooleanArray { - pub fn view(&self) -> BooleanArrayView { - BooleanArrayView { - len: self.len, - validity: self.validity.as_deref(), - values: &self.values, - } - } -} #[derive(Clone, Debug)] pub struct PrimitiveArray { @@ -109,14 +54,7 @@ pub struct PrimitiveArray { } impl PrimitiveArray { - pub fn view(&self) -> PrimitiveArrayView { - PrimitiveArrayView { - validity: self.validity.as_deref(), - values: &self.values, - } - } - - pub fn map_values(self, func: impl Fn(T) -> R) -> PrimitiveArray { + pub(crate) fn map_values(self, func: impl Fn(T) -> R) -> PrimitiveArray { PrimitiveArray { validity: self.validity, values: self.values.into_iter().map(func).collect(), @@ -131,16 +69,6 @@ pub struct TimeArray { pub values: Vec, } -impl TimeArray { - pub fn view(&self) -> TimeArrayView { - TimeArrayView { - unit: self.unit, - validity: self.validity.as_deref(), - values: &self.values, - } - } -} - #[derive(Clone, Debug)] pub struct StructArray { pub len: usize, @@ -148,16 +76,6 @@ pub struct StructArray { pub fields: Vec, } -impl StructArray { - pub fn view(&self) -> StructArrayView { - StructArrayView { - len: self.len, - validity: self.validity.as_deref(), - fields: self.fields.iter().map(|f| f.view()).collect(), - } - } -} - #[derive(Clone, Debug)] pub struct ListArray { pub len: usize, @@ -166,17 +84,6 @@ pub struct ListArray { pub element: Box, } -impl ListArray { - pub fn view(&self) -> ListArrayView { - ListArrayView { - len: self.len, - validity: self.validity.as_deref(), - offsets: &self.offsets, - element: Box::new(self.element.view()), - } - } -} - #[derive(Clone, Debug)] pub struct Utf8Array { pub len: usize, @@ 
-184,14 +91,3 @@ pub struct Utf8Array { pub offsets: Vec, pub data: Vec, } - -impl Utf8Array { - pub fn view(&self) -> Utf8ArrayView { - Utf8ArrayView { - len: self.len, - validity: self.validity.as_deref(), - offsets: &self.offsets, - data: &self.data, - } - } -} diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs index 40f80bb9..d3161a77 100644 --- a/serde_arrow/src/internal/arrow/array_view.rs +++ b/serde_arrow/src/internal/arrow/array_view.rs @@ -1,12 +1,6 @@ use half::f16; -use crate::internal::arrow::{ - array::{ - Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, - Utf8Array, - }, - data_type::TimeUnit, -}; +use crate::internal::arrow::data_type::TimeUnit; pub enum ArrayView<'a> { Null(NullArrayView), @@ -36,110 +30,39 @@ pub enum ArrayView<'a> { LargeList(ListArrayView<'a, i64>), } -impl<'a> ArrayView<'a> { - pub fn array(&self) -> Array { - match self { - Self::Null(view) => Array::Null(view.array()), - Self::Boolean(view) => Array::Boolean(view.array()), - Self::Int8(view) => Array::Int8(view.array()), - Self::Int16(view) => Array::Int16(view.array()), - Self::Int32(view) => Array::Int32(view.array()), - Self::Int64(view) => Array::Int64(view.array()), - Self::UInt8(view) => Array::UInt8(view.array()), - Self::UInt16(view) => Array::UInt16(view.array()), - Self::UInt32(view) => Array::UInt32(view.array()), - Self::UInt64(view) => Array::UInt64(view.array()), - Self::Float16(view) => Array::Float16(view.array()), - Self::Float32(view) => Array::Float32(view.array()), - Self::Float64(view) => Array::Float64(view.array()), - Self::Date32(view) => Array::Date32(view.array()), - Self::Date64(view) => Array::Date64(view.array()), - Self::Time32(view) => Array::Time32(view.array()), - Self::Time64(view) => Array::Time64(view.array()), - Self::Utf8(view) => Array::Utf8(view.array()), - Self::Binary(view) => Array::Binary(view.array()), - Self::LargeBinary(view) => 
Array::LargeBinary(view.array()), - Self::LargeUtf8(view) => Array::LargeUtf8(view.array()), - Self::Decimal128(view) => Array::Decimal128(view.array()), - Self::Struct(view) => Array::Struct(view.array()), - Self::List(view) => Array::List(view.array()), - Self::LargeList(view) => Array::LargeList(view.array()), - } - } -} - pub struct NullArrayView { pub len: usize, } -impl NullArrayView { - pub fn array(&self) -> NullArray { - NullArray { len: self.len } - } +#[derive(Debug, Clone, Copy)] +pub struct BitsWithOffset<'a> { + pub offset: usize, + pub data: &'a [u8], } pub struct BooleanArrayView<'a> { pub len: usize, - pub validity: Option<&'a [u8]>, - pub values: &'a [u8], -} - -impl<'a> BooleanArrayView<'a> { - pub fn array(&self) -> BooleanArray { - BooleanArray { - len: self.len, - validity: self.validity.map(<[_]>::to_vec), - values: self.values.to_owned(), - } - } + pub validity: Option>, + pub values: BitsWithOffset<'a>, } pub struct PrimitiveArrayView<'a, T> { - pub validity: Option<&'a [u8]>, + pub validity: Option>, pub values: &'a [T], } -impl<'a, T: Clone> PrimitiveArrayView<'a, T> { - pub fn array(&self) -> PrimitiveArray { - PrimitiveArray { - validity: self.validity.map(<[_]>::to_vec), - values: self.values.to_vec(), - } - } -} - pub struct TimeArrayView<'a, T> { pub unit: TimeUnit, - pub validity: Option<&'a [u8]>, + pub validity: Option>, pub values: &'a [T], } -impl<'a, T: Clone> TimeArrayView<'a, T> { - pub fn array(&self) -> TimeArray { - TimeArray { - unit: self.unit, - validity: self.validity.map(<[_]>::to_vec), - values: self.values.to_vec(), - } - } -} - pub struct StructArrayView<'a> { pub len: usize, - pub validity: Option<&'a [u8]>, + pub validity: Option>, pub fields: Vec>, } -impl<'a> StructArrayView<'a> { - pub fn array(&self) -> StructArray { - StructArray { - len: self.len, - validity: self.validity.map(<[_]>::to_vec), - fields: self.fields.iter().map(|f| f.array()).collect(), - } - } -} - pub struct ListArrayView<'a, O> { pub len: 
usize, pub validity: Option<&'a [u8]>, @@ -147,31 +70,9 @@ pub struct ListArrayView<'a, O> { pub element: Box>, } -impl<'a, O: Clone> ListArrayView<'a, O> { - pub fn array(&self) -> ListArray { - ListArray { - len: self.len, - validity: self.validity.map(<[_]>::to_vec), - offsets: self.offsets.to_vec(), - element: Box::new(self.element.array()), - } - } -} - pub struct Utf8ArrayView<'a, O> { pub len: usize, pub validity: Option<&'a [u8]>, pub offsets: &'a [O], pub data: &'a [u8], } - -impl<'a, O: Clone> Utf8ArrayView<'a, O> { - pub fn array(&self) -> Utf8Array { - Utf8Array { - len: self.len, - validity: self.validity.map(<[_]>::to_vec), - offsets: self.offsets.to_vec(), - data: self.data.to_vec(), - } - } -} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 218182c8..d358ebd1 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -1,17 +1,15 @@ //! A common arrow abstraction to simplify conversion between different arrow //! 
implementations +#![allow(dead_code, unused)] mod array; mod array_view; mod data_type; -#[allow(unused)] pub use array::{ Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, }; -#[allow(unused)] pub use array_view::{ ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, Utf8ArrayView, }; -#[allow(unused)] pub use data_type::{BaseDataTypeDisplay, DataType, Field, TimeUnit}; From b3bdb91759af67ba9590f7fe4794715f27015b4e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 19:54:10 +0200 Subject: [PATCH 011/178] Export BitsWithOffset --- serde_arrow/src/internal/arrow/array.rs | 1 - serde_arrow/src/internal/arrow/mod.rs | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index af8ca95a..51b074bb 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -46,7 +46,6 @@ pub struct BooleanArray { pub values: Vec, } - #[derive(Clone, Debug)] pub struct PrimitiveArray { pub validity: Option>, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index d358ebd1..48e25b65 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -9,7 +9,7 @@ pub use array::{ Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, }; pub use array_view::{ - ArrayView, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, - Utf8ArrayView, + ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, + StructArrayView, Utf8ArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, Field, TimeUnit}; From 00eac9673b725302e780bf4e9691b5831b0302ad Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 20:19:37 +0200 Subject: [PATCH 012/178] Impl Date32, Date64, Timestamp
serialization --- serde_arrow/src/arrow_impl/serialization.rs | 27 ++++++----- serde_arrow/src/internal/arrow/array.rs | 10 ++++ serde_arrow/src/internal/arrow/mod.rs | 3 +- .../internal/serialization/date32_builder.rs | 11 ++++- .../internal/serialization/date64_builder.rs | 16 ++++++- .../internal/serialization/time_builder.rs | 47 ++++++++++++------- 6 files changed, 81 insertions(+), 33 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 81125592..d936e31c 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -83,17 +83,9 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::U64(_) | A::F16(_) | A::F32(_) - | A::F64(_)) => builder.into_array().try_into(), - A::Date32(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Date64(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), + | A::F64(_) + | A::Date32(_) + | A::Date64(_)) => builder.into_array().try_into(), A::Time32(builder) => build_array_data_primitive( Field::try_from(&builder.field)?.data_type().clone(), builder.buffer, @@ -282,6 +274,19 @@ impl TryFrom for ArrayData { A::Float16(arr) => primitive_into_data(ArrowT::Float16, arr.map_values(f16_to_f16)), A::Float32(arr) => primitive_into_data(ArrowT::Float32, arr), A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr), + A::Date32(arr) => primitive_into_data(ArrowT::Date32, arr), + A::Date64(arr) => primitive_into_data(ArrowT::Date64, arr), + A::Timestamp(arr) => { + let data_type = ArrowT::Timestamp(arr.unit.into(), arr.timezone.map(String::into)); + Ok(ArrayData::try_new( + data_type, + arr.values.len(), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?) 
+ } array => fail!("{:?} not implemented", array), } } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 51b074bb..e1a44cd3 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -23,6 +23,7 @@ pub enum Array { Date64(PrimitiveArray), Time32(TimeArray), Time64(TimeArray), + Timestamp(TimestampArray), Utf8(Utf8Array), LargeUtf8(Utf8Array), Binary(Utf8Array), @@ -68,6 +69,15 @@ pub struct TimeArray { pub values: Vec, } +#[derive(Debug, Clone)] + +pub struct TimestampArray { + pub unit: TimeUnit, + pub timezone: Option, + pub validity: Option>, + pub values: Vec, +} + #[derive(Clone, Debug)] pub struct StructArray { pub len: usize, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 48e25b65..39a8e203 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,7 +6,8 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, + Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, + TimestampArray, Utf8Array, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index e0e8d9b0..2fad40fb 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -1,6 +1,10 @@ use chrono::{NaiveDate, NaiveDateTime}; -use crate::internal::{arrow::Array, error::Result, schema::GenericField}; +use crate::internal::{ + arrow::{Array, PrimitiveArray}, + error::Result, + schema::GenericField, +}; use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; @@ -33,7 +37,10 @@ impl Date32Builder { } pub fn 
into_array(self) -> Array { - unimplemented!() + Array::Date32(PrimitiveArray { + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + }) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 9952cd68..c01211a9 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,5 +1,5 @@ use crate::internal::{ - arrow::{Array, TimeUnit}, + arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, error::{Error, Result}, schema::{GenericDataType, GenericField}, }; @@ -38,7 +38,19 @@ impl Date64Builder { } pub fn into_array(self) -> Array { - unimplemented!() + if let GenericDataType::Timestamp(unit, timezone) = self.field.data_type { + Array::Timestamp(TimestampArray { + unit, + timezone, + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + }) + } else { + Array::Date64(PrimitiveArray { + validity: self.validity.map(|validity| validity.buffer), + values: self.buffer, + }) + } } } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 35cc84d9..7e28b5da 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -1,7 +1,7 @@ use chrono::Timelike; use crate::internal::{ - arrow::{Array, TimeUnit}, + arrow::{Array, TimeArray, TimeUnit}, error::{Error, Result}, schema::GenericField, }; @@ -13,25 +13,16 @@ pub struct TimeBuilder { pub field: GenericField, pub validity: Option, pub buffer: Vec, - pub seconds_factor: i64, - pub nanoseconds_factor: i64, + pub unit: TimeUnit, } impl TimeBuilder { pub fn new(field: GenericField, nullable: bool, unit: TimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = match unit { - TimeUnit::Nanosecond => (1_000_000_000, 1), - TimeUnit::Microsecond => (1_000_000, 1_000), 
- TimeUnit::Millisecond => (1_000, 1_000_000), - TimeUnit::Second => (1, 1_000_000_000), - }; - Self { field, validity: nullable.then(MutableBitBuffer::default), buffer: Vec::new(), - seconds_factor, - nanoseconds_factor, + unit, } } @@ -40,17 +31,32 @@ impl TimeBuilder { field: self.field.clone(), validity: self.validity.as_mut().map(std::mem::take), buffer: std::mem::take(&mut self.buffer), - seconds_factor: self.seconds_factor, - nanoseconds_factor: self.nanoseconds_factor, + unit: self.unit, } } pub fn is_nullable(&self) -> bool { self.validity.is_some() } +} +impl TimeBuilder { pub fn into_array(self) -> Array { - unimplemented!() + Array::Time32(TimeArray { + unit: self.unit, + validity: self.validity.map(|v| v.buffer), + values: self.buffer, + }) + } +} + +impl TimeBuilder { + pub fn into_array(self) -> Array { + Array::Time64(TimeArray { + unit: self.unit, + validity: self.validity.map(|v| v.buffer), + values: self.buffer, + }) } } @@ -77,10 +83,17 @@ where } fn serialize_str(&mut self, v: &str) -> Result<()> { + let (seconds_factor, nanoseconds_factor) = match self.unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; + use chrono::naive::NaiveTime; let time = v.parse::()?; - let timestamp = time.num_seconds_from_midnight() as i64 * self.seconds_factor - + time.nanosecond() as i64 / self.nanoseconds_factor; + let timestamp = time.num_seconds_from_midnight() as i64 * seconds_factor + + time.nanosecond() as i64 / nanoseconds_factor; push_validity(&mut self.validity, true)?; self.buffer.push(timestamp.try_into()?); From d43e44e6a42e8223b8e1e018d76d42b7bf7369f5 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 20:22:36 +0200 Subject: [PATCH 013/178] Add support for Time32, Time64 serialization --- serde_arrow/src/arrow_impl/serialization.rs | 49 ++++++++++++--------- 1 file changed, 27 insertions(+), 
22 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index d936e31c..f8581645 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -85,17 +85,9 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::F32(_) | A::F64(_) | A::Date32(_) - | A::Date64(_)) => builder.into_array().try_into(), - A::Time32(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), - A::Time64(builder) => build_array_data_primitive( - Field::try_from(&builder.field)?.data_type().clone(), - builder.buffer, - builder.validity, - ), + | A::Date64(_) + | A::Time32(_) + | A::Time64(_)) => builder.into_array().try_into(), A::Duration(builder) => build_array_data_primitive( T::Duration(builder.unit.into()), builder.buffer, @@ -276,17 +268,30 @@ impl TryFrom for ArrayData { A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr), A::Date32(arr) => primitive_into_data(ArrowT::Date32, arr), A::Date64(arr) => primitive_into_data(ArrowT::Date64, arr), - A::Timestamp(arr) => { - let data_type = ArrowT::Timestamp(arr.unit.into(), arr.timezone.map(String::into)); - Ok(ArrayData::try_new( - data_type, - arr.values.len(), - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.values).into_inner()], - vec![], - )?) 
- } + A::Timestamp(arr) => Ok(ArrayData::try_new( + ArrowT::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), + arr.values.len(), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), + A::Time32(arr) => Ok(ArrayData::try_new( + ArrowT::Time32(arr.unit.into()), + arr.values.len(), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), + A::Time64(arr) => Ok(ArrayData::try_new( + ArrowT::Time64(arr.unit.into()), + arr.values.len(), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), array => fail!("{:?} not implemented", array), } } From d978af7da2253a588fe24155666a0ebb7fb10f71 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 20:35:59 +0200 Subject: [PATCH 014/178] Add Duration serialization --- serde_arrow/src/arrow_impl/serialization.rs | 84 +++++++++---------- serde_arrow/src/internal/arrow/array.rs | 10 +-- .../serialization/duration_builder.rs | 8 +- 3 files changed, 45 insertions(+), 57 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index f8581645..76e516b7 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -87,12 +87,8 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::Date32(_) | A::Date64(_) | A::Time32(_) - | A::Time64(_)) => builder.into_array().try_into(), - A::Duration(builder) => build_array_data_primitive( - T::Duration(builder.unit.into()), - builder.buffer, - builder.validity, - ), + | A::Time64(_) + | A::Duration(_)) => builder.into_array().try_into(), A::Decimal128(builder) => build_array_data_primitive( T::Decimal128(builder.precision, builder.scale), builder.buffer, @@ -249,49 +245,44 @@ impl TryFrom for ArrayData { A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), A::Boolean(arr) => Ok(ArrayData::try_new( 
ArrowT::Boolean, + // NOTE: use the explicit len arr.len, arr.validity.map(Buffer::from), 0, vec![ScalarBuffer::from(arr.values).into_inner()], vec![], )?), - A::Int8(arr) => primitive_into_data(ArrowT::Int8, arr), - A::Int16(arr) => primitive_into_data(ArrowT::Int16, arr), - A::Int32(arr) => primitive_into_data(ArrowT::Int32, arr), - A::Int64(arr) => primitive_into_data(ArrowT::Int64, arr), - A::UInt8(arr) => primitive_into_data(ArrowT::UInt8, arr), - A::UInt16(arr) => primitive_into_data(ArrowT::UInt16, arr), - A::UInt32(arr) => primitive_into_data(ArrowT::UInt32, arr), - A::UInt64(arr) => primitive_into_data(ArrowT::UInt64, arr), - A::Float16(arr) => primitive_into_data(ArrowT::Float16, arr.map_values(f16_to_f16)), - A::Float32(arr) => primitive_into_data(ArrowT::Float32, arr), - A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr), - A::Date32(arr) => primitive_into_data(ArrowT::Date32, arr), - A::Date64(arr) => primitive_into_data(ArrowT::Date64, arr), - A::Timestamp(arr) => Ok(ArrayData::try_new( + A::Int8(arr) => primitive_into_data(ArrowT::Int8, arr.validity, arr.values), + A::Int16(arr) => primitive_into_data(ArrowT::Int16, arr.validity, arr.values), + A::Int32(arr) => primitive_into_data(ArrowT::Int32, arr.validity, arr.values), + A::Int64(arr) => primitive_into_data(ArrowT::Int64, arr.validity, arr.values), + A::UInt8(arr) => primitive_into_data(ArrowT::UInt8, arr.validity, arr.values), + A::UInt16(arr) => primitive_into_data(ArrowT::UInt16, arr.validity, arr.values), + A::UInt32(arr) => primitive_into_data(ArrowT::UInt32, arr.validity, arr.values), + A::UInt64(arr) => primitive_into_data(ArrowT::UInt64, arr.validity, arr.values), + A::Float16(arr) => primitive_into_data( + ArrowT::Float16, + arr.validity, + arr.values.into_iter().map(f16_to_f16).collect(), + ), + A::Float32(arr) => primitive_into_data(ArrowT::Float32, arr.validity, arr.values), + A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr.validity, arr.values), + A::Date32(arr) 
=> primitive_into_data(ArrowT::Date32, arr.validity, arr.values), + A::Date64(arr) => primitive_into_data(ArrowT::Date64, arr.validity, arr.values), + A::Timestamp(arr) => primitive_into_data( ArrowT::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), - arr.values.len(), - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.values).into_inner()], - vec![], - )?), - A::Time32(arr) => Ok(ArrayData::try_new( - ArrowT::Time32(arr.unit.into()), - arr.values.len(), - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.values).into_inner()], - vec![], - )?), - A::Time64(arr) => Ok(ArrayData::try_new( - ArrowT::Time64(arr.unit.into()), - arr.values.len(), - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.values).into_inner()], - vec![], - )?), + arr.validity, + arr.values, + ), + A::Time32(arr) => { + primitive_into_data(ArrowT::Time32(arr.unit.into()), arr.validity, arr.values) + } + A::Time64(arr) => { + primitive_into_data(ArrowT::Time64(arr.unit.into()), arr.validity, arr.values) + } + A::Duration(arr) => { + primitive_into_data(ArrowT::Duration(arr.unit.into()), arr.validity, arr.values) + } array => fail!("{:?} not implemented", array), } } @@ -299,14 +290,15 @@ impl TryFrom for ArrayData { fn primitive_into_data( data_type: DataType, - array: crate::internal::arrow::PrimitiveArray, + validity: Option>, + values: Vec, ) -> Result { Ok(ArrayData::try_new( data_type, - array.values.len(), - array.validity.map(Buffer::from), + values.len(), + validity.map(Buffer::from), 0, - vec![ScalarBuffer::from(array.values).into_inner()], + vec![ScalarBuffer::from(values).into_inner()], vec![], )?) 
} diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index e1a44cd3..69a8a44a 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -24,6 +24,7 @@ pub enum Array { Time32(TimeArray), Time64(TimeArray), Timestamp(TimestampArray), + Duration(TimeArray), Utf8(Utf8Array), LargeUtf8(Utf8Array), Binary(Utf8Array), @@ -53,15 +54,6 @@ pub struct PrimitiveArray { pub values: Vec, } -impl PrimitiveArray { - pub(crate) fn map_values(self, func: impl Fn(T) -> R) -> PrimitiveArray { - PrimitiveArray { - validity: self.validity, - values: self.values.into_iter().map(func).collect(), - } - } -} - #[derive(Debug, Clone)] pub struct TimeArray { pub unit: TimeUnit, diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 4514fd45..0b6edb8c 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,5 +1,5 @@ use crate::internal::{ - arrow::{Array, TimeUnit}, + arrow::{Array, TimeArray, TimeUnit}, error::Result, }; @@ -34,7 +34,11 @@ impl DurationBuilder { } pub fn into_array(self) -> Array { - unimplemented!() + Array::Duration(TimeArray { + unit: self.unit, + validity: self.validity.map(|b| b.buffer), + values: self.buffer, + }) } } From 421cd1a4c13d8d96362b35ac61cc2ee3913b4df3 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 20:41:17 +0200 Subject: [PATCH 015/178] Add Decimal128 serialization --- serde_arrow/src/arrow_impl/serialization.rs | 38 ++++--------------- serde_arrow/src/internal/arrow/array.rs | 10 ++++- serde_arrow/src/internal/arrow/mod.rs | 4 +- .../internal/serialization/decimal_builder.rs | 9 ++++- 4 files changed, 25 insertions(+), 36 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 76e516b7..e2ecb339 100644 --- 
a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -88,12 +88,8 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::Date64(_) | A::Time32(_) | A::Time64(_) - | A::Duration(_)) => builder.into_array().try_into(), - A::Decimal128(builder) => build_array_data_primitive( - T::Decimal128(builder.precision, builder.scale), - builder.buffer, - builder.validity, - ), + | A::Duration(_) + | A::Decimal128(_)) => builder.into_array().try_into(), A::Utf8(builder) => build_array_data_utf8( T::Utf8, builder.offsets.offsets, @@ -283,6 +279,11 @@ impl TryFrom for ArrayData { A::Duration(arr) => { primitive_into_data(ArrowT::Duration(arr.unit.into()), arr.validity, arr.values) } + A::Decimal128(arr) => primitive_into_data( + ArrowT::Decimal128(arr.precision, arr.scale), + arr.validity, + arr.values, + ), array => fail!("{:?} not implemented", array), } } @@ -303,31 +304,6 @@ fn primitive_into_data( )?) } -fn build_array_data_primitive( - data_type: DataType, - data: Vec, - validity: Option, -) -> Result { - let len = data.len(); - build_array_data_primitive_with_len(data_type, len, data, validity) -} - -fn build_array_data_primitive_with_len( - data_type: DataType, - len: usize, - data: Vec, - validity: Option, -) -> Result { - Ok(ArrayData::try_new( - data_type, - len, - validity.map(|b| Buffer::from(b.buffer)), - 0, - vec![ScalarBuffer::from(data).into_inner()], - vec![], - )?) 
-} - fn build_array_data_utf8( data_type: DataType, offsets: Vec, diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 69a8a44a..b504f073 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -29,7 +29,7 @@ pub enum Array { LargeUtf8(Utf8Array), Binary(Utf8Array), LargeBinary(Utf8Array), - Decimal128(PrimitiveArray), + Decimal128(DecimalArray), Struct(StructArray), List(ListArray), LargeList(ListArray), @@ -92,3 +92,11 @@ pub struct Utf8Array { pub offsets: Vec, pub data: Vec, } + +#[derive(Clone, Debug)] +pub struct DecimalArray { + pub precision: u8, + pub scale: i8, + pub validity: Option>, + pub values: Vec, +} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 39a8e203..b524c86b 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,8 +6,8 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, - TimestampArray, Utf8Array, + Array, BooleanArray, DecimalArray, ListArray, NullArray, PrimitiveArray, StructArray, + TimeArray, TimestampArray, Utf8Array, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index e3444cca..1e8889a9 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,5 +1,5 @@ use crate::internal::{ - arrow::Array, + arrow::{Array, DecimalArray}, error::Result, utils::decimal::{self, DecimalParser}, }; @@ -47,7 +47,12 @@ impl DecimalBuilder { } pub fn into_array(self) -> Array { - unimplemented!() + Array::Decimal128(DecimalArray { + precision: self.precision, + scale: self.scale, + validity: 
self.validity.map(|b| b.buffer), + values: self.buffer, + }) } } From 40942aceb701d587c5c34f0a9304c5fa9292c417 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 20:51:41 +0200 Subject: [PATCH 016/178] Implement Utf8, LargeUtf8, Binary, LargeBinary serialization --- serde_arrow/src/arrow_impl/serialization.rs | 69 +++++-------------- serde_arrow/src/internal/arrow/array.rs | 10 +-- serde_arrow/src/internal/arrow/mod.rs | 4 +- .../internal/serialization/binary_builder.rs | 22 +++++- .../internal/serialization/utf8_builder.rs | 22 +++++- 5 files changed, 64 insertions(+), 63 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index e2ecb339..bbe79166 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -14,11 +14,7 @@ use crate::{ internal::{ error::{fail, Error, Result}, schema::{GenericField, SerdeArrowSchema}, - serialization::{ - utils::{MutableBitBuffer, MutableOffsetBuffer}, - ArrayBuilder, OuterSequenceBuilder, - }, - utils::Offset, + serialization::{utils::MutableBitBuffer, ArrayBuilder, OuterSequenceBuilder}, }, }; @@ -89,19 +85,11 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::Time32(_) | A::Time64(_) | A::Duration(_) - | A::Decimal128(_)) => builder.into_array().try_into(), - A::Utf8(builder) => build_array_data_utf8( - T::Utf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), - A::LargeUtf8(builder) => build_array_data_utf8( - T::LargeUtf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), + | A::Decimal128(_) + | A::Utf8(_) + | A::LargeUtf8(_) + | A::Binary(_) + | A::LargeBinary(_)) => builder.into_array().try_into(), A::LargeList(builder) => build_array_data_list( T::LargeList(Arc::new(Field::try_from(&builder.field)?)), builder.offsets.offsets.len() - 1, @@ -134,15 +122,6 @@ fn build_array_data(builder: ArrayBuilder) -> Result { .add_child_data(child_data) 
.build()?) } - A::Binary(builder) => { - build_array_data_binary(T::Binary, builder.offsets, builder.buffer, builder.validity) - } - A::LargeBinary(builder) => build_array_data_binary( - T::LargeBinary, - builder.offsets, - builder.buffer, - builder.validity, - ), A::FixedSizeBinary(builder) => { let data_buffer = ScalarBuffer::from(builder.buffer).into_inner(); let validity = if let Some(validity) = builder.validity { @@ -284,6 +263,14 @@ impl TryFrom for ArrayData { arr.validity, arr.values, ), + A::Utf8(arr) => bytes_into_data(ArrowT::Utf8, arr.offsets, arr.data, arr.validity), + A::LargeUtf8(arr) => { + bytes_into_data(ArrowT::LargeUtf8, arr.offsets, arr.data, arr.validity) + } + A::Binary(arr) => bytes_into_data(ArrowT::Binary, arr.offsets, arr.data, arr.validity), + A::LargeBinary(arr) => { + bytes_into_data(ArrowT::LargeBinary, arr.offsets, arr.data, arr.validity) + } array => fail!("{:?} not implemented", array), } } @@ -304,17 +291,17 @@ fn primitive_into_data( )?) } -fn build_array_data_utf8( +fn bytes_into_data( data_type: DataType, offsets: Vec, data: Vec, - validity: Option, + validity: Option>, ) -> Result { let values_len = offsets.len() - 1; let offsets = ScalarBuffer::from(offsets).into_inner(); let data = ScalarBuffer::from(data).into_inner(); - let validity = validity.map(|b| Buffer::from(b.buffer)); + let validity = validity.map(Buffer::from); Ok(ArrayData::try_new( data_type, @@ -326,28 +313,6 @@ fn build_array_data_utf8( )?) 
} -fn build_array_data_binary( - data_type: DataType, - offsets: MutableOffsetBuffer, - data: Vec, - validity: Option, -) -> Result { - let len = offsets.len(); - let offset_buffer = ScalarBuffer::from(offsets.offsets).into_inner(); - let data_buffer = ScalarBuffer::from(data).into_inner(); - let validity = if let Some(validity) = validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; - Ok(ArrayData::builder(data_type) - .len(len) - .null_bit_buffer(validity) - .add_buffer(offset_buffer) - .add_buffer(data_buffer) - .build()?) -} - fn build_array_data_list( data_type: DataType, len: usize, diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index b504f073..e78e0da6 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -25,10 +25,10 @@ pub enum Array { Time64(TimeArray), Timestamp(TimestampArray), Duration(TimeArray), - Utf8(Utf8Array), - LargeUtf8(Utf8Array), - Binary(Utf8Array), - LargeBinary(Utf8Array), + Utf8(BytesArray), + LargeUtf8(BytesArray), + Binary(BytesArray), + LargeBinary(BytesArray), Decimal128(DecimalArray), Struct(StructArray), List(ListArray), @@ -86,7 +86,7 @@ pub struct ListArray { } #[derive(Clone, Debug)] -pub struct Utf8Array { +pub struct BytesArray { pub len: usize, pub validity: Option>, pub offsets: Vec, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index b524c86b..25517533 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,8 +6,8 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, DecimalArray, ListArray, NullArray, PrimitiveArray, StructArray, - TimeArray, TimestampArray, Utf8Array, + Array, BooleanArray, BytesArray, DecimalArray, ListArray, NullArray, PrimitiveArray, + StructArray, TimeArray, TimestampArray, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, 
PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 49bdcdff..7f55cbfb 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, BytesArray}, error::Result, utils::{Mut, Offset}, }; @@ -38,9 +38,27 @@ impl BinaryBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } +} + +impl BinaryBuilder { + pub fn into_array(self) -> Array { + Array::Binary(BytesArray { + len: self.offsets.len(), + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + }) + } +} +impl BinaryBuilder { pub fn into_array(self) -> Array { - unimplemented!() + Array::LargeBinary(BytesArray { + len: self.offsets.len(), + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + }) } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index f7eb710a..ba7a3aa3 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,5 +1,5 @@ use crate::internal::{ - arrow::Array, + arrow::{Array, BytesArray}, error::{fail, Result}, utils::Offset, }; @@ -35,9 +35,27 @@ impl Utf8Builder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } +} + +impl Utf8Builder { + pub fn into_array(self) -> Array { + Array::Utf8(BytesArray { + len: self.offsets.len(), + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + }) + } +} +impl Utf8Builder { pub fn into_array(self) -> Array { - unimplemented!() + Array::LargeUtf8(BytesArray { + len: self.offsets.len(), + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + data: self.buffer, + 
}) } } From dc2ed0dcf25edaa3ed74f357ea8e57d90976231f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 18 Jul 2024 21:01:44 +0200 Subject: [PATCH 017/178] Refactor arrow list construction --- serde_arrow/src/arrow_impl/serialization.rs | 46 ++++++++++----------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index bbe79166..10cd1d26 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -14,7 +14,7 @@ use crate::{ internal::{ error::{fail, Error, Result}, schema::{GenericField, SerdeArrowSchema}, - serialization::{utils::MutableBitBuffer, ArrayBuilder, OuterSequenceBuilder}, + serialization::{ArrayBuilder, OuterSequenceBuilder}, }, }; @@ -90,19 +90,19 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::LargeUtf8(_) | A::Binary(_) | A::LargeBinary(_)) => builder.into_array().try_into(), - A::LargeList(builder) => build_array_data_list( + A::LargeList(builder) => list_into_data( T::LargeList(Arc::new(Field::try_from(&builder.field)?)), builder.offsets.offsets.len() - 1, builder.offsets.offsets, build_array_data(*builder.element)?, - builder.validity, + builder.validity.map(|v| v.buffer), ), - A::List(builder) => build_array_data_list( + A::List(builder) => list_into_data( T::List(Arc::new(Field::try_from(&builder.field)?)), builder.offsets.offsets.len() - 1, builder.offsets.offsets, build_array_data(*builder.element)?, - builder.validity, + builder.validity.map(|v| v.buffer), ), A::FixedSizedList(builder) => { let data_type = T::FixedSizeList( @@ -297,36 +297,32 @@ fn bytes_into_data( data: Vec, validity: Option>, ) -> Result { - let values_len = offsets.len() - 1; - - let offsets = ScalarBuffer::from(offsets).into_inner(); - let data = ScalarBuffer::from(data).into_inner(); - let validity = validity.map(Buffer::from); - Ok(ArrayData::try_new( data_type, - values_len, - validity, + 
offsets.len() - 1, + validity.map(Buffer::from), 0, - vec![offsets, data], + vec![ + ScalarBuffer::from(offsets).into_inner(), + ScalarBuffer::from(data).into_inner(), + ], vec![], )?) } -fn build_array_data_list( +fn list_into_data( data_type: DataType, len: usize, offsets: Vec, child_data: ArrayData, - validity: Option, + validity: Option>, ) -> Result { - let offset_buffer = ScalarBuffer::from(offsets).into_inner(); - let validity = validity.map(|b| Buffer::from(b.buffer)); - - Ok(ArrayData::builder(data_type) - .len(len) - .add_buffer(offset_buffer) - .add_child_data(child_data) - .null_bit_buffer(validity) - .build()?) + Ok(ArrayData::try_new( + data_type, + len, + validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(offsets).into_inner()], + vec![child_data], + )?) } From 1f4fdef44ef50412893257bd19116420f301f7c4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 16:03:12 +0200 Subject: [PATCH 018/178] Implement FixedSizeList, List, LargeList, Struct --- serde_arrow/src/arrow_impl/serialization.rs | 134 ++++++++++-------- serde_arrow/src/internal/arrow/array.rs | 23 ++- serde_arrow/src/internal/arrow/mod.rs | 4 +- .../internal/serialization/array_builder.rs | 2 +- .../internal/serialization/binary_builder.rs | 14 +- .../internal/serialization/bool_builder.rs | 6 +- .../internal/serialization/date32_builder.rs | 6 +- .../internal/serialization/date64_builder.rs | 10 +- .../internal/serialization/decimal_builder.rs | 6 +- .../serialization/dictionary_utf8_builder.rs | 2 +- .../serialization/duration_builder.rs | 6 +- .../fixed_size_binary_builder.rs | 2 +- .../serialization/fixed_size_list_builder.rs | 15 +- .../internal/serialization/float_builder.rs | 6 +- .../src/internal/serialization/int_builder.rs | 6 +- .../internal/serialization/list_builder.rs | 28 +++- .../src/internal/serialization/map_builder.rs | 2 +- .../internal/serialization/null_builder.rs | 4 +- .../internal/serialization/struct_builder.rs | 41 ++++-- 
.../internal/serialization/time_builder.rs | 12 +- .../internal/serialization/union_builder.rs | 2 +- .../serialization/unknown_variant_builder.rs | 4 +- .../internal/serialization/utf8_builder.rs | 14 +- .../src/internal/serialization/utils.rs | 10 ++ 24 files changed, 215 insertions(+), 144 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 10cd1d26..9a0165b5 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -12,6 +12,7 @@ use crate::{ }, }, internal::{ + arrow::FieldMeta, error::{fail, Error, Result}, schema::{GenericField, SerdeArrowSchema}, serialization::{ArrayBuilder, OuterSequenceBuilder}, @@ -89,39 +90,11 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::Utf8(_) | A::LargeUtf8(_) | A::Binary(_) - | A::LargeBinary(_)) => builder.into_array().try_into(), - A::LargeList(builder) => list_into_data( - T::LargeList(Arc::new(Field::try_from(&builder.field)?)), - builder.offsets.offsets.len() - 1, - builder.offsets.offsets, - build_array_data(*builder.element)?, - builder.validity.map(|v| v.buffer), - ), - A::List(builder) => list_into_data( - T::List(Arc::new(Field::try_from(&builder.field)?)), - builder.offsets.offsets.len() - 1, - builder.offsets.offsets, - build_array_data(*builder.element)?, - builder.validity.map(|v| v.buffer), - ), - A::FixedSizedList(builder) => { - let data_type = T::FixedSizeList( - Arc::new(Field::try_from(&builder.field)?), - builder.n.try_into()?, - ); - let child_data = build_array_data(*builder.element)?; - let validity = if let Some(validity) = builder.validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; - - Ok(ArrayData::builder(data_type) - .len(builder.num_elements) - .null_bit_buffer(validity) - .add_child_data(child_data) - .build()?) 
- } + | A::LargeBinary(_) + | A::Struct(_) + | A::LargeList(_) + | A::List(_) + | A::FixedSizedList(_)) => builder.into_array()?.try_into(), A::FixedSizeBinary(builder) => { let data_buffer = ScalarBuffer::from(builder.buffer).into_inner(); let validity = if let Some(validity) = builder.validity { @@ -138,34 +111,6 @@ fn build_array_data(builder: ArrayBuilder) -> Result { .build()?, ) } - A::Struct(builder) => { - let mut data = Vec::new(); - for (_, field) in builder.named_fields { - data.push(build_array_data(field)?); - } - - let (validity, len) = if let Some(validity) = builder.validity { - (Some(Buffer::from(validity.buffer)), validity.len) - } else { - if data.is_empty() { - fail!("cannot built non-nullable structs without fields"); - } - (None, data[0].len()) - }; - - let fields = builder - .fields - .iter() - .map(Field::try_from) - .collect::>>()?; - let data_type = T::Struct(fields.into()); - - Ok(ArrayData::builder(data_type) - .len(len) - .null_bit_buffer(validity) - .child_data(data) - .build()?) - } A::Map(builder) => Ok(ArrayData::builder(T::Map( Arc::new(Field::try_from(&builder.entry_field)?), false, @@ -271,11 +216,74 @@ impl TryFrom for ArrayData { A::LargeBinary(arr) => { bytes_into_data(ArrowT::LargeBinary, arr.offsets, arr.data, arr.validity) } - array => fail!("{:?} not implemented", array), + A::Struct(arr) => { + let mut fields = Vec::new(); + let mut data = Vec::new(); + + for (field, meta) in arr.fields { + let child: ArrayData = field.try_into()?; + let field = Field::new(meta.name, child.data_type().clone(), meta.nullable) + .with_metadata(meta.metadata); + fields.push(Arc::new(field)); + data.push(child); + } + let data_type = ArrowT::Struct(fields.into()); + + Ok(ArrayData::builder(data_type) + .len(arr.len) + .null_bit_buffer(arr.validity.map(Buffer::from)) + .child_data(data) + .build()?) 
+ } + A::List(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + ArrowT::List(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::LargeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + ArrowT::LargeList(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::FixedSizeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + + if (child.len() % usize::try_from(arr.n)?) != 0 { + fail!( + "Invalid FixedSizeList: number of child elements ({}) not divisible by n ({})", + child.len(), + arr.n, + ); + } + let field = field_from_data_and_meta(&child, arr.meta); + Ok( + ArrayData::builder(ArrowT::FixedSizeList(Arc::new(field), arr.n)) + .len(child.len() / usize::try_from(arr.n)?) + .null_bit_buffer(arr.validity.map(Buffer::from)) + .add_child_data(child) + .build()?, + ) + } } } } +fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> Field { + Field::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) +} + fn primitive_into_data( data_type: DataType, validity: Option>, @@ -299,7 +307,7 @@ fn bytes_into_data( ) -> Result { Ok(ArrayData::try_new( data_type, - offsets.len() - 1, + offsets.len().saturating_sub(1), validity.map(Buffer::from), 0, vec![ diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index e78e0da6..e43b1911 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -1,4 +1,6 @@ //! 
Owned versions of the different array types +use std::collections::HashMap; + use half::f16; use crate::internal::arrow::data_type::TimeUnit; @@ -33,6 +35,7 @@ pub enum Array { Struct(StructArray), List(ListArray), LargeList(ListArray), + FixedSizeList(FixedSizeListArray), } #[derive(Clone, Debug)] @@ -74,20 +77,34 @@ pub struct TimestampArray { pub struct StructArray { pub len: usize, pub validity: Option>, - pub fields: Vec, + pub fields: Vec<(Array, FieldMeta)>, +} + +#[derive(Clone, Debug)] +pub struct FieldMeta { + pub name: String, + pub nullable: bool, + pub metadata: HashMap, } #[derive(Clone, Debug)] pub struct ListArray { - pub len: usize, pub validity: Option>, pub offsets: Vec, + pub meta: FieldMeta, + pub element: Box, +} + +#[derive(Clone, Debug)] +pub struct FixedSizeListArray { + pub n: i32, + pub validity: Option>, + pub meta: FieldMeta, pub element: Box, } #[derive(Clone, Debug)] pub struct BytesArray { - pub len: usize, pub validity: Option>, pub offsets: Vec, pub data: Vec, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 25517533..d0cf9675 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,8 +6,8 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, BytesArray, DecimalArray, ListArray, NullArray, PrimitiveArray, - StructArray, TimeArray, TimestampArray, + Array, BooleanArray, BytesArray, DecimalArray, FieldMeta, FixedSizeListArray, ListArray, + NullArray, PrimitiveArray, StructArray, TimeArray, TimestampArray, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index faa3a6a3..fc73c70d 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -132,7 +132,7 @@ 
impl ArrayBuilder { dispatch!(self, Self(builder) => builder.is_nullable()) } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { dispatch!(self, Self(builder) => builder.into_array()) } } diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 7f55cbfb..192195ac 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -41,24 +41,22 @@ impl BinaryBuilder { } impl BinaryBuilder { - pub fn into_array(self) -> Array { - Array::Binary(BytesArray { - len: self.offsets.len(), + pub fn into_array(self) -> Result { + Ok(Array::Binary(BytesArray { validity: self.validity.map(|b| b.buffer), offsets: self.offsets.offsets, data: self.buffer, - }) + })) } } impl BinaryBuilder { - pub fn into_array(self) -> Array { - Array::LargeBinary(BytesArray { - len: self.offsets.len(), + pub fn into_array(self) -> Result { + Ok(Array::LargeBinary(BytesArray { validity: self.validity.map(|b| b.buffer), offsets: self.offsets.offsets, data: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 0670a637..a884e341 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -30,12 +30,12 @@ impl BoolBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { - Array::Boolean(BooleanArray { + pub fn into_array(self) -> Result { + Ok(Array::Boolean(BooleanArray { len: self.buffer.len, validity: self.validity.map(|v| v.buffer), values: self.buffer.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 2fad40fb..8537c5df 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ 
b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -36,11 +36,11 @@ impl Date32Builder { self.validity.is_some() } - pub fn into_array(self) -> Array { - Array::Date32(PrimitiveArray { + pub fn into_array(self) -> Result { + Ok(Array::Date32(PrimitiveArray { validity: self.validity.map(|validity| validity.buffer), values: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index c01211a9..9668a199 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -37,19 +37,19 @@ impl Date64Builder { self.validity.is_some() } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { if let GenericDataType::Timestamp(unit, timezone) = self.field.data_type { - Array::Timestamp(TimestampArray { + Ok(Array::Timestamp(TimestampArray { unit, timezone, validity: self.validity.map(|validity| validity.buffer), values: self.buffer, - }) + })) } else { - Array::Date64(PrimitiveArray { + Ok(Array::Date64(PrimitiveArray { validity: self.validity.map(|validity| validity.buffer), values: self.buffer, - }) + })) } } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 1e8889a9..5bb48de3 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -46,13 +46,13 @@ impl DecimalBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { - Array::Decimal128(DecimalArray { + pub fn into_array(self) -> Result { + Ok(Array::Decimal128(DecimalArray { precision: self.precision, scale: self.scale, validity: self.validity.map(|b| b.buffer), values: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs 
index b7f068d3..004d7d62 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -42,7 +42,7 @@ impl DictionaryUtf8Builder { self.indices.is_nullable() } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { unimplemented!() } } diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 0b6edb8c..8b43845a 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -33,12 +33,12 @@ impl DurationBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { - Array::Duration(TimeArray { + pub fn into_array(self) -> Result { + Ok(Array::Duration(TimeArray { unit: self.unit, validity: self.validity.map(|b| b.buffer), values: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index f8960459..d3d4d78a 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -43,7 +43,7 @@ impl FixedSizeBinaryBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { unimplemented!() } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 157a8ba6..06754e27 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, FixedSizeListArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -9,7 +9,9 @@ use crate::internal::{ use super::{ 
array_builder::ArrayBuilder, - utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}, + utils::{ + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, + }, }; #[derive(Debug, Clone)] @@ -50,8 +52,13 @@ impl FixedSizeListBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { - unimplemented!() + pub fn into_array(self) -> Result { + Ok(Array::FixedSizeList(FixedSizeListArray { + n: self.n.try_into()?, + meta: meta_from_field(self.field)?, + validity: self.validity.map(|v| v.buffer), + element: Box::new((*self.element).into_array()?), + })) } } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index efd0bf66..6381f905 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -43,11 +43,11 @@ impl FloatBuilder { macro_rules! impl_into_array { ($ty:ty, $var:ident) => { impl FloatBuilder<$ty> { - pub fn into_array(self) -> Array { - Array::$var(PrimitiveArray { + pub fn into_array(self) -> Result { + Ok(Array::$var(PrimitiveArray { validity: self.validity.map(|b| b.buffer), values: self.buffer, - }) + })) } } }; diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 89d4d2f4..391a1f41 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -34,11 +34,11 @@ impl IntBuilder { macro_rules! 
impl_into_array { ($ty:ty, $var:ident) => { impl IntBuilder<$ty> { - pub fn into_array(self) -> Array { - Array::$var(PrimitiveArray { + pub fn into_array(self) -> Result { + Ok(Array::$var(PrimitiveArray { validity: self.validity.map(|b| b.buffer), values: self.buffer, - }) + })) } } }; diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index c981947b..784e2f22 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, ListArray}, error::Result, schema::GenericField, utils::{Mut, Offset}, @@ -10,8 +10,8 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, - SimpleSerializer, + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, + MutableOffsetBuffer, SimpleSerializer, }, }; @@ -46,9 +46,27 @@ impl ListBuilder { pub fn is_nullable(&self) -> bool { self.validity.is_some() } +} + +impl ListBuilder { + pub fn into_array(self) -> Result { + Ok(Array::List(ListArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + element: Box::new(self.element.into_array()?), + meta: meta_from_field(self.field)?, + })) + } +} - pub fn into_array(self) -> Array { - unimplemented!() +impl ListBuilder { + pub fn into_array(self) -> Result { + Ok(Array::LargeList(ListArray { + validity: self.validity.map(|b| b.buffer), + offsets: self.offsets.offsets, + element: Box::new(self.element.into_array()?), + meta: meta_from_field(self.field)?, + })) } } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index e950553a..c4ba5bbc 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ 
b/serde_arrow/src/internal/serialization/map_builder.rs @@ -41,7 +41,7 @@ impl MapBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { unimplemented!() } } diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index f574250d..eb02acdb 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -25,8 +25,8 @@ impl NullBuilder { true } - pub fn into_array(self) -> Array { - Array::Null(NullArray { len: self.count }) + pub fn into_array(self) -> Result { + Ok(Array::Null(NullArray { len: self.count })) } } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 8b12b1b2..2753c03e 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, StructArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -11,13 +11,16 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}, + utils::{ + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, + }, }; const UNKNOWN_KEY: usize = usize::MAX; #[derive(Debug, Clone)] pub struct StructBuilder { + // TODO: clean this up pub fields: Vec, pub validity: Option, pub named_fields: Vec<(String, ArrayBuilder)>, @@ -25,6 +28,7 @@ pub struct StructBuilder { pub seen: Vec, pub next: usize, pub index: BTreeMap, + pub len: usize, } impl StructBuilder { @@ -33,15 +37,11 @@ impl StructBuilder { named_fields: Vec<(String, ArrayBuilder)>, is_nullable: bool, ) -> Result { - let mut index = BTreeMap::new(); - let cached_names = vec![None; 
named_fields.len()]; - let seen = vec![false; named_fields.len()]; - let next = 0; - if fields.len() != named_fields.len() { fail!("mismatched number of fields and builders"); } + let mut index = BTreeMap::new(); for (idx, (name, _)) in named_fields.iter().enumerate() { if index.contains_key(name) { fail!("Duplicate field {name}"); @@ -51,12 +51,13 @@ impl StructBuilder { Ok(Self { fields, + seen: vec![false; named_fields.len()], + cached_names: vec![None; named_fields.len()], validity: is_nullable.then(MutableBitBuffer::default), named_fields, - cached_names, - seen, - next, + next: 0, index, + len: 0, }) } @@ -75,6 +76,7 @@ impl StructBuilder { ), seen: std::mem::replace(&mut self.seen, vec![false; self.named_fields.len()]), next: std::mem::take(&mut self.next), + len: std::mem::take(&mut self.len), index: self.index.clone(), } } @@ -83,14 +85,26 @@ impl StructBuilder { self.validity.is_some() } - pub fn into_array(self) -> Array { - unimplemented!() + pub fn into_array(self) -> Result { + let mut fields = Vec::new(); + for (field, (_, builder)) in self.fields.into_iter().zip(self.named_fields) { + let meta = meta_from_field(field)?; + let array = builder.into_array()?; + fields.push((array, meta)); + } + + Ok(Array::Struct(StructArray { + len: self.len, + validity: self.validity.map(|b| b.buffer), + fields, + })) } } impl StructBuilder { fn start(&mut self) -> Result<()> { push_validity(&mut self.validity, true)?; + self.len += 1; self.reset(); Ok(()) } @@ -135,6 +149,7 @@ impl SimpleSerializer for StructBuilder { fn serialize_default(&mut self) -> Result<()> { push_validity_default(&mut self.validity); + self.len += 1; for (_, field) in &mut self.named_fields { field.serialize_default()?; } @@ -144,11 +159,11 @@ impl SimpleSerializer for StructBuilder { fn serialize_none(&mut self) -> Result<()> { push_validity(&mut self.validity, false)?; + self.len += 1; for (_, field) in &mut self.named_fields { field.serialize_default()?; } - Ok(()) } diff --git 
a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 7e28b5da..71ea9fa1 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -41,22 +41,22 @@ impl TimeBuilder { } impl TimeBuilder { - pub fn into_array(self) -> Array { - Array::Time32(TimeArray { + pub fn into_array(self) -> Result { + Ok(Array::Time32(TimeArray { unit: self.unit, validity: self.validity.map(|v| v.buffer), values: self.buffer, - }) + })) } } impl TimeBuilder { - pub fn into_array(self) -> Array { - Array::Time64(TimeArray { + pub fn into_array(self) -> Result { + Ok(Array::Time64(TimeArray { unit: self.unit, validity: self.validity.map(|v| v.buffer), values: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 564f1fd5..e127dab6 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -41,7 +41,7 @@ impl UnionBuilder { false } - pub fn into_array(self) -> Array { + pub fn into_array(self) -> Result { unimplemented!() } } diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 35e895e2..845574de 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -22,8 +22,8 @@ impl UnknownVariantBuilder { false } - pub fn into_array(self) -> Array { - Array::Null(NullArray { len: 0 }) + pub fn into_array(self) -> Result { + Ok(Array::Null(NullArray { len: 0 })) } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index ba7a3aa3..0ebd6051 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ 
b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -38,24 +38,22 @@ impl Utf8Builder { } impl Utf8Builder { - pub fn into_array(self) -> Array { - Array::Utf8(BytesArray { - len: self.offsets.len(), + pub fn into_array(self) -> Result { + Ok(Array::Utf8(BytesArray { validity: self.validity.map(|b| b.buffer), offsets: self.offsets.offsets, data: self.buffer, - }) + })) } } impl Utf8Builder { - pub fn into_array(self) -> Array { - Array::LargeUtf8(BytesArray { - len: self.offsets.len(), + pub fn into_array(self) -> Result { + Ok(Array::LargeUtf8(BytesArray { validity: self.validity.map(|b| b.buffer), offsets: self.offsets.offsets, data: self.buffer, - }) + })) } } diff --git a/serde_arrow/src/internal/serialization/utils.rs b/serde_arrow/src/internal/serialization/utils.rs index 398ad714..4a65d9ac 100644 --- a/serde_arrow/src/internal/serialization/utils.rs +++ b/serde_arrow/src/internal/serialization/utils.rs @@ -7,12 +7,22 @@ use serde::{ }; use crate::internal::{ + arrow::FieldMeta, error::{fail, Error, Result}, + schema::{merge_strategy_with_metadata, GenericField}, utils::{Mut, Offset}, }; use super::ArrayBuilder; +pub fn meta_from_field(field: GenericField) -> Result { + Ok(FieldMeta { + name: field.name, + nullable: field.nullable, + metadata: merge_strategy_with_metadata(field.metadata, field.strategy)?, + }) +} + #[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct MutableBitBuffer { pub(crate) buffer: Vec, From f3a5a9df40d1f287f587c8dfcee5b713e312766e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 16:12:11 +0200 Subject: [PATCH 019/178] Implement FixedSizeBinary --- serde_arrow/src/arrow_impl/serialization.rs | 52 ++++++++++--------- serde_arrow/src/internal/arrow/array.rs | 8 +++ serde_arrow/src/internal/arrow/mod.rs | 5 +- .../fixed_size_binary_builder.rs | 8 ++- 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs 
index 9a0165b5..4d3aabdf 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -94,23 +94,8 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::Struct(_) | A::LargeList(_) | A::List(_) - | A::FixedSizedList(_)) => builder.into_array()?.try_into(), - A::FixedSizeBinary(builder) => { - let data_buffer = ScalarBuffer::from(builder.buffer).into_inner(); - let validity = if let Some(validity) = builder.validity { - Some(Buffer::from(validity.buffer)) - } else { - None - }; - - Ok( - ArrayData::builder(T::FixedSizeBinary(builder.n.try_into()?)) - .len(builder.len) - .null_bit_buffer(validity) - .add_buffer(data_buffer) - .build()?, - ) - } + | A::FixedSizedList(_) + | A::FixedSizeBinary(_)) => builder.into_array()?.try_into(), A::Map(builder) => Ok(ArrayData::builder(T::Map( Arc::new(Field::try_from(&builder.entry_field)?), false, @@ -259,7 +244,6 @@ impl TryFrom for ArrayData { } A::FixedSizeList(arr) => { let child: ArrayData = (*arr.element).try_into()?; - if (child.len() % usize::try_from(arr.n)?) != 0 { fail!( "Invalid FixedSizeList: number of child elements ({}) not divisible by n ({})", @@ -268,13 +252,31 @@ impl TryFrom for ArrayData { ); } let field = field_from_data_and_meta(&child, arr.meta); - Ok( - ArrayData::builder(ArrowT::FixedSizeList(Arc::new(field), arr.n)) - .len(child.len() / usize::try_from(arr.n)?) - .null_bit_buffer(arr.validity.map(Buffer::from)) - .add_child_data(child) - .build()?, - ) + Ok(ArrayData::try_new( + ArrowT::FixedSizeList(Arc::new(field), arr.n), + child.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![], + vec![child], + )?) + } + A::FixedSizeBinary(arr) => { + if (arr.data.len() % usize::try_from(arr.n)?) 
!= 0 { + fail!( + "Invalid FixedSizeBinary: number of child elements ({}) not divisible by n ({})", + arr.data.len(), + arr.n, + ); + } + Ok(ArrayData::try_new( + ArrowT::FixedSizeBinary(arr.n), + arr.data.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.data).into_inner()], + vec![], + )?) } } } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index e43b1911..f5fd9404 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -31,6 +31,7 @@ pub enum Array { LargeUtf8(BytesArray), Binary(BytesArray), LargeBinary(BytesArray), + FixedSizeBinary(FixedSizeBinaryArray), Decimal128(DecimalArray), Struct(StructArray), List(ListArray), @@ -110,6 +111,13 @@ pub struct BytesArray { pub data: Vec, } +#[derive(Clone, Debug)] +pub struct FixedSizeBinaryArray { + pub n: i32, + pub validity: Option>, + pub data: Vec, +} + #[derive(Clone, Debug)] pub struct DecimalArray { pub precision: u8, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index d0cf9675..a1896667 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,8 +6,9 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, BytesArray, DecimalArray, FieldMeta, FixedSizeListArray, ListArray, - NullArray, PrimitiveArray, StructArray, TimeArray, TimestampArray, + Array, BooleanArray, BytesArray, DecimalArray, FieldMeta, FixedSizeBinaryArray, + FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, + TimestampArray, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index d3d4d78a..06fc3193 100644 --- 
a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, FixedSizeBinaryArray}, error::{fail, Result}, utils::Mut, }; @@ -44,7 +44,11 @@ impl FixedSizeBinaryBuilder { } pub fn into_array(self) -> Result { - unimplemented!() + Ok(Array::FixedSizeBinary(FixedSizeBinaryArray { + n: self.n.try_into()?, + validity: self.validity.map(|v| v.buffer), + data: self.buffer, + })) } } From 063c5db904d07f5e5abde2084c7fe1b715d9a03a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 16:52:40 +0200 Subject: [PATCH 020/178] Implement Dictionary --- serde_arrow/src/arrow_impl/serialization.rs | 28 +++++++++++-------- serde_arrow/src/internal/arrow/array.rs | 7 +++++ serde_arrow/src/internal/arrow/mod.rs | 6 ++-- .../serialization/dictionary_utf8_builder.rs | 7 +++-- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index 4d3aabdf..b5d066ac 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -95,7 +95,8 @@ fn build_array_data(builder: ArrayBuilder) -> Result { | A::LargeList(_) | A::List(_) | A::FixedSizedList(_) - | A::FixedSizeBinary(_)) => builder.into_array()?.try_into(), + | A::FixedSizeBinary(_) + | A::DictionaryUtf8(_)) => builder.into_array()?.try_into(), A::Map(builder) => Ok(ArrayData::builder(T::Map( Arc::new(Field::try_from(&builder.entry_field)?), false, @@ -105,17 +106,6 @@ fn build_array_data(builder: ArrayBuilder) -> Result { .add_child_data(build_array_data(*builder.entry)?) 
.null_bit_buffer(builder.validity.map(|b| Buffer::from(b.buffer))) .build()?), - A::DictionaryUtf8(builder) => { - let indices = build_array_data(*builder.indices)?; - let values = build_array_data(*builder.values)?; - let data_type = Field::try_from(&builder.field)?.data_type().clone(); - - Ok(indices - .into_builder() - .data_type(data_type) - .child_data(vec![values]) - .build()?) - } A::Union(builder) => { let data_type = Field::try_from(&builder.field)?.data_type().clone(); let children = builder @@ -278,6 +268,20 @@ impl TryFrom for ArrayData { vec![], )?) } + A::Dictionary(arr) => { + let indices: ArrayData = (*arr.indices).try_into()?; + let values: ArrayData = (*arr.values).try_into()?; + let data_type = ArrowT::Dictionary( + Box::new(indices.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + Ok(indices + .into_builder() + .data_type(data_type) + .child_data(vec![values]) + .build()?) + } } } } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index f5fd9404..87bb2a34 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -37,6 +37,7 @@ pub enum Array { List(ListArray), LargeList(ListArray), FixedSizeList(FixedSizeListArray), + Dictionary(DictionaryArray), } #[derive(Clone, Debug)] @@ -125,3 +126,9 @@ pub struct DecimalArray { pub validity: Option>, pub values: Vec, } + +#[derive(Clone, Debug)] +pub struct DictionaryArray { + pub indices: Box, + pub values: Box, +} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index a1896667..195c384c 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,9 +6,9 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, BytesArray, DecimalArray, FieldMeta, FixedSizeBinaryArray, - FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, - TimestampArray, + Array, BooleanArray, 
BytesArray, DecimalArray, DictionaryArray, FieldMeta, + FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, + TimeArray, TimestampArray, }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index 004d7d62..cbbbd028 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::Serialize; use crate::internal::{ - arrow::Array, + arrow::{Array, DictionaryArray}, error::{fail, Result}, schema::GenericField, utils::Mut, @@ -43,7 +43,10 @@ impl DictionaryUtf8Builder { } pub fn into_array(self) -> Result { - unimplemented!() + Ok(Array::Dictionary(DictionaryArray { + indices: Box::new((*self.indices).into_array()?), + values: Box::new((*self.values).into_array()?), + })) } } From 7461c8bf4d325b2979304ec9f3b97b52b278f4a4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 17:59:11 +0200 Subject: [PATCH 021/178] Implement Map --- serde_arrow/src/arrow_impl/serialization.rs | 26 +++++++++++-------- serde_arrow/src/internal/arrow/array.rs | 1 + .../src/internal/serialization/map_builder.rs | 17 +++++++++--- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index b5d066ac..3ad07f7d 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -64,7 +64,7 @@ fn build_array(builder: ArrayBuilder) -> Result { } fn build_array_data(builder: ArrayBuilder) -> Result { - use {ArrayBuilder as A, DataType as T}; + use ArrayBuilder as A; match builder { builder @ (A::UnknownVariant(_) @@ -96,16 +96,8 @@ fn 
build_array_data(builder: ArrayBuilder) -> Result { | A::List(_) | A::FixedSizedList(_) | A::FixedSizeBinary(_) - | A::DictionaryUtf8(_)) => builder.into_array()?.try_into(), - A::Map(builder) => Ok(ArrayData::builder(T::Map( - Arc::new(Field::try_from(&builder.entry_field)?), - false, - )) - .len(builder.offsets.offsets.len() - 1) - .add_buffer(ScalarBuffer::from(builder.offsets.offsets).into_inner()) - .add_child_data(build_array_data(*builder.entry)?) - .null_bit_buffer(builder.validity.map(|b| Buffer::from(b.buffer))) - .build()?), + | A::DictionaryUtf8(_) + | A::Map(_)) => builder.into_array()?.try_into(), A::Union(builder) => { let data_type = Field::try_from(&builder.field)?.data_type().clone(); let children = builder @@ -282,6 +274,18 @@ impl TryFrom for ArrayData { .child_data(vec![values]) .build()?) } + A::Map(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + Ok(ArrayData::try_new( + ArrowT::Map(Arc::new(field), false), + arr.offsets.len().saturating_sub(1), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.offsets).into_inner()], + vec![child], + )?) 
+ } } } } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 87bb2a34..8bec1ebc 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -38,6 +38,7 @@ pub enum Array { LargeList(ListArray), FixedSizeList(FixedSizeListArray), Dictionary(DictionaryArray), + Map(ListArray), } #[derive(Clone, Debug)] diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index c4ba5bbc..6679219d 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -1,12 +1,16 @@ use serde::Serialize; -use crate::internal::{arrow::Array, error::Result, schema::GenericField}; +use crate::internal::{ + arrow::{Array, ListArray}, + error::Result, + schema::GenericField, +}; use super::{ array_builder::ArrayBuilder, utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, - SimpleSerializer, + meta_from_field, push_validity, push_validity_default, MutableBitBuffer, + MutableOffsetBuffer, SimpleSerializer, }, }; @@ -42,7 +46,12 @@ impl MapBuilder { } pub fn into_array(self) -> Result { - unimplemented!() + Ok(Array::Map(ListArray { + meta: meta_from_field(self.entry_field)?, + element: Box::new((*self.entry).into_array()?), + validity: self.validity.map(|v| v.buffer), + offsets: self.offsets.offsets, + })) } } From c70c39c0585cfa37410ddd93d5e5f551d129504f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 18:13:33 +0200 Subject: [PATCH 022/178] Implement Union --- serde_arrow/src/arrow_impl/serialization.rs | 145 +++++++----------- serde_arrow/src/internal/arrow/array.rs | 8 + serde_arrow/src/internal/arrow/mod.rs | 2 +- .../internal/serialization/union_builder.rs | 20 ++- 4 files changed, 83 insertions(+), 92 deletions(-) diff --git a/serde_arrow/src/arrow_impl/serialization.rs 
b/serde_arrow/src/arrow_impl/serialization.rs index 3ad07f7d..d47c823c 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -9,6 +9,7 @@ use crate::{ buffer::{Buffer, ScalarBuffer}, datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Field, FieldRef, Float16Type, Schema, + UnionMode, }, }, internal::{ @@ -59,69 +60,15 @@ impl OuterSequenceBuilder { } fn build_array(builder: ArrayBuilder) -> Result { - let data = build_array_data(builder)?; + let data = builder.into_array()?.try_into()?; Ok(make_array(data)) } -fn build_array_data(builder: ArrayBuilder) -> Result { - use ArrayBuilder as A; - - match builder { - builder @ (A::UnknownVariant(_) - | A::Null(_) - | A::Bool(_) - | A::I8(_) - | A::I16(_) - | A::I32(_) - | A::I64(_) - | A::U8(_) - | A::U16(_) - | A::U32(_) - | A::U64(_) - | A::F16(_) - | A::F32(_) - | A::F64(_) - | A::Date32(_) - | A::Date64(_) - | A::Time32(_) - | A::Time64(_) - | A::Duration(_) - | A::Decimal128(_) - | A::Utf8(_) - | A::LargeUtf8(_) - | A::Binary(_) - | A::LargeBinary(_) - | A::Struct(_) - | A::LargeList(_) - | A::List(_) - | A::FixedSizedList(_) - | A::FixedSizeBinary(_) - | A::DictionaryUtf8(_) - | A::Map(_)) => builder.into_array()?.try_into(), - A::Union(builder) => { - let data_type = Field::try_from(&builder.field)?.data_type().clone(); - let children = builder - .fields - .into_iter() - .map(build_array_data) - .collect::>>()?; - let len = builder.types.len(); - - Ok(ArrayData::builder(data_type) - .len(len) - .add_buffer(Buffer::from_vec(builder.types)) - .add_buffer(Buffer::from_vec(builder.offsets)) - .child_data(children) - .build()?) 
- } - } -} - impl TryFrom for ArrayData { type Error = Error; fn try_from(value: crate::internal::arrow::Array) -> Result { - use {crate::internal::arrow::Array as A, DataType as ArrowT}; + use {crate::internal::arrow::Array as A, DataType as T}; type ArrowF16 = ::Native; fn f16_to_f16(v: f16) -> ArrowF16 { @@ -131,7 +78,7 @@ impl TryFrom for ArrayData { match value { A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), A::Boolean(arr) => Ok(ArrayData::try_new( - ArrowT::Boolean, + T::Boolean, // NOTE: use the explicit len arr.len, arr.validity.map(Buffer::from), @@ -139,49 +86,47 @@ impl TryFrom for ArrayData { vec![ScalarBuffer::from(arr.values).into_inner()], vec![], )?), - A::Int8(arr) => primitive_into_data(ArrowT::Int8, arr.validity, arr.values), - A::Int16(arr) => primitive_into_data(ArrowT::Int16, arr.validity, arr.values), - A::Int32(arr) => primitive_into_data(ArrowT::Int32, arr.validity, arr.values), - A::Int64(arr) => primitive_into_data(ArrowT::Int64, arr.validity, arr.values), - A::UInt8(arr) => primitive_into_data(ArrowT::UInt8, arr.validity, arr.values), - A::UInt16(arr) => primitive_into_data(ArrowT::UInt16, arr.validity, arr.values), - A::UInt32(arr) => primitive_into_data(ArrowT::UInt32, arr.validity, arr.values), - A::UInt64(arr) => primitive_into_data(ArrowT::UInt64, arr.validity, arr.values), + A::Int8(arr) => primitive_into_data(T::Int8, arr.validity, arr.values), + A::Int16(arr) => primitive_into_data(T::Int16, arr.validity, arr.values), + A::Int32(arr) => primitive_into_data(T::Int32, arr.validity, arr.values), + A::Int64(arr) => primitive_into_data(T::Int64, arr.validity, arr.values), + A::UInt8(arr) => primitive_into_data(T::UInt8, arr.validity, arr.values), + A::UInt16(arr) => primitive_into_data(T::UInt16, arr.validity, arr.values), + A::UInt32(arr) => primitive_into_data(T::UInt32, arr.validity, arr.values), + A::UInt64(arr) => primitive_into_data(T::UInt64, arr.validity, arr.values), A::Float16(arr) => primitive_into_data( - 
ArrowT::Float16, + T::Float16, arr.validity, arr.values.into_iter().map(f16_to_f16).collect(), ), - A::Float32(arr) => primitive_into_data(ArrowT::Float32, arr.validity, arr.values), - A::Float64(arr) => primitive_into_data(ArrowT::Float64, arr.validity, arr.values), - A::Date32(arr) => primitive_into_data(ArrowT::Date32, arr.validity, arr.values), - A::Date64(arr) => primitive_into_data(ArrowT::Date64, arr.validity, arr.values), + A::Float32(arr) => primitive_into_data(T::Float32, arr.validity, arr.values), + A::Float64(arr) => primitive_into_data(T::Float64, arr.validity, arr.values), + A::Date32(arr) => primitive_into_data(T::Date32, arr.validity, arr.values), + A::Date64(arr) => primitive_into_data(T::Date64, arr.validity, arr.values), A::Timestamp(arr) => primitive_into_data( - ArrowT::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), + T::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), arr.validity, arr.values, ), A::Time32(arr) => { - primitive_into_data(ArrowT::Time32(arr.unit.into()), arr.validity, arr.values) + primitive_into_data(T::Time32(arr.unit.into()), arr.validity, arr.values) } A::Time64(arr) => { - primitive_into_data(ArrowT::Time64(arr.unit.into()), arr.validity, arr.values) + primitive_into_data(T::Time64(arr.unit.into()), arr.validity, arr.values) } A::Duration(arr) => { - primitive_into_data(ArrowT::Duration(arr.unit.into()), arr.validity, arr.values) + primitive_into_data(T::Duration(arr.unit.into()), arr.validity, arr.values) } A::Decimal128(arr) => primitive_into_data( - ArrowT::Decimal128(arr.precision, arr.scale), + T::Decimal128(arr.precision, arr.scale), arr.validity, arr.values, ), - A::Utf8(arr) => bytes_into_data(ArrowT::Utf8, arr.offsets, arr.data, arr.validity), - A::LargeUtf8(arr) => { - bytes_into_data(ArrowT::LargeUtf8, arr.offsets, arr.data, arr.validity) - } - A::Binary(arr) => bytes_into_data(ArrowT::Binary, arr.offsets, arr.data, arr.validity), + A::Utf8(arr) => bytes_into_data(T::Utf8, arr.offsets, 
arr.data, arr.validity), + A::LargeUtf8(arr) => bytes_into_data(T::LargeUtf8, arr.offsets, arr.data, arr.validity), + A::Binary(arr) => bytes_into_data(T::Binary, arr.offsets, arr.data, arr.validity), A::LargeBinary(arr) => { - bytes_into_data(ArrowT::LargeBinary, arr.offsets, arr.data, arr.validity) + bytes_into_data(T::LargeBinary, arr.offsets, arr.data, arr.validity) } A::Struct(arr) => { let mut fields = Vec::new(); @@ -194,7 +139,7 @@ impl TryFrom for ArrayData { fields.push(Arc::new(field)); data.push(child); } - let data_type = ArrowT::Struct(fields.into()); + let data_type = T::Struct(fields.into()); Ok(ArrayData::builder(data_type) .len(arr.len) @@ -206,7 +151,7 @@ impl TryFrom for ArrayData { let child: ArrayData = (*arr.element).try_into()?; let field = field_from_data_and_meta(&child, arr.meta); list_into_data( - ArrowT::List(Arc::new(field)), + T::List(Arc::new(field)), arr.offsets.len().saturating_sub(1), arr.offsets, child, @@ -217,7 +162,7 @@ impl TryFrom for ArrayData { let child: ArrayData = (*arr.element).try_into()?; let field = field_from_data_and_meta(&child, arr.meta); list_into_data( - ArrowT::LargeList(Arc::new(field)), + T::LargeList(Arc::new(field)), arr.offsets.len().saturating_sub(1), arr.offsets, child, @@ -235,7 +180,7 @@ impl TryFrom for ArrayData { } let field = field_from_data_and_meta(&child, arr.meta); Ok(ArrayData::try_new( - ArrowT::FixedSizeList(Arc::new(field), arr.n), + T::FixedSizeList(Arc::new(field), arr.n), child.len() / usize::try_from(arr.n)?, arr.validity.map(Buffer::from), 0, @@ -252,7 +197,7 @@ impl TryFrom for ArrayData { ); } Ok(ArrayData::try_new( - ArrowT::FixedSizeBinary(arr.n), + T::FixedSizeBinary(arr.n), arr.data.len() / usize::try_from(arr.n)?, arr.validity.map(Buffer::from), 0, @@ -263,7 +208,7 @@ impl TryFrom for ArrayData { A::Dictionary(arr) => { let indices: ArrayData = (*arr.indices).try_into()?; let values: ArrayData = (*arr.values).try_into()?; - let data_type = ArrowT::Dictionary( + let data_type = 
T::Dictionary( Box::new(indices.data_type().clone()), Box::new(values.data_type().clone()), ); @@ -278,7 +223,7 @@ impl TryFrom for ArrayData { let child: ArrayData = (*arr.element).try_into()?; let field = field_from_data_and_meta(&child, arr.meta); Ok(ArrayData::try_new( - ArrowT::Map(Arc::new(field), false), + T::Map(Arc::new(field), false), arr.offsets.len().saturating_sub(1), arr.validity.map(Buffer::from), 0, @@ -286,6 +231,30 @@ impl TryFrom for ArrayData { vec![child], )?) } + A::DenseUnion(arr) => { + let mut fields = Vec::new(); + let mut child_data = Vec::new(); + + for (idx, (array, meta)) in arr.fields.into_iter().enumerate() { + let child: ArrayData = array.try_into()?; + let field = field_from_data_and_meta(&child, meta); + + fields.push((idx as i8, Arc::new(field))); + child_data.push(child); + } + + Ok(ArrayData::try_new( + DataType::Union(fields.into_iter().collect(), UnionMode::Dense), + arr.types.len(), + None, + 0, + vec![ + ScalarBuffer::from(arr.types).into_inner(), + ScalarBuffer::from(arr.offsets).into_inner(), + ], + child_data, + )?) 
+ } } } } diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 8bec1ebc..6c349f96 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -39,6 +39,7 @@ pub enum Array { FixedSizeList(FixedSizeListArray), Dictionary(DictionaryArray), Map(ListArray), + DenseUnion(DenseUnionArray), } #[derive(Clone, Debug)] @@ -133,3 +134,10 @@ pub struct DictionaryArray { pub indices: Box, pub values: Box, } + +#[derive(Clone, Debug)] +pub struct DenseUnionArray { + pub types: Vec, + pub offsets: Vec, + pub fields: Vec<(Array, FieldMeta)>, +} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 195c384c..671dcb0f 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,7 +6,7 @@ mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, BytesArray, DecimalArray, DictionaryArray, FieldMeta, + Array, BooleanArray, BytesArray, DecimalArray, DenseUnionArray, DictionaryArray, FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, TimestampArray, }; diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index e127dab6..2aebac47 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,11 +1,14 @@ use crate::internal::{ - arrow::Array, + arrow::{Array, DenseUnionArray}, error::{fail, Result}, schema::GenericField, utils::Mut, }; -use super::{utils::SimpleSerializer, ArrayBuilder}; +use super::{ + utils::{meta_from_field, SimpleSerializer}, + ArrayBuilder, +}; #[derive(Debug, Clone)] pub struct UnionBuilder { @@ -42,7 +45,18 @@ impl UnionBuilder { } pub fn into_array(self) -> Result { - unimplemented!() + let mut fields = Vec::new(); + for (field, builder) in 
self.field.children.into_iter().zip(self.fields) { + let meta = meta_from_field(field)?; + let array = builder.into_array()?; + fields.push((array, meta)); + } + + Ok(Array::DenseUnion(DenseUnionArray { + types: self.types, + offsets: self.offsets, + fields: fields, + })) } } From d738eecc0571d9917220476942a94698295d244d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 18:15:06 +0200 Subject: [PATCH 023/178] Address clippy --- serde_arrow/src/internal/serialization/union_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 2aebac47..720e3dc8 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -55,7 +55,7 @@ impl UnionBuilder { Ok(Array::DenseUnion(DenseUnionArray { types: self.types, offsets: self.offsets, - fields: fields, + fields, })) } } From 73ecc2d24e08f665670721f4ff1a963358747eed Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 18:20:22 +0200 Subject: [PATCH 024/178] Move unused markers into view / types submodule to fix PR --- serde_arrow/src/internal/arrow/array_view.rs | 1 + serde_arrow/src/internal/arrow/mod.rs | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs index d3161a77..81cc7eb4 100644 --- a/serde_arrow/src/internal/arrow/array_view.rs +++ b/serde_arrow/src/internal/arrow/array_view.rs @@ -1,3 +1,4 @@ +#![allow(dead_code, unused)] use half::f16; use crate::internal::arrow::data_type::TimeUnit; diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 671dcb0f..17048a17 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -1,6 +1,5 @@ //! 
A common arrow abstraction to simplify conversion between different arrow //! implementations -#![allow(dead_code, unused)] mod array; mod array_view; mod data_type; @@ -10,8 +9,4 @@ pub use array::{ FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, TimestampArray, }; -pub use array_view::{ - ArrayView, BitsWithOffset, BooleanArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, - StructArrayView, Utf8ArrayView, -}; -pub use data_type::{BaseDataTypeDisplay, DataType, Field, TimeUnit}; +pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; From 7d48bef1afd6ef2ba3bf9603e2b6128324c9dabf Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 21:30:30 +0200 Subject: [PATCH 025/178] Implement ArrayView based Deserializers for Null, Bool, {UI,I}nt{8,16,32,64}, Float{16,32,64}, Decimal128 --- .../src/arrow2_impl/deserialization.rs | 6 +- serde_arrow/src/arrow_impl/deserialization.rs | 186 ++++++++++-------- serde_arrow/src/internal/arrow/array_view.rs | 9 +- serde_arrow/src/internal/arrow/mod.rs | 4 + .../deserialization/array_deserializer.rs | 78 +++++++- 5 files changed, 195 insertions(+), 88 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 5e564447..761ef111 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -88,7 +88,7 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - T::Null => Ok(NullDeserializer.into()), + T::Null => Ok(ArrayDeserializer::Null(NullDeserializer)), T::Bool => build_bool_deserializer(field, array), T::U8 => build_integer_deserializer::(field, array), T::U16 => build_integer_deserializer::(field, array), @@ -151,7 +151,9 @@ pub fn build_bool_deserializer<'a>( }; let validity = get_validity(array); - Ok(BoolDeserializer::new(buffer, validity).into()) + 
Ok(ArrayDeserializer::Bool(BoolDeserializer::new( + buffer, validity, + ))) } pub fn build_integer_deserializer<'a, T>( diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 2ff608bd..a36b1ffb 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,28 +1,27 @@ use crate::internal::{ - arrow::TimeUnit, + arrow::{ + ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, + PrimitiveArrayView, TimeUnit, + }, deserialization::{ array_deserializer::ArrayDeserializer, binary_deserializer::BinaryDeserializer, - bool_deserializer::BoolDeserializer, construction, date32_deserializer::Date32Deserializer, date64_deserializer::Date64Deserializer, - decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_list_deserializer::FixedSizeListDeserializer, - float_deserializer::{Float, FloatDeserializer}, integer_deserializer::{Integer, IntegerDeserializer}, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, - null_deserializer::NullDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, deserializer::Deserializer, - error::{fail, Result}, + error::{fail, Error, Result}, schema::{GenericDataType, GenericField}, utils::Offset, }; @@ -30,7 +29,7 @@ use crate::internal::{ use crate::_impl::arrow::{ array::{ Array, BooleanArray, DictionaryArray, FixedSizeListArray, GenericBinaryArray, - GenericListArray, GenericStringArray, MapArray, OffsetSizeTrait, PrimitiveArray, + GenericListArray, GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UnionArray, }, datatypes::{ @@ -127,20 +126,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use 
{GenericDataType as T, TimeUnit as U}; match &field.data_type { - T::Null => Ok(NullDeserializer.into()), - T::Bool => build_bool_deserializer(field, array), - T::U8 => build_integer_deserializer::(field, array), - T::U16 => build_integer_deserializer::(field, array), - T::U32 => build_integer_deserializer::(field, array), - T::U64 => build_integer_deserializer::(field, array), - T::I8 => build_integer_deserializer::(field, array), - T::I16 => build_integer_deserializer::(field, array), - T::I32 => build_integer_deserializer::(field, array), - T::I64 => build_integer_deserializer::(field, array), - T::F16 => build_float16_deserializer(field, array), - T::F32 => build_float_deserializer::(field, array), - T::F64 => build_float_deserializer::(field, array), - T::Decimal128(_, _) => build_decimal128_deserializer(field, array), T::Date32 => build_date32_deserializer(field, array), T::Date64 => build_date64_deserializer(field, array), T::Time32(unit) => construction::build_time32_deserializer( @@ -195,27 +180,10 @@ pub fn build_array_deserializer<'a>( T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), T::Dictionary => build_dictionary_deserializer(field, array), + _ => ArrayDeserializer::new(field, array.try_into()?), } } -pub fn build_bool_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!("cannot convert {} array into bool", array.data_type()); - }; - - let buffer = BitBuffer { - data: array.values().values(), - offset: array.values().offset(), - number_of_bits: array.values().len(), - }; - let validity = get_validity(array); - - Ok(BoolDeserializer::new(buffer, validity).into()) -} - pub fn build_integer_deserializer<'a, T>( _field: &GenericField, array: &'a dyn Array, @@ -228,45 +196,6 @@ where Ok(IntegerDeserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) } -pub fn 
build_float16_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - Ok(FloatDeserializer::new( - as_primitive_values::(array)?, - get_validity(array), - ) - .into()) -} - -pub fn build_float_deserializer<'a, T>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - T: ArrowPrimitiveType, - T::Native: Float, - ArrayDeserializer<'a>: From>, -{ - Ok(FloatDeserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) -} - -pub fn build_decimal128_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let GenericDataType::Decimal128(_, scale) = field.data_type else { - fail!("Invalid data type for Decimal128Deserializer"); - }; - - Ok(DecimalDeserializer::new( - as_primitive_values::(array)?, - get_validity(array), - scale, - ) - .into()) -} - pub fn build_date32_deserializer<'a>( _field: &GenericField, array: &'a dyn Array, @@ -592,3 +521,104 @@ fn get_validity(arr: &dyn Array) -> Option> { number_of_bits, }) } + +impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { + type Error = Error; + + fn try_from(array: &'a dyn Array) -> Result { + let any = array.as_any(); + if let Some(array) = any.downcast_ref::() { + Ok(ArrayView::Null(NullArrayView { len: array.len() })) + } else if let Some(array) = any.downcast_ref::() { + Ok(ArrayView::Boolean(BooleanArrayView { + len: array.len(), + validity: get_bits_with_offset(array), + values: BitsWithOffset { + offset: array.values().offset(), + data: array.values().values(), + }, + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Int8(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Int16(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Int32(PrimitiveArrayView { + validity: 
get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Int64(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::UInt8(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::UInt16(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::UInt32(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::UInt64(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Float16(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Float32(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Float64(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + let &DataType::Decimal128(precision, scale) = array.data_type() else { + fail!( + "Invalid data type for Decimal128 array: {}", + array.data_type() + ); + }; + Ok(ArrayView::Decimal128(DecimalArrayView { + precision, + scale, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else { + fail!( + "Cannot build an array view for {dt}", + dt = array.data_type() + ); + } + } +} + +fn get_bits_with_offset(array: &dyn Array) -> Option> { + let validity = array.nulls()?; + 
Some(BitsWithOffset { + offset: validity.offset(), + data: validity.validity(), + }) +} diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs index 81cc7eb4..81d9199b 100644 --- a/serde_arrow/src/internal/arrow/array_view.rs +++ b/serde_arrow/src/internal/arrow/array_view.rs @@ -25,7 +25,7 @@ pub enum ArrayView<'a> { LargeUtf8(Utf8ArrayView<'a, i64>), Binary(Utf8ArrayView<'a, i32>), LargeBinary(Utf8ArrayView<'a, i64>), - Decimal128(PrimitiveArrayView<'a, i128>), + Decimal128(DecimalArrayView<'a, i128>), Struct(StructArrayView<'a>), List(ListArrayView<'a, i32>), LargeList(ListArrayView<'a, i64>), @@ -52,6 +52,13 @@ pub struct PrimitiveArrayView<'a, T> { pub values: &'a [T], } +pub struct DecimalArrayView<'a, T> { + pub precision: u8, + pub scale: i8, + pub validity: Option>, + pub values: &'a [T], +} + pub struct TimeArrayView<'a, T> { pub unit: TimeUnit, pub validity: Option>, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 17048a17..1cef413e 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -9,4 +9,8 @@ pub use array::{ FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, TimeArray, TimestampArray, }; +pub use array_view::{ + ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, + PrimitiveArrayView, +}; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index afd65ea9..0dd9daf9 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,7 +2,9 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ + arrow::{ArrayView, BitsWithOffset}, error::{Error, Result}, + 
schema::GenericField, utils::Mut, }; @@ -16,7 +18,7 @@ use super::{ integer_deserializer::IntegerDeserializer, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, simple_deserializer::SimpleDeserializer, string_deserializer::StringDeserializer, - struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, + struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::BitBuffer, }; pub enum ArrayDeserializer<'a> { @@ -67,18 +69,80 @@ pub enum ArrayDeserializer<'a> { Enum(EnumDeserializer<'a>), } -impl<'a> From for ArrayDeserializer<'a> { - fn from(value: NullDeserializer) -> Self { - Self::Null(value) +impl<'a> ArrayDeserializer<'a> { + pub fn new(_field: &GenericField, array: ArrayView<'a>) -> Result { + match array { + ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), + ArrayView::Boolean(view) => Ok(Self::Bool(BoolDeserializer::new( + buffer_from_bits_with_offset(view.values, view.len), + buffer_from_bits_with_offset_opt(view.validity, view.len), + ))), + ArrayView::Int8(view) => Ok(Self::I8(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Int16(view) => Ok(Self::I16(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Int32(view) => Ok(Self::I32(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Int64(view) => Ok(Self::I64(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::UInt8(view) => Ok(Self::U8(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::UInt16(view) => Ok(Self::U16(IntegerDeserializer::new( + view.values, + 
buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::UInt32(view) => Ok(Self::U32(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::UInt64(view) => Ok(Self::U64(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Float16(view) => Ok(Self::F16(FloatDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Float32(view) => Ok(Self::F32(FloatDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Float64(view) => Ok(Self::F64(FloatDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Decimal128(view) => Ok(Self::Decimal128(DecimalDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.scale, + ))), + _ => unimplemented!(), + } } } -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: BoolDeserializer<'a>) -> Self { - Self::Bool(value) +fn buffer_from_bits_with_offset(bits: BitsWithOffset, len: usize) -> BitBuffer { + BitBuffer { + data: bits.data, + offset: bits.offset, + number_of_bits: len, } } +fn buffer_from_bits_with_offset_opt(bits: Option, len: usize) -> Option { + Some(buffer_from_bits_with_offset(bits?, len)) +} + impl<'a> From> for ArrayDeserializer<'a> { fn from(value: IntegerDeserializer<'a, i8>) -> Self { Self::I8(value) From 0938c22eb352ce65234ddd14bbf3365371c5488a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 21:44:07 +0200 Subject: [PATCH 026/178] Implement Date32, Date64 --- serde_arrow/src/arrow_impl/deserialization.rs | 38 +++++-------------- .../deserialization/array_deserializer.rs | 14 ++++++- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git 
a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index a36b1ffb..f3b9a60b 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -7,8 +7,6 @@ use crate::internal::{ array_deserializer::ArrayDeserializer, binary_deserializer::BinaryDeserializer, construction, - date32_deserializer::Date32Deserializer, - date64_deserializer::Date64Deserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_list_deserializer::FixedSizeListDeserializer, @@ -126,8 +124,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use {GenericDataType as T, TimeUnit as U}; match &field.data_type { - T::Date32 => build_date32_deserializer(field, array), - T::Date64 => build_date64_deserializer(field, array), T::Time32(unit) => construction::build_time32_deserializer( field, match unit { @@ -196,30 +192,6 @@ where Ok(IntegerDeserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) } -pub fn build_date32_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - Ok(Date32Deserializer::new( - as_primitive_values::(array)?, - get_validity(array), - ) - .into()) -} - -pub fn build_date64_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - Ok(Date64Deserializer::new( - as_primitive_values::(array)?, - get_validity(array), - TimeUnit::Millisecond, - field.is_utc()?, - ) - .into()) -} - pub fn build_string_deserializer<'a, O>( _field: &GenericField, array: &'a dyn Array, @@ -606,6 +578,16 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { validity: get_bits_with_offset(array), values: array.values(), })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Date32(PrimitiveArrayView { + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Date64(PrimitiveArrayView 
{ + validity: get_bits_with_offset(array), + values: array.values(), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 0dd9daf9..3770a12d 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,7 +2,7 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ - arrow::{ArrayView, BitsWithOffset}, + arrow::{ArrayView, BitsWithOffset, TimeUnit}, error::{Error, Result}, schema::GenericField, utils::Mut, @@ -70,7 +70,7 @@ pub enum ArrayDeserializer<'a> { } impl<'a> ArrayDeserializer<'a> { - pub fn new(_field: &GenericField, array: ArrayView<'a>) -> Result { + pub fn new(field: &GenericField, array: ArrayView<'a>) -> Result { match array { ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), ArrayView::Boolean(view) => Ok(Self::Bool(BoolDeserializer::new( @@ -126,6 +126,16 @@ impl<'a> ArrayDeserializer<'a> { buffer_from_bits_with_offset_opt(view.validity, view.values.len()), view.scale, ))), + ArrayView::Date32(view) => Ok(Self::Date32(Date32Deserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), + ArrayView::Date64(view) => Ok(Self::Date64(Date64Deserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + TimeUnit::Millisecond, + field.is_utc()?, + ))), _ => unimplemented!(), } } From 6e6298625c163cd08a96c49aefa5b000c3f1c4fe Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 21:54:42 +0200 Subject: [PATCH 027/178] Implement Time32, Time64 --- .../src/arrow2_impl/deserialization.rs | 30 +++--------- serde_arrow/src/arrow_impl/deserialization.rs | 46 ++++++++++--------- serde_arrow/src/internal/arrow/mod.rs | 2 +- 
.../deserialization/array_deserializer.rs | 10 ++++ .../internal/deserialization/construction.rs | 25 +--------- 5 files changed, 44 insertions(+), 69 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 761ef111..297d3a38 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,23 +1,7 @@ use crate::internal::{ arrow::TimeUnit, deserialization::{ - array_deserializer::ArrayDeserializer, - bool_deserializer::BoolDeserializer, - construction, - date32_deserializer::Date32Deserializer, - date64_deserializer::Date64Deserializer, - decimal_deserializer::DecimalDeserializer, - dictionary_deserializer::DictionaryDeserializer, - enum_deserializer::EnumDeserializer, - float_deserializer::{Float, FloatDeserializer}, - integer_deserializer::{Integer, IntegerDeserializer}, - list_deserializer::ListDeserializer, - map_deserializer::MapDeserializer, - null_deserializer::NullDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - string_deserializer::StringDeserializer, - struct_deserializer::StructDeserializer, - utils::{check_supported_list_layout, BitBuffer}, + array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, construction, date32_deserializer::Date32Deserializer, date64_deserializer::Date64Deserializer, decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, float_deserializer::{Float, FloatDeserializer}, integer_deserializer::{Integer, IntegerDeserializer}, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::{check_supported_list_layout, BitBuffer} }, deserializer::Deserializer, 
error::{fail, Result}, @@ -104,16 +88,16 @@ pub fn build_array_deserializer<'a>( T::Decimal128(_, _) => build_decimal128_deserializer(field, array), T::Date32 => build_date32_deserializer(field, array), T::Date64 => build_date64_deserializer(field, array), - T::Time32(_) => construction::build_time32_deserializer( - field, + T::Time32(unit) => Ok(ArrayDeserializer::Time32(TimeDeserializer::new( as_primitive_values::(array)?, get_validity(array), - ), - T::Time64(_) => construction::build_time64_deserializer( - field, + *unit, + ))), + T::Time64(unit) => Ok(ArrayDeserializer::Time64(TimeDeserializer::new( as_primitive_values::(array)?, get_validity(array), - ), + *unit, + ))), T::Timestamp(_, _) => construction::build_timestamp_deserializer( field, as_primitive_values::(array)?, diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index f3b9a60b..a16fc219 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,7 +1,7 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, - PrimitiveArrayView, TimeUnit, + PrimitiveArrayView, TimeArrayView, TimeUnit, }, deserialization::{ array_deserializer::ArrayDeserializer, @@ -124,26 +124,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use {GenericDataType as T, TimeUnit as U}; match &field.data_type { - T::Time32(unit) => construction::build_time32_deserializer( - field, - match unit { - U::Second => as_primitive_values::(array)?, - U::Millisecond => as_primitive_values::(array)?, - // Not supported according to the arrow docs - unit => fail!("cannot build deserializer for Time64({unit})"), - }, - get_validity(array), - ), - T::Time64(unit) => construction::build_time64_deserializer( - field, - match unit { - U::Microsecond => as_primitive_values::(array)?, - U::Nanosecond => as_primitive_values::(array)?, - // Not supported according to the arrow 
docs - unit => fail!("cannot build deserializer for Time64({unit})"), - }, - get_validity(array), - ), T::Timestamp(unit, _) => construction::build_timestamp_deserializer( field, match unit { @@ -588,6 +568,30 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { validity: get_bits_with_offset(array), values: array.values(), })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Time32(TimeArrayView { + unit: TimeUnit::Millisecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Time32(TimeArrayView { + unit: TimeUnit::Second, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Time64(TimeArrayView { + unit: TimeUnit::Nanosecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Time64(TimeArrayView { + unit: TimeUnit::Microsecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 1cef413e..44a67759 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -11,6 +11,6 @@ pub use array::{ }; pub use array_view::{ ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, - PrimitiveArrayView, + PrimitiveArrayView, TimeArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 3770a12d..45a90daa 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -136,6 +136,16 @@ impl<'a> 
ArrayDeserializer<'a> { TimeUnit::Millisecond, field.is_utc()?, ))), + ArrayView::Time32(view) => Ok(Self::Time32(TimeDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.unit, + ))), + ArrayView::Time64(view) => Ok(Self::Time64(TimeDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.unit, + ))), _ => unimplemented!(), } } diff --git a/serde_arrow/src/internal/deserialization/construction.rs b/serde_arrow/src/internal/deserialization/construction.rs index c03a3091..79f8e554 100644 --- a/serde_arrow/src/internal/deserialization/construction.rs +++ b/serde_arrow/src/internal/deserialization/construction.rs @@ -5,7 +5,7 @@ use crate::internal::{ }; use super::{ - array_deserializer::ArrayDeserializer, time_deserializer::TimeDeserializer, utils::BitBuffer, + array_deserializer::ArrayDeserializer, utils::BitBuffer, }; pub fn build_timestamp_deserializer<'a>( @@ -35,26 +35,3 @@ pub fn build_timestamp_deserializer<'a>( Ok(Date64Deserializer::new(values, validity, *unit, field.is_utc()?).into()) } -pub fn build_time32_deserializer<'a>( - field: &GenericField, - values: &'a [i32], - validity: Option>, -) -> Result> { - let GenericDataType::Time32(unit) = &field.data_type else { - fail!("invalid data type for time64"); - }; - - Ok(TimeDeserializer::::new(values, validity, *unit).into()) -} - -pub fn build_time64_deserializer<'a>( - field: &GenericField, - values: &'a [i64], - validity: Option>, -) -> Result> { - let GenericDataType::Time64(unit) = &field.data_type else { - fail!("invalid data type for time64"); - }; - - Ok(TimeDeserializer::::new(values, validity, *unit).into()) -} From d8a388c921f38bf1536f7935f65944548985c939 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 21 Jul 2024 22:30:24 +0200 Subject: [PATCH 028/178] Reformat code --- .../src/arrow2_impl/deserialization.rs | 19 ++++++++++++++++++- 
.../internal/deserialization/construction.rs | 5 +---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 297d3a38..68fa0769 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,7 +1,24 @@ use crate::internal::{ arrow::TimeUnit, deserialization::{ - array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, construction, date32_deserializer::Date32Deserializer, date64_deserializer::Date64Deserializer, decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, float_deserializer::{Float, FloatDeserializer}, integer_deserializer::{Integer, IntegerDeserializer}, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::{check_supported_list_layout, BitBuffer} + array_deserializer::ArrayDeserializer, + bool_deserializer::BoolDeserializer, + construction, + date32_deserializer::Date32Deserializer, + date64_deserializer::Date64Deserializer, + decimal_deserializer::DecimalDeserializer, + dictionary_deserializer::DictionaryDeserializer, + enum_deserializer::EnumDeserializer, + float_deserializer::{Float, FloatDeserializer}, + integer_deserializer::{Integer, IntegerDeserializer}, + list_deserializer::ListDeserializer, + map_deserializer::MapDeserializer, + null_deserializer::NullDeserializer, + outer_sequence_deserializer::OuterSequenceDeserializer, + string_deserializer::StringDeserializer, + struct_deserializer::StructDeserializer, + time_deserializer::TimeDeserializer, + utils::{check_supported_list_layout, BitBuffer}, }, deserializer::Deserializer, error::{fail, Result}, diff 
--git a/serde_arrow/src/internal/deserialization/construction.rs b/serde_arrow/src/internal/deserialization/construction.rs index 79f8e554..a6d26c4c 100644 --- a/serde_arrow/src/internal/deserialization/construction.rs +++ b/serde_arrow/src/internal/deserialization/construction.rs @@ -4,9 +4,7 @@ use crate::internal::{ schema::{GenericDataType, GenericField, Strategy}, }; -use super::{ - array_deserializer::ArrayDeserializer, utils::BitBuffer, -}; +use super::{array_deserializer::ArrayDeserializer, utils::BitBuffer}; pub fn build_timestamp_deserializer<'a>( field: &GenericField, @@ -34,4 +32,3 @@ pub fn build_timestamp_deserializer<'a>( Ok(Date64Deserializer::new(values, validity, *unit, field.is_utc()?).into()) } - From a82a839b725bf24d2cbeca3fe07c4763ea1fd3d7 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 19:35:10 +0200 Subject: [PATCH 029/178] Implement Duration, Timestamp --- serde_arrow/src/arrow_impl/deserialization.rs | 106 +++++++------ serde_arrow/src/internal/arrow/array.rs | 143 +++++++++++++++++- serde_arrow/src/internal/arrow/array_view.rs | 86 ----------- serde_arrow/src/internal/arrow/mod.rs | 12 +- .../deserialization/array_deserializer.rs | 26 +++- .../internal/deserialization/construction.rs | 1 + 6 files changed, 225 insertions(+), 149 deletions(-) delete mode 100644 serde_arrow/src/internal/arrow/array_view.rs diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index a16fc219..638cb868 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,16 +1,15 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, - PrimitiveArrayView, TimeArrayView, TimeUnit, + PrimitiveArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ array_deserializer::ArrayDeserializer, binary_deserializer::BinaryDeserializer, - construction, 
dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_list_deserializer::FixedSizeListDeserializer, - integer_deserializer::{Integer, IntegerDeserializer}, + integer_deserializer::Integer, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, @@ -31,8 +30,8 @@ use crate::_impl::arrow::{ RecordBatch, StructArray, UnionArray, }, datatypes::{ - ArrowDictionaryKeyType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, - Decimal128Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + ArrowDictionaryKeyType, DataType, Date32Type, Date64Type, Decimal128Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, @@ -122,28 +121,8 @@ pub fn build_array_deserializer<'a>( field: &GenericField, array: &'a dyn Array, ) -> Result> { - use {GenericDataType as T, TimeUnit as U}; + use GenericDataType as T; match &field.data_type { - T::Timestamp(unit, _) => construction::build_timestamp_deserializer( - field, - match unit { - U::Second => as_primitive_values::(array)?, - U::Millisecond => as_primitive_values::(array)?, - U::Microsecond => as_primitive_values::(array)?, - U::Nanosecond => as_primitive_values::(array)?, - }, - get_validity(array), - ), - T::Duration(U::Second) => build_integer_deserializer::(field, array), - T::Duration(U::Millisecond) => { - build_integer_deserializer::(field, array) - } - T::Duration(U::Microsecond) => { - build_integer_deserializer::(field, array) - } - T::Duration(U::Nanosecond) => { - build_integer_deserializer::(field, array) - } T::Utf8 => build_string_deserializer::(field, array), T::LargeUtf8 => 
build_string_deserializer::(field, array), T::Struct => build_struct_deserializer(field, array), @@ -160,18 +139,6 @@ pub fn build_array_deserializer<'a>( } } -pub fn build_integer_deserializer<'a, T>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - T: ArrowPrimitiveType, - T::Native: Integer, - ArrayDeserializer<'a>: From>, -{ - Ok(IntegerDeserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) -} - pub fn build_string_deserializer<'a, O>( _field: &GenericField, array: &'a dyn Array, @@ -451,17 +418,6 @@ pub fn build_union_deserializer<'a>( Ok(EnumDeserializer::new(type_ids, variants).into()) } -fn as_primitive_values(array: &dyn Array) -> Result<&[T::Native]> { - let Some(array) = array.as_any().downcast_ref::>() else { - fail!( - "cannot convert {} array into {}", - array.data_type(), - T::DATA_TYPE, - ); - }; - Ok(array.values()) -} - fn get_validity(arr: &dyn Array) -> Option> { let validity = arr.nulls()?; let data = validity.validity(); @@ -592,6 +548,58 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { validity: get_bits_with_offset(array), values: array.values(), })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Timestamp(TimestampArrayView { + unit: TimeUnit::Nanosecond, + timezone: array.timezone().map(str::to_owned), + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Timestamp(TimestampArrayView { + unit: TimeUnit::Microsecond, + timezone: array.timezone().map(str::to_owned), + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Timestamp(TimestampArrayView { + unit: TimeUnit::Millisecond, + timezone: array.timezone().map(str::to_owned), + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + 
Ok(ArrayView::Timestamp(TimestampArrayView { + unit: TimeUnit::Second, + timezone: array.timezone().map(str::to_owned), + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Duration(TimeArrayView { + unit: TimeUnit::Nanosecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Duration(TimeArrayView { + unit: TimeUnit::Microsecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Duration(TimeArrayView { + unit: TimeUnit::Millisecond, + validity: get_bits_with_offset(array), + values: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Duration(TimeArrayView { + unit: TimeUnit::Second, + validity: get_bits_with_offset(array), + values: array.values(), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 6c349f96..5a528c4a 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -42,11 +42,66 @@ pub enum Array { DenseUnion(DenseUnionArray), } +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum ArrayView<'a> { + Null(NullArrayView), + Boolean(BooleanArrayView<'a>), + Int8(PrimitiveArrayView<'a, i8>), + Int16(PrimitiveArrayView<'a, i16>), + Int32(PrimitiveArrayView<'a, i32>), + Int64(PrimitiveArrayView<'a, i64>), + UInt8(PrimitiveArrayView<'a, u8>), + UInt16(PrimitiveArrayView<'a, u16>), + UInt32(PrimitiveArrayView<'a, u32>), + UInt64(PrimitiveArrayView<'a, u64>), + Float16(PrimitiveArrayView<'a, f16>), + Float32(PrimitiveArrayView<'a, f32>), + Float64(PrimitiveArrayView<'a, f64>), + Date32(PrimitiveArrayView<'a, i32>), + Date64(PrimitiveArrayView<'a, i64>), + Time32(TimeArrayView<'a, i32>), + 
Time64(TimeArrayView<'a, i64>), + Timestamp(TimestampArrayView<'a>), + Duration(TimeArrayView<'a, i64>), + Utf8(BytesArrayView<'a, i32>), + LargeUtf8(BytesArrayView<'a, i64>), + Binary(BytesArrayView<'a, i32>), + LargeBinary(BytesArrayView<'a, i64>), + FixedSizeBinary(FixedSizeBinaryArrayView<'a>), + Decimal128(DecimalArrayView<'a, i128>), + Struct(StructArrayView<'a>), + List(ListArrayView<'a, i32>), + LargeList(ListArrayView<'a, i64>), + FixedSizeList(FixedSizeListArrayView<'a>), + Dictionary(DictionaryArrayView<'a>), + Map(ListArrayView<'a, i32>), + DenseUnion(DenseUnionArrayView<'a>), +} + +#[derive(Debug, Clone, Copy)] +pub struct BitsWithOffset<'a> { + pub offset: usize, + pub data: &'a [u8], +} + +#[derive(Clone, Debug)] +pub struct FieldMeta { + pub name: String, + pub nullable: bool, + pub metadata: HashMap, +} + #[derive(Clone, Debug)] pub struct NullArray { pub len: usize, } +#[derive(Clone, Debug)] +pub struct NullArrayView { + pub len: usize, +} + #[derive(Clone, Debug)] pub struct BooleanArray { // Note: len is required to know how many bits of values are used @@ -55,12 +110,25 @@ pub struct BooleanArray { pub values: Vec, } +#[derive(Clone, Debug)] +pub struct BooleanArrayView<'a> { + pub len: usize, + pub validity: Option>, + pub values: BitsWithOffset<'a>, +} + #[derive(Clone, Debug)] pub struct PrimitiveArray { pub validity: Option>, pub values: Vec, } +#[derive(Clone, Debug)] +pub struct PrimitiveArrayView<'a, T> { + pub validity: Option>, + pub values: &'a [T], +} + #[derive(Debug, Clone)] pub struct TimeArray { pub unit: TimeUnit, @@ -68,6 +136,13 @@ pub struct TimeArray { pub values: Vec, } +#[derive(Debug, Clone)] +pub struct TimeArrayView<'a, T> { + pub unit: TimeUnit, + pub validity: Option>, + pub values: &'a [T], +} + #[derive(Debug, Clone)] pub struct TimestampArray { @@ -77,6 +152,15 @@ pub struct TimestampArray { pub values: Vec, } +#[derive(Debug, Clone)] + +pub struct TimestampArrayView<'a> { + pub unit: TimeUnit, + pub timezone: 
Option, + pub validity: Option>, + pub values: &'a [i64], +} + #[derive(Clone, Debug)] pub struct StructArray { pub len: usize, @@ -85,10 +169,10 @@ pub struct StructArray { } #[derive(Clone, Debug)] -pub struct FieldMeta { - pub name: String, - pub nullable: bool, - pub metadata: HashMap, +pub struct StructArrayView<'a> { + pub len: usize, + pub validity: Option>, + pub fields: Vec<(ArrayView<'a>, FieldMeta)>, } #[derive(Clone, Debug)] @@ -99,6 +183,14 @@ pub struct ListArray { pub element: Box, } +#[derive(Clone, Debug)] +pub struct ListArrayView<'a, O> { + pub validity: Option>, + pub offsets: &'a [O], + pub meta: FieldMeta, + pub element: Box>, +} + #[derive(Clone, Debug)] pub struct FixedSizeListArray { pub n: i32, @@ -107,6 +199,14 @@ pub struct FixedSizeListArray { pub element: Box, } +#[derive(Clone, Debug)] +pub struct FixedSizeListArrayView<'a> { + pub n: i32, + pub validity: Option>, + pub meta: FieldMeta, + pub element: Box>, +} + #[derive(Clone, Debug)] pub struct BytesArray { pub validity: Option>, @@ -114,6 +214,13 @@ pub struct BytesArray { pub data: Vec, } +#[derive(Clone, Debug)] +pub struct BytesArrayView<'a, O> { + pub validity: Option>, + pub offsets: &'a [O], + pub data: &'a [u8], +} + #[derive(Clone, Debug)] pub struct FixedSizeBinaryArray { pub n: i32, @@ -121,6 +228,13 @@ pub struct FixedSizeBinaryArray { pub data: Vec, } +#[derive(Clone, Debug)] +pub struct FixedSizeBinaryArrayView<'a> { + pub n: i32, + pub validity: Option>, + pub data: &'a [u8], +} + #[derive(Clone, Debug)] pub struct DecimalArray { pub precision: u8, @@ -129,15 +243,36 @@ pub struct DecimalArray { pub values: Vec, } +#[derive(Clone, Debug)] +pub struct DecimalArrayView<'a, T> { + pub precision: u8, + pub scale: i8, + pub validity: Option>, + pub values: &'a [T], +} + #[derive(Clone, Debug)] pub struct DictionaryArray { pub indices: Box, pub values: Box, } +#[derive(Clone, Debug)] +pub struct DictionaryArrayView<'a> { + pub indices: Box>, + pub values: Box>, +} + 
#[derive(Clone, Debug)] pub struct DenseUnionArray { pub types: Vec, pub offsets: Vec, pub fields: Vec<(Array, FieldMeta)>, } + +#[derive(Clone, Debug)] +pub struct DenseUnionArrayView<'a> { + pub types: &'a [i8], + pub offsets: &'a [i32], + pub fields: Vec<(ArrayView<'a>, FieldMeta)>, +} diff --git a/serde_arrow/src/internal/arrow/array_view.rs b/serde_arrow/src/internal/arrow/array_view.rs deleted file mode 100644 index 81d9199b..00000000 --- a/serde_arrow/src/internal/arrow/array_view.rs +++ /dev/null @@ -1,86 +0,0 @@ -#![allow(dead_code, unused)] -use half::f16; - -use crate::internal::arrow::data_type::TimeUnit; - -pub enum ArrayView<'a> { - Null(NullArrayView), - Boolean(BooleanArrayView<'a>), - Int8(PrimitiveArrayView<'a, i8>), - Int16(PrimitiveArrayView<'a, i16>), - Int32(PrimitiveArrayView<'a, i32>), - Int64(PrimitiveArrayView<'a, i64>), - UInt8(PrimitiveArrayView<'a, u8>), - UInt16(PrimitiveArrayView<'a, u16>), - UInt32(PrimitiveArrayView<'a, u32>), - UInt64(PrimitiveArrayView<'a, u64>), - Float16(PrimitiveArrayView<'a, f16>), - Float32(PrimitiveArrayView<'a, f32>), - Float64(PrimitiveArrayView<'a, f64>), - Date32(PrimitiveArrayView<'a, i32>), - Date64(PrimitiveArrayView<'a, i64>), - Time32(TimeArrayView<'a, i32>), - Time64(TimeArrayView<'a, i64>), - Utf8(Utf8ArrayView<'a, i32>), - LargeUtf8(Utf8ArrayView<'a, i64>), - Binary(Utf8ArrayView<'a, i32>), - LargeBinary(Utf8ArrayView<'a, i64>), - Decimal128(DecimalArrayView<'a, i128>), - Struct(StructArrayView<'a>), - List(ListArrayView<'a, i32>), - LargeList(ListArrayView<'a, i64>), -} - -pub struct NullArrayView { - pub len: usize, -} - -#[derive(Debug, Clone, Copy)] -pub struct BitsWithOffset<'a> { - pub offset: usize, - pub data: &'a [u8], -} - -pub struct BooleanArrayView<'a> { - pub len: usize, - pub validity: Option>, - pub values: BitsWithOffset<'a>, -} - -pub struct PrimitiveArrayView<'a, T> { - pub validity: Option>, - pub values: &'a [T], -} - -pub struct DecimalArrayView<'a, T> { - pub precision: u8, 
- pub scale: i8, - pub validity: Option>, - pub values: &'a [T], -} - -pub struct TimeArrayView<'a, T> { - pub unit: TimeUnit, - pub validity: Option>, - pub values: &'a [T], -} - -pub struct StructArrayView<'a> { - pub len: usize, - pub validity: Option>, - pub fields: Vec>, -} - -pub struct ListArrayView<'a, O> { - pub len: usize, - pub validity: Option<&'a [u8]>, - pub offsets: &'a [O], - pub element: Box>, -} - -pub struct Utf8ArrayView<'a, O> { - pub len: usize, - pub validity: Option<&'a [u8]>, - pub offsets: &'a [O], - pub data: &'a [u8], -} diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 44a67759..0479e577 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -1,16 +1,12 @@ //! A common arrow abstraction to simplify conversion between different arrow //! implementations mod array; -mod array_view; mod data_type; pub use array::{ - Array, BooleanArray, BytesArray, DecimalArray, DenseUnionArray, DictionaryArray, FieldMeta, - FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, - TimeArray, TimestampArray, -}; -pub use array_view::{ - ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, - PrimitiveArrayView, TimeArrayView, + Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, + DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, FieldMeta, + FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, NullArrayView, PrimitiveArray, + PrimitiveArrayView, StructArray, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 45a90daa..1ee422be 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -3,8 +3,8 @@ use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ arrow::{ArrayView, BitsWithOffset, TimeUnit}, - error::{Error, Result}, - schema::GenericField, + error::{fail, Error, Result}, + schema::{GenericField, Strategy}, utils::Mut, }; @@ -146,6 +146,28 @@ impl<'a> ArrayDeserializer<'a> { buffer_from_bits_with_offset_opt(view.validity, view.values.len()), view.unit, ))), + ArrayView::Timestamp(view) => match field.strategy.as_ref() { + Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) => { + Ok(Self::Date64(Date64Deserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.unit, + field.is_utc()?, + ))) + } + Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), + None => Ok(Date64Deserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.unit, + field.is_utc()?, + ) + .into()), + }, + ArrayView::Duration(view) => Ok(Self::I64(IntegerDeserializer::new( + view.values, + buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + ))), _ => unimplemented!(), } } diff --git a/serde_arrow/src/internal/deserialization/construction.rs b/serde_arrow/src/internal/deserialization/construction.rs index a6d26c4c..db7e8ab1 100644 --- a/serde_arrow/src/internal/deserialization/construction.rs +++ b/serde_arrow/src/internal/deserialization/construction.rs @@ -6,6 +6,7 @@ use crate::internal::{ use super::{array_deserializer::ArrayDeserializer, utils::BitBuffer}; +#[allow(unused)] pub fn build_timestamp_deserializer<'a>( field: &GenericField, values: &'a [i64], From f1b724f3cf8cd5d279654eef06533f61c5d919b1 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 19:46:34 +0200 Subject: [PATCH 030/178] Implement Utf8, LargeUtf8, Binary, LargeBinary --- serde_arrow/src/arrow_impl/deserialization.rs | 72
+++++++------------ .../deserialization/array_deserializer.rs | 32 +++++++++ 2 files changed, 58 insertions(+), 46 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 638cb868..54caa546 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,11 +1,10 @@ use crate::internal::{ arrow::{ - ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, NullArrayView, - PrimitiveArrayView, TimeArrayView, TimeUnit, TimestampArrayView, + ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, + NullArrayView, PrimitiveArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ array_deserializer::ArrayDeserializer, - binary_deserializer::BinaryDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_list_deserializer::FixedSizeListDeserializer, @@ -13,7 +12,6 @@ use crate::internal::{ list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, - string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, @@ -123,14 +121,10 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - T::Utf8 => build_string_deserializer::(field, array), - T::LargeUtf8 => build_string_deserializer::(field, array), T::Struct => build_struct_deserializer(field, array), T::List => build_list_deserializer::(field, array), T::LargeList => build_list_deserializer::(field, array), T::FixedSizeList(n) => build_fixed_size_list_deserializer(field, array, *n), - T::Binary => build_binary_deserializer::(field, array), - T::LargeBinary => build_binary_deserializer::(field, array), T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), T::Map => 
build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), @@ -139,25 +133,6 @@ pub fn build_array_deserializer<'a>( } } -pub fn build_string_deserializer<'a, O>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - O: OffsetSizeTrait + Offset, - ArrayDeserializer<'a>: From>, -{ - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot convert {} array into string", array.data_type()); - }; - - let buffer = array.value_data(); - let offsets = array.value_offsets(); - let validity = get_validity(array); - - Ok(StringDeserializer::new(buffer, offsets, validity).into()) -} - pub fn build_dictionary_deserializer<'a>( field: &GenericField, array: &'a dyn Array, @@ -295,25 +270,6 @@ where Ok(ListDeserializer::new(item, offsets, validity).into()) } -pub fn build_binary_deserializer<'a, O>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - O: OffsetSizeTrait + Offset, - ArrayDeserializer<'a>: From>, -{ - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot convert {} array into string", array.data_type()); - }; - - let buffer = array.value_data(); - let offsets = array.value_offsets(); - let validity = get_validity(array); - - Ok(BinaryDeserializer::new(buffer, offsets, validity).into()) -} - #[cfg(has_arrow_fixed_binary_support)] pub fn build_fixed_size_binary_deserializer<'a>( _field: &GenericField, @@ -600,6 +556,30 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { validity: get_bits_with_offset(array), values: array.values(), })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::Utf8(BytesArrayView { + validity: get_bits_with_offset(array), + offsets: array.offsets(), + data: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::LargeUtf8(BytesArrayView { + validity: get_bits_with_offset(array), + offsets: array.offsets(), + data: array.values(), + })) + } else if let Some(array) = 
any.downcast_ref::>() { + Ok(ArrayView::Binary(BytesArrayView { + validity: get_bits_with_offset(array), + offsets: array.offsets(), + data: array.values(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(ArrayView::LargeBinary(BytesArrayView { + validity: get_bits_with_offset(array), + offsets: array.offsets(), + data: array.values(), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 1ee422be..65ab12b6 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -168,6 +168,38 @@ impl<'a> ArrayDeserializer<'a> { view.values, buffer_from_bits_with_offset_opt(view.validity, view.values.len()), ))), + ArrayView::Utf8(view) => Ok(Self::Utf8(StringDeserializer::new( + view.data, + view.offsets, + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), + ))), + ArrayView::LargeUtf8(view) => Ok(Self::LargeUtf8(StringDeserializer::new( + view.data, + view.offsets, + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), + ))), + ArrayView::Binary(view) => Ok(Self::Binary(BinaryDeserializer::new( + view.data, + view.offsets, + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), + ))), + ArrayView::LargeBinary(view) => Ok(Self::LargeBinary(BinaryDeserializer::new( + view.data, + view.offsets, + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), + ))), _ => unimplemented!(), } } From 32eebfc23b5fb3309b4f87a1072896ac502469aa Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 20:04:39 +0200 Subject: [PATCH 031/178] Fix arrow=37 support, start to implement ListArray support --- serde_arrow/src/arrow_impl/deserialization.rs | 32 
++++++++++++++++--- serde_arrow/src/internal/arrow/mod.rs | 5 +-- .../deserialization/array_deserializer.rs | 17 ++++++++-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 54caa546..fb28cb79 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,7 +1,8 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - NullArrayView, PrimitiveArrayView, TimeArrayView, TimeUnit, TimestampArrayView, + ListArrayView, NullArrayView, PrimitiveArrayView, TimeArrayView, TimeUnit, + TimestampArrayView, }, deserialization::{ array_deserializer::ArrayDeserializer, @@ -18,6 +19,7 @@ use crate::internal::{ deserializer::Deserializer, error::{fail, Error, Result}, schema::{GenericDataType, GenericField}, + serialization::utils::meta_from_field, utils::Offset, }; @@ -559,27 +561,47 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::Utf8(BytesArrayView { validity: get_bits_with_offset(array), - offsets: array.offsets(), + offsets: array.value_offsets(), data: array.values(), })) } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::LargeUtf8(BytesArrayView { validity: get_bits_with_offset(array), - offsets: array.offsets(), + offsets: array.value_offsets(), data: array.values(), })) } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::Binary(BytesArrayView { validity: get_bits_with_offset(array), - offsets: array.offsets(), + offsets: array.value_offsets(), data: array.values(), })) } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::LargeBinary(BytesArrayView { validity: get_bits_with_offset(array), - offsets: array.offsets(), + offsets: array.value_offsets(), data: array.values(), })) + } else if let Some(array) = any.downcast_ref::>() { + let 
DataType::List(field) = array.data_type() else { + fail!("invalid data type for list array: {}", array.data_type()); + }; + Ok(ArrayView::List(ListArrayView { + validity: get_bits_with_offset(array), + offsets: array.value_offsets(), + meta: meta_from_field(field.as_ref().try_into()?)?, + element: Box::new(array.values().as_ref().try_into()?), + })) + } else if let Some(array) = any.downcast_ref::>() { + let DataType::LargeList(field) = array.data_type() else { + fail!("invalid data type for list array: {}", array.data_type()); + }; + Ok(ArrayView::LargeList(ListArrayView { + validity: get_bits_with_offset(array), + offsets: array.value_offsets(), + meta: meta_from_field(field.as_ref().try_into()?)?, + element: Box::new(array.values().as_ref().try_into()?), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 0479e577..c8fa384b 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,7 +6,8 @@ mod data_type; pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, FieldMeta, - FixedSizeBinaryArray, FixedSizeListArray, ListArray, NullArray, NullArrayView, PrimitiveArray, - PrimitiveArrayView, StructArray, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, + FixedSizeBinaryArray, FixedSizeListArray, ListArray, ListArrayView, NullArray, NullArrayView, + PrimitiveArray, PrimitiveArrayView, StructArray, TimeArray, TimeArrayView, TimestampArray, + TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 65ab12b6..edf78c8f 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -152,7 +152,7 @@ impl<'a> ArrayDeserializer<'a> { view.values, buffer_from_bits_with_offset_opt(view.validity, view.values.len()), view.unit, - field.is_utc()?, + is_utc_timestamp(view.timezone.as_deref())?, ))) } Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), @@ -160,7 +160,7 @@ impl<'a> ArrayDeserializer<'a> { view.values, buffer_from_bits_with_offset_opt(view.validity, view.values.len()), view.unit, - field.is_utc()?, + is_utc_timestamp(view.timezone.as_deref())?, ) .into()), }, @@ -200,11 +200,24 @@ impl<'a> ArrayDeserializer<'a> { view.offsets.len().saturating_sub(1), ), ))), + ArrayView::List(view) => Ok(Self::List(ListDeserializer::new( + todo!(), + view.offsets, + buffer_from_bits_with_offset_opt(view.validity, view.offsets.len()), + ))), _ => unimplemented!(), } } } +fn is_utc_timestamp(timezone: Option<&str>) -> Result { + match timezone { + Some(tz) if tz.to_lowercase() == "utc" => Ok(true), + Some(tz) => fail!("unsupported timezone {}", tz), + None => Ok(false), + } +} + fn buffer_from_bits_with_offset(bits: BitsWithOffset, len: usize) -> BitBuffer { BitBuffer { data: bits.data, From fb66cc733b571f35afdc1b2c924f4c70de31c72a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 20:10:20 +0200 Subject: [PATCH 032/178] Fix arrow==37 support --- serde_arrow/src/arrow_impl/deserialization.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index fb28cb79..ab094916 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -562,25 +562,25 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { Ok(ArrayView::Utf8(BytesArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - data: array.values(), + data: array.value_data(), })) } else if let 
Some(array) = any.downcast_ref::>() { Ok(ArrayView::LargeUtf8(BytesArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - data: array.values(), + data: array.value_data(), })) } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::Binary(BytesArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - data: array.values(), + data: array.value_data(), })) } else if let Some(array) = any.downcast_ref::>() { Ok(ArrayView::LargeBinary(BytesArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - data: array.values(), + data: array.value_data(), })) } else if let Some(array) = any.downcast_ref::>() { let DataType::List(field) = array.data_type() else { From a6b6d7a00c542c4a999b782cea70ea22e3a668db Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 20:45:29 +0200 Subject: [PATCH 033/178] Implement List, Struct --- serde_arrow/src/arrow_impl/deserialization.rs | 46 ++++++++---------- serde_arrow/src/internal/arrow/mod.rs | 4 +- .../deserialization/array_deserializer.rs | 48 ++++++++++++++++--- 3 files changed, 63 insertions(+), 35 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index ab094916..5b89941f 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,7 +1,7 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - ListArrayView, NullArrayView, PrimitiveArrayView, TimeArrayView, TimeUnit, + ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ @@ -123,15 +123,12 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - T::Struct => build_struct_deserializer(field, array), - T::List => build_list_deserializer::(field, array), - 
T::LargeList => build_list_deserializer::(field, array), T::FixedSizeList(n) => build_fixed_size_list_deserializer(field, array, *n), T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), T::Dictionary => build_dictionary_deserializer(field, array), - _ => ArrayDeserializer::new(field, array.try_into()?), + _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } } @@ -250,28 +247,6 @@ pub fn build_struct_fields<'a>( Ok((deserializers, len)) } -pub fn build_list_deserializer<'a, O>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - O: OffsetSizeTrait + Offset, - ArrayDeserializer<'a>: From>, -{ - let Some(array) = array.as_any().downcast_ref::>() else { - fail!( - "Cannot interpret {} array as GenericListArray", - array.data_type() - ); - }; - - let item = build_array_deserializer(&field.children[0], array.values())?; - let offsets = array.value_offsets(); - let validity = get_validity(array); - - Ok(ListDeserializer::new(item, offsets, validity).into()) -} - #[cfg(has_arrow_fixed_binary_support)] pub fn build_fixed_size_binary_deserializer<'a>( _field: &GenericField, @@ -602,6 +577,23 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { meta: meta_from_field(field.as_ref().try_into()?)?, element: Box::new(array.values().as_ref().try_into()?), })) + } else if let Some(array) = any.downcast_ref::() { + let DataType::Struct(column_fields) = array.data_type() else { + fail!("invalid data type for struct array: {}", array.data_type()); + }; + + let mut fields = Vec::new(); + for (field, array) in std::iter::zip(column_fields, array.columns()) { + let view = ArrayView::try_from(array.as_ref())?; + let meta = meta_from_field(GenericField::try_from(field.as_ref())?)?; + fields.push((view, meta)); + } + + Ok(ArrayView::Struct(StructArrayView { + len: array.len(), + validity: get_bits_with_offset(array), + 
fields, + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index c8fa384b..d9064258 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -7,7 +7,7 @@ pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, ListArray, ListArrayView, NullArray, NullArrayView, - PrimitiveArray, PrimitiveArrayView, StructArray, TimeArray, TimeArrayView, TimestampArray, - TimestampArrayView, + PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, + TimestampArray, TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index edf78c8f..cff90fce 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,9 +2,9 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ - arrow::{ArrayView, BitsWithOffset, TimeUnit}, + arrow::{ArrayView, BitsWithOffset, FieldMeta, TimeUnit}, error::{fail, Error, Result}, - schema::{GenericField, Strategy}, + schema::{Strategy, STRATEGY_KEY}, utils::Mut, }; @@ -70,7 +70,7 @@ pub enum ArrayDeserializer<'a> { } impl<'a> ArrayDeserializer<'a> { - pub fn new(field: &GenericField, array: ArrayView<'a>) -> Result { + pub fn new(strategy: Option<&Strategy>, array: ArrayView<'a>) -> Result { match array { ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), ArrayView::Boolean(view) => Ok(Self::Bool(BoolDeserializer::new( @@ -134,7 +134,7 @@ impl<'a> ArrayDeserializer<'a> { view.values, 
buffer_from_bits_with_offset_opt(view.validity, view.values.len()), TimeUnit::Millisecond, - field.is_utc()?, + is_utc_date64(strategy)?, ))), ArrayView::Time32(view) => Ok(Self::Time32(TimeDeserializer::new( view.values, @@ -146,7 +146,7 @@ impl<'a> ArrayDeserializer<'a> { buffer_from_bits_with_offset_opt(view.validity, view.values.len()), view.unit, ))), - ArrayView::Timestamp(view) => match field.strategy.as_ref() { + ArrayView::Timestamp(view) => match strategy { Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) => { Ok(Self::Date64(Date64Deserializer::new( view.values, @@ -201,10 +201,31 @@ impl<'a> ArrayDeserializer<'a> { ), ))), ArrayView::List(view) => Ok(Self::List(ListDeserializer::new( - todo!(), + ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, buffer_from_bits_with_offset_opt(view.validity, view.offsets.len()), ))), + ArrayView::LargeList(view) => Ok(Self::LargeList(ListDeserializer::new( + ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, + view.offsets, + buffer_from_bits_with_offset_opt(view.validity, view.offsets.len()), + ))), + ArrayView::Struct(view) => { + let mut fields = Vec::new(); + for (field_view, field_meta) in view.fields { + let field_deserializer = + ArrayDeserializer::new(get_strategy(&field_meta)?.as_ref(), field_view)?; + let field_name = field_meta.name; + + fields.push((field_name, field_deserializer)); + } + + Ok(Self::Struct(StructDeserializer::new( + fields, + buffer_from_bits_with_offset_opt(view.validity, view.len), + view.len, + ))) + } _ => unimplemented!(), } } @@ -218,6 +239,21 @@ fn is_utc_timestamp(timezone: Option<&str>) -> Result { } } +fn is_utc_date64(strategy: Option<&Strategy>) -> Result { + match strategy { + None | Some(Strategy::UtcStrAsDate64) => Ok(true), + Some(Strategy::NaiveStrAsDate64) => Ok(false), + Some(strategy) => fail!("invalid strategy for date64 deserializer: {strategy}"), + } +} + +fn get_strategy(meta: &FieldMeta) -> 
Result> { + let Some(strategy) = meta.metadata.get(STRATEGY_KEY) else { + return Ok(None); + }; + Ok(Some(strategy.parse()?)) +} + fn buffer_from_bits_with_offset(bits: BitsWithOffset, len: usize) -> BitBuffer { BitBuffer { data: bits.data, From a76a66af9ca7003837bde45466b3c9beabbadc32 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 23 Jul 2024 21:12:38 +0200 Subject: [PATCH 034/178] Implement FixedSizeList --- serde_arrow/src/arrow_impl/deserialization.rs | 59 ++++--------------- serde_arrow/src/internal/arrow/array.rs | 2 + serde_arrow/src/internal/arrow/mod.rs | 6 +- .../deserialization/array_deserializer.rs | 18 +++++- .../serialization/fixed_size_list_builder.rs | 1 + 5 files changed, 35 insertions(+), 51 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 5b89941f..3187d834 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,19 +1,16 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, - TimestampArrayView, + FixedSizeListArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, + TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, - fixed_size_list_deserializer::FixedSizeListDeserializer, integer_deserializer::Integer, - list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, - struct_deserializer::StructDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, deserializer::Deserializer, @@ -123,7 +120,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - 
T::FixedSizeList(n) => build_fixed_size_list_deserializer(field, array, *n), T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), @@ -202,26 +198,6 @@ pub fn build_dictionary_deserializer<'a>( } } -pub fn build_struct_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!("Cannot convert {} array into struct", array.data_type()); - }; - - let fields = &field.children; - let arrays = array - .columns() - .iter() - .map(|array| array.as_ref()) - .collect::>(); - let validity = get_validity(array); - - let (deserializers, len) = build_struct_fields(fields, &arrays)?; - Ok(StructDeserializer::new(deserializers, validity, len).into()) -} - pub fn build_struct_fields<'a>( fields: &[GenericField], arrays: &[&'a dyn Array], @@ -274,26 +250,6 @@ pub fn build_fixed_size_binary_deserializer<'a>( fail!("FixedSizeBinary arrays are not supported for arrow<=46"); } -pub fn build_fixed_size_list_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, - n: i32, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!( - "Cannot interpret {} array as GenericListArray", - array.data_type() - ); - }; - - let n = n.try_into()?; - let len = array.len(); - let item = build_array_deserializer(&field.children[0], array.values())?; - let validity = get_validity(array); - - Ok(FixedSizeListDeserializer::new(item, validity, n, len).into()) -} - pub fn build_map_deserializer<'a>( field: &GenericField, array: &'a dyn Array, @@ -577,6 +533,17 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { meta: meta_from_field(field.as_ref().try_into()?)?, element: Box::new(array.values().as_ref().try_into()?), })) + } else if let Some(array) = any.downcast_ref::() { + let DataType::FixedSizeList(field, n) = array.data_type() else { + fail!("invalid data 
type for list array: {}", array.data_type()); + }; + Ok(ArrayView::FixedSizeList(FixedSizeListArrayView { + len: array.len(), + n: *n, + validity: get_bits_with_offset(array), + meta: meta_from_field(field.as_ref().try_into()?)?, + element: Box::new(array.values().as_ref().try_into()?), + })) } else if let Some(array) = any.downcast_ref::() { let DataType::Struct(column_fields) = array.data_type() else { fail!("invalid data type for struct array: {}", array.data_type()); diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 5a528c4a..b3db7209 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -193,6 +193,7 @@ pub struct ListArrayView<'a, O> { #[derive(Clone, Debug)] pub struct FixedSizeListArray { + pub len: usize, pub n: i32, pub validity: Option>, pub meta: FieldMeta, @@ -201,6 +202,7 @@ pub struct FixedSizeListArray { #[derive(Clone, Debug)] pub struct FixedSizeListArrayView<'a> { + pub len: usize, pub n: i32, pub validity: Option>, pub meta: FieldMeta, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index d9064258..d36c3d19 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,8 +6,8 @@ mod data_type; pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, FieldMeta, - FixedSizeBinaryArray, FixedSizeListArray, ListArray, ListArrayView, NullArray, NullArrayView, - PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, - TimestampArray, TimestampArrayView, + FixedSizeBinaryArray, FixedSizeListArray, FixedSizeListArrayView, ListArray, ListArrayView, + NullArray, NullArrayView, PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, + TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, }; pub use 
data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index cff90fce..51ed2ff0 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -203,13 +203,27 @@ impl<'a> ArrayDeserializer<'a> { ArrayView::List(view) => Ok(Self::List(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, - buffer_from_bits_with_offset_opt(view.validity, view.offsets.len()), + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), ))), ArrayView::LargeList(view) => Ok(Self::LargeList(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, - buffer_from_bits_with_offset_opt(view.validity, view.offsets.len()), + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), ))), + ArrayView::FixedSizeList(view) => { + Ok(Self::FixedSizeList(FixedSizeListDeserializer::new( + ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, + buffer_from_bits_with_offset_opt(view.validity, view.len), + view.n.try_into()?, + view.len, + ))) + } ArrayView::Struct(view) => { let mut fields = Vec::new(); for (field_view, field_meta) in view.fields { diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 06754e27..b977d648 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -54,6 +54,7 @@ impl FixedSizeListBuilder { pub fn into_array(self) -> Result { Ok(Array::FixedSizeList(FixedSizeListArray { + len: self.num_elements, n: self.n.try_into()?, meta: meta_from_field(self.field)?, 
validity: self.validity.map(|v| v.buffer), From 8d7d76e1d76b89021907bd247db5acdd18c93dea Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 24 Jul 2024 20:46:53 +0200 Subject: [PATCH 035/178] Implement Map --- .../src/arrow2_impl/deserialization.rs | 4 +- serde_arrow/src/arrow_impl/deserialization.rs | 51 ++++++------------- .../deserialization/array_deserializer.rs | 26 +++++++++- .../deserialization/list_deserializer.rs | 13 +++-- .../deserialization/map_deserializer.rs | 13 +++-- .../src/internal/deserialization/utils.rs | 16 +++--- 6 files changed, 66 insertions(+), 57 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 68fa0769..63d2aede 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -370,7 +370,7 @@ where }; let item = build_array_deserializer(item_field, array.values().as_ref())?; - Ok(ListDeserializer::new(item, offsets, validity).into()) + Ok(ListDeserializer::new(item, offsets, validity)?.into()) } pub fn build_map_deserializer<'a>( @@ -407,7 +407,7 @@ pub fn build_map_deserializer<'a>( let keys = build_array_deserializer(keys_field, keys.as_ref())?; let values = build_array_deserializer(values_field, values.as_ref())?; - Ok(MapDeserializer::new(keys, values, offsets, validity).into()) + Ok(MapDeserializer::new(keys, values, offsets, validity)?.into()) } pub fn build_union_deserializer<'a>( diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 3187d834..714bb4c5 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -5,13 +5,9 @@ use crate::internal::{ TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ - array_deserializer::ArrayDeserializer, - dictionary_deserializer::DictionaryDeserializer, - enum_deserializer::EnumDeserializer, - integer_deserializer::Integer, - 
map_deserializer::MapDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - utils::{check_supported_list_layout, BitBuffer}, + array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, + enum_deserializer::EnumDeserializer, integer_deserializer::Integer, + outer_sequence_deserializer::OuterSequenceDeserializer, utils::BitBuffer, }, deserializer::Deserializer, error::{fail, Error, Result}, @@ -121,7 +117,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), - T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), T::Dictionary => build_dictionary_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), @@ -250,34 +245,6 @@ pub fn build_fixed_size_binary_deserializer<'a>( fail!("FixedSizeBinary arrays are not supported for arrow<=46"); } -pub fn build_map_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(entries_field) = field.children.first() else { - fail!("cannot get children of map"); - }; - let Some(keys_field) = entries_field.children.first() else { - fail!("cannot get keys field"); - }; - let Some(values_field) = entries_field.children.get(1) else { - fail!("cannot get values field"); - }; - let Some(array) = array.as_any().downcast_ref::() else { - fail!("cannot convert {} array into map array", array.data_type()); - }; - - let offsets = array.value_offsets(); - let validity = get_validity(array); - - check_supported_list_layout(validity, offsets)?; - - let key = build_array_deserializer(keys_field, array.keys())?; - let value = build_array_deserializer(values_field, array.values())?; - - Ok(MapDeserializer::new(key, value, offsets, validity).into()) -} - pub fn build_union_deserializer<'a>( field: &GenericField, array: &'a dyn Array, @@ -561,6 +528,18 @@ impl<'a> 
TryFrom<&'a dyn Array> for ArrayView<'a> { validity: get_bits_with_offset(array), fields, })) + } else if let Some(array) = any.downcast_ref::() { + let DataType::Map(entries_field, _) = array.data_type() else { + fail!("invalid data type for map array: {}", array.data_type()); + }; + let entries_array: &dyn Array = array.entries(); + + Ok(ArrayView::Map(ListArrayView { + validity: get_bits_with_offset(array), + offsets: array.value_offsets(), + meta: meta_from_field(GenericField::try_from(entries_field.as_ref())?)?, + element: Box::new(entries_array.try_into()?), + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 51ed2ff0..0bfde2ea 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -207,7 +207,7 @@ impl<'a> ArrayDeserializer<'a> { view.validity, view.offsets.len().saturating_sub(1), ), - ))), + )?)), ArrayView::LargeList(view) => Ok(Self::LargeList(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, @@ -215,7 +215,7 @@ impl<'a> ArrayDeserializer<'a> { view.validity, view.offsets.len().saturating_sub(1), ), - ))), + )?)), ArrayView::FixedSizeList(view) => { Ok(Self::FixedSizeList(FixedSizeListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, @@ -240,6 +240,28 @@ impl<'a> ArrayDeserializer<'a> { view.len, ))) } + ArrayView::Map(view) => { + let ArrayView::Struct(entries_view) = *view.element else { + fail!("invalid entries field in map array"); + }; + let Ok(entries_fields) = <[_; 2]>::try_from(entries_view.fields) else { + fail!("invalid entries field in map array") + }; + let [(keys_view, keys_meta), (values_view, values_meta)] = entries_fields; + let keys = 
ArrayDeserializer::new(get_strategy(&keys_meta)?.as_ref(), keys_view)?; + let values = + ArrayDeserializer::new(get_strategy(&values_meta)?.as_ref(), values_view)?; + + Ok(Self::Map(MapDeserializer::new( + keys, + values, + view.offsets, + buffer_from_bits_with_offset_opt( + view.validity, + view.offsets.len().saturating_sub(1), + ), + )?)) + } _ => unimplemented!(), } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 0272b35d..ce4fccb7 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -6,8 +6,9 @@ use crate::internal::{ }; use super::{ - array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::BitBuffer, + array_deserializer::ArrayDeserializer, + simple_deserializer::SimpleDeserializer, + utils::{check_supported_list_layout, BitBuffer}, }; pub struct ListDeserializer<'a, O: Offset> { @@ -22,13 +23,15 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { item: ArrayDeserializer<'a>, offsets: &'a [O], validity: Option>, - ) -> Self { - Self { + ) -> Result { + check_supported_list_layout(validity, offsets)?; + + Ok(Self { item: Box::new(item), offsets, validity, next: (0, 0), - } + }) } pub fn peek_next(&self) -> Result { diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 33ad6263..10f18c8c 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -6,8 +6,9 @@ use crate::internal::{ }; use super::{ - array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::BitBuffer, + array_deserializer::ArrayDeserializer, + simple_deserializer::SimpleDeserializer, + utils::{check_supported_list_layout, BitBuffer}, }; pub struct MapDeserializer<'a> { @@ -24,14 
+25,16 @@ impl<'a> MapDeserializer<'a> { value: ArrayDeserializer<'a>, offsets: &'a [i32], validity: Option>, - ) -> Self { - Self { + ) -> Result { + check_supported_list_layout(validity, offsets)?; + + Ok(Self { key: Box::new(key), value: Box::new(value), offsets, validity, next: (0, 0), - } + }) } pub fn peek_next(&self) -> Result { diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index 213dca3a..cdf8b32b 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ b/serde_arrow/src/internal/deserialization/utils.rs @@ -1,4 +1,7 @@ -use crate::internal::error::{error, fail, Result}; +use crate::internal::{ + error::{error, fail, Result}, + utils::Offset, +}; #[derive(Debug, PartialEq, Clone, Copy)] pub struct BitBuffer<'a> { @@ -81,13 +84,10 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { /// **non-empty** segment in the child array." /// /// [arrow format spec]: https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout -pub fn check_supported_list_layout<'a, O>( +pub fn check_supported_list_layout<'a, O: Offset>( validity: Option>, offsets: &'a [O], -) -> Result<()> -where - O: std::ops::Sub + std::cmp::PartialEq + From + Copy, -{ +) -> Result<()> { let Some(validity) = validity else { return Ok(()); }; @@ -101,7 +101,9 @@ where ); } for i in 0..validity.len() { - if !validity.is_set(i) && (offsets[i + 1] - offsets[i]) != O::from(0) { + let curr = offsets[i].try_into_usize()?; + let next = offsets[i + 1].try_into_usize()?; + if !validity.is_set(i) && (next - curr) != 0 { fail!("lists with data in null values are currently not supported in deserialization"); } } From 5da2ea9c68a3c26b83e7ef8596924e2ddfe983f6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 24 Jul 2024 21:18:07 +0200 Subject: [PATCH 036/178] Implement Dictionary --- serde_arrow/src/arrow_impl/deserialization.rs | 108 +++++------------- serde_arrow/src/internal/arrow/mod.rs | 8 +- 
.../deserialization/array_deserializer.rs | 96 ++++++++++++++-- 3 files changed, 121 insertions(+), 91 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index 714bb4c5..e58a4216 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,25 +1,23 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - FixedSizeListArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, - TimeArrayView, TimeUnit, TimestampArrayView, + DictionaryArrayView, FixedSizeListArrayView, ListArrayView, NullArrayView, + PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, deserialization::{ - array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, - enum_deserializer::EnumDeserializer, integer_deserializer::Integer, + array_deserializer::ArrayDeserializer, enum_deserializer::EnumDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, utils::BitBuffer, }, deserializer::Deserializer, error::{fail, Error, Result}, schema::{GenericDataType, GenericField}, serialization::utils::meta_from_field, - utils::Offset, }; use crate::_impl::arrow::{ array::{ Array, BooleanArray, DictionaryArray, FixedSizeListArray, GenericBinaryArray, - GenericListArray, GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, + GenericListArray, GenericStringArray, MapArray, NullArray, PrimitiveArray, RecordBatch, StructArray, UnionArray, }, datatypes::{ @@ -118,81 +116,10 @@ pub fn build_array_deserializer<'a>( match &field.data_type { T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), T::Union => build_union_deserializer(field, array), - T::Dictionary => build_dictionary_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } } -pub fn 
build_dictionary_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - use GenericDataType as T; - - let Some(key_field) = field.children.first() else { - fail!("Missing key field"); - }; - let Some(value_field) = field.children.get(1) else { - fail!("Missing key field"); - }; - - return match (&key_field.data_type, &value_field.data_type) { - (T::U8, T::Utf8) => typed::(field, array), - (T::U16, T::Utf8) => typed::(field, array), - (T::U32, T::Utf8) => typed::(field, array), - (T::U64, T::Utf8) => typed::(field, array), - (T::I8, T::Utf8) => typed::(field, array), - (T::I16, T::Utf8) => typed::(field, array), - (T::I32, T::Utf8) => typed::(field, array), - (T::I64, T::Utf8) => typed::(field, array), - (T::U8, T::LargeUtf8) => typed::(field, array), - (T::U16, T::LargeUtf8) => typed::(field, array), - (T::U32, T::LargeUtf8) => typed::(field, array), - (T::U64, T::LargeUtf8) => typed::(field, array), - (T::I8, T::LargeUtf8) => typed::(field, array), - (T::I16, T::LargeUtf8) => typed::(field, array), - (T::I32, T::LargeUtf8) => typed::(field, array), - (T::I64, T::LargeUtf8) => typed::(field, array), - _ => fail!("invalid dicitonary key / value data type"), - }; - - pub fn typed<'a, K, V>( - _field: &GenericField, - array: &'a dyn Array, - ) -> Result> - where - K: ArrowDictionaryKeyType, - K::Native: Integer, - V: OffsetSizeTrait + Offset, - DictionaryDeserializer<'a, K::Native, V>: Into>, - { - let Some(array) = array.as_any().downcast_ref::>() else { - fail!( - "cannot convert {} array into dictionary array", - array.data_type() - ); - }; - let Some(values) = array - .values() - .as_any() - .downcast_ref::>() - else { - fail!("invalid values"); - }; - - let keys_buffer = array.keys().values(); - let keys_validity = get_validity(array); - - let values_data = values.value_data(); - let values_offsets = values.value_offsets(); - - Ok( - DictionaryDeserializer::new(keys_buffer, keys_validity, values_data, values_offsets) - .into(), - ) - 
} -} - pub fn build_struct_fields<'a>( fields: &[GenericField], arrays: &[&'a dyn Array], @@ -540,6 +467,22 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { meta: meta_from_field(GenericField::try_from(entries_field.as_ref())?)?, element: Box::new(entries_array.try_into()?), })) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::>() { + wrap_dictionary_array::(array) } else { fail!( "Cannot build an array view for {dt}", @@ -549,6 +492,17 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { } } +fn wrap_dictionary_array( + array: &DictionaryArray, +) -> Result> { + let keys: &dyn Array = array.keys(); + + Ok(ArrayView::Dictionary(DictionaryArrayView { + indices: Box::new(keys.try_into()?), + values: Box::new(array.values().as_ref().try_into()?), + })) +} + fn get_bits_with_offset(array: &dyn Array) -> Option> { let validity = array.nulls()?; Some(BitsWithOffset { diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index d36c3d19..6c582d91 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -5,9 +5,9 @@ mod data_type; pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, - DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, FieldMeta, - FixedSizeBinaryArray, FixedSizeListArray, FixedSizeListArrayView, 
ListArray, ListArrayView, - NullArray, NullArrayView, PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, - TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, + DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, DictionaryArrayView, + FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, FixedSizeListArrayView, ListArray, + ListArrayView, NullArray, NullArrayView, PrimitiveArray, PrimitiveArrayView, StructArray, + StructArrayView, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 0bfde2ea..c0d1eea9 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,23 +2,32 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ - arrow::{ArrayView, BitsWithOffset, FieldMeta, TimeUnit}, + arrow::{ArrayView, BitsWithOffset, BytesArrayView, FieldMeta, PrimitiveArrayView, TimeUnit}, error::{fail, Error, Result}, schema::{Strategy, STRATEGY_KEY}, - utils::Mut, + utils::{Mut, Offset}, }; use super::{ - binary_deserializer::BinaryDeserializer, bool_deserializer::BoolDeserializer, - date32_deserializer::Date32Deserializer, date64_deserializer::Date64Deserializer, - decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, + binary_deserializer::BinaryDeserializer, + bool_deserializer::BoolDeserializer, + date32_deserializer::Date32Deserializer, + date64_deserializer::Date64Deserializer, + decimal_deserializer::DecimalDeserializer, + dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_binary_deserializer::FixedSizeBinaryDeserializer, - fixed_size_list_deserializer::FixedSizeListDeserializer, 
float_deserializer::FloatDeserializer, - integer_deserializer::IntegerDeserializer, list_deserializer::ListDeserializer, - map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, - simple_deserializer::SimpleDeserializer, string_deserializer::StringDeserializer, - struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::BitBuffer, + fixed_size_list_deserializer::FixedSizeListDeserializer, + float_deserializer::FloatDeserializer, + integer_deserializer::{Integer, IntegerDeserializer}, + list_deserializer::ListDeserializer, + map_deserializer::MapDeserializer, + null_deserializer::NullDeserializer, + simple_deserializer::SimpleDeserializer, + string_deserializer::StringDeserializer, + struct_deserializer::StructDeserializer, + time_deserializer::TimeDeserializer, + utils::BitBuffer, }; pub enum ArrayDeserializer<'a> { @@ -262,11 +271,78 @@ impl<'a> ArrayDeserializer<'a> { ), )?)) } + ArrayView::Dictionary(view) => match (*view.indices, *view.values) { + (ArrayView::Int8(keys), ArrayView::Utf8(values)) => { + Ok(Self::DictionaryI8I32(build_dictionary_array(keys, values)?)) + } + (ArrayView::Int16(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI16I32( + build_dictionary_array(keys, values)?, + )), + (ArrayView::Int32(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI32I32( + build_dictionary_array(keys, values)?, + )), + (ArrayView::Int64(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI64I32( + build_dictionary_array(keys, values)?, + )), + (ArrayView::UInt8(keys), ArrayView::Utf8(values)) => { + Ok(Self::DictionaryU8I32(build_dictionary_array(keys, values)?)) + } + (ArrayView::UInt16(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU16I32( + build_dictionary_array(keys, values)?, + )), + (ArrayView::UInt32(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU32I32( + build_dictionary_array(keys, values)?, + )), + (ArrayView::UInt64(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU64I32( + 
build_dictionary_array(keys, values)?, + )), + (ArrayView::Int8(keys), ArrayView::LargeUtf8(values)) => { + Ok(Self::DictionaryI8I64(build_dictionary_array(keys, values)?)) + } + (ArrayView::Int16(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryI16I64(build_dictionary_array(keys, values)?), + ), + (ArrayView::Int32(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryI32I64(build_dictionary_array(keys, values)?), + ), + (ArrayView::Int64(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryI64I64(build_dictionary_array(keys, values)?), + ), + (ArrayView::UInt8(keys), ArrayView::LargeUtf8(values)) => { + Ok(Self::DictionaryU8I64(build_dictionary_array(keys, values)?)) + } + (ArrayView::UInt16(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryU16I64(build_dictionary_array(keys, values)?), + ), + (ArrayView::UInt32(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryU32I64(build_dictionary_array(keys, values)?), + ), + (ArrayView::UInt64(keys), ArrayView::LargeUtf8(values)) => Ok( + Self::DictionaryU64I64(build_dictionary_array(keys, values)?), + ), + _ => fail!("unsupported dictionary array"), + }, _ => unimplemented!(), } } } +fn build_dictionary_array<'a, K: Integer, V: Offset>( + keys: PrimitiveArrayView<'a, K>, + values: BytesArrayView<'a, V>, +) -> Result> { + if values.validity.is_some() { + // TODO: check whether all values are defined? 
+ fail!("dictionaries with nullable values are not supported"); + } + Ok(DictionaryDeserializer::new( + keys.values, + buffer_from_bits_with_offset_opt(keys.validity, keys.values.len()), + values.data, + values.offsets, + )) +} + fn is_utc_timestamp(timezone: Option<&str>) -> Result { match timezone { Some(tz) if tz.to_lowercase() == "utc" => Ok(true), From 2475512f904d6fa9d35ef06955b4f34fc2e91db5 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 24 Jul 2024 21:43:05 +0200 Subject: [PATCH 037/178] Implement Union --- serde_arrow/src/arrow_impl/deserialization.rs | 66 +++++++++---------- serde_arrow/src/internal/arrow/mod.rs | 9 +-- .../deserialization/array_deserializer.rs | 10 +++ 3 files changed, 46 insertions(+), 39 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index e58a4216..c7128f08 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -1,11 +1,12 @@ use crate::internal::{ arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - DictionaryArrayView, FixedSizeListArrayView, ListArrayView, NullArrayView, - PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, TimestampArrayView, + DenseUnionArrayView, DictionaryArrayView, FixedSizeListArrayView, ListArrayView, + NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, + TimestampArrayView, }, deserialization::{ - array_deserializer::ArrayDeserializer, enum_deserializer::EnumDeserializer, + array_deserializer::ArrayDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, utils::BitBuffer, }, deserializer::Deserializer, @@ -17,8 +18,8 @@ use crate::internal::{ use crate::_impl::arrow::{ array::{ Array, BooleanArray, DictionaryArray, FixedSizeListArray, GenericBinaryArray, - GenericListArray, GenericStringArray, MapArray, NullArray, PrimitiveArray, - RecordBatch, StructArray, UnionArray, + 
GenericListArray, GenericStringArray, MapArray, NullArray, PrimitiveArray, RecordBatch, + StructArray, UnionArray, }, datatypes::{ ArrowDictionaryKeyType, DataType, Date32Type, Date64Type, Decimal128Type, @@ -115,7 +116,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), - T::Union => build_union_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } } @@ -172,35 +172,6 @@ pub fn build_fixed_size_binary_deserializer<'a>( fail!("FixedSizeBinary arrays are not supported for arrow<=46"); } -pub fn build_union_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!( - "Cannot interpret {} array as a union array", - array.data_type() - ); - }; - - if !matches!(array.data_type(), DataType::Union(_, UnionMode::Dense)) { - fail!("Invalid data type: only dense unions are supported"); - } - - let type_ids = array.type_ids(); - - let mut variants = Vec::new(); - for (type_id, field) in field.children.iter().enumerate() { - // TODO: how to prevent a panic? 
+ validate the order / type_ids - let name = field.name.to_owned(); - let deser = build_array_deserializer(field, array.child(type_id.try_into()?).as_ref())?; - - variants.push((name, deser)); - } - - Ok(EnumDeserializer::new(type_ids, variants).into()) -} - fn get_validity(arr: &dyn Array) -> Option> { let validity = arr.nulls()?; let data = validity.validity(); @@ -483,6 +454,31 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { wrap_dictionary_array::(array) } else if let Some(array) = any.downcast_ref::>() { wrap_dictionary_array::(array) + } else if let Some(array) = any.downcast_ref::() { + let DataType::Union(union_fields, UnionMode::Dense) = array.data_type() else { + fail!("Invalid data type: only dense unions are supported"); + }; + + let mut fields = Vec::new(); + for (type_idx, (type_id, field)) in union_fields.iter().enumerate() { + if type_id < 0 || usize::try_from(type_id)? != type_idx { + fail!("invalid union, only unions with consecutive variants are supported"); + } + + let meta = meta_from_field(GenericField::try_from(field.as_ref())?)?; + let view: ArrayView = array.child(type_id).as_ref().try_into()?; + fields.push((view, meta)); + } + let Some(offsets) = array.offsets() else { + fail!("Dense unions must have an offset array"); + }; + + // array.type_ids() + Ok(ArrayView::DenseUnion(DenseUnionArrayView { + types: array.type_ids(), + offsets, + fields, + })) } else { fail!( "Cannot build an array view for {dt}", diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 6c582d91..90c965c2 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -5,9 +5,10 @@ mod data_type; pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, - DecimalArray, DecimalArrayView, DenseUnionArray, DictionaryArray, DictionaryArrayView, - FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, FixedSizeListArrayView, ListArray, - 
ListArrayView, NullArray, NullArrayView, PrimitiveArray, PrimitiveArrayView, StructArray, - StructArrayView, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, + DecimalArray, DecimalArrayView, DenseUnionArray, DenseUnionArrayView, DictionaryArray, + DictionaryArrayView, FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, + FixedSizeListArrayView, ListArray, ListArrayView, NullArray, NullArrayView, PrimitiveArray, + PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, TimestampArray, + TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index c0d1eea9..9ed3f9ca 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -322,6 +322,16 @@ impl<'a> ArrayDeserializer<'a> { ), _ => fail!("unsupported dictionary array"), }, + ArrayView::DenseUnion(view) => { + let mut fields = Vec::new(); + for (field_view, field_meta) in view.fields { + let field_deserializer = + ArrayDeserializer::new(get_strategy(&field_meta)?.as_ref(), field_view)?; + fields.push((field_meta.name, field_deserializer)) + } + + Ok(Self::Enum(EnumDeserializer::new(view.types, fields))) + } _ => unimplemented!(), } } From 7424407a7cda81abcecb3f650fb60a3d2c4bdb67 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 24 Jul 2024 21:55:16 +0200 Subject: [PATCH 038/178] Implement FixedSizeBinary --- serde_arrow/src/arrow_impl/deserialization.rs | 122 ++++++------------ serde_arrow/src/internal/arrow/mod.rs | 8 +- .../deserialization/array_deserializer.rs | 14 +- 3 files changed, 58 insertions(+), 86 deletions(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index c7128f08..e38aa496 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ 
b/serde_arrow/src/arrow_impl/deserialization.rs @@ -7,19 +7,19 @@ use crate::internal::{ }, deserialization::{ array_deserializer::ArrayDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, utils::BitBuffer, + outer_sequence_deserializer::OuterSequenceDeserializer, }, deserializer::Deserializer, error::{fail, Error, Result}, - schema::{GenericDataType, GenericField}, + schema::GenericField, serialization::utils::meta_from_field, }; use crate::_impl::arrow::{ array::{ - Array, BooleanArray, DictionaryArray, FixedSizeListArray, GenericBinaryArray, - GenericListArray, GenericStringArray, MapArray, NullArray, PrimitiveArray, RecordBatch, - StructArray, UnionArray, + Array, BooleanArray, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, + GenericBinaryArray, GenericListArray, GenericStringArray, MapArray, NullArray, + PrimitiveArray, RecordBatch, StructArray, UnionArray, }, datatypes::{ ArrowDictionaryKeyType, DataType, Date32Type, Date64Type, Decimal128Type, @@ -71,7 +71,24 @@ impl<'de> Deserializer<'de> { .map(|array| array.as_ref()) .collect::>(); - let (deserializers, len) = build_struct_fields(&fields, &arrays)?; + if fields.len() != arrays.len() { + fail!( + "different number of fields ({}) and arrays ({})", + fields.len(), + arrays.len() + ); + } + let len = arrays.first().map(|array| array.len()).unwrap_or_default(); + + let mut deserializers = Vec::new(); + for (field, array) in std::iter::zip(&fields, arrays) { + if array.len() != len { + fail!("arrays of different lengths are not supported"); + } + + let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + deserializers.push((field.name.clone(), deserializer)); + } let deserializer = OuterSequenceDeserializer::new(deserializers, len); let deserializer = Deserializer(deserializer); @@ -109,81 +126,6 @@ impl<'de> Deserializer<'de> { } } -pub fn build_array_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - use 
GenericDataType as T; - match &field.data_type { - T::FixedSizeBinary(_) => build_fixed_size_binary_deserializer(field, array), - _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), - } -} - -pub fn build_struct_fields<'a>( - fields: &[GenericField], - arrays: &[&'a dyn Array], -) -> Result<(Vec<(String, ArrayDeserializer<'a>)>, usize)> { - if fields.len() != arrays.len() { - fail!( - "different number of fields ({}) and arrays ({})", - fields.len(), - arrays.len() - ); - } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - - let mut deserializers = Vec::new(); - for (field, &array) in std::iter::zip(fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } - - deserializers.push((field.name.clone(), build_array_deserializer(field, array)?)); - } - - Ok((deserializers, len)) -} - -#[cfg(has_arrow_fixed_binary_support)] -pub fn build_fixed_size_binary_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - use crate::_impl::arrow::array::FixedSizeBinaryArray; - use crate::internal::deserialization::fixed_size_binary_deserializer::FixedSizeBinaryDeserializer; - - let Some(array) = array.as_any().downcast_ref::() else { - fail!("cannot convert {} array into string", array.data_type()); - }; - - let shape = (array.len(), array.value_length().try_into()?); - let buffer = array.value_data(); - let validity = get_validity(array); - - Ok(FixedSizeBinaryDeserializer::new(shape, buffer, validity).into()) -} - -#[cfg(not(has_arrow_fixed_binary_support))] -pub fn build_fixed_size_binary_deserializer<'a>( - _field: &GenericField, - _array: &'a dyn Array, -) -> Result> { - fail!("FixedSizeBinary arrays are not supported for arrow<=46"); -} - -fn get_validity(arr: &dyn Array) -> Option> { - let validity = arr.nulls()?; - let data = validity.validity(); - let offset = validity.offset(); - let number_of_bits = validity.len(); - Some(BitBuffer { - data, - offset, 
- number_of_bits, - }) -} - impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { type Error = Error; @@ -378,6 +320,8 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { offsets: array.value_offsets(), data: array.value_data(), })) + } else if let Some(array) = any.downcast_ref::() { + wrap_fixed_size_binary_array(array) } else if let Some(array) = any.downcast_ref::>() { let DataType::List(field) = array.data_type() else { fail!("invalid data type for list array: {}", array.data_type()); @@ -499,6 +443,22 @@ fn wrap_dictionary_array( })) } +#[cfg(has_arrow_fixed_binary_support)] +pub fn wrap_fixed_size_binary_array<'a>(array: &'a FixedSizeBinaryArray) -> Result> { + use crate::internal::arrow::FixedSizeBinaryArrayView; + + Ok(ArrayView::FixedSizeBinary(FixedSizeBinaryArrayView { + n: array.value_length(), + validity: get_bits_with_offset(array), + data: array.value_data(), + })) +} + +#[cfg(not(has_arrow_fixed_binary_support))] +pub fn wrap_fixed_size_binary_array<'a>(_array: &'a FixedSizeBinaryArray) -> Result> { + fail!("FixedSizeBinary arrays are not supported for arrow<=46"); +} + fn get_bits_with_offset(array: &dyn Array) -> Option> { let validity = array.nulls()?; Some(BitsWithOffset { diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 90c965c2..d439e73e 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -6,9 +6,9 @@ mod data_type; pub use array::{ Array, ArrayView, BitsWithOffset, BooleanArray, BooleanArrayView, BytesArray, BytesArrayView, DecimalArray, DecimalArrayView, DenseUnionArray, DenseUnionArrayView, DictionaryArray, - DictionaryArrayView, FieldMeta, FixedSizeBinaryArray, FixedSizeListArray, - FixedSizeListArrayView, ListArray, ListArrayView, NullArray, NullArrayView, PrimitiveArray, - PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, TimestampArray, - TimestampArrayView, + DictionaryArrayView, FieldMeta, FixedSizeBinaryArray, 
FixedSizeBinaryArrayView, + FixedSizeListArray, FixedSizeListArrayView, ListArray, ListArrayView, NullArray, NullArrayView, + PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, + TimestampArray, TimestampArrayView, }; pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 9ed3f9ca..8f52798e 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -209,6 +209,19 @@ impl<'a> ArrayDeserializer<'a> { view.offsets.len().saturating_sub(1), ), ))), + ArrayView::FixedSizeBinary(view) => { + let value_length: usize = view.n.try_into()?; + if view.data.len() % value_length != 0 { + fail!("Invalid FixedSizeBinary array: Data is not evenly divisible into chunks of size {value_length}"); + } + let len = view.data.len() / value_length; + + Ok(Self::FixedSizeBinary(FixedSizeBinaryDeserializer::new( + (len, value_length), + view.data, + buffer_from_bits_with_offset_opt(view.validity, len), + ))) + } ArrayView::List(view) => Ok(Self::List(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, @@ -332,7 +345,6 @@ impl<'a> ArrayDeserializer<'a> { Ok(Self::Enum(EnumDeserializer::new(view.types, fields))) } - _ => unimplemented!(), } } } From 76b29a8e19235f7accd1cd43be585c7fcf0b0138 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:02:01 +0200 Subject: [PATCH 039/178] Add tests to confirm that from_type does not work DateTime and Uuid --- Cargo.lock | 9 +++-- serde_arrow/Cargo.toml | 1 + .../src/test_with_arrow/impls/chrono.rs | 6 ++++ .../test_with_arrow/impls/issue_203_uuid.rs | 34 +++++++++++++++++++ serde_arrow/src/test_with_arrow/impls/mod.rs | 1 + 5 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 
serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs diff --git a/Cargo.lock b/Cargo.lock index 73d47c08..45605717 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2312,6 +2312,7 @@ dependencies = [ "serde_bytes", "serde_json", "simd-json", + "uuid", ] [[package]] @@ -2505,9 +2506,13 @@ checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" [[package]] name = "uuid" -version = "1.6.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "serde", +] [[package]] name = "value-trait" diff --git a/serde_arrow/Cargo.toml b/serde_arrow/Cargo.toml index ee35bba1..780fc528 100644 --- a/serde_arrow/Cargo.toml +++ b/serde_arrow/Cargo.toml @@ -132,6 +132,7 @@ serde_json = "1" serde_bytes = "0.11" rand = "0.8" bigdecimal = {version = "0.4", features = ["serde"] } +uuid = { version = "1.10.0", features = ["serde", "v4"] } # for benchmarks # arrow-version:replace: arrow-json-{version} = {{ package = "arrow-json", version = "{version}" }} diff --git a/serde_arrow/src/test_with_arrow/impls/chrono.rs b/serde_arrow/src/test_with_arrow/impls/chrono.rs index 621e0e13..dcfe1334 100644 --- a/serde_arrow/src/test_with_arrow/impls/chrono.rs +++ b/serde_arrow/src/test_with_arrow/impls/chrono.rs @@ -9,6 +9,12 @@ use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use serde::{Deserialize, Serialize}; use serde_json::json; +#[test] +fn trace_from_type_does_not_work() { + let res = SerdeArrowSchema::from_type::>>(TracingOptions::default()); + assert_error(&res, "premature end of input"); +} + #[test] fn utc_as_str() { let items = [ diff --git a/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs b/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs new file mode 100644 index 00000000..ed4d4a81 --- /dev/null +++ 
b/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs @@ -0,0 +1,34 @@ +use serde_json::json; +use uuid::Uuid; + +use crate::{ + internal::testing::assert_error, + schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, + utils::Item, +}; + +use super::utils::Test; + +#[test] +fn example_as_list() { + let items = [ + Item(Uuid::new_v4()), + Item(Uuid::new_v4()), + Item(Uuid::new_v4()), + ]; + + Test::new() + .with_schema(json!([{ + "name": "item", + "data_type": "LargeUtf8", + }])) + .trace_schema_from_samples(&items, TracingOptions::default()) + .serialize(&items) + .deserialize(&items); +} + +#[test] +fn trace_from_type_does_not_work() { + let res = SerdeArrowSchema::from_type::>(TracingOptions::default()); + assert_error(&res, "UUID parsing failed"); +} diff --git a/serde_arrow/src/test_with_arrow/impls/mod.rs b/serde_arrow/src/test_with_arrow/impls/mod.rs index ca900144..90695d09 100644 --- a/serde_arrow/src/test_with_arrow/impls/mod.rs +++ b/serde_arrow/src/test_with_arrow/impls/mod.rs @@ -14,6 +14,7 @@ mod tuple; mod r#union; mod wrappers; +mod issue_203_uuid; mod issue_59_decimals; mod issue_74_unknown_fields; mod issue_79_declared_but_missing_fields; From 35c6646b227da19a5b95cbabd4f9a9ffdecd2b36 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:02:31 +0200 Subject: [PATCH 040/178] Document issues with from_type for Uuid and DateTime --- serde_arrow/src/internal/schema/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 3af0ea09..13b33ceb 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -150,6 +150,8 @@ pub trait SchemaLike: Sized + Sealed { /// - auto detection of date time strings /// - non self-describing types such as `serde_json::Value` /// - flattened structure (`#[serde(flatten)]`) + /// - types that require specific data to be deserialized, such as the + /// `DateTime` type of 
`chrono` or the `Uuid` type of the `uuid` package /// /// Consider using [`from_samples`][SchemaLike::from_samples] in these /// cases. From 1e839aa9be6f06f2609701f948436f23d06d8442 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:08:20 +0200 Subject: [PATCH 041/178] Remove stray comment --- serde_arrow/src/arrow_impl/deserialization.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index e38aa496..cd688740 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -417,7 +417,6 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { fail!("Dense unions must have an offset array"); }; - // array.type_ids() Ok(ArrayView::DenseUnion(DenseUnionArrayView { types: array.type_ids(), offsets, From 0c77a6f713780463760c874a11cd2ff754dd70c9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:15:31 +0200 Subject: [PATCH 042/178] Refactor code, remove unnecessary abstractions --- serde_arrow/src/arrow_impl/api.rs | 46 ++++++++++++++++-- serde_arrow/src/arrow_impl/serialization.rs | 53 +-------------------- 2 files changed, 44 insertions(+), 55 deletions(-) diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index c73e906c..2b9d8a0a 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -1,14 +1,19 @@ #![deny(missing_docs)] +use std::sync::Arc; + use serde::{Deserialize, Serialize}; use crate::{ _impl::arrow::{ - array::{Array, ArrayRef, RecordBatch}, - datatypes::FieldRef, + array::{make_array, Array, ArrayRef, RecordBatch}, + datatypes::{FieldRef, Schema}, }, internal::{ - array_builder::ArrayBuilder, deserializer::Deserializer, error::Result, - schema::SerdeArrowSchema, serializer::Serializer, + array_builder::ArrayBuilder, + deserializer::Deserializer, + error::Result, + schema::{GenericField, SerdeArrowSchema}, + 
serializer::Serializer, }, }; @@ -169,3 +174,36 @@ pub fn to_record_batch( pub fn from_record_batch<'de, T: Deserialize<'de>>(record_batch: &'de RecordBatch) -> Result { T::deserialize(Deserializer::from_record_batch(record_batch)?) } + +/// Support `arrow` (*requires one of the `arrow-*` features*) +impl crate::internal::array_builder::ArrayBuilder { + /// Build an ArrayBuilder from `arrow` fields (*requires one of the + /// `arrow-*` features*) + pub fn from_arrow(fields: &[FieldRef]) -> Result { + let fields = fields + .iter() + .map(|f| GenericField::try_from(f.as_ref())) + .collect::>>()?; + Self::new(SerdeArrowSchema { fields }) + } + + /// Construct `arrow` arrays and reset the builder (*requires one of the + /// `arrow-*` features*) + pub fn to_arrow(&mut self) -> Result> { + let mut arrays = Vec::new(); + for field in self.builder.take_records()? { + let data = field.into_array()?.try_into()?; + arrays.push(make_array(data)); + } + Ok(arrays) + } + + /// Construct a [`RecordBatch`] and reset the builder (*requires one of the + /// `arrow-*` features*) + pub fn to_record_batch(&mut self) -> Result { + let arrays = self.to_arrow()?; + let fields = Vec::::try_from(&self.schema)?; + let schema = Schema::new(fields); + Ok(RecordBatch::try_new(Arc::new(schema), arrays)?) 
+ } +} diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs index d47c823c..ada887ec 100644 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ b/serde_arrow/src/arrow_impl/serialization.rs @@ -5,65 +5,16 @@ use half::f16; use crate::{ _impl::arrow::{ - array::{make_array, Array, ArrayData, ArrayRef, NullArray, RecordBatch}, + array::{Array, ArrayData, NullArray}, buffer::{Buffer, ScalarBuffer}, - datatypes::{ - ArrowNativeType, ArrowPrimitiveType, DataType, Field, FieldRef, Float16Type, Schema, - UnionMode, - }, + datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field, Float16Type, UnionMode}, }, internal::{ arrow::FieldMeta, error::{fail, Error, Result}, - schema::{GenericField, SerdeArrowSchema}, - serialization::{ArrayBuilder, OuterSequenceBuilder}, }, }; -/// Support `arrow` (*requires one of the `arrow-*` features*) -impl crate::internal::array_builder::ArrayBuilder { - /// Build an ArrayBuilder from `arrow` fields (*requires one of the - /// `arrow-*` features*) - pub fn from_arrow(fields: &[FieldRef]) -> Result { - let fields = fields - .iter() - .map(|f| GenericField::try_from(f.as_ref())) - .collect::>>()?; - Self::new(SerdeArrowSchema { fields }) - } - - /// Construct `arrow` arrays and reset the builder (*requires one of the - /// `arrow-*` features*) - pub fn to_arrow(&mut self) -> Result> { - self.builder.build_arrow() - } - - /// Construct a [`RecordBatch`] and reset the builder (*requires one of the - /// `arrow-*` features*) - pub fn to_record_batch(&mut self) -> Result { - let arrays = self.builder.build_arrow()?; - let fields = Vec::::try_from(&self.schema)?; - let schema = Schema::new(fields); - Ok(RecordBatch::try_new(Arc::new(schema), arrays)?) 
- } -} - -impl OuterSequenceBuilder { - pub fn build_arrow(&mut self) -> Result> { - let fields = self.take_records()?; - let arrays = fields - .into_iter() - .map(build_array) - .collect::>>()?; - Ok(arrays) - } -} - -fn build_array(builder: ArrayBuilder) -> Result { - let data = builder.into_array()?.try_into()?; - Ok(make_array(data)) -} - impl TryFrom for ArrayData { type Error = Error; From e0a5bd5e1ae5f9849ec71d734296662df264fdef Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:33:44 +0200 Subject: [PATCH 043/178] More refactoring --- serde_arrow/src/arrow_impl/api.rs | 106 +++++++++++++++++- serde_arrow/src/arrow_impl/deserialization.rs | 105 +---------------- serde_arrow/src/arrow_impl/type_support.rs | 14 ++- 3 files changed, 115 insertions(+), 110 deletions(-) diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 2b9d8a0a..f683f410 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -10,13 +10,19 @@ use crate::{ }, internal::{ array_builder::ArrayBuilder, + deserialization::{ + array_deserializer::ArrayDeserializer, + outer_sequence_deserializer::OuterSequenceDeserializer, + }, deserializer::Deserializer, - error::Result, - schema::{GenericField, SerdeArrowSchema}, + error::{fail, Result}, + schema::SerdeArrowSchema, serializer::Serializer, }, }; +use super::type_support::fields_from_field_refs; + /// Build arrow arrays from the given items (*requires one of the `arrow-*` /// features*) /// @@ -180,10 +186,7 @@ impl crate::internal::array_builder::ArrayBuilder { /// Build an ArrayBuilder from `arrow` fields (*requires one of the /// `arrow-*` features*) pub fn from_arrow(fields: &[FieldRef]) -> Result { - let fields = fields - .iter() - .map(|f| GenericField::try_from(f.as_ref())) - .collect::>>()?; + let fields = fields_from_field_refs(fields)?; Self::new(SerdeArrowSchema { fields }) } @@ -207,3 +210,94 @@ impl crate::internal::array_builder::ArrayBuilder { 
Ok(RecordBatch::try_new(Arc::new(schema), arrays)?) } } + +impl<'de> Deserializer<'de> { + /// Construct a new deserializer from `arrow` arrays (*requires one of the + /// `arrow-*` features*) + /// + /// Usage + /// ```rust + /// # fn main() -> serde_arrow::Result<()> { + /// # let (_, arrays) = serde_arrow::_impl::docs::defs::example_arrow_arrays(); + /// # use serde_arrow::_impl::arrow; + /// use arrow::datatypes::FieldRef; + /// use serde::{Deserialize, Serialize}; + /// use serde_arrow::{Deserializer, schema::{SchemaLike, TracingOptions}}; + /// + /// ##[derive(Deserialize, Serialize)] + /// struct Record { + /// a: Option, + /// b: u64, + /// } + /// + /// let fields = Vec::::from_type::(TracingOptions::default())?; + /// + /// let deserializer = Deserializer::from_arrow(&fields, &arrays)?; + /// let items = Vec::::deserialize(deserializer)?; + /// # Ok(()) + /// # } + /// ``` + pub fn from_arrow(fields: &[FieldRef], arrays: &'de [A]) -> Result + where + A: AsRef, + { + let fields = fields_from_field_refs(fields)?; + let arrays = arrays + .iter() + .map(|array| array.as_ref()) + .collect::>(); + + if fields.len() != arrays.len() { + fail!( + "different number of fields ({}) and arrays ({})", + fields.len(), + arrays.len() + ); + } + let len = arrays.first().map(|array| array.len()).unwrap_or_default(); + + let mut deserializers = Vec::new(); + for (field, array) in std::iter::zip(&fields, arrays) { + if array.len() != len { + fail!("arrays of different lengths are not supported"); + } + + let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + deserializers.push((field.name.clone(), deserializer)); + } + + let deserializer = OuterSequenceDeserializer::new(deserializers, len); + let deserializer = Deserializer(deserializer); + + Ok(deserializer) + } + + /// Construct a new deserializer from a record batch (*requires one of the + /// `arrow-*` features*) + /// + /// Usage: + /// + /// ```rust + /// # fn main() -> 
serde_arrow::Result<()> { + /// # let record_batch = serde_arrow::_impl::docs::defs::example_record_batch(); + /// # + /// use serde::Deserialize; + /// use serde_arrow::Deserializer; + /// + /// ##[derive(Deserialize)] + /// struct Record { + /// a: Option, + /// b: u64, + /// } + /// + /// let deserializer = Deserializer::from_record_batch(&record_batch)?; + /// let items = Vec::::deserialize(deserializer)?; + /// # Ok(()) + /// # } + /// ``` + /// + pub fn from_record_batch(record_batch: &'de RecordBatch) -> Result { + let schema = record_batch.schema(); + Deserializer::from_arrow(schema.fields(), record_batch.columns()) + } +} diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/deserialization.rs index cd688740..d9a3fbd2 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/deserialization.rs @@ -5,11 +5,6 @@ use crate::internal::{ NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, - deserialization::{ - array_deserializer::ArrayDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - }, - deserializer::Deserializer, error::{fail, Error, Result}, schema::GenericField, serialization::utils::meta_from_field, @@ -19,113 +14,19 @@ use crate::_impl::arrow::{ array::{ Array, BooleanArray, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, MapArray, NullArray, - PrimitiveArray, RecordBatch, StructArray, UnionArray, + PrimitiveArray, StructArray, UnionArray, }, datatypes::{ ArrowDictionaryKeyType, DataType, Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, 
Int32Type, Int64Type, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionMode, }, }; -impl<'de> Deserializer<'de> { - /// Construct a new deserializer from `arrow` arrays (*requires one of the - /// `arrow-*` features*) - /// - /// Usage - /// ```rust - /// # fn main() -> serde_arrow::Result<()> { - /// # let (_, arrays) = serde_arrow::_impl::docs::defs::example_arrow_arrays(); - /// # use serde_arrow::_impl::arrow; - /// use arrow::datatypes::FieldRef; - /// use serde::{Deserialize, Serialize}; - /// use serde_arrow::{Deserializer, schema::{SchemaLike, TracingOptions}}; - /// - /// ##[derive(Deserialize, Serialize)] - /// struct Record { - /// a: Option, - /// b: u64, - /// } - /// - /// let fields = Vec::::from_type::(TracingOptions::default())?; - /// - /// let deserializer = Deserializer::from_arrow(&fields, &arrays)?; - /// let items = Vec::::deserialize(deserializer)?; - /// # Ok(()) - /// # } - /// ``` - pub fn from_arrow(fields: &[FieldRef], arrays: &'de [A]) -> Result - where - A: AsRef, - { - let fields = fields - .iter() - .map(|field| GenericField::try_from(field.as_ref())) - .collect::>>()?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); - - if fields.len() != arrays.len() { - fail!( - "different number of fields ({}) and arrays ({})", - fields.len(), - arrays.len() - ); - } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - - let mut deserializers = Vec::new(); - for (field, array) in std::iter::zip(&fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } - - let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); - } - - let deserializer = 
OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); - - Ok(deserializer) - } - - /// Construct a new deserializer from a record batch (*requires one of the - /// `arrow-*` features*) - /// - /// Usage: - /// - /// ```rust - /// # fn main() -> serde_arrow::Result<()> { - /// # let record_batch = serde_arrow::_impl::docs::defs::example_record_batch(); - /// # - /// use serde::Deserialize; - /// use serde_arrow::Deserializer; - /// - /// ##[derive(Deserialize)] - /// struct Record { - /// a: Option, - /// b: u64, - /// } - /// - /// let deserializer = Deserializer::from_record_batch(&record_batch)?; - /// let items = Vec::::deserialize(deserializer)?; - /// # Ok(()) - /// # } - /// ``` - /// - pub fn from_record_batch(record_batch: &'de RecordBatch) -> Result { - let schema = record_batch.schema(); - Deserializer::from_arrow(schema.fields(), record_batch.columns()) - } -} - impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { type Error = Error; diff --git a/serde_arrow/src/arrow_impl/type_support.rs b/serde_arrow/src/arrow_impl/type_support.rs index 9c4edf64..8fdbddf0 100644 --- a/serde_arrow/src/arrow_impl/type_support.rs +++ b/serde_arrow/src/arrow_impl/type_support.rs @@ -1,7 +1,10 @@ -use crate::_impl::arrow::{datatypes::Field, error::ArrowError}; +use crate::_impl::arrow::{ + datatypes::{Field, FieldRef}, + error::ArrowError, +}; use crate::internal::{ - error::Error, + error::{Error, Result}, schema::{extensions::FixedShapeTensorField, GenericField}, }; @@ -26,3 +29,10 @@ impl TryFrom for Field { Self::try_from(&value) } } + +pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result> { + fields + .iter() + .map(|field| GenericField::try_from(field.as_ref())) + .collect() +} From 7f572a94c4e4b1e6bf231c73907ff15ae6646e5c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 11:39:51 +0200 Subject: [PATCH 044/178] More refactoring --- .../{deserialization.rs => array.rs} | 306 ++++++++++++++++-- 
serde_arrow/src/arrow_impl/mod.rs | 3 +- serde_arrow/src/arrow_impl/serialization.rs | 266 --------------- 3 files changed, 283 insertions(+), 292 deletions(-) rename serde_arrow/src/arrow_impl/{deserialization.rs => array.rs} (58%) delete mode 100644 serde_arrow/src/arrow_impl/serialization.rs diff --git a/serde_arrow/src/arrow_impl/deserialization.rs b/serde_arrow/src/arrow_impl/array.rs similarity index 58% rename from serde_arrow/src/arrow_impl/deserialization.rs rename to serde_arrow/src/arrow_impl/array.rs index d9a3fbd2..fadcd94a 100644 --- a/serde_arrow/src/arrow_impl/deserialization.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -1,32 +1,235 @@ -use crate::internal::{ - arrow::{ - ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - DenseUnionArrayView, DictionaryArrayView, FixedSizeListArrayView, ListArrayView, - NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, - TimestampArrayView, - }, - error::{fail, Error, Result}, - schema::GenericField, - serialization::utils::meta_from_field, -}; +//! 
Convert between arrow arrays and the internal array representation +use std::sync::Arc; + +use half::f16; -use crate::_impl::arrow::{ - array::{ - Array, BooleanArray, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, - GenericBinaryArray, GenericListArray, GenericStringArray, MapArray, NullArray, - PrimitiveArray, StructArray, UnionArray, +use crate::{ + _impl::arrow::{ + array::{ + Array, ArrayData, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, MapArray, + NullArray, PrimitiveArray, StructArray, UnionArray, + }, + buffer::{Buffer, ScalarBuffer}, + datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Date32Type, + Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Field, Float16Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionMode, + }, }, - datatypes::{ - ArrowDictionaryKeyType, DataType, Date32Type, Date64Type, Decimal128Type, - DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, UnionMode, + internal::{ + arrow::FieldMeta, + arrow::{ + ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, + DenseUnionArrayView, DictionaryArrayView, FixedSizeListArrayView, ListArrayView, + NullArrayView, PrimitiveArrayView, 
StructArrayView, TimeArrayView, TimeUnit, + TimestampArrayView, + }, + error::{fail, Error, Result}, + schema::GenericField, + serialization::utils::meta_from_field, }, }; +impl TryFrom for ArrayData { + type Error = Error; + + fn try_from(value: crate::internal::arrow::Array) -> Result { + use {crate::internal::arrow::Array as A, DataType as T}; + type ArrowF16 = ::Native; + + fn f16_to_f16(v: f16) -> ArrowF16 { + ArrowF16::from_bits(v.to_bits()) + } + + match value { + A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), + A::Boolean(arr) => Ok(ArrayData::try_new( + T::Boolean, + // NOTE: use the explicit len + arr.len, + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.values).into_inner()], + vec![], + )?), + A::Int8(arr) => primitive_into_data(T::Int8, arr.validity, arr.values), + A::Int16(arr) => primitive_into_data(T::Int16, arr.validity, arr.values), + A::Int32(arr) => primitive_into_data(T::Int32, arr.validity, arr.values), + A::Int64(arr) => primitive_into_data(T::Int64, arr.validity, arr.values), + A::UInt8(arr) => primitive_into_data(T::UInt8, arr.validity, arr.values), + A::UInt16(arr) => primitive_into_data(T::UInt16, arr.validity, arr.values), + A::UInt32(arr) => primitive_into_data(T::UInt32, arr.validity, arr.values), + A::UInt64(arr) => primitive_into_data(T::UInt64, arr.validity, arr.values), + A::Float16(arr) => primitive_into_data( + T::Float16, + arr.validity, + arr.values.into_iter().map(f16_to_f16).collect(), + ), + A::Float32(arr) => primitive_into_data(T::Float32, arr.validity, arr.values), + A::Float64(arr) => primitive_into_data(T::Float64, arr.validity, arr.values), + A::Date32(arr) => primitive_into_data(T::Date32, arr.validity, arr.values), + A::Date64(arr) => primitive_into_data(T::Date64, arr.validity, arr.values), + A::Timestamp(arr) => primitive_into_data( + T::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), + arr.validity, + arr.values, + ), + A::Time32(arr) => { + 
primitive_into_data(T::Time32(arr.unit.into()), arr.validity, arr.values) + } + A::Time64(arr) => { + primitive_into_data(T::Time64(arr.unit.into()), arr.validity, arr.values) + } + A::Duration(arr) => { + primitive_into_data(T::Duration(arr.unit.into()), arr.validity, arr.values) + } + A::Decimal128(arr) => primitive_into_data( + T::Decimal128(arr.precision, arr.scale), + arr.validity, + arr.values, + ), + A::Utf8(arr) => bytes_into_data(T::Utf8, arr.offsets, arr.data, arr.validity), + A::LargeUtf8(arr) => bytes_into_data(T::LargeUtf8, arr.offsets, arr.data, arr.validity), + A::Binary(arr) => bytes_into_data(T::Binary, arr.offsets, arr.data, arr.validity), + A::LargeBinary(arr) => { + bytes_into_data(T::LargeBinary, arr.offsets, arr.data, arr.validity) + } + A::Struct(arr) => { + let mut fields = Vec::new(); + let mut data = Vec::new(); + + for (field, meta) in arr.fields { + let child: ArrayData = field.try_into()?; + let field = Field::new(meta.name, child.data_type().clone(), meta.nullable) + .with_metadata(meta.metadata); + fields.push(Arc::new(field)); + data.push(child); + } + let data_type = T::Struct(fields.into()); + + Ok(ArrayData::builder(data_type) + .len(arr.len) + .null_bit_buffer(arr.validity.map(Buffer::from)) + .child_data(data) + .build()?) + } + A::List(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + T::List(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::LargeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + list_into_data( + T::LargeList(Arc::new(field)), + arr.offsets.len().saturating_sub(1), + arr.offsets, + child, + arr.validity, + ) + } + A::FixedSizeList(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + if (child.len() % usize::try_from(arr.n)?) 
!= 0 { + fail!( + "Invalid FixedSizeList: number of child elements ({}) not divisible by n ({})", + child.len(), + arr.n, + ); + } + let field = field_from_data_and_meta(&child, arr.meta); + Ok(ArrayData::try_new( + T::FixedSizeList(Arc::new(field), arr.n), + child.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![], + vec![child], + )?) + } + A::FixedSizeBinary(arr) => { + if (arr.data.len() % usize::try_from(arr.n)?) != 0 { + fail!( + "Invalid FixedSizeBinary: number of child elements ({}) not divisible by n ({})", + arr.data.len(), + arr.n, + ); + } + Ok(ArrayData::try_new( + T::FixedSizeBinary(arr.n), + arr.data.len() / usize::try_from(arr.n)?, + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.data).into_inner()], + vec![], + )?) + } + A::Dictionary(arr) => { + let indices: ArrayData = (*arr.indices).try_into()?; + let values: ArrayData = (*arr.values).try_into()?; + let data_type = T::Dictionary( + Box::new(indices.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + Ok(indices + .into_builder() + .data_type(data_type) + .child_data(vec![values]) + .build()?) + } + A::Map(arr) => { + let child: ArrayData = (*arr.element).try_into()?; + let field = field_from_data_and_meta(&child, arr.meta); + Ok(ArrayData::try_new( + T::Map(Arc::new(field), false), + arr.offsets.len().saturating_sub(1), + arr.validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(arr.offsets).into_inner()], + vec![child], + )?) 
+ } + A::DenseUnion(arr) => { + let mut fields = Vec::new(); + let mut child_data = Vec::new(); + + for (idx, (array, meta)) in arr.fields.into_iter().enumerate() { + let child: ArrayData = array.try_into()?; + let field = field_from_data_and_meta(&child, meta); + + fields.push((idx as i8, Arc::new(field))); + child_data.push(child); + } + + Ok(ArrayData::try_new( + DataType::Union(fields.into_iter().collect(), UnionMode::Dense), + arr.types.len(), + None, + 0, + vec![ + ScalarBuffer::from(arr.types).into_inner(), + ScalarBuffer::from(arr.offsets).into_inner(), + ], + child_data, + )?) + } + } + } +} + impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { type Error = Error; @@ -332,6 +535,61 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { } } +fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> Field { + Field::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) +} + +fn primitive_into_data( + data_type: DataType, + validity: Option>, + values: Vec, +) -> Result { + Ok(ArrayData::try_new( + data_type, + values.len(), + validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(values).into_inner()], + vec![], + )?) +} + +fn bytes_into_data( + data_type: DataType, + offsets: Vec, + data: Vec, + validity: Option>, +) -> Result { + Ok(ArrayData::try_new( + data_type, + offsets.len().saturating_sub(1), + validity.map(Buffer::from), + 0, + vec![ + ScalarBuffer::from(offsets).into_inner(), + ScalarBuffer::from(data).into_inner(), + ], + vec![], + )?) +} + +fn list_into_data( + data_type: DataType, + len: usize, + offsets: Vec, + child_data: ArrayData, + validity: Option>, +) -> Result { + Ok(ArrayData::try_new( + data_type, + len, + validity.map(Buffer::from), + 0, + vec![ScalarBuffer::from(offsets).into_inner()], + vec![child_data], + )?) 
+} + fn wrap_dictionary_array( array: &DictionaryArray, ) -> Result> { diff --git a/serde_arrow/src/arrow_impl/mod.rs b/serde_arrow/src/arrow_impl/mod.rs index e8360559..063b900d 100644 --- a/serde_arrow/src/arrow_impl/mod.rs +++ b/serde_arrow/src/arrow_impl/mod.rs @@ -5,7 +5,6 @@ //! #![deny(missing_docs)] pub(crate) mod api; -mod deserialization; +mod array; mod schema; -pub(crate) mod serialization; mod type_support; diff --git a/serde_arrow/src/arrow_impl/serialization.rs b/serde_arrow/src/arrow_impl/serialization.rs deleted file mode 100644 index ada887ec..00000000 --- a/serde_arrow/src/arrow_impl/serialization.rs +++ /dev/null @@ -1,266 +0,0 @@ -#![allow(missing_docs)] -use std::sync::Arc; - -use half::f16; - -use crate::{ - _impl::arrow::{ - array::{Array, ArrayData, NullArray}, - buffer::{Buffer, ScalarBuffer}, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field, Float16Type, UnionMode}, - }, - internal::{ - arrow::FieldMeta, - error::{fail, Error, Result}, - }, -}; - -impl TryFrom for ArrayData { - type Error = Error; - - fn try_from(value: crate::internal::arrow::Array) -> Result { - use {crate::internal::arrow::Array as A, DataType as T}; - type ArrowF16 = ::Native; - - fn f16_to_f16(v: f16) -> ArrowF16 { - ArrowF16::from_bits(v.to_bits()) - } - - match value { - A::Null(arr) => Ok(NullArray::new(arr.len).into_data()), - A::Boolean(arr) => Ok(ArrayData::try_new( - T::Boolean, - // NOTE: use the explicit len - arr.len, - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.values).into_inner()], - vec![], - )?), - A::Int8(arr) => primitive_into_data(T::Int8, arr.validity, arr.values), - A::Int16(arr) => primitive_into_data(T::Int16, arr.validity, arr.values), - A::Int32(arr) => primitive_into_data(T::Int32, arr.validity, arr.values), - A::Int64(arr) => primitive_into_data(T::Int64, arr.validity, arr.values), - A::UInt8(arr) => primitive_into_data(T::UInt8, arr.validity, arr.values), - A::UInt16(arr) => 
primitive_into_data(T::UInt16, arr.validity, arr.values), - A::UInt32(arr) => primitive_into_data(T::UInt32, arr.validity, arr.values), - A::UInt64(arr) => primitive_into_data(T::UInt64, arr.validity, arr.values), - A::Float16(arr) => primitive_into_data( - T::Float16, - arr.validity, - arr.values.into_iter().map(f16_to_f16).collect(), - ), - A::Float32(arr) => primitive_into_data(T::Float32, arr.validity, arr.values), - A::Float64(arr) => primitive_into_data(T::Float64, arr.validity, arr.values), - A::Date32(arr) => primitive_into_data(T::Date32, arr.validity, arr.values), - A::Date64(arr) => primitive_into_data(T::Date64, arr.validity, arr.values), - A::Timestamp(arr) => primitive_into_data( - T::Timestamp(arr.unit.into(), arr.timezone.map(String::into)), - arr.validity, - arr.values, - ), - A::Time32(arr) => { - primitive_into_data(T::Time32(arr.unit.into()), arr.validity, arr.values) - } - A::Time64(arr) => { - primitive_into_data(T::Time64(arr.unit.into()), arr.validity, arr.values) - } - A::Duration(arr) => { - primitive_into_data(T::Duration(arr.unit.into()), arr.validity, arr.values) - } - A::Decimal128(arr) => primitive_into_data( - T::Decimal128(arr.precision, arr.scale), - arr.validity, - arr.values, - ), - A::Utf8(arr) => bytes_into_data(T::Utf8, arr.offsets, arr.data, arr.validity), - A::LargeUtf8(arr) => bytes_into_data(T::LargeUtf8, arr.offsets, arr.data, arr.validity), - A::Binary(arr) => bytes_into_data(T::Binary, arr.offsets, arr.data, arr.validity), - A::LargeBinary(arr) => { - bytes_into_data(T::LargeBinary, arr.offsets, arr.data, arr.validity) - } - A::Struct(arr) => { - let mut fields = Vec::new(); - let mut data = Vec::new(); - - for (field, meta) in arr.fields { - let child: ArrayData = field.try_into()?; - let field = Field::new(meta.name, child.data_type().clone(), meta.nullable) - .with_metadata(meta.metadata); - fields.push(Arc::new(field)); - data.push(child); - } - let data_type = T::Struct(fields.into()); - - 
Ok(ArrayData::builder(data_type) - .len(arr.len) - .null_bit_buffer(arr.validity.map(Buffer::from)) - .child_data(data) - .build()?) - } - A::List(arr) => { - let child: ArrayData = (*arr.element).try_into()?; - let field = field_from_data_and_meta(&child, arr.meta); - list_into_data( - T::List(Arc::new(field)), - arr.offsets.len().saturating_sub(1), - arr.offsets, - child, - arr.validity, - ) - } - A::LargeList(arr) => { - let child: ArrayData = (*arr.element).try_into()?; - let field = field_from_data_and_meta(&child, arr.meta); - list_into_data( - T::LargeList(Arc::new(field)), - arr.offsets.len().saturating_sub(1), - arr.offsets, - child, - arr.validity, - ) - } - A::FixedSizeList(arr) => { - let child: ArrayData = (*arr.element).try_into()?; - if (child.len() % usize::try_from(arr.n)?) != 0 { - fail!( - "Invalid FixedSizeList: number of child elements ({}) not divisible by n ({})", - child.len(), - arr.n, - ); - } - let field = field_from_data_and_meta(&child, arr.meta); - Ok(ArrayData::try_new( - T::FixedSizeList(Arc::new(field), arr.n), - child.len() / usize::try_from(arr.n)?, - arr.validity.map(Buffer::from), - 0, - vec![], - vec![child], - )?) - } - A::FixedSizeBinary(arr) => { - if (arr.data.len() % usize::try_from(arr.n)?) != 0 { - fail!( - "Invalid FixedSizeBinary: number of child elements ({}) not divisible by n ({})", - arr.data.len(), - arr.n, - ); - } - Ok(ArrayData::try_new( - T::FixedSizeBinary(arr.n), - arr.data.len() / usize::try_from(arr.n)?, - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.data).into_inner()], - vec![], - )?) - } - A::Dictionary(arr) => { - let indices: ArrayData = (*arr.indices).try_into()?; - let values: ArrayData = (*arr.values).try_into()?; - let data_type = T::Dictionary( - Box::new(indices.data_type().clone()), - Box::new(values.data_type().clone()), - ); - - Ok(indices - .into_builder() - .data_type(data_type) - .child_data(vec![values]) - .build()?) 
- } - A::Map(arr) => { - let child: ArrayData = (*arr.element).try_into()?; - let field = field_from_data_and_meta(&child, arr.meta); - Ok(ArrayData::try_new( - T::Map(Arc::new(field), false), - arr.offsets.len().saturating_sub(1), - arr.validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(arr.offsets).into_inner()], - vec![child], - )?) - } - A::DenseUnion(arr) => { - let mut fields = Vec::new(); - let mut child_data = Vec::new(); - - for (idx, (array, meta)) in arr.fields.into_iter().enumerate() { - let child: ArrayData = array.try_into()?; - let field = field_from_data_and_meta(&child, meta); - - fields.push((idx as i8, Arc::new(field))); - child_data.push(child); - } - - Ok(ArrayData::try_new( - DataType::Union(fields.into_iter().collect(), UnionMode::Dense), - arr.types.len(), - None, - 0, - vec![ - ScalarBuffer::from(arr.types).into_inner(), - ScalarBuffer::from(arr.offsets).into_inner(), - ], - child_data, - )?) - } - } - } -} - -fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> Field { - Field::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) -} - -fn primitive_into_data( - data_type: DataType, - validity: Option>, - values: Vec, -) -> Result { - Ok(ArrayData::try_new( - data_type, - values.len(), - validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(values).into_inner()], - vec![], - )?) -} - -fn bytes_into_data( - data_type: DataType, - offsets: Vec, - data: Vec, - validity: Option>, -) -> Result { - Ok(ArrayData::try_new( - data_type, - offsets.len().saturating_sub(1), - validity.map(Buffer::from), - 0, - vec![ - ScalarBuffer::from(offsets).into_inner(), - ScalarBuffer::from(data).into_inner(), - ], - vec![], - )?) 
-} - -fn list_into_data( - data_type: DataType, - len: usize, - offsets: Vec, - child_data: ArrayData, - validity: Option>, -) -> Result { - Ok(ArrayData::try_new( - data_type, - len, - validity.map(Buffer::from), - 0, - vec![ScalarBuffer::from(offsets).into_inner()], - vec![child_data], - )?) -} From 2877b79908f7fd84c639054977f73eb51c58e751 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 12:14:40 +0200 Subject: [PATCH 045/178] Implement Null, Int, UInt, Float --- serde_arrow/src/arrow2_impl/api.rs | 74 ++++++++++++++++++- serde_arrow/src/arrow2_impl/array.rs | 55 ++++++++++++++ .../src/arrow2_impl/deserialization.rs | 52 +------------ serde_arrow/src/arrow2_impl/mod.rs | 3 +- serde_arrow/src/arrow2_impl/serialization.rs | 71 +++++------------- serde_arrow/src/arrow_impl/api.rs | 2 +- 6 files changed, 150 insertions(+), 107 deletions(-) create mode 100644 serde_arrow/src/arrow2_impl/array.rs diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index c0352f8e..cb7c9153 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -2,12 +2,17 @@ //! //! Functions to convert Rust objects into Arrow arrays and back. //! 
+#![deny(missing_docs)] use serde::{Deserialize, Serialize}; use crate::{ _impl::arrow2::{array::Array, datatypes::Field}, internal::{ - array_builder::ArrayBuilder, deserializer::Deserializer, error::Result, + array_builder::ArrayBuilder, + deserialization::outer_sequence_deserializer::OuterSequenceDeserializer, + deserializer::Deserializer, + error::Result, + schema::{GenericField, SerdeArrowSchema}, serializer::Serializer, }, }; @@ -93,3 +98,70 @@ where let deserializer = Deserializer::from_arrow2(fields, arrays)?; T::deserialize(deserializer) } + +/// Support `arrow2` (*requires one of the `arrow2-*` features*) +impl crate::internal::array_builder::ArrayBuilder { + /// Build an ArrayBuilder from `arrow2` fields (*requires one of the + /// `arrow2-*` features*) + pub fn from_arrow2(fields: &[Field]) -> Result { + Self::new(SerdeArrowSchema::try_from(fields)?) + } + + /// Construct `arrow2` arrays and reset the builder (*requires one of the + /// `arrow2-*` features*) + pub fn to_arrow2(&mut self) -> Result>> { + let mut arrays = Vec::new(); + for field in self.builder.take_records()? 
{ + arrays.push(super::serialization::build_array(field)?); + } + Ok(arrays) + } +} + +impl<'de> Deserializer<'de> { + /// Build a deserializer from `arrow2` arrays (*requires one of the + /// `arrow2-*` features*) + /// + /// Usage: + /// + /// ```rust + /// # fn main() -> serde_arrow::Result<()> { + /// # use serde_arrow::_impl::arrow2; + /// # let (_, arrays) = serde_arrow::_impl::docs::defs::example_arrow2_arrays(); + /// use arrow2::datatypes::Field; + /// use serde::{Deserialize, Serialize}; + /// use serde_arrow::{Deserializer, schema::{SchemaLike, TracingOptions}}; + /// + /// ##[derive(Deserialize, Serialize)] + /// struct Record { + /// a: Option, + /// b: u64, + /// } + /// + /// let fields = Vec::::from_type::(TracingOptions::default())?; + /// + /// let deserializer = Deserializer::from_arrow2(&fields, &arrays)?; + /// let items = Vec::::deserialize(deserializer)?; + /// # Ok(()) + /// # } + /// ``` + pub fn from_arrow2(fields: &[Field], arrays: &'de [A]) -> Result + where + A: AsRef, + { + let fields = fields + .iter() + .map(GenericField::try_from) + .collect::>>()?; + let arrays = arrays + .iter() + .map(|array| array.as_ref()) + .collect::>(); + + let (deserializers, len) = super::deserialization::build_struct_fields(&fields, &arrays)?; + let deserializer = OuterSequenceDeserializer::new(deserializers, len); + let deserializer = Deserializer(deserializer); + + Ok(deserializer) + } +} diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs new file mode 100644 index 00000000..85bc4fa9 --- /dev/null +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -0,0 +1,55 @@ +use crate::{ + _impl::arrow2::{ + array::{Array as A2Array, NullArray as A2NullArray, PrimitiveArray as A2PrimitiveArray}, + bitmap::Bitmap, + buffer::Buffer, + datatypes::DataType, + types::{f16, NativeType}, + }, + internal::{ + arrow::Array, + error::{fail, Error, Result}, + }, +}; + +impl TryFrom for Box { + type Error = Error; + + fn try_from(value: Array) 
-> Result { + use {Array as A, DataType as T}; + match value { + A::Null(arr) => Ok(Box::new(A2NullArray::new(T::Null, arr.len))), + A::Int8(arr) => build_primitive_array(T::Int8, arr.values, arr.validity), + A::Int16(arr) => build_primitive_array(T::Int16, arr.values, arr.validity), + A::Int32(arr) => build_primitive_array(T::Int32, arr.values, arr.validity), + A::Int64(arr) => build_primitive_array(T::Int64, arr.values, arr.validity), + A::UInt8(arr) => build_primitive_array(T::UInt8, arr.values, arr.validity), + A::UInt16(arr) => build_primitive_array(T::UInt16, arr.values, arr.validity), + A::UInt32(arr) => build_primitive_array(T::UInt32, arr.values, arr.validity), + A::UInt64(arr) => build_primitive_array(T::UInt64, arr.values, arr.validity), + A::Float16(arr) => build_primitive_array( + T::Float16, + arr.values + .into_iter() + .map(|v| f16::from_bits(v.to_bits())) + .collect(), + arr.validity, + ), + A::Float32(arr) => build_primitive_array(T::Float32, arr.values, arr.validity), + A::Float64(arr) => build_primitive_array(T::Float64, arr.values, arr.validity), + _ => fail!("cannot convert array to arrow2 array"), + } + } +} + +fn build_primitive_array( + data_type: DataType, + buffer: Vec, + validity: Option>, +) -> Result> { + let validity = validity.map(|v| Bitmap::from_u8_vec(v, buffer.len())); + let buffer = Buffer::from(buffer); + Ok(Box::new(A2PrimitiveArray::try_new( + data_type, buffer, validity, + )?)) +} diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 63d2aede..4c8f294b 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -14,13 +14,11 @@ use crate::internal::{ list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, 
time_deserializer::TimeDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, - deserializer::Deserializer, error::{fail, Result}, schema::{GenericDataType, GenericField}, utils::Offset, @@ -31,58 +29,10 @@ use crate::_impl::arrow2::{ Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, StructArray, UnionArray, Utf8Array, }, - datatypes::{DataType, Field, UnionMode}, + datatypes::{DataType, UnionMode}, types::{f16, NativeType, Offset as ArrowOffset}, }; -impl<'de> Deserializer<'de> { - /// Build a deserializer from `arrow2` arrays (*requires one of the - /// `arrow2-*` features*) - /// - /// Usage: - /// - /// ```rust - /// # fn main() -> serde_arrow::Result<()> { - /// # use serde_arrow::_impl::arrow2; - /// # let (_, arrays) = serde_arrow::_impl::docs::defs::example_arrow2_arrays(); - /// use arrow2::datatypes::Field; - /// use serde::{Deserialize, Serialize}; - /// use serde_arrow::{Deserializer, schema::{SchemaLike, TracingOptions}}; - /// - /// ##[derive(Deserialize, Serialize)] - /// struct Record { - /// a: Option, - /// b: u64, - /// } - /// - /// let fields = Vec::::from_type::(TracingOptions::default())?; - /// - /// let deserializer = Deserializer::from_arrow2(&fields, &arrays)?; - /// let items = Vec::::deserialize(deserializer)?; - /// # Ok(()) - /// # } - /// ``` - pub fn from_arrow2(fields: &[Field], arrays: &'de [A]) -> Result - where - A: AsRef, - { - let fields = fields - .iter() - .map(GenericField::try_from) - .collect::>>()?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); - - let (deserializers, len) = build_struct_fields(&fields, &arrays)?; - let deserializer = OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); - - Ok(deserializer) - } -} - pub fn build_array_deserializer<'a>( field: &GenericField, array: &'a dyn Array, diff --git a/serde_arrow/src/arrow2_impl/mod.rs b/serde_arrow/src/arrow2_impl/mod.rs index 
75bf9489..babd4ae8 100644 --- a/serde_arrow/src/arrow2_impl/mod.rs +++ b/serde_arrow/src/arrow2_impl/mod.rs @@ -4,7 +4,8 @@ //! #![deny(missing_docs)] pub(crate) mod api; +mod array; pub(crate) mod deserialization; -pub(crate) mod schema; +mod schema; pub(crate) mod serialization; mod type_support; diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index 21c42727..71ee2e02 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -3,54 +3,38 @@ use crate::{ _impl::arrow2::{ array::{ - Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, NullArray, + Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, StructArray, UnionArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, datatypes::{DataType, Field}, offset::OffsetsBuffer, - types::{f16, NativeType, Offset}, + types::{NativeType, Offset}, }, internal::{ error::{fail, Result}, - schema::{GenericField, SerdeArrowSchema}, - serialization::{utils::MutableBitBuffer, ArrayBuilder, OuterSequenceBuilder}, + schema::GenericField, + serialization::{utils::MutableBitBuffer, ArrayBuilder}, }, }; -/// Support `arrow2` (*requires one of the `arrow2-*` features*) -impl crate::internal::array_builder::ArrayBuilder { - /// Build an ArrayBuilder from `arrow2` fields (*requires one of the - /// `arrow2-*` features*) - pub fn from_arrow2(fields: &[Field]) -> Result { - Self::new(SerdeArrowSchema::try_from(fields)?) 
- } - - /// Construct `arrow2` arrays and reset the builder (*requires one of the - /// `arrow2-*` features*) - pub fn to_arrow2(&mut self) -> Result>> { - self.builder.build_arrow2() - } -} - -impl OuterSequenceBuilder { - /// Build the arrow2 arrays - pub fn build_arrow2(&mut self) -> Result>> { - let fields = self.take_records()?; - let arrays = fields - .into_iter() - .map(build_array) - .collect::>>()?; - Ok(arrays) - } -} - -fn build_array(builder: ArrayBuilder) -> Result> { +pub fn build_array(builder: ArrayBuilder) -> Result> { use {ArrayBuilder as A, DataType as T}; match builder { - A::Null(builder) => Ok(Box::new(NullArray::new(T::Null, builder.count))), - A::UnknownVariant(_) => Ok(Box::new(NullArray::new(T::Null, 0))), + A::Null(_) + | A::UnknownVariant(_) + | A::I8(_) + | A::I16(_) + | A::I32(_) + | A::I64(_) + | A::U8(_) + | A::U16(_) + | A::U32(_) + | A::U64(_) + | A::F16(_) + | A::F32(_) + | A::F64(_) => builder.into_array()?.try_into(), A::Bool(builder) => { let buffer = Bitmap::from_u8_vec(builder.buffer.buffer, builder.buffer.len); let validity = build_validity(builder.validity); @@ -60,25 +44,6 @@ fn build_array(builder: ArrayBuilder) -> Result> { validity, )?)) } - A::I8(builder) => build_primitive_array(T::Int8, builder.buffer, builder.validity), - A::I16(builder) => build_primitive_array(T::Int16, builder.buffer, builder.validity), - A::I32(builder) => build_primitive_array(T::Int32, builder.buffer, builder.validity), - A::I64(builder) => build_primitive_array(T::Int64, builder.buffer, builder.validity), - A::U8(builder) => build_primitive_array(T::UInt8, builder.buffer, builder.validity), - A::U16(builder) => build_primitive_array(T::UInt16, builder.buffer, builder.validity), - A::U32(builder) => build_primitive_array(T::UInt32, builder.buffer, builder.validity), - A::U64(builder) => build_primitive_array(T::UInt64, builder.buffer, builder.validity), - A::F16(builder) => build_primitive_array( - T::Float16, - builder - .buffer - 
.into_iter() - .map(|v| f16::from_bits(v.to_bits())) - .collect(), - builder.validity, - ), - A::F32(builder) => build_primitive_array(T::Float32, builder.buffer, builder.validity), - A::F64(builder) => build_primitive_array(T::Float64, builder.buffer, builder.validity), A::Date32(builder) => build_primitive_array( Field::try_from(&builder.field)?.data_type, builder.buffer, diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index f683f410..347979c1 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -62,7 +62,7 @@ use super::type_support::fields_from_field_refs; /// ``` /// pub fn to_arrow(fields: &[FieldRef], items: &T) -> Result> { - let builder = ArrayBuilder::new(SerdeArrowSchema::try_from(fields)?)?; + let builder = ArrayBuilder::from_arrow(fields)?; items .serialize(Serializer::new(builder))? .into_inner() From a66a33c06a6f7281f575db98e4fe627bf6b9cc28 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 12:20:42 +0200 Subject: [PATCH 046/178] Implement Date32, Date64, Time32, Time64, Timestamp, Duration --- serde_arrow/src/arrow2_impl/array.rs | 16 ++++++++++ serde_arrow/src/arrow2_impl/serialization.rs | 32 ++++---------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 85bc4fa9..9452ed77 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -37,6 +37,22 @@ impl TryFrom for Box { ), A::Float32(arr) => build_primitive_array(T::Float32, arr.values, arr.validity), A::Float64(arr) => build_primitive_array(T::Float64, arr.values, arr.validity), + A::Date32(arr) => build_primitive_array(T::Date32, arr.values, arr.validity), + A::Date64(arr) => build_primitive_array(T::Date64, arr.values, arr.validity), + A::Duration(arr) => { + build_primitive_array(T::Duration(arr.unit.into()), arr.values, arr.validity) + } + A::Time32(arr) => { + 
build_primitive_array(T::Time32(arr.unit.into()), arr.values, arr.validity) + } + A::Time64(arr) => { + build_primitive_array(T::Time64(arr.unit.into()), arr.values, arr.validity) + } + A::Timestamp(arr) => build_primitive_array( + T::Timestamp(arr.unit.into(), arr.timezone), + arr.values, + arr.validity, + ), _ => fail!("cannot convert array to arrow2 array"), } } diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index 71ee2e02..df3f2b19 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -34,7 +34,12 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { | A::U64(_) | A::F16(_) | A::F32(_) - | A::F64(_) => builder.into_array()?.try_into(), + | A::F64(_) + | A::Date32(_) + | A::Date64(_) + | A::Duration(_) + | A::Time32(_) + | A::Time64(_) => builder.into_array()?.try_into(), A::Bool(builder) => { let buffer = Bitmap::from_u8_vec(builder.buffer.buffer, builder.buffer.len); let validity = build_validity(builder.validity); @@ -44,31 +49,6 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { validity, )?)) } - A::Date32(builder) => build_primitive_array( - Field::try_from(&builder.field)?.data_type, - builder.buffer, - builder.validity, - ), - A::Date64(builder) => build_primitive_array( - Field::try_from(&builder.field)?.data_type, - builder.buffer, - builder.validity, - ), - A::Time32(builder) => build_primitive_array( - Field::try_from(&builder.field)?.data_type, - builder.buffer, - builder.validity, - ), - A::Time64(builder) => build_primitive_array( - Field::try_from(&builder.field)?.data_type, - builder.buffer, - builder.validity, - ), - A::Duration(builder) => build_primitive_array( - T::Duration(builder.unit.into()), - builder.buffer, - builder.validity, - ), A::Decimal128(builder) => build_primitive_array( T::Decimal(builder.precision as usize, usize::try_from(builder.scale)?), builder.buffer, From 
ef74957d88ef47f48dda307e2ff9bf929166d8de Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 12:22:59 +0200 Subject: [PATCH 047/178] Implement Decimal128 --- serde_arrow/src/arrow2_impl/array.rs | 5 +++++ serde_arrow/src/arrow2_impl/serialization.rs | 22 +++----------------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 9452ed77..240a391c 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -53,6 +53,11 @@ impl TryFrom for Box { arr.values, arr.validity, ), + A::Decimal128(arr) => build_primitive_array( + T::Decimal(arr.precision as usize, usize::try_from(arr.scale)?), + arr.values, + arr.validity, + ), _ => fail!("cannot convert array to arrow2 array"), } } diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index df3f2b19..da1cdd05 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -10,7 +10,7 @@ use crate::{ buffer::Buffer, datatypes::{DataType, Field}, offset::OffsetsBuffer, - types::{NativeType, Offset}, + types::Offset, }, internal::{ error::{fail, Result}, @@ -39,7 +39,8 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { | A::Date64(_) | A::Duration(_) | A::Time32(_) - | A::Time64(_) => builder.into_array()?.try_into(), + | A::Time64(_) + | A::Decimal128(_) => builder.into_array()?.try_into(), A::Bool(builder) => { let buffer = Bitmap::from_u8_vec(builder.buffer.buffer, builder.buffer.len); let validity = build_validity(builder.validity); @@ -49,11 +50,6 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { validity, )?)) } - A::Decimal128(builder) => build_primitive_array( - T::Decimal(builder.precision as usize, usize::try_from(builder.scale)?), - builder.buffer, - builder.validity, - ), A::Utf8(builder) => build_array_utf8_array( T::Utf8, builder.offsets.offsets, @@ -157,18 
+153,6 @@ fn build_validity(validity: Option) -> Option { Some(Bitmap::from_u8_vec(validity.buffer, validity.len)) } -fn build_primitive_array( - data_type: DataType, - buffer: Vec, - validity: Option, -) -> Result> { - let buffer = Buffer::from(buffer); - let validity = build_validity(validity); - Ok(Box::new(PrimitiveArray::try_new( - data_type, buffer, validity, - )?)) -} - fn build_dictionary_array( field: GenericField, data_type: DataType, From a43ca4a0cce5379e75e1359b9eb96707523d86c6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 12:26:17 +0200 Subject: [PATCH 048/178] Implement Boolean --- serde_arrow/src/arrow2_impl/array.rs | 10 +++++++++- serde_arrow/src/arrow2_impl/serialization.rs | 16 ++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 240a391c..93a36ab2 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,6 +1,9 @@ use crate::{ _impl::arrow2::{ - array::{Array as A2Array, NullArray as A2NullArray, PrimitiveArray as A2PrimitiveArray}, + array::{ + Array as A2Array, BooleanArray as A2BooleanArray, NullArray as A2NullArray, + PrimitiveArray as A2PrimitiveArray, + }, bitmap::Bitmap, buffer::Buffer, datatypes::DataType, @@ -19,6 +22,11 @@ impl TryFrom for Box { use {Array as A, DataType as T}; match value { A::Null(arr) => Ok(Box::new(A2NullArray::new(T::Null, arr.len))), + A::Boolean(arr) => Ok(Box::new(A2BooleanArray::try_new( + T::Boolean, + Bitmap::from_u8_vec(arr.values, arr.len), + arr.validity.map(|v| Bitmap::from_u8_vec(v, arr.len)), + )?)), A::Int8(arr) => build_primitive_array(T::Int8, arr.values, arr.validity), A::Int16(arr) => build_primitive_array(T::Int16, arr.values, arr.validity), A::Int32(arr) => build_primitive_array(T::Int32, arr.values, arr.validity), diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs 
index da1cdd05..7fecf816 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -3,8 +3,8 @@ use crate::{ _impl::arrow2::{ array::{ - Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, - PrimitiveArray, StructArray, UnionArray, Utf8Array, + Array, DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, + StructArray, UnionArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, @@ -40,16 +40,8 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { | A::Duration(_) | A::Time32(_) | A::Time64(_) - | A::Decimal128(_) => builder.into_array()?.try_into(), - A::Bool(builder) => { - let buffer = Bitmap::from_u8_vec(builder.buffer.buffer, builder.buffer.len); - let validity = build_validity(builder.validity); - Ok(Box::new(BooleanArray::try_new( - T::Boolean, - buffer, - validity, - )?)) - } + | A::Decimal128(_) + | A::Bool(_) => builder.into_array()?.try_into(), A::Utf8(builder) => build_array_utf8_array( T::Utf8, builder.offsets.offsets, From 94ba5ac3b11fb9e32243a2fbb0d93e0e54c87627 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 12:33:55 +0200 Subject: [PATCH 049/178] Add Utft, LargeUtf8, Binary, LargeBinary --- serde_arrow/src/arrow2_impl/array.rs | 49 +++++++++++++++++--- serde_arrow/src/arrow2_impl/serialization.rs | 37 +++------------ 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 93a36ab2..d4404d89 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,13 +1,12 @@ use crate::{ _impl::arrow2::{ array::{ - Array as A2Array, BooleanArray as A2BooleanArray, NullArray as A2NullArray, - PrimitiveArray as A2PrimitiveArray, + Array as A2Array, BinaryArray, BooleanArray, NullArray, PrimitiveArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, datatypes::DataType, - types::{f16, NativeType}, + types::{f16, NativeType, 
Offset}, }, internal::{ arrow::Array, @@ -21,8 +20,8 @@ impl TryFrom for Box { fn try_from(value: Array) -> Result { use {Array as A, DataType as T}; match value { - A::Null(arr) => Ok(Box::new(A2NullArray::new(T::Null, arr.len))), - A::Boolean(arr) => Ok(Box::new(A2BooleanArray::try_new( + A::Null(arr) => Ok(Box::new(NullArray::new(T::Null, arr.len))), + A::Boolean(arr) => Ok(Box::new(BooleanArray::try_new( T::Boolean, Bitmap::from_u8_vec(arr.values, arr.len), arr.validity.map(|v| Bitmap::from_u8_vec(v, arr.len)), @@ -66,6 +65,14 @@ impl TryFrom for Box { arr.values, arr.validity, ), + A::Utf8(arr) => build_utf8_array(T::Utf8, arr.offsets, arr.data, arr.validity), + A::LargeUtf8(arr) => { + build_utf8_array(T::LargeUtf8, arr.offsets, arr.data, arr.validity) + } + A::Binary(arr) => build_binary_array(T::Binary, arr.offsets, arr.data, arr.validity), + A::LargeBinary(arr) => { + build_binary_array(T::LargeBinary, arr.offsets, arr.data, arr.validity) + } _ => fail!("cannot convert array to arrow2 array"), } } @@ -78,7 +85,37 @@ fn build_primitive_array( ) -> Result> { let validity = validity.map(|v| Bitmap::from_u8_vec(v, buffer.len())); let buffer = Buffer::from(buffer); - Ok(Box::new(A2PrimitiveArray::try_new( + Ok(Box::new(PrimitiveArray::try_new( data_type, buffer, validity, )?)) } + +fn build_utf8_array( + data_type: DataType, + offsets: Vec, + data: Vec, + validity: Option>, +) -> Result> { + let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); + Ok(Box::new(Utf8Array::new( + data_type, + offsets.try_into()?, + Buffer::from(data), + validity, + ))) +} + +fn build_binary_array( + data_type: DataType, + offsets: Vec, + data: Vec, + validity: Option>, +) -> Result> { + let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); + Ok(Box::new(BinaryArray::new( + data_type, + offsets.try_into()?, + Buffer::from(data), + validity, + ))) +} diff --git a/serde_arrow/src/arrow2_impl/serialization.rs 
b/serde_arrow/src/arrow2_impl/serialization.rs index 7fecf816..127f9c1b 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -4,13 +4,12 @@ use crate::{ _impl::arrow2::{ array::{ Array, DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, - StructArray, UnionArray, Utf8Array, + StructArray, UnionArray, }, bitmap::Bitmap, buffer::Buffer, datatypes::{DataType, Field}, offset::OffsetsBuffer, - types::Offset, }, internal::{ error::{fail, Result}, @@ -41,19 +40,11 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { | A::Time32(_) | A::Time64(_) | A::Decimal128(_) - | A::Bool(_) => builder.into_array()?.try_into(), - A::Utf8(builder) => build_array_utf8_array( - T::Utf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), - A::LargeUtf8(builder) => build_array_utf8_array( - T::LargeUtf8, - builder.offsets.offsets, - builder.buffer, - builder.validity, - ), + | A::Bool(_) + | A::Utf8(_) + | A::LargeUtf8(_) + | A::Binary(_) + | A::LargeBinary(_) => builder.into_array()?.try_into(), A::LargeList(builder) => Ok(Box::new(ListArray::try_new( T::LargeList(Box::new(Field::try_from(&builder.field)?)), OffsetsBuffer::try_from(builder.offsets.offsets)?, @@ -67,8 +58,6 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { build_validity(builder.validity), )?)), A::FixedSizedList(_) => fail!("FixedSizedList is not supported by arrow2"), - A::Binary(_) => fail!("Binary is not supported by arrow2"), - A::LargeBinary(_) => fail!("LargeBinary is not supported by arrow2"), A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), A::Struct(builder) => { let mut values = Vec::new(); @@ -158,17 +147,3 @@ fn build_dictionary_array( data_type, indices, values, )?)) } - -fn build_array_utf8_array( - data_type: DataType, - offsets: Vec, - data: Vec, - validity: Option, -) -> Result> { - Ok(Box::new(Utf8Array::new( - data_type, - OffsetsBuffer::try_from(offsets)?, - 
Buffer::from(data), - build_validity(validity), - ))) -} From 87ee1affefdd14deabf366d4c5a2ff2e2285b1a0 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:03:54 +0200 Subject: [PATCH 050/178] Add List, LargeList, Struct --- serde_arrow/src/arrow2_impl/array.rs | 64 ++++++++++++++++++-- serde_arrow/src/arrow2_impl/serialization.rs | 43 +++---------- 2 files changed, 67 insertions(+), 40 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index d4404d89..1854eb28 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,15 +1,16 @@ use crate::{ _impl::arrow2::{ array::{ - Array as A2Array, BinaryArray, BooleanArray, NullArray, PrimitiveArray, Utf8Array, + Array as A2Array, BinaryArray, BooleanArray, ListArray, NullArray, PrimitiveArray, + StructArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, - datatypes::DataType, + datatypes::{DataType, Field}, types::{f16, NativeType, Offset}, }, internal::{ - arrow::Array, + arrow::{Array, FieldMeta}, error::{fail, Error, Result}, }, }; @@ -73,7 +74,41 @@ impl TryFrom for Box { A::LargeBinary(arr) => { build_binary_array(T::LargeBinary, arr.offsets, arr.data, arr.validity) } - _ => fail!("cannot convert array to arrow2 array"), + A::List(arr) => build_list_array( + T::List, + arr.offsets, + arr.meta, + (*arr.element).try_into()?, + arr.validity, + ), + A::LargeList(arr) => build_list_array( + T::LargeList, + arr.offsets, + arr.meta, + (*arr.element).try_into()?, + arr.validity, + ), + A::Struct(arr) => { + let mut fields = Vec::new(); + let mut values = Vec::new(); + + for (child, meta) in arr.fields { + let child: Box = child.try_into()?; + let field = field_from_array_and_meta(child.as_ref(), meta); + + values.push(child); + fields.push(field); + } + + Ok(Box::new(StructArray::new( + T::Struct(fields), + values, + arr.validity.map(|v| Bitmap::from_u8_vec(v, arr.len)), + ))) + } + A::FixedSizeList(_) => 
fail!("FixedSizeList is not supported by arrow2"), + A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), + arr => fail!("cannot convert array {arr:?} to arrow2 array"), } } } @@ -119,3 +154,24 @@ fn build_binary_array( validity, ))) } + +fn build_list_array) -> DataType, O: Offset>( + data_type: F, + offsets: Vec, + meta: FieldMeta, + values: Box, + validity: Option>, +) -> Result> { + let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); + Ok(Box::new(ListArray::new( + data_type(Box::new(field_from_array_and_meta(values.as_ref(), meta))), + offsets.try_into()?, + values, + validity, + ))) +} + +fn field_from_array_and_meta(arr: &dyn A2Array, meta: FieldMeta) -> Field { + Field::new(meta.name, arr.data_type().clone(), meta.nullable) + .with_metadata(meta.metadata.into_iter().collect()) +} diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index 127f9c1b..8987673f 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -2,10 +2,7 @@ //! 
use crate::{ _impl::arrow2::{ - array::{ - Array, DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, - StructArray, UnionArray, - }, + array::{Array, DictionaryArray, DictionaryKey, MapArray, PrimitiveArray, UnionArray}, bitmap::Bitmap, buffer::Buffer, datatypes::{DataType, Field}, @@ -44,38 +41,12 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { | A::Utf8(_) | A::LargeUtf8(_) | A::Binary(_) - | A::LargeBinary(_) => builder.into_array()?.try_into(), - A::LargeList(builder) => Ok(Box::new(ListArray::try_new( - T::LargeList(Box::new(Field::try_from(&builder.field)?)), - OffsetsBuffer::try_from(builder.offsets.offsets)?, - build_array(*builder.element)?, - build_validity(builder.validity), - )?)), - A::List(builder) => Ok(Box::new(ListArray::try_new( - T::List(Box::new(Field::try_from(&builder.field)?)), - OffsetsBuffer::try_from(builder.offsets.offsets)?, - build_array(*builder.element)?, - build_validity(builder.validity), - )?)), - A::FixedSizedList(_) => fail!("FixedSizedList is not supported by arrow2"), - A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), - A::Struct(builder) => { - let mut values = Vec::new(); - for (_, field) in builder.named_fields { - values.push(build_array(field)?); - } - - let fields = builder - .fields - .iter() - .map(Field::try_from) - .collect::>>()?; - Ok(Box::new(StructArray::try_new( - T::Struct(fields), - values, - build_validity(builder.validity), - )?)) - } + | A::LargeBinary(_) + | A::List(_) + | A::LargeList(_) + | A::Struct(_) + | A::FixedSizedList(_) + | A::FixedSizeBinary(_) => builder.into_array()?.try_into(), A::Map(builder) => Ok(Box::new(MapArray::try_new( T::Map(Box::new(Field::try_from(&builder.entry_field)?), false), OffsetsBuffer::try_from(builder.offsets.offsets)?, From ba796f5451e83511f0fa75cbf8c00a9d81cb5d96 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:08:56 +0200 Subject: [PATCH 051/178] Implement Map --- 
serde_arrow/src/arrow2_impl/array.rs | 17 ++++++++- serde_arrow/src/arrow2_impl/serialization.rs | 39 +------------------- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 1854eb28..e3f4fc15 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,8 +1,8 @@ use crate::{ _impl::arrow2::{ array::{ - Array as A2Array, BinaryArray, BooleanArray, ListArray, NullArray, PrimitiveArray, - StructArray, Utf8Array, + Array as A2Array, BinaryArray, BooleanArray, ListArray, MapArray, NullArray, + PrimitiveArray, StructArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, @@ -106,6 +106,19 @@ impl TryFrom for Box { arr.validity.map(|v| Bitmap::from_u8_vec(v, arr.len)), ))) } + A::Map(arr) => { + let child: Box = (*arr.element).try_into()?; + let field = field_from_array_and_meta(child.as_ref(), arr.meta); + let validity = arr + .validity + .map(|v| Bitmap::from_u8_vec(v, arr.offsets.len().saturating_sub(1))); + Ok(Box::new(MapArray::new( + T::Map(Box::new(field), false), + arr.offsets.try_into()?, + child, + validity, + ))) + } A::FixedSizeList(_) => fail!("FixedSizeList is not supported by arrow2"), A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), arr => fail!("cannot convert array {arr:?} to arrow2 array"), diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index 8987673f..ae3fafbd 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -2,11 +2,10 @@ //! 
use crate::{ _impl::arrow2::{ - array::{Array, DictionaryArray, DictionaryKey, MapArray, PrimitiveArray, UnionArray}, + array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray, UnionArray}, bitmap::Bitmap, buffer::Buffer, datatypes::{DataType, Field}, - offset::OffsetsBuffer, }, internal::{ error::{fail, Result}, @@ -18,41 +17,6 @@ use crate::{ pub fn build_array(builder: ArrayBuilder) -> Result> { use {ArrayBuilder as A, DataType as T}; match builder { - A::Null(_) - | A::UnknownVariant(_) - | A::I8(_) - | A::I16(_) - | A::I32(_) - | A::I64(_) - | A::U8(_) - | A::U16(_) - | A::U32(_) - | A::U64(_) - | A::F16(_) - | A::F32(_) - | A::F64(_) - | A::Date32(_) - | A::Date64(_) - | A::Duration(_) - | A::Time32(_) - | A::Time64(_) - | A::Decimal128(_) - | A::Bool(_) - | A::Utf8(_) - | A::LargeUtf8(_) - | A::Binary(_) - | A::LargeBinary(_) - | A::List(_) - | A::LargeList(_) - | A::Struct(_) - | A::FixedSizedList(_) - | A::FixedSizeBinary(_) => builder.into_array()?.try_into(), - A::Map(builder) => Ok(Box::new(MapArray::try_new( - T::Map(Box::new(Field::try_from(&builder.entry_field)?), false), - OffsetsBuffer::try_from(builder.offsets.offsets)?, - build_array(*builder.entry)?, - build_validity(builder.validity), - )?)), A::DictionaryUtf8(builder) => { let values = build_array(*builder.values)?; match *builder.indices { @@ -97,6 +61,7 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { Some(Buffer::from(builder.offsets)), )?)) } + _ => builder.into_array()?.try_into(), } } From d9c2e242c1c4231e611c57ccdb92941667b8b278 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:14:27 +0200 Subject: [PATCH 052/178] Add Union support --- serde_arrow/src/arrow2_impl/array.rs | 42 ++++++++++++++------ serde_arrow/src/arrow2_impl/serialization.rs | 16 +------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index e3f4fc15..43d16ddd 100644 --- 
a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -2,11 +2,11 @@ use crate::{ _impl::arrow2::{ array::{ Array as A2Array, BinaryArray, BooleanArray, ListArray, MapArray, NullArray, - PrimitiveArray, StructArray, Utf8Array, + PrimitiveArray, StructArray, UnionArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, UnionMode}, types::{f16, NativeType, Offset}, }, internal::{ @@ -89,17 +89,7 @@ impl TryFrom for Box { arr.validity, ), A::Struct(arr) => { - let mut fields = Vec::new(); - let mut values = Vec::new(); - - for (child, meta) in arr.fields { - let child: Box = child.try_into()?; - let field = field_from_array_and_meta(child.as_ref(), meta); - - values.push(child); - fields.push(field); - } - + let (values, fields) = array_with_meta_to_array_and_fields(arr.fields)?; Ok(Box::new(StructArray::new( T::Struct(fields), values, @@ -119,6 +109,15 @@ impl TryFrom for Box { validity, ))) } + A::DenseUnion(arr) => { + let (values, fields) = array_with_meta_to_array_and_fields(arr.fields)?; + Ok(Box::new(UnionArray::try_new( + T::Union(fields, None, UnionMode::Dense), + arr.types.into(), + values, + Some(arr.offsets.into()), + )?)) + } A::FixedSizeList(_) => fail!("FixedSizeList is not supported by arrow2"), A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), arr => fail!("cannot convert array {arr:?} to arrow2 array"), @@ -188,3 +187,20 @@ fn field_from_array_and_meta(arr: &dyn A2Array, meta: FieldMeta) -> Field { Field::new(meta.name, arr.data_type().clone(), meta.nullable) .with_metadata(meta.metadata.into_iter().collect()) } + +fn array_with_meta_to_array_and_fields( + arrays: Vec<(Array, FieldMeta)>, +) -> Result<(Vec>, Vec)> { + let mut res_fields = Vec::new(); + let mut res_arrays = Vec::new(); + + for (child, meta) in arrays { + let child: Box = child.try_into()?; + let field = field_from_array_and_meta(child.as_ref(), meta); + + 
res_arrays.push(child); + res_fields.push(field); + } + + Ok((res_arrays, res_fields)) +} diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs index ae3fafbd..d991f91c 100644 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ b/serde_arrow/src/arrow2_impl/serialization.rs @@ -2,7 +2,7 @@ //! use crate::{ _impl::arrow2::{ - array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray, UnionArray}, + array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, bitmap::Bitmap, buffer::Buffer, datatypes::{DataType, Field}, @@ -47,20 +47,6 @@ pub fn build_array(builder: ArrayBuilder) -> Result> { builder => fail!("Cannot use {} as an index for a dictionary", builder.name()), } } - A::Union(builder) => { - let data_type = Field::try_from(&builder.field)?.data_type; - let children = builder - .fields - .into_iter() - .map(build_array) - .collect::>()?; - Ok(Box::new(UnionArray::try_new( - data_type, - Buffer::from(builder.types), - children, - Some(Buffer::from(builder.offsets)), - )?)) - } _ => builder.into_array()?.try_into(), } } From 0d3230c0459014a202e131443afa473693337e26 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:33:42 +0200 Subject: [PATCH 053/178] Implement Dictionary --- serde_arrow/src/arrow2_impl/api.rs | 2 +- serde_arrow/src/arrow2_impl/array.rs | 41 +++++++++-- serde_arrow/src/arrow2_impl/mod.rs | 1 - serde_arrow/src/arrow2_impl/serialization.rs | 71 -------------------- 4 files changed, 36 insertions(+), 79 deletions(-) delete mode 100644 serde_arrow/src/arrow2_impl/serialization.rs diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index cb7c9153..ec96881f 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -112,7 +112,7 @@ impl crate::internal::array_builder::ArrayBuilder { pub fn to_arrow2(&mut self) -> Result>> { let mut arrays = Vec::new(); for field in self.builder.take_records()? 
{ - arrays.push(super::serialization::build_array(field)?); + arrays.push(field.into_array()?.try_into()?); } Ok(arrays) } diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 43d16ddd..ab081f5a 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,16 +1,16 @@ use crate::{ _impl::arrow2::{ array::{ - Array as A2Array, BinaryArray, BooleanArray, ListArray, MapArray, NullArray, - PrimitiveArray, StructArray, UnionArray, Utf8Array, + Array as A2Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, ListArray, + MapArray, NullArray, PrimitiveArray, StructArray, UnionArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, - datatypes::{DataType, Field, UnionMode}, + datatypes::{DataType, Field, IntegerType, UnionMode}, types::{f16, NativeType, Offset}, }, internal::{ - arrow::{Array, FieldMeta}, + arrow::{Array, FieldMeta, PrimitiveArray as InternalPrimitiveArray}, error::{fail, Error, Result}, }, }; @@ -19,7 +19,7 @@ impl TryFrom for Box { type Error = Error; fn try_from(value: Array) -> Result { - use {Array as A, DataType as T}; + use {Array as A, DataType as T, IntegerType as I}; match value { A::Null(arr) => Ok(Box::new(NullArray::new(T::Null, arr.len))), A::Boolean(arr) => Ok(Box::new(BooleanArray::try_new( @@ -74,6 +74,18 @@ impl TryFrom for Box { A::LargeBinary(arr) => { build_binary_array(T::LargeBinary, arr.offsets, arr.data, arr.validity) } + A::Dictionary(arr) => match *arr.indices { + A::Int8(indices) => build_dictionary_array(I::Int8, indices, *arr.values), + A::Int16(indices) => build_dictionary_array(I::Int16, indices, *arr.values), + A::Int32(indices) => build_dictionary_array(I::Int32, indices, *arr.values), + A::Int64(indices) => build_dictionary_array(I::Int64, indices, *arr.values), + A::UInt8(indices) => build_dictionary_array(I::UInt8, indices, *arr.values), + A::UInt16(indices) => build_dictionary_array(I::UInt16, indices, *arr.values), + 
A::UInt32(indices) => build_dictionary_array(I::UInt32, indices, *arr.values), + A::UInt64(indices) => build_dictionary_array(I::UInt64, indices, *arr.values), + // TODO: improve error message by including the data type + _ => fail!("unsupported dictionary index array during arrow2 conversion"), + }, A::List(arr) => build_list_array( T::List, arr.offsets, @@ -120,7 +132,6 @@ impl TryFrom for Box { } A::FixedSizeList(_) => fail!("FixedSizeList is not supported by arrow2"), A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), - arr => fail!("cannot convert array {arr:?} to arrow2 array"), } } } @@ -204,3 +215,21 @@ fn array_with_meta_to_array_and_fields( Ok((res_arrays, res_fields)) } + +fn build_dictionary_array( + indices_type: IntegerType, + indices: InternalPrimitiveArray, + values: Array, +) -> Result> { + let values: Box = values.try_into()?; + let validity = indices + .validity + .map(|v| Bitmap::from_u8_vec(v, indices.values.len())); + let keys = PrimitiveArray::new(indices_type.into(), indices.values.into(), validity); + + Ok(Box::new(DictionaryArray::try_new( + DataType::Dictionary(indices_type, Box::new(values.data_type().clone()), false), + keys, + values, + )?)) +} diff --git a/serde_arrow/src/arrow2_impl/mod.rs b/serde_arrow/src/arrow2_impl/mod.rs index babd4ae8..64df8e4d 100644 --- a/serde_arrow/src/arrow2_impl/mod.rs +++ b/serde_arrow/src/arrow2_impl/mod.rs @@ -7,5 +7,4 @@ pub(crate) mod api; mod array; pub(crate) mod deserialization; mod schema; -pub(crate) mod serialization; mod type_support; diff --git a/serde_arrow/src/arrow2_impl/serialization.rs b/serde_arrow/src/arrow2_impl/serialization.rs deleted file mode 100644 index d991f91c..00000000 --- a/serde_arrow/src/arrow2_impl/serialization.rs +++ /dev/null @@ -1,71 +0,0 @@ -//! Build arrow2 arrays from individual buffers -//! 
-use crate::{ - _impl::arrow2::{ - array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, - bitmap::Bitmap, - buffer::Buffer, - datatypes::{DataType, Field}, - }, - internal::{ - error::{fail, Result}, - schema::GenericField, - serialization::{utils::MutableBitBuffer, ArrayBuilder}, - }, -}; - -pub fn build_array(builder: ArrayBuilder) -> Result> { - use {ArrayBuilder as A, DataType as T}; - match builder { - A::DictionaryUtf8(builder) => { - let values = build_array(*builder.values)?; - match *builder.indices { - A::U8(ib) => { - build_dictionary_array(builder.field, T::UInt8, ib.buffer, ib.validity, values) - } - A::U16(ib) => { - build_dictionary_array(builder.field, T::UInt16, ib.buffer, ib.validity, values) - } - A::U32(ib) => { - build_dictionary_array(builder.field, T::UInt32, ib.buffer, ib.validity, values) - } - A::U64(ib) => { - build_dictionary_array(builder.field, T::UInt64, ib.buffer, ib.validity, values) - } - A::I8(ib) => { - build_dictionary_array(builder.field, T::Int8, ib.buffer, ib.validity, values) - } - A::I16(ib) => { - build_dictionary_array(builder.field, T::Int16, ib.buffer, ib.validity, values) - } - A::I32(ib) => { - build_dictionary_array(builder.field, T::Int32, ib.buffer, ib.validity, values) - } - A::I64(ib) => { - build_dictionary_array(builder.field, T::Int64, ib.buffer, ib.validity, values) - } - builder => fail!("Cannot use {} as an index for a dictionary", builder.name()), - } - } - _ => builder.into_array()?.try_into(), - } -} - -fn build_validity(validity: Option) -> Option { - let validity = validity?; - Some(Bitmap::from_u8_vec(validity.buffer, validity.len)) -} - -fn build_dictionary_array( - field: GenericField, - data_type: DataType, - indices: Vec, - validity: Option, - values: Box, -) -> Result> { - let indices = PrimitiveArray::new(data_type, Buffer::from(indices), build_validity(validity)); - let data_type = Field::try_from(&field)?.data_type; - Ok(Box::new(DictionaryArray::try_new( - data_type, indices, 
values, - )?)) -} From 1b63937935fac028cd344345e492b39f303c2ab6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:41:41 +0200 Subject: [PATCH 054/178] Implement Null view --- serde_arrow/src/arrow2_impl/array.rs | 20 ++++++++++++++++++- .../src/arrow2_impl/deserialization.rs | 3 +-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index ab081f5a..f9e536ef 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -10,7 +10,9 @@ use crate::{ types::{f16, NativeType, Offset}, }, internal::{ - arrow::{Array, FieldMeta, PrimitiveArray as InternalPrimitiveArray}, + arrow::{ + Array, ArrayView, FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, + }, error::{fail, Error, Result}, }, }; @@ -136,6 +138,22 @@ impl TryFrom for Box { } } +impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { + type Error = Error; + + fn try_from(array: &'a dyn A2Array) -> Result { + let any = array.as_any(); + if let Some(array) = any.downcast_ref::() { + Ok(ArrayView::Null(NullArrayView { len: array.len() })) + } else { + fail!( + "Cannot convert array with data type {:?} into an array view", + array.data_type() + ); + } + } +} + fn build_primitive_array( data_type: DataType, buffer: Vec, diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 4c8f294b..6891b1cb 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -13,7 +13,6 @@ use crate::internal::{ integer_deserializer::{Integer, IntegerDeserializer}, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, - null_deserializer::NullDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, @@ -39,7 +38,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use 
GenericDataType as T; match &field.data_type { - T::Null => Ok(ArrayDeserializer::Null(NullDeserializer)), T::Bool => build_bool_deserializer(field, array), T::U8 => build_integer_deserializer::(field, array), T::U16 => build_integer_deserializer::(field, array), @@ -83,6 +81,7 @@ pub fn build_array_deserializer<'a>( T::FixedSizeList(_) => fail!("FixedSizedList is not supported by arrow2"), T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), + _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } } From c69a1b02d3da7c8b0eda2fc56f1427669f4af31b Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 13:57:19 +0200 Subject: [PATCH 055/178] Implement Int, UInt, Float view --- serde_arrow/src/arrow2_impl/array.rs | 44 ++++++++++++++++++- .../src/arrow2_impl/deserialization.rs | 35 +-------------- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index f9e536ef..c39286ed 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -11,7 +11,8 @@ use crate::{ }, internal::{ arrow::{ - Array, ArrayView, FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, + Array, ArrayView, BitsWithOffset, FieldMeta, NullArrayView, + PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, }, error::{fail, Error, Result}, }, @@ -142,9 +143,36 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { type Error = Error; fn try_from(array: &'a dyn A2Array) -> Result { + use ArrayView as V; + let any = array.as_any(); if let Some(array) = any.downcast_ref::() { - Ok(ArrayView::Null(NullArrayView { len: array.len() })) + Ok(V::Null(NullArrayView { len: array.len() })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Int8(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Int16(view_primitive_array(array))) + } else 
if let Some(array) = any.downcast_ref::>() { + Ok(V::Int32(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Int64(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::UInt8(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::UInt16(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::UInt32(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::UInt64(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Float16(PrimitiveArrayView { + values: bytemuck::cast_slice::(array.values().as_slice()), + validity: bits_with_offset_from_bitmap(array.validity()), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Float32(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Float64(view_primitive_array(array))) } else { fail!( "Cannot convert array with data type {:?} into an array view", @@ -251,3 +279,15 @@ fn build_dictionary_array( values, )?)) } + +fn view_primitive_array(array: &PrimitiveArray) -> PrimitiveArrayView<'_, T> { + PrimitiveArrayView { + values: array.values().as_slice(), + validity: bits_with_offset_from_bitmap(array.validity()), + } +} + +fn bits_with_offset_from_bitmap(bitmap: Option<&Bitmap>) -> Option> { + let (data, offset, _) = bitmap?.as_slice(); + Some(BitsWithOffset { data, offset }) +} diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 6891b1cb..39ef25c2 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -9,7 +9,6 @@ use crate::internal::{ decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, - float_deserializer::{Float, FloatDeserializer}, 
integer_deserializer::{Integer, IntegerDeserializer}, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, @@ -29,7 +28,7 @@ use crate::_impl::arrow2::{ StructArray, UnionArray, Utf8Array, }, datatypes::{DataType, UnionMode}, - types::{f16, NativeType, Offset as ArrowOffset}, + types::{NativeType, Offset as ArrowOffset}, }; pub fn build_array_deserializer<'a>( @@ -39,17 +38,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Bool => build_bool_deserializer(field, array), - T::U8 => build_integer_deserializer::(field, array), - T::U16 => build_integer_deserializer::(field, array), - T::U32 => build_integer_deserializer::(field, array), - T::U64 => build_integer_deserializer::(field, array), - T::I8 => build_integer_deserializer::(field, array), - T::I16 => build_integer_deserializer::(field, array), - T::I32 => build_integer_deserializer::(field, array), - T::I64 => build_integer_deserializer::(field, array), - T::F16 => build_float16_deserializer(field, array), - T::F32 => build_float_deserializer::(field, array), - T::F64 => build_float_deserializer::(field, array), T::Decimal128(_, _) => build_decimal128_deserializer(field, array), T::Date32 => build_date32_deserializer(field, array), T::Date64 => build_date64_deserializer(field, array), @@ -117,27 +105,6 @@ where Ok(IntegerDeserializer::new(as_primitive_values(array)?, get_validity(array)).into()) } -pub fn build_float16_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let buffer = as_primitive_values(array)?; - let validity = get_validity(array); - - Ok(FloatDeserializer::new(bytemuck::cast_slice::(buffer), validity).into()) -} - -pub fn build_float_deserializer<'a, T>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - T: Float + NativeType + 'static, - ArrayDeserializer<'a>: From>, -{ - Ok(FloatDeserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) -} - pub fn 
build_decimal128_deserializer<'a>( field: &GenericField, array: &'a dyn Array, From 486921a072ec9a395f25596016be3870abec25aa Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:05:48 +0200 Subject: [PATCH 056/178] Implement Date32, Date64, Decimal128 view --- serde_arrow/src/arrow2_impl/array.rs | 26 ++++++++++-- .../src/arrow2_impl/deserialization.rs | 42 ------------------- 2 files changed, 22 insertions(+), 46 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index c39286ed..cff6f36e 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -11,7 +11,7 @@ use crate::{ }, internal::{ arrow::{ - Array, ArrayView, BitsWithOffset, FieldMeta, NullArrayView, + Array, ArrayView, BitsWithOffset, DecimalArrayView, FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, }, error::{fail, Error, Result}, @@ -143,7 +143,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { type Error = Error; fn try_from(array: &'a dyn A2Array) -> Result { - use ArrayView as V; + use {ArrayView as V, DataType as T}; let any = array.as_any(); if let Some(array) = any.downcast_ref::() { @@ -153,9 +153,27 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { } else if let Some(array) = any.downcast_ref::>() { Ok(V::Int16(view_primitive_array(array))) } else if let Some(array) = any.downcast_ref::>() { - Ok(V::Int32(view_primitive_array(array))) + match array.data_type() { + T::Int32 => Ok(V::Int32(view_primitive_array(array))), + T::Date32 => Ok(V::Date32(view_primitive_array(array))), + dt => fail!("unsupported data type {dt:?} for i32 arrow2 array"), + } } else if let Some(array) = any.downcast_ref::>() { - Ok(V::Int64(view_primitive_array(array))) + match array.data_type() { + T::Int64 => Ok(V::Int64(view_primitive_array(array))), + T::Date64 => Ok(V::Date64(view_primitive_array(array))), + dt => fail!("unsupported data type {dt:?} for i64 
arrow2 array"), + } + } else if let Some(array) = any.downcast_ref::>() { + match array.data_type() { + T::Decimal(precision, scale) => Ok(V::Decimal128(DecimalArrayView { + precision: (*precision).try_into()?, + scale: (*scale).try_into()?, + validity: bits_with_offset_from_bitmap(array.validity()), + values: array.values().as_slice(), + })), + dt => fail!("unsupported data type {dt:?} for i128 arrow2 array"), + } } else if let Some(array) = any.downcast_ref::>() { Ok(V::UInt8(view_primitive_array(array))) } else if let Some(array) = any.downcast_ref::>() { diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 39ef25c2..a68c0092 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,12 +1,8 @@ use crate::internal::{ - arrow::TimeUnit, deserialization::{ array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, construction, - date32_deserializer::Date32Deserializer, - date64_deserializer::Date64Deserializer, - decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, integer_deserializer::{Integer, IntegerDeserializer}, @@ -38,9 +34,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Bool => build_bool_deserializer(field, array), - T::Decimal128(_, _) => build_decimal128_deserializer(field, array), - T::Date32 => build_date32_deserializer(field, array), - T::Date64 => build_date64_deserializer(field, array), T::Time32(unit) => Ok(ArrayDeserializer::Time32(TimeDeserializer::new( as_primitive_values::(array)?, get_validity(array), @@ -105,41 +98,6 @@ where Ok(IntegerDeserializer::new(as_primitive_values(array)?, get_validity(array)).into()) } -pub fn build_decimal128_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let GenericDataType::Decimal128(_, scale) = field.data_type else 
{ - fail!("Invalid data type for Decimal128Deserializer"); - }; - Ok(DecimalDeserializer::new( - as_primitive_values::(array)?, - get_validity(array), - scale, - ) - .into()) -} - -pub fn build_date32_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - Ok(Date32Deserializer::new(as_primitive_values::(array)?, get_validity(array)).into()) -} - -pub fn build_date64_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - Ok(Date64Deserializer::new( - as_primitive_values(array)?, - get_validity(array), - TimeUnit::Millisecond, - field.is_utc()?, - ) - .into()) -} - pub fn build_string_deserializer<'a, O>( _field: &GenericField, array: &'a dyn Array, From 50fe251e8c42120f46735d54599204c67bfa431c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:11:39 +0200 Subject: [PATCH 057/178] Implement Time32, Time64, Timestamp, Duration view --- serde_arrow/src/arrow2_impl/array.rs | 24 +++++++++- .../src/arrow2_impl/deserialization.rs | 46 ++----------------- serde_arrow/src/arrow2_impl/schema.rs | 11 +++++ 3 files changed, 38 insertions(+), 43 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index cff6f36e..2d4fa54e 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -12,7 +12,8 @@ use crate::{ internal::{ arrow::{ Array, ArrayView, BitsWithOffset, DecimalArrayView, FieldMeta, NullArrayView, - PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, + PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, TimeArrayView, + TimestampArrayView, }, error::{fail, Error, Result}, }, @@ -156,12 +157,33 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { match array.data_type() { T::Int32 => Ok(V::Int32(view_primitive_array(array))), T::Date32 => Ok(V::Date32(view_primitive_array(array))), + T::Time32(unit) => Ok(V::Time32(TimeArrayView { + unit: (*unit).into(), + validity: 
bits_with_offset_from_bitmap(array.validity()), + values: array.values().as_slice(), + })), dt => fail!("unsupported data type {dt:?} for i32 arrow2 array"), } } else if let Some(array) = any.downcast_ref::>() { match array.data_type() { T::Int64 => Ok(V::Int64(view_primitive_array(array))), T::Date64 => Ok(V::Date64(view_primitive_array(array))), + T::Timestamp(unit, tz) => Ok(V::Timestamp(TimestampArrayView { + unit: (*unit).into(), + timezone: tz.to_owned(), + validity: bits_with_offset_from_bitmap(array.validity()), + values: array.values().as_slice(), + })), + T::Time64(unit) => Ok(V::Time64(TimeArrayView { + unit: (*unit).into(), + validity: bits_with_offset_from_bitmap(array.validity()), + values: array.values().as_slice(), + })), + T::Duration(unit) => Ok(V::Duration(TimeArrayView { + unit: (*unit).into(), + validity: bits_with_offset_from_bitmap(array.validity()), + values: array.values().as_slice(), + })), dt => fail!("unsupported data type {dt:?} for i64 arrow2 array"), } } else if let Some(array) = any.downcast_ref::>() { diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index a68c0092..d69294b8 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -2,15 +2,13 @@ use crate::internal::{ deserialization::{ array_deserializer::ArrayDeserializer, bool_deserializer::BoolDeserializer, - construction, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, - integer_deserializer::{Integer, IntegerDeserializer}, + integer_deserializer::Integer, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, - time_deserializer::TimeDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, error::{fail, Result}, @@ -20,11 +18,11 @@ use crate::internal::{ use crate::_impl::arrow2::{ array::{ - Array, BooleanArray, 
DictionaryArray, DictionaryKey, ListArray, MapArray, PrimitiveArray, - StructArray, UnionArray, Utf8Array, + Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, + UnionArray, Utf8Array, }, datatypes::{DataType, UnionMode}, - types::{NativeType, Offset as ArrowOffset}, + types::Offset as ArrowOffset, }; pub fn build_array_deserializer<'a>( @@ -34,22 +32,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Bool => build_bool_deserializer(field, array), - T::Time32(unit) => Ok(ArrayDeserializer::Time32(TimeDeserializer::new( - as_primitive_values::(array)?, - get_validity(array), - *unit, - ))), - T::Time64(unit) => Ok(ArrayDeserializer::Time64(TimeDeserializer::new( - as_primitive_values::(array)?, - get_validity(array), - *unit, - ))), - T::Timestamp(_, _) => construction::build_timestamp_deserializer( - field, - as_primitive_values::(array)?, - get_validity(array), - ), - T::Duration(_) => build_integer_deserializer::(field, array), T::Utf8 => build_string_deserializer::(field, array), T::LargeUtf8 => build_string_deserializer::(field, array), T::Dictionary => build_dictionary_deserializer(field, array), @@ -87,17 +69,6 @@ pub fn build_bool_deserializer<'a>( ))) } -pub fn build_integer_deserializer<'a, T>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - T: Integer + NativeType + 'static, - ArrayDeserializer<'a>: From>, -{ - Ok(IntegerDeserializer::new(as_primitive_values(array)?, get_validity(array)).into()) -} - pub fn build_string_deserializer<'a, O>( _field: &GenericField, array: &'a dyn Array, @@ -312,15 +283,6 @@ pub fn build_union_deserializer<'a>( Ok(EnumDeserializer::new(type_ids, variants).into()) } -fn as_primitive_values(array: &dyn Array) -> Result<&[T]> { - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot interpret array as integer array"); - }; - - let buffer = array.values().as_slice(); - Ok(buffer) -} - fn get_validity(arr: 
&dyn Array) -> Option> { let validity = arr.validity()?; let (data, offset, number_of_bits) = validity.as_slice(); diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 07ab3843..58c22e4a 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -319,3 +319,14 @@ impl From for ArrowTimeUnit { } } } + +impl From for TimeUnit { + fn from(value: ArrowTimeUnit) -> Self { + match value { + ArrowTimeUnit::Second => Self::Second, + ArrowTimeUnit::Millisecond => Self::Millisecond, + ArrowTimeUnit::Microsecond => Self::Microsecond, + ArrowTimeUnit::Nanosecond => Self::Nanosecond, + } + } +} From 3c07cfae6359651904bf29ced51fd7615e414362 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:15:17 +0200 Subject: [PATCH 058/178] Implement Boolean view --- serde_arrow/src/arrow2_impl/array.rs | 11 +++++--- .../src/arrow2_impl/deserialization.rs | 25 +------------------ 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 2d4fa54e..a54e5b9c 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -11,9 +11,7 @@ use crate::{ }, internal::{ arrow::{ - Array, ArrayView, BitsWithOffset, DecimalArrayView, FieldMeta, NullArrayView, - PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, TimeArrayView, - TimestampArrayView, + Array, ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, TimeArrayView, TimestampArrayView }, error::{fail, Error, Result}, }, @@ -149,6 +147,13 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { let any = array.as_any(); if let Some(array) = any.downcast_ref::() { Ok(V::Null(NullArrayView { len: array.len() })) + } else if let Some(array) = any.downcast_ref::() { + let (values_data, values_offset, _) = 
array.values().as_slice(); + Ok(V::Boolean(BooleanArrayView { + len: array.len(), + validity: bits_with_offset_from_bitmap(array.validity()), + values: BitsWithOffset { offset: values_offset, data: values_data }, + })) } else if let Some(array) = any.downcast_ref::>() { Ok(V::Int8(view_primitive_array(array))) } else if let Some(array) = any.downcast_ref::>() { diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index d69294b8..09420497 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,7 +1,6 @@ use crate::internal::{ deserialization::{ array_deserializer::ArrayDeserializer, - bool_deserializer::BoolDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, integer_deserializer::Integer, @@ -18,7 +17,7 @@ use crate::internal::{ use crate::_impl::arrow2::{ array::{ - Array, BooleanArray, DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, + Array, DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, UnionArray, Utf8Array, }, datatypes::{DataType, UnionMode}, @@ -31,7 +30,6 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - T::Bool => build_bool_deserializer(field, array), T::Utf8 => build_string_deserializer::(field, array), T::LargeUtf8 => build_string_deserializer::(field, array), T::Dictionary => build_dictionary_deserializer(field, array), @@ -48,27 +46,6 @@ pub fn build_array_deserializer<'a>( } } -pub fn build_bool_deserializer<'a>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!("cannot interpret array as Bool array"); - }; - - let (data, offset, number_of_bits) = array.values().as_slice(); - let buffer = BitBuffer { - data, - offset, - number_of_bits, - }; - let validity = get_validity(array); - - 
Ok(ArrayDeserializer::Bool(BoolDeserializer::new( - buffer, validity, - ))) -} - pub fn build_string_deserializer<'a, O>( _field: &GenericField, array: &'a dyn Array, From 4716655bd2d1244db577a14fe1585539144a42e1 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:28:09 +0200 Subject: [PATCH 059/178] Implement Utf8, LargeUtf8, Binary, LargeBinary view --- Changes.md | 1 + serde_arrow/src/arrow2_impl/array.rs | 33 +++++++++++++++++-- .../src/arrow2_impl/deserialization.rs | 28 ++-------------- serde_arrow/src/arrow2_impl/schema.rs | 2 ++ .../src/test_with_arrow/impls/bytes.rs | 8 ----- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/Changes.md b/Changes.md index c4588156..df9cd12b 100644 --- a/Changes.md +++ b/Changes.md @@ -2,6 +2,7 @@ ## 0.12 +- Add `Binary` / `LargeBinary` support for `arrow2` - Remove `serde_arrow::schema::Schema` - Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` - Use `impl serde::Serialize` instead of `&(impl serde::Serialize + ?Sized)` diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index a54e5b9c..f5eb915a 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -11,7 +11,9 @@ use crate::{ }, internal::{ arrow::{ - Array, ArrayView, BitsWithOffset, BooleanArrayView, DecimalArrayView, FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, TimeArrayView, TimestampArrayView + Array, ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, + FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, + TimeArrayView, TimestampArrayView, }, error::{fail, Error, Result}, }, @@ -152,7 +154,10 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { Ok(V::Boolean(BooleanArrayView { len: array.len(), validity: bits_with_offset_from_bitmap(array.validity()), - values: BitsWithOffset { offset: values_offset, data: values_data }, + 
values: BitsWithOffset { + offset: values_offset, + data: values_data, + }, })) } else if let Some(array) = any.downcast_ref::>() { Ok(V::Int8(view_primitive_array(array))) @@ -218,6 +223,30 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { Ok(V::Float32(view_primitive_array(array))) } else if let Some(array) = any.downcast_ref::>() { Ok(V::Float64(view_primitive_array(array))) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Utf8(BytesArrayView { + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + data: array.values().as_slice(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::LargeUtf8(BytesArrayView { + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + data: array.values().as_slice(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Binary(BytesArrayView { + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + data: array.values().as_slice(), + })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::LargeBinary(BytesArrayView { + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + data: array.values().as_slice(), + })) } else { fail!( "Cannot convert array with data type {:?} into an array view", diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 09420497..69e0c144 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -6,7 +6,6 @@ use crate::internal::{ integer_deserializer::Integer, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, - string_deserializer::StringDeserializer, struct_deserializer::StructDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, @@ -17,8 +16,8 @@ use crate::internal::{ use crate::_impl::arrow2::{ array::{ - Array, 
DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, - UnionArray, Utf8Array, + Array, DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, UnionArray, + Utf8Array, }, datatypes::{DataType, UnionMode}, types::Offset as ArrowOffset, @@ -30,14 +29,10 @@ pub fn build_array_deserializer<'a>( ) -> Result> { use GenericDataType as T; match &field.data_type { - T::Utf8 => build_string_deserializer::(field, array), - T::LargeUtf8 => build_string_deserializer::(field, array), T::Dictionary => build_dictionary_deserializer(field, array), T::Struct => build_struct_deserializer(field, array), T::List => build_list_deserializer::(field, array), T::LargeList => build_list_deserializer::(field, array), - T::Binary => fail!("Binary is not supported by arrow2"), - T::LargeBinary => fail!("LargeBinary is not supported by arrow2"), T::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), T::FixedSizeList(_) => fail!("FixedSizedList is not supported by arrow2"), T::Map => build_map_deserializer(field, array), @@ -46,25 +41,6 @@ pub fn build_array_deserializer<'a>( } } -pub fn build_string_deserializer<'a, O>( - _field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - O: ArrowOffset + Offset, - ArrayDeserializer<'a>: From>, -{ - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot interpret array as Utf8 array"); - }; - - let buffer = array.values().as_slice(); - let offsets = array.offsets().as_slice(); - let validity = get_validity(array); - - Ok(StringDeserializer::new(buffer, offsets, validity).into()) -} - pub fn build_dictionary_deserializer<'a>( field: &GenericField, array: &'a dyn Array, diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 58c22e4a..c19d6d96 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -176,6 +176,8 @@ impl TryFrom<&Field> for GenericField { children.push((&Field::new("", 
data_type.as_ref().clone(), false)).try_into()?); T::Dictionary } + DataType::Binary => T::Binary, + DataType::LargeBinary => T::LargeBinary, dt => fail!("Cannot convert data type {dt:?}"), }; diff --git a/serde_arrow/src/test_with_arrow/impls/bytes.rs b/serde_arrow/src/test_with_arrow/impls/bytes.rs index 37079947..c6e6d930 100644 --- a/serde_arrow/src/test_with_arrow/impls/bytes.rs +++ b/serde_arrow/src/test_with_arrow/impls/bytes.rs @@ -34,7 +34,6 @@ fn example_as_binary() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "Binary"}])) .serialize(&items) .deserialize(&items); @@ -49,7 +48,6 @@ fn example_large_binary() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary"}])) .trace_schema_from_type::>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) @@ -66,7 +64,6 @@ fn example_large_binary_nullable() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary", "nullable": true}])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) @@ -84,7 +81,6 @@ fn example_vec_as_large_binary() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary"}])) .serialize(&items) .deserialize(&items); @@ -99,7 +95,6 @@ fn example_vec_as_large_binary_nullable() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary", "nullable": true}])) .serialize(&items) .check_nulls(&[&[false, true, false]]) @@ -111,7 +106,6 @@ fn example_vec_i64_as_large_binary() { let items = [Item(vec![1_i64, 2, 3]), Item(vec![128, 255, 75])]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary"}])) .serialize(&items) .deserialize(&items); @@ -126,7 +120,6 @@ fn example_borrowed() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": 
"LargeBinary"}])) .trace_schema_from_type::>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) @@ -143,7 +136,6 @@ fn example_borrowed_nullable() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "LargeBinary", "nullable": true}])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) From 74d161cc71602aa09f210e3c263b44b743c1b286 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:39:20 +0200 Subject: [PATCH 060/178] Implement List, LargeList, Struct view --- serde_arrow/src/arrow2_impl/array.rs | 50 ++++++++++++++++- .../src/arrow2_impl/deserialization.rs | 55 +------------------ 2 files changed, 49 insertions(+), 56 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index f5eb915a..cefc659e 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -12,10 +12,11 @@ use crate::{ internal::{ arrow::{ Array, ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - FieldMeta, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, - TimeArrayView, TimestampArrayView, + FieldMeta, ListArrayView, NullArrayView, PrimitiveArray as InternalPrimitiveArray, + PrimitiveArrayView, StructArrayView, TimeArrayView, TimestampArrayView, }, error::{fail, Error, Result}, + serialization::utils::meta_from_field, }, }; @@ -247,6 +248,51 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { offsets: array.offsets().as_slice(), data: array.values().as_slice(), })) + } else if let Some(array) = any.downcast_ref::>() { + let T::List(field) = array.data_type() else { + fail!( + "invalid data type for arrow2 List array: {:?}", + array.data_type() + ); + }; + Ok(V::List(ListArrayView { + meta: meta_from_field(field.as_ref().try_into()?)?, + validity: bits_with_offset_from_bitmap(array.validity()), + 
offsets: array.offsets().as_slice(), + element: Box::new(array.values().as_ref().try_into()?), + })) + } else if let Some(array) = any.downcast_ref::>() { + let T::LargeList(field) = array.data_type() else { + fail!( + "invalid data type for arrow2 LargeList array: {:?}", + array.data_type() + ); + }; + Ok(V::LargeList(ListArrayView { + meta: meta_from_field(field.as_ref().try_into()?)?, + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + element: Box::new(array.values().as_ref().try_into()?), + })) + } else if let Some(array) = any.downcast_ref::() { + let T::Struct(child_fields) = array.data_type() else { + fail!( + "invalid data type for arrow2 Struct array: {:?}", + array.data_type() + ); + }; + let mut fields = Vec::new(); + for (child_field, child) in child_fields.iter().zip(array.values()) { + fields.push(( + child.as_ref().try_into()?, + meta_from_field(child_field.try_into()?)?, + )); + } + Ok(V::Struct(StructArrayView { + len: array.len(), + validity: bits_with_offset_from_bitmap(array.validity()), + fields, + })) } else { fail!( "Cannot convert array with data type {:?} into an array view", diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 69e0c144..c9ce2a41 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -4,9 +4,7 @@ use crate::internal::{ dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, integer_deserializer::Integer, - list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, - struct_deserializer::StructDeserializer, utils::{check_supported_list_layout, BitBuffer}, }, error::{fail, Result}, @@ -15,10 +13,7 @@ use crate::internal::{ }; use crate::_impl::arrow2::{ - array::{ - Array, DictionaryArray, DictionaryKey, ListArray, MapArray, StructArray, UnionArray, - Utf8Array, - }, + array::{Array, DictionaryArray, 
DictionaryKey, MapArray, StructArray, UnionArray, Utf8Array}, datatypes::{DataType, UnionMode}, types::Offset as ArrowOffset, }; @@ -30,11 +25,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Dictionary => build_dictionary_deserializer(field, array), - T::Struct => build_struct_deserializer(field, array), - T::List => build_list_deserializer::(field, array), - T::LargeList => build_list_deserializer::(field, array), - T::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), - T::FixedSizeList(_) => fail!("FixedSizedList is not supported by arrow2"), T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), @@ -103,26 +93,6 @@ pub fn build_dictionary_deserializer<'a>( } } -pub fn build_struct_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!("Cannot convert array into struct"); - }; - - let fields = &field.children; - let arrays = array - .values() - .iter() - .map(|array| array.as_ref()) - .collect::>(); - let validity = get_validity(array); - - let (deserializers, len) = build_struct_fields(fields, &arrays)?; - Ok(StructDeserializer::new(deserializers, validity, len).into()) -} - pub fn build_struct_fields<'a>( fields: &[GenericField], arrays: &[&'a dyn Array], @@ -148,29 +118,6 @@ pub fn build_struct_fields<'a>( Ok((deserializers, len)) } -pub fn build_list_deserializer<'a, O>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> -where - O: Offset + ArrowOffset, - ArrayDeserializer<'a>: From>, -{ - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot interpret array as LargeList array"); - }; - - let validity = get_validity(array); - let offsets = array.offsets().as_slice(); - - let Some(item_field) = field.children.first() else { - fail!("cannot get first 
child of list array") - }; - let item = build_array_deserializer(item_field, array.values().as_ref())?; - - Ok(ListDeserializer::new(item, offsets, validity)?.into()) -} - pub fn build_map_deserializer<'a>( field: &GenericField, array: &'a dyn Array, From 7de38b495ab1f4cdd495ae31e1208028538b1da0 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:46:23 +0200 Subject: [PATCH 061/178] Implement Map view --- serde_arrow/src/arrow2_impl/array.rs | 16 +++++++ .../src/arrow2_impl/deserialization.rs | 48 ++----------------- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index cefc659e..e39e9f65 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -293,6 +293,22 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { validity: bits_with_offset_from_bitmap(array.validity()), fields, })) + } else if let Some(array) = any.downcast_ref::() { + let T::Map(field, _) = array.data_type() else { + fail!( + "invalid data type for arrow2 Map array: {:?}", + array.data_type(), + ); + }; + let meta = meta_from_field(field.as_ref().try_into()?)?; + let element: ArrayView<'_> = array.field().as_ref().try_into()?; + + Ok(V::Map(ListArrayView { + element: Box::new(element), + meta, + validity: bits_with_offset_from_bitmap(array.validity()), + offsets: array.offsets().as_slice(), + })) } else { fail!( "Cannot convert array with data type {:?} into an array view", diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index c9ce2a41..15b6b7ee 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,11 +1,7 @@ use crate::internal::{ deserialization::{ - array_deserializer::ArrayDeserializer, - dictionary_deserializer::DictionaryDeserializer, - enum_deserializer::EnumDeserializer, - integer_deserializer::Integer, - 
map_deserializer::MapDeserializer, - utils::{check_supported_list_layout, BitBuffer}, + array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, + enum_deserializer::EnumDeserializer, integer_deserializer::Integer, utils::BitBuffer, }, error::{fail, Result}, schema::{GenericDataType, GenericField}, @@ -13,7 +9,7 @@ use crate::internal::{ }; use crate::_impl::arrow2::{ - array::{Array, DictionaryArray, DictionaryKey, MapArray, StructArray, UnionArray, Utf8Array}, + array::{Array, DictionaryArray, DictionaryKey, UnionArray, Utf8Array}, datatypes::{DataType, UnionMode}, types::Offset as ArrowOffset, }; @@ -25,7 +21,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Dictionary => build_dictionary_deserializer(field, array), - T::Map => build_map_deserializer(field, array), T::Union => build_union_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } @@ -118,43 +113,6 @@ pub fn build_struct_fields<'a>( Ok((deserializers, len)) } -pub fn build_map_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(entries_field) = field.children.first() else { - fail!("cannot get children of map"); - }; - let Some(keys_field) = entries_field.children.first() else { - fail!("cannot get keys field"); - }; - let Some(values_field) = entries_field.children.get(1) else { - fail!("cannot get values field"); - }; - let Some(array) = array.as_any().downcast_ref::() else { - fail!("cannot convert array into map array"); - }; - let Some(entries) = array.field().as_any().downcast_ref::() else { - fail!("cannot convert map field into struct array"); - }; - let Some(keys) = entries.values().first() else { - fail!("cannot get keys array of map entries"); - }; - let Some(values) = entries.values().get(1) else { - fail!("cannot get values array of map entries"); - }; - - let offsets = array.offsets().as_slice(); - let validity = 
get_validity(array); - - check_supported_list_layout(validity, offsets)?; - - let keys = build_array_deserializer(keys_field, keys.as_ref())?; - let values = build_array_deserializer(values_field, values.as_ref())?; - - Ok(MapDeserializer::new(keys, values, offsets, validity)?.into()) -} - pub fn build_union_deserializer<'a>( field: &GenericField, array: &'a dyn Array, From f0e98553a90cecc3fda0afe0630338002605ab21 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 14:55:41 +0200 Subject: [PATCH 062/178] Implement Union view --- serde_arrow/src/arrow2_impl/array.rs | 29 ++++++++++++++-- .../src/arrow2_impl/deserialization.rs | 34 ++----------------- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index e39e9f65..443a827f 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -12,8 +12,9 @@ use crate::{ internal::{ arrow::{ Array, ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - FieldMeta, ListArrayView, NullArrayView, PrimitiveArray as InternalPrimitiveArray, - PrimitiveArrayView, StructArrayView, TimeArrayView, TimestampArrayView, + DenseUnionArrayView, FieldMeta, ListArrayView, NullArrayView, + PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, StructArrayView, + TimeArrayView, TimestampArrayView, }, error::{fail, Error, Result}, serialization::utils::meta_from_field, @@ -309,6 +310,30 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { validity: bits_with_offset_from_bitmap(array.validity()), offsets: array.offsets().as_slice(), })) + } else if let Some(array) = any.downcast_ref::() { + // TODO: check type ids + let T::Union(union_fields, _, UnionMode::Dense) = array.data_type() else { + fail!("Invalid data type: only dense unions are supported"); + }; + + let types = array.types().as_slice(); + let Some(offsets) = array.offsets() else { + fail!("DenseUnion array 
without offsets are not supported"); + }; + + let mut fields = Vec::new(); + for (child, child_field) in array.fields().iter().zip(union_fields) { + fields.push(( + child.as_ref().try_into()?, + meta_from_field(child_field.try_into()?)?, + )); + } + + Ok(V::DenseUnion(DenseUnionArrayView { + types, + offsets: offsets.as_slice(), + fields, + })) } else { fail!( "Cannot convert array with data type {:?} into an array view", diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index 15b6b7ee..aaa6a141 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,7 +1,7 @@ use crate::internal::{ deserialization::{ array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, - enum_deserializer::EnumDeserializer, integer_deserializer::Integer, utils::BitBuffer, + integer_deserializer::Integer, utils::BitBuffer, }, error::{fail, Result}, schema::{GenericDataType, GenericField}, @@ -9,8 +9,7 @@ use crate::internal::{ }; use crate::_impl::arrow2::{ - array::{Array, DictionaryArray, DictionaryKey, UnionArray, Utf8Array}, - datatypes::{DataType, UnionMode}, + array::{Array, DictionaryArray, DictionaryKey, Utf8Array}, types::Offset as ArrowOffset, }; @@ -21,7 +20,6 @@ pub fn build_array_deserializer<'a>( use GenericDataType as T; match &field.data_type { T::Dictionary => build_dictionary_deserializer(field, array), - T::Union => build_union_deserializer(field, array), _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), } } @@ -113,34 +111,6 @@ pub fn build_struct_fields<'a>( Ok((deserializers, len)) } -pub fn build_union_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - let Some(array) = array.as_any().downcast_ref::() else { - fail!("Cannot interpret array as a union array"); - }; - - if !matches!(array.data_type(), DataType::Union(_, _, UnionMode::Dense)) { - fail!("Invalid data 
type: only dense unions are supported"); - } - - let type_ids = array.types().as_slice(); - - let mut variants = Vec::new(); - for (type_id, field) in field.children.iter().enumerate() { - let name = field.name.to_owned(); - let Some(child) = array.fields().get(type_id) else { - fail!("Cannot get variant"); - }; - let deser = build_array_deserializer(field, child.as_ref())?; - - variants.push((name, deser)); - } - - Ok(EnumDeserializer::new(type_ids, variants).into()) -} - fn get_validity(arr: &dyn Array) -> Option> { let validity = arr.validity()?; let (data, offset, number_of_bits) = validity.as_slice(); From a29a87670dd863d1ae8c6d71fa9b13c31bfb9618 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 15:03:51 +0200 Subject: [PATCH 063/178] Implement Dictionary view --- serde_arrow/src/arrow2_impl/array.rs | 32 +++++- .../src/arrow2_impl/deserialization.rs | 100 +----------------- 2 files changed, 36 insertions(+), 96 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 443a827f..45e49b9d 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -12,7 +12,7 @@ use crate::{ internal::{ arrow::{ Array, ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - DenseUnionArrayView, FieldMeta, ListArrayView, NullArrayView, + DenseUnionArrayView, DictionaryArrayView, FieldMeta, ListArrayView, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, StructArrayView, TimeArrayView, TimestampArrayView, }, @@ -249,6 +249,22 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { offsets: array.offsets().as_slice(), data: array.values().as_slice(), })) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::Int8, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::Int16, array)?)) + } else if let Some(array) = 
any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::Int32, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::Int64, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::UInt8, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::UInt16, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::UInt32, array)?)) + } else if let Some(array) = any.downcast_ref::>() { + Ok(V::Dictionary(view_dictionary_array(V::UInt64, array)?)) } else if let Some(array) = any.downcast_ref::>() { let T::List(field) = array.data_type() else { fail!( @@ -448,6 +464,20 @@ fn view_primitive_array(array: &PrimitiveArray) -> PrimitiveAr } } +fn view_dictionary_array< + 'a, + K: DictionaryKey, + I: FnOnce(PrimitiveArrayView<'a, K>) -> ArrayView<'a>, +>( + index_type: I, + array: &'a DictionaryArray, +) -> Result> { + Ok(DictionaryArrayView { + indices: Box::new(index_type(view_primitive_array(array.keys()))), + values: Box::new(array.values().as_ref().try_into()?), + }) +} + fn bits_with_offset_from_bitmap(bitmap: Option<&Bitmap>) -> Option> { let (data, offset, _) = bitmap?.as_slice(); Some(BitsWithOffset { data, offset }) diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs index aaa6a141..cc28277a 100644 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ b/serde_arrow/src/arrow2_impl/deserialization.rs @@ -1,90 +1,10 @@ use crate::internal::{ - deserialization::{ - array_deserializer::ArrayDeserializer, dictionary_deserializer::DictionaryDeserializer, - integer_deserializer::Integer, utils::BitBuffer, - }, + deserialization::array_deserializer::ArrayDeserializer, error::{fail, Result}, - schema::{GenericDataType, GenericField}, - utils::Offset, + schema::GenericField, }; -use crate::_impl::arrow2::{ 
- array::{Array, DictionaryArray, DictionaryKey, Utf8Array}, - types::Offset as ArrowOffset, -}; - -pub fn build_array_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - use GenericDataType as T; - match &field.data_type { - T::Dictionary => build_dictionary_deserializer(field, array), - _ => ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?), - } -} - -pub fn build_dictionary_deserializer<'a>( - field: &GenericField, - array: &'a dyn Array, -) -> Result> { - use GenericDataType as T; - - let Some(key_field) = field.children.first() else { - fail!("Missing key field"); - }; - let Some(value_field) = field.children.get(1) else { - fail!("Missing key field"); - }; - - return match (&key_field.data_type, &value_field.data_type) { - (T::U8, T::Utf8) => typed::(field, array), - (T::U16, T::Utf8) => typed::(field, array), - (T::U32, T::Utf8) => typed::(field, array), - (T::U64, T::Utf8) => typed::(field, array), - (T::I8, T::Utf8) => typed::(field, array), - (T::I16, T::Utf8) => typed::(field, array), - (T::I32, T::Utf8) => typed::(field, array), - (T::I64, T::Utf8) => typed::(field, array), - (T::U8, T::LargeUtf8) => typed::(field, array), - (T::U16, T::LargeUtf8) => typed::(field, array), - (T::U32, T::LargeUtf8) => typed::(field, array), - (T::U64, T::LargeUtf8) => typed::(field, array), - (T::I8, T::LargeUtf8) => typed::(field, array), - (T::I16, T::LargeUtf8) => typed::(field, array), - (T::I32, T::LargeUtf8) => typed::(field, array), - (T::I64, T::LargeUtf8) => typed::(field, array), - _ => fail!("invalid dicitonary key / value data type"), - }; - - pub fn typed<'a, K, V>( - _field: &GenericField, - array: &'a dyn Array, - ) -> Result> - where - K: DictionaryKey + Integer, - V: Offset + ArrowOffset, - DictionaryDeserializer<'a, K, V>: Into>, - { - let Some(array) = array.as_any().downcast_ref::>() else { - fail!("cannot convert array into dictionary array"); - }; - let Some(values) = 
array.values().as_any().downcast_ref::>() else { - fail!("invalid values"); - }; - - let keys_buffer = array.keys().values(); - let keys_validity = get_validity(array); - - let values_data = values.values().as_slice(); - let values_offsets = values.offsets().as_slice(); - - Ok( - DictionaryDeserializer::new(keys_buffer, keys_validity, values_data, values_offsets) - .into(), - ) - } -} +use crate::_impl::arrow2::array::Array; pub fn build_struct_fields<'a>( fields: &[GenericField], @@ -104,19 +24,9 @@ pub fn build_struct_fields<'a>( if array.len() != len { fail!("arrays of different lengths are not supported"); } - - deserializers.push((field.name.clone(), build_array_deserializer(field, array)?)); + let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + deserializers.push((field.name.clone(), deserializer)); } Ok((deserializers, len)) } - -fn get_validity(arr: &dyn Array) -> Option> { - let validity = arr.validity()?; - let (data, offset, number_of_bits) = validity.as_slice(); - Some(BitBuffer { - data, - offset, - number_of_bits, - }) -} From c30281df6685fff68554e238857fe084a617c00f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 3 Aug 2024 15:05:58 +0200 Subject: [PATCH 064/178] Remove unnecessary abstraction --- serde_arrow/src/arrow2_impl/api.rs | 26 +++++++++++++-- .../src/arrow2_impl/deserialization.rs | 32 ------------------- serde_arrow/src/arrow2_impl/mod.rs | 1 - 3 files changed, 23 insertions(+), 36 deletions(-) delete mode 100644 serde_arrow/src/arrow2_impl/deserialization.rs diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index ec96881f..912fd33d 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -9,9 +9,12 @@ use crate::{ _impl::arrow2::{array::Array, datatypes::Field}, internal::{ array_builder::ArrayBuilder, - deserialization::outer_sequence_deserializer::OuterSequenceDeserializer, + deserialization::{ + 
array_deserializer::ArrayDeserializer, + outer_sequence_deserializer::OuterSequenceDeserializer, + }, deserializer::Deserializer, - error::Result, + error::{fail, Result}, schema::{GenericField, SerdeArrowSchema}, serializer::Serializer, }, @@ -158,7 +161,24 @@ impl<'de> Deserializer<'de> { .map(|array| array.as_ref()) .collect::>(); - let (deserializers, len) = super::deserialization::build_struct_fields(&fields, &arrays)?; + if fields.len() != arrays.len() { + fail!( + "different number of fields ({}) and arrays ({})", + fields.len(), + arrays.len() + ); + } + let len = arrays.first().map(|array| array.len()).unwrap_or_default(); + + let mut deserializers = Vec::new(); + for (field, array) in std::iter::zip(fields, arrays) { + if array.len() != len { + fail!("arrays of different lengths are not supported"); + } + let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + deserializers.push((field.name.clone(), deserializer)); + } + let deserializer = OuterSequenceDeserializer::new(deserializers, len); let deserializer = Deserializer(deserializer); diff --git a/serde_arrow/src/arrow2_impl/deserialization.rs b/serde_arrow/src/arrow2_impl/deserialization.rs deleted file mode 100644 index cc28277a..00000000 --- a/serde_arrow/src/arrow2_impl/deserialization.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::internal::{ - deserialization::array_deserializer::ArrayDeserializer, - error::{fail, Result}, - schema::GenericField, -}; - -use crate::_impl::arrow2::array::Array; - -pub fn build_struct_fields<'a>( - fields: &[GenericField], - arrays: &[&'a dyn Array], -) -> Result<(Vec<(String, ArrayDeserializer<'a>)>, usize)> { - if fields.len() != arrays.len() { - fail!( - "different number of fields ({}) and arrays ({})", - fields.len(), - arrays.len() - ); - } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - - let mut deserializers = Vec::new(); - for (field, &array) in std::iter::zip(fields, arrays) { - if array.len() != 
len { - fail!("arrays of different lengths are not supported"); - } - let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); - } - - Ok((deserializers, len)) -} diff --git a/serde_arrow/src/arrow2_impl/mod.rs b/serde_arrow/src/arrow2_impl/mod.rs index 64df8e4d..2a33d5c9 100644 --- a/serde_arrow/src/arrow2_impl/mod.rs +++ b/serde_arrow/src/arrow2_impl/mod.rs @@ -5,6 +5,5 @@ #![deny(missing_docs)] pub(crate) mod api; mod array; -pub(crate) mod deserialization; mod schema; mod type_support; From f98a5ca2a1890f1f4b2077c3d17502fd0c851af9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 10:15:40 +0200 Subject: [PATCH 065/178] Add known cfgs to Cargo.toml --- serde_arrow/Cargo.toml | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/serde_arrow/Cargo.toml b/serde_arrow/Cargo.toml index 780fc528..49c62da6 100644 --- a/serde_arrow/Cargo.toml +++ b/serde_arrow/Cargo.toml @@ -167,4 +167,31 @@ features = [ "serde-with-float", # NOTE activating this feature breaks JSON -> float processing # "serde-with-arbitrary-precision", -] \ No newline at end of file +] + +[lints.rust.unexpected_cfgs] +level = "warn" +check-cfg = [ + 'cfg(has_arrow2)', + 'cfg(has_arrow2_0_17)', + 'cfg(has_arrow2_0_16)', + 'cfg(has_arrow)', + 'cfg(has_arrow_fixed_binary_support)', + # arrow-version:insert: 'cfg(has_arrow_{version})', + 'cfg(has_arrow_52)', + 'cfg(has_arrow_51)', + 'cfg(has_arrow_50)', + 'cfg(has_arrow_49)', + 'cfg(has_arrow_48)', + 'cfg(has_arrow_47)', + 'cfg(has_arrow_46)', + 'cfg(has_arrow_45)', + 'cfg(has_arrow_44)', + 'cfg(has_arrow_43)', + 'cfg(has_arrow_42)', + 'cfg(has_arrow_41)', + 'cfg(has_arrow_40)', + 'cfg(has_arrow_39)', + 'cfg(has_arrow_38)', + 'cfg(has_arrow_37)', +] From 058ab4e265297c4611efa51ba6cbd339a5b791c8 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 10:50:10 +0200 Subject: [PATCH 066/178] 
Fix clippy --- serde_arrow/src/arrow2_impl/array.rs | 24 +++++++++++++----------- serde_arrow/src/arrow_impl/array.rs | 4 ++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 45e49b9d..36b1d10d 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -21,7 +21,9 @@ use crate::{ }, }; -impl TryFrom for Box { +type ArrayRef = Box; + +impl TryFrom for ArrayRef { type Error = Error; fn try_from(value: Array) -> Result { @@ -115,7 +117,7 @@ impl TryFrom for Box { ))) } A::Map(arr) => { - let child: Box = (*arr.element).try_into()?; + let child: ArrayRef = (*arr.element).try_into()?; let field = field_from_array_and_meta(child.as_ref(), arr.meta); let validity = arr .validity @@ -363,7 +365,7 @@ fn build_primitive_array( data_type: DataType, buffer: Vec, validity: Option>, -) -> Result> { +) -> Result { let validity = validity.map(|v| Bitmap::from_u8_vec(v, buffer.len())); let buffer = Buffer::from(buffer); Ok(Box::new(PrimitiveArray::try_new( @@ -376,7 +378,7 @@ fn build_utf8_array( offsets: Vec, data: Vec, validity: Option>, -) -> Result> { +) -> Result { let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); Ok(Box::new(Utf8Array::new( data_type, @@ -391,7 +393,7 @@ fn build_binary_array( offsets: Vec, data: Vec, validity: Option>, -) -> Result> { +) -> Result { let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); Ok(Box::new(BinaryArray::new( data_type, @@ -405,9 +407,9 @@ fn build_list_array) -> DataType, O: Offset>( data_type: F, offsets: Vec, meta: FieldMeta, - values: Box, + values: ArrayRef, validity: Option>, -) -> Result> { +) -> Result { let validity = validity.map(|v| Bitmap::from_u8_vec(v, offsets.len().saturating_sub(1))); Ok(Box::new(ListArray::new( data_type(Box::new(field_from_array_and_meta(values.as_ref(), meta))), @@ -424,12 +426,12 @@ fn 
field_from_array_and_meta(arr: &dyn A2Array, meta: FieldMeta) -> Field { fn array_with_meta_to_array_and_fields( arrays: Vec<(Array, FieldMeta)>, -) -> Result<(Vec>, Vec)> { +) -> Result<(Vec, Vec)> { let mut res_fields = Vec::new(); let mut res_arrays = Vec::new(); for (child, meta) in arrays { - let child: Box = child.try_into()?; + let child: ArrayRef = child.try_into()?; let field = field_from_array_and_meta(child.as_ref(), meta); res_arrays.push(child); @@ -443,8 +445,8 @@ fn build_dictionary_array( indices_type: IntegerType, indices: InternalPrimitiveArray, values: Array, -) -> Result> { - let values: Box = values.try_into()?; +) -> Result { + let values: ArrayRef = values.try_into()?; let validity = indices .validity .map(|v| Bitmap::from_u8_vec(v, indices.values.len())); diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index fadcd94a..dbc6f5ba 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -602,7 +602,7 @@ fn wrap_dictionary_array( } #[cfg(has_arrow_fixed_binary_support)] -pub fn wrap_fixed_size_binary_array<'a>(array: &'a FixedSizeBinaryArray) -> Result> { +pub fn wrap_fixed_size_binary_array(array: &FixedSizeBinaryArray) -> Result> { use crate::internal::arrow::FixedSizeBinaryArrayView; Ok(ArrayView::FixedSizeBinary(FixedSizeBinaryArrayView { @@ -613,7 +613,7 @@ pub fn wrap_fixed_size_binary_array<'a>(array: &'a FixedSizeBinaryArray) -> Resu } #[cfg(not(has_arrow_fixed_binary_support))] -pub fn wrap_fixed_size_binary_array<'a>(_array: &'a FixedSizeBinaryArray) -> Result> { +pub fn wrap_fixed_size_binary_array(_array: &FixedSizeBinaryArray) -> Result> { fail!("FixedSizeBinary arrays are not supported for arrow<=46"); } From fa1b7b2c84b0526bde915536f7d7bac848c514b1 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 10:50:47 +0200 Subject: [PATCH 067/178] Remove unused cfg --- serde_arrow/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git 
a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index 44be12f0..364e4ccf 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -284,7 +284,6 @@ pub mod _impl { #[cfg(has_arrow_39)] build_arrow_crate!(arrow_array_39, arrow_buffer_39, arrow_data_39, arrow_schema_39); #[cfg(has_arrow_38)] build_arrow_crate!(arrow_array_38, arrow_buffer_38, arrow_data_38, arrow_schema_38); #[cfg(has_arrow_37)] build_arrow_crate!(arrow_array_37, arrow_buffer_37, arrow_data_37, arrow_schema_37); - #[cfg(has_arrow_36)] build_arrow_crate!(arrow_array_36, arrow_buffer_36, arrow_data_36, arrow_schema_36); /// Documentation pub mod docs { From ce7835bc28950a41bfcc0ba532648eef281061d7 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 10:51:00 +0200 Subject: [PATCH 068/178] Use arrays directly in builders (where sensible) --- serde_arrow/src/internal/arrow/array.rs | 4 + .../src/internal/serialization/array_ext.rs | 330 ++++++++++++++++++ .../internal/serialization/binary_builder.rs | 57 +-- .../internal/serialization/bool_builder.rs | 51 +-- .../internal/serialization/date32_builder.rs | 49 +-- .../internal/serialization/date64_builder.rs | 62 ++-- .../internal/serialization/decimal_builder.rs | 78 +---- .../serialization/duration_builder.rs | 62 ++-- .../fixed_size_binary_builder.rs | 38 +- .../serialization/fixed_size_list_builder.rs | 33 +- .../internal/serialization/float_builder.rs | 104 +++--- .../src/internal/serialization/int_builder.rs | 78 ++--- .../internal/serialization/list_builder.rs | 53 ++- .../src/internal/serialization/map_builder.rs | 34 +- serde_arrow/src/internal/serialization/mod.rs | 1 + .../serialization/outer_sequence_builder.rs | 10 +- .../internal/serialization/struct_builder.rs | 32 +- .../internal/serialization/time_builder.rs | 56 ++- .../internal/serialization/utf8_builder.rs | 53 +-- .../src/internal/serialization/utils.rs | 104 +----- serde_arrow/src/internal/utils/mod.rs | 2 +- 21 files changed, 640 insertions(+), 651 
deletions(-) create mode 100644 serde_arrow/src/internal/serialization/array_ext.rs diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index b3db7209..2521fd8e 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -191,10 +191,14 @@ pub struct ListArrayView<'a, O> { pub element: Box>, } +/// An array comprised of lists of fixed size #[derive(Clone, Debug)] pub struct FixedSizeListArray { + /// The number of elements in this array, each a list with `n` children pub len: usize, + /// The number of children per element pub n: i32, + /// The validity mask of the elements pub validity: Option>, pub meta: FieldMeta, pub element: Box, diff --git a/serde_arrow/src/internal/serialization/array_ext.rs b/serde_arrow/src/internal/serialization/array_ext.rs new file mode 100644 index 00000000..3248ae60 --- /dev/null +++ b/serde_arrow/src/internal/serialization/array_ext.rs @@ -0,0 +1,330 @@ +//! Extension of the array types + +use crate::internal::{ + arrow::{BytesArray, PrimitiveArray}, + error::{fail, Result}, + utils::Offset, +}; + +pub trait ArrayExt: Sized + 'static { + fn take(&mut self) -> Self; +} + +pub trait ScalarArrayExt<'value>: ArrayExt { + type Value: 'value; + + fn push_scalar_default(&mut self) -> Result<()>; + fn push_scalar_none(&mut self) -> Result<()>; + fn push_scalar_value(&mut self, value: Self::Value) -> Result<()>; +} + +/// An array that models a sequence +/// +/// As some sequence arrays, e.g., `ListArrays`, can contain arbitrarily nested subarrays, the +/// element itself is not modelled. 
+pub trait SeqArrayExt: ArrayExt { + fn push_seq_default(&mut self) -> Result<()>; + fn push_seq_none(&mut self) -> Result<()>; + fn start_seq(&mut self) -> Result<()>; + fn push_seq_elements(&mut self, n: usize) -> Result<()>; + fn end_seq(&mut self) -> Result<()>; +} + +pub fn new_primitive_array(is_nullable: bool) -> PrimitiveArray { + PrimitiveArray { + validity: is_nullable.then(Vec::new), + values: Vec::new(), + } +} + +impl ArrayExt for PrimitiveArray { + fn take(&mut self) -> Self { + Self { + validity: self.validity.as_mut().map(std::mem::take), + values: std::mem::take(&mut self.values), + } + } +} + +impl ScalarArrayExt<'static> for PrimitiveArray { + type Value = T; + + fn push_scalar_default(&mut self) -> Result<()> { + set_validity_default(self.validity.as_mut(), self.values.len()); + self.values.push(T::default()); + Ok(()) + } + + fn push_scalar_none(&mut self) -> Result<()> { + set_validity(self.validity.as_mut(), self.values.len(), false)?; + self.values.push(T::default()); + Ok(()) + } + + fn push_scalar_value(&mut self, value: Self::Value) -> Result<()> { + set_validity(self.validity.as_mut(), self.values.len(), true)?; + self.values.push(value); + Ok(()) + } +} + +pub fn new_bytes_array(is_nullable: bool) -> BytesArray { + BytesArray { + validity: is_nullable.then(Vec::new), + offsets: vec![O::default()], + data: Vec::new(), + } +} + +impl ArrayExt for BytesArray { + fn take(&mut self) -> Self { + Self { + validity: self.validity.as_mut().map(std::mem::take), + data: std::mem::take(&mut self.data), + offsets: std::mem::replace(&mut self.offsets, vec![O::default()]), + } + } +} + +impl SeqArrayExt for BytesArray { + fn push_seq_default(&mut self) -> Result<()> { + self.push_scalar_default() + } + + fn push_seq_none(&mut self) -> Result<()> { + self.push_scalar_none() + } + + fn start_seq(&mut self) -> Result<()> { + set_validity( + self.validity.as_mut(), + self.offsets.len().saturating_sub(1), + true, + )?; + duplicate_last(&mut self.offsets)?; 
+ Ok(()) + } + + fn push_seq_elements(&mut self, n: usize) -> Result<()> { + increment_last(&mut self.offsets, n)?; + Ok(()) + } + + fn end_seq(&mut self) -> Result<()> { + Ok(()) + } +} + +impl<'s, O: Offset> ScalarArrayExt<'s> for BytesArray { + type Value = &'s [u8]; + + fn push_scalar_default(&mut self) -> Result<()> { + set_validity_default(self.validity.as_mut(), self.offsets.len().saturating_sub(1)); + duplicate_last(&mut self.offsets)?; + Ok(()) + } + + fn push_scalar_none(&mut self) -> Result<()> { + set_validity( + self.validity.as_mut(), + self.offsets.len().saturating_sub(1), + false, + )?; + duplicate_last(&mut self.offsets)?; + Ok(()) + } + + fn push_scalar_value(&mut self, value: Self::Value) -> Result<()> { + set_validity( + self.validity.as_mut(), + self.offsets.len().saturating_sub(1), + true, + )?; + duplicate_last(&mut self.offsets)?; + increment_last(&mut self.offsets, value.len())?; + self.data.extend(value); + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct OffsetsArray { + pub validity: Option>, + pub offsets: Vec, +} + +impl OffsetsArray { + pub fn new(is_nullable: bool) -> Self { + Self { + validity: is_nullable.then(Vec::new), + offsets: vec![O::default()], + } + } +} + +impl ArrayExt for OffsetsArray { + fn take(&mut self) -> Self { + Self { + validity: self.validity.as_mut().map(std::mem::take), + offsets: std::mem::replace(&mut self.offsets, vec![O::default()]), + } + } +} + +impl SeqArrayExt for OffsetsArray { + fn push_seq_default(&mut self) -> Result<()> { + set_validity_default(self.validity.as_mut(), self.offsets.len().saturating_sub(1)); + duplicate_last(&mut self.offsets)?; + Ok(()) + } + + fn push_seq_none(&mut self) -> Result<()> { + set_validity( + self.validity.as_mut(), + self.offsets.len().saturating_sub(1), + false, + )?; + duplicate_last(&mut self.offsets)?; + Ok(()) + } + + fn start_seq(&mut self) -> Result<()> { + set_validity( + self.validity.as_mut(), + self.offsets.len().saturating_sub(1), + true, + )?; + 
duplicate_last(&mut self.offsets)?; + Ok(()) + } + + fn push_seq_elements(&mut self, n: usize) -> Result<()> { + increment_last(&mut self.offsets, n)?; + Ok(()) + } + + fn end_seq(&mut self) -> Result<()> { + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct CountArray { + pub len: usize, + pub validity: Option>, +} + +impl CountArray { + pub fn new(is_nullable: bool) -> Self { + Self { + len: 0, + validity: is_nullable.then(Vec::new), + } + } +} + +impl ArrayExt for CountArray { + fn take(&mut self) -> Self { + Self { + len: std::mem::take(&mut self.len), + validity: self.validity.as_mut().map(std::mem::take), + } + } +} + +impl SeqArrayExt for CountArray { + fn push_seq_default(&mut self) -> Result<()> { + set_validity_default(self.validity.as_mut(), self.len); + self.len += 1; + Ok(()) + } + + fn push_seq_none(&mut self) -> Result<()> { + set_validity(self.validity.as_mut(), self.len, false)?; + self.len += 1; + Ok(()) + } + + fn start_seq(&mut self) -> Result<()> { + set_validity(self.validity.as_mut(), self.len, true)?; + self.len += 1; + Ok(()) + } + + fn push_seq_elements(&mut self, _n: usize) -> Result<()> { + Ok(()) + } + + fn end_seq(&mut self) -> Result<()> { + Ok(()) + } +} + +pub fn duplicate_last(vec: &mut Vec) -> Result<()> { + let Some(last) = vec.last() else { + fail!("invalid offset array") + }; + vec.push(last.clone()); + Ok(()) +} + +pub fn increment_last(vec: &mut Vec, inc: usize) -> Result<()> { + let Some(last) = vec.last_mut() else { + fail!("invalid offset array") + }; + *last = *last + O::try_form_usize(inc)?; + Ok(()) +} + +pub fn set_validity(buffer: Option<&mut Vec>, idx: usize, value: bool) -> Result<()> { + if let Some(buffer) = buffer { + set_bit_buffer(buffer, idx, value); + Ok(()) + } else if value { + Ok(()) + } else { + fail!("cannot push null for non-nullable array"); + } +} + +/// In contrast to `set_validity` nulls for non-nullable fields are not an error +pub fn set_validity_default(buffer: Option<&mut Vec>, idx: usize) { 
+ if let Some(buffer) = buffer { + set_bit_buffer(buffer, idx, false); + } +} + +pub fn set_bit_buffer(buffer: &mut Vec, idx: usize, value: bool) { + while idx / 8 >= buffer.len() { + buffer.push(0); + } + + let bit_mask: u8 = 1 << (idx % 8); + if value { + buffer[idx / 8] |= bit_mask; + } else { + buffer[idx / 8] &= !bit_mask; + } +} + +#[test] +fn test_set_bit_buffer() { + let mut buffer = vec![]; + + set_bit_buffer(&mut buffer, 0, true); + set_bit_buffer(&mut buffer, 1, false); + set_bit_buffer(&mut buffer, 2, false); + set_bit_buffer(&mut buffer, 3, false); + set_bit_buffer(&mut buffer, 4, true); + set_bit_buffer(&mut buffer, 5, true); + + assert_eq!(buffer, vec![0b_0011_0001]); + + set_bit_buffer(&mut buffer, 16 + 2, true); + set_bit_buffer(&mut buffer, 16 + 4, false); + + assert_eq!(buffer, vec![0b_0011_0001, 0b_0000_0000, 0b_0000_0100]); + + set_bit_buffer(&mut buffer, 4, false); + assert_eq!(buffer, vec![0b_0010_0001, 0b_0000_0000, 0b_0000_0100]); +} diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 192195ac..f926c8bb 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -6,77 +6,55 @@ use crate::internal::{ utils::{Mut, Offset}, }; -use super::utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, SimpleSerializer, +use super::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, + utils::SimpleSerializer, }; #[derive(Debug, Clone)] -pub struct BinaryBuilder { - pub validity: Option, - pub offsets: MutableOffsetBuffer, - pub buffer: Vec, -} +pub struct BinaryBuilder(BytesArray); impl BinaryBuilder { pub fn new(is_nullable: bool) -> Self { - Self { - validity: is_nullable.then(MutableBitBuffer::default), - offsets: Default::default(), - buffer: Vec::new(), - } + Self(new_bytes_array(is_nullable)) } pub fn take(&mut self) -> Self { - Self { - 
validity: self.validity.as_mut().map(std::mem::take), - offsets: std::mem::take(&mut self.offsets), - buffer: std::mem::take(&mut self.buffer), - } + Self(self.0.take()) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.0.validity.is_some() } } impl BinaryBuilder { pub fn into_array(self) -> Result { - Ok(Array::Binary(BytesArray { - validity: self.validity.map(|b| b.buffer), - offsets: self.offsets.offsets, - data: self.buffer, - })) + Ok(Array::Binary(self.0)) } } impl BinaryBuilder { pub fn into_array(self) -> Result { - Ok(Array::LargeBinary(BytesArray { - validity: self.validity.map(|b| b.buffer), - offsets: self.offsets.offsets, - data: self.buffer, - })) + Ok(Array::LargeBinary(self.0)) } } impl BinaryBuilder { fn start(&mut self) -> Result<()> { - push_validity(&mut self.validity, true) + self.0.start_seq() } fn element(&mut self, value: &V) -> Result<()> { let mut u8_serializer = U8Serializer(0); value.serialize(Mut(&mut u8_serializer))?; - self.offsets.inc_current_items()?; - self.buffer.push(u8_serializer.0); - - Ok(()) + self.0.data.push(u8_serializer.0); + self.0.push_seq_elements(1) } fn end(&mut self) -> Result<()> { - self.offsets.push_current_items(); Ok(()) } } @@ -87,14 +65,11 @@ impl SimpleSerializer for BinaryBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.offsets.push_current_items(); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_current_items(); - push_validity(&mut self.validity, false) + self.0.push_scalar_none() } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { @@ -134,9 +109,7 @@ impl SimpleSerializer for BinaryBuilder { } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.extend(v); - self.offsets.push(v.len()) + self.0.push_scalar_value(v) } } diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs 
b/serde_arrow/src/internal/serialization/bool_builder.rs index a884e341..1c4c24d1 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -3,39 +3,37 @@ use crate::internal::{ error::Result, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{set_bit_buffer, set_validity, set_validity_default}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] -pub struct BoolBuilder { - pub validity: Option, - pub buffer: MutableBitBuffer, -} +pub struct BoolBuilder(BooleanArray); impl BoolBuilder { pub fn new(is_nullable: bool) -> Self { - Self { - validity: is_nullable.then(MutableBitBuffer::default), - buffer: MutableBitBuffer::default(), - } + Self(BooleanArray { + len: 0, + validity: is_nullable.then(Vec::new), + values: Vec::new(), + }) } pub fn take(&mut self) -> Self { - Self { - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), - } + Self(BooleanArray { + len: std::mem::take(&mut self.0.len), + validity: self.0.validity.as_mut().map(std::mem::take), + values: std::mem::take(&mut self.0.values), + }) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.0.validity.is_some() } pub fn into_array(self) -> Result { - Ok(Array::Boolean(BooleanArray { - len: self.buffer.len, - validity: self.validity.map(|v| v.buffer), - values: self.buffer.buffer, - })) + Ok(Array::Boolean(self.0)) } } @@ -45,20 +43,23 @@ impl SimpleSerializer for BoolBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(false); + set_validity_default(self.0.validity.as_mut(), self.0.len); + set_bit_buffer(&mut self.0.values, self.0.len, false); + self.0.len += 1; Ok(()) } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(false); + 
set_validity(self.0.validity.as_mut(), self.0.len, false)?; + set_bit_buffer(&mut self.0.values, self.0.len, false); + self.0.len += 1; Ok(()) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v); + set_validity(self.0.validity.as_mut(), self.0.len, true)?; + set_bit_buffer(&mut self.0.values, self.0.len, v); + self.0.len += 1; Ok(()) } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 8537c5df..077c2672 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -3,44 +3,31 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, error::Result, - schema::GenericField, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] -pub struct Date32Builder { - pub field: GenericField, - pub validity: Option, - pub buffer: Vec, -} +pub struct Date32Builder(PrimitiveArray); impl Date32Builder { - pub fn new(field: GenericField, nullable: bool) -> Self { - Self { - field, - validity: nullable.then(MutableBitBuffer::default), - buffer: Vec::new(), - } + pub fn new(is_nullable: bool) -> Self { + Self(new_primitive_array(is_nullable)) } pub fn take(&mut self) -> Self { - Self { - field: self.field.clone(), - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), - } + Self(self.0.take()) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.0.validity.is_some() } pub fn into_array(self) -> Result { - Ok(Array::Date32(PrimitiveArray { - validity: self.validity.map(|validity| validity.buffer), - values: self.buffer, - })) + Ok(Array::Date32(self.0)) } } @@ -50,15 +37,11 @@ impl 
SimpleSerializer for Date32Builder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(0); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(0); - Ok(()) + self.0.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -68,14 +51,10 @@ impl SimpleSerializer for Date32Builder { let duration_since_epoch = date.signed_duration_since(UNIX_EPOCH); let days_since_epoch = duration_since_epoch.num_days().try_into()?; - push_validity(&mut self.validity, true)?; - self.buffer.push(days_since_epoch); - Ok(()) + self.0.push_scalar_value(days_since_epoch) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v); - Ok(()) + self.0.push_scalar_value(v) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 9668a199..7aa44e67 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,26 +1,27 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{Error, Result}, + error::{fail, Result}, schema::{GenericDataType, GenericField}, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] pub struct Date64Builder { pub field: GenericField, pub utc: bool, - pub validity: Option, - pub buffer: Vec, + pub array: PrimitiveArray, } impl Date64Builder { - pub fn new(field: GenericField, utc: bool, nullable: bool) -> Self { + pub fn new(field: GenericField, utc: bool, is_nullable: bool) -> Self { Self { field, utc, - validity: nullable.then(MutableBitBuffer::default), - buffer: 
Vec::new(), + array: new_primitive_array(is_nullable), } } @@ -28,27 +29,28 @@ impl Date64Builder { Self { field: self.field.clone(), utc: self.utc, - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), + array: self.array.take(), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.array.validity.is_some() } + // TODO: fix this pub fn into_array(self) -> Result { if let GenericDataType::Timestamp(unit, timezone) = self.field.data_type { Ok(Array::Timestamp(TimestampArray { unit, timezone, - validity: self.validity.map(|validity| validity.buffer), - values: self.buffer, + validity: self.array.validity, + values: self.array.values, })) } else { + // TOOD: check data type Ok(Array::Date64(PrimitiveArray { - validity: self.validity.map(|validity| validity.buffer), - values: self.buffer, + validity: self.array.validity, + values: self.array.values, })) } } @@ -60,15 +62,11 @@ impl SimpleSerializer for Date64Builder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(0); - Ok(()) + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(0); - Ok(()) + self.array.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -82,24 +80,28 @@ impl SimpleSerializer for Date64Builder { let timestamp = match self.field.data_type { GenericDataType::Timestamp(TimeUnit::Nanosecond, _) => { - date_time - .timestamp_nanos_opt() - .ok_or_else(|| Error::custom(format!("Timestamp '{v}' cannot be converted to nanoseconds. The dates that can be represented as nanoseconds are between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.")))? - }, + match date_time.timestamp_nanos_opt() { + Some(timestamp) => timestamp, + _ => fail!( + concat!( + "Timestamp '{date_time}' cannot be converted to nanoseconds. 
", + "The dates that can be represented as nanoseconds are between ", + "1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.", + ), + date_time = date_time, + ), + } + } GenericDataType::Timestamp(TimeUnit::Microsecond, _) => date_time.timestamp_micros(), GenericDataType::Timestamp(TimeUnit::Millisecond, _) => date_time.timestamp_millis(), GenericDataType::Timestamp(TimeUnit::Second, _) => date_time.timestamp(), _ => date_time.timestamp_millis(), }; - push_validity(&mut self.validity, true)?; - self.buffer.push(timestamp); - Ok(()) + self.array.push_scalar_value(timestamp) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v); - Ok(()) + self.array.push_scalar_value(v) } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 5bb48de3..437f876e 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,32 +1,33 @@ use crate::internal::{ - arrow::{Array, DecimalArray}, + arrow::{Array, DecimalArray, PrimitiveArray}, error::Result, utils::decimal::{self, DecimalParser}, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] pub struct DecimalBuilder { pub precision: u8, pub scale: i8, - pub validity: Option, - pub buffer: Vec, pub f32_factor: f32, pub f64_factor: f64, pub parser: DecimalParser, + pub array: PrimitiveArray, } impl DecimalBuilder { - pub fn new(precision: u8, scale: i8, nullable: bool) -> Self { + pub fn new(precision: u8, scale: i8, is_nullable: bool) -> Self { Self { precision, scale, - validity: nullable.then(MutableBitBuffer::default), - buffer: Vec::new(), f32_factor: (10.0_f32).powi(scale as i32), f64_factor: (10.0_f64).powi(scale as i32), 
parser: DecimalParser::new(precision, scale, true), + array: new_primitive_array(is_nullable), } } @@ -34,24 +35,23 @@ impl DecimalBuilder { Self { precision: self.precision, scale: self.scale, - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), f32_factor: self.f32_factor, f64_factor: self.f64_factor, parser: self.parser, + array: self.array.take(), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.array.validity.is_some() } pub fn into_array(self) -> Result { Ok(Array::Decimal128(DecimalArray { precision: self.precision, scale: self.scale, - validity: self.validity.map(|b| b.buffer), - values: self.buffer, + validity: self.array.validity, + values: self.array.values, })) } } @@ -62,27 +62,19 @@ impl SimpleSerializer for DecimalBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(0); - Ok(()) + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(0); - Ok(()) + self.array.push_scalar_none() } fn serialize_f32(&mut self, v: f32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push((v * self.f32_factor) as i128); - Ok(()) + self.array.push_scalar_value((v * self.f32_factor) as i128) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push((v * self.f64_factor) as i128); - Ok(()) + self.array.push_scalar_value((v * self.f64_factor) as i128) } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -91,42 +83,6 @@ impl SimpleSerializer for DecimalBuilder { .parser .parse_decimal128(&mut parse_buffer, v.as_bytes())?; - push_validity(&mut self.validity, true)?; - self.buffer.push(val); - Ok(()) + self.array.push_scalar_value(val) } } - -/* - - fn accept_f32( - &self, - _: &Structure, - context: &mut SerializationContext, - val: f32, - ) -> Result { - - } - 
- fn accept_f64( - &self, - _: &Structure, - context: &mut SerializationContext, - val: f64, - ) -> Result { - let val = (val * self.f64_factor) as i128; - context.buffers.u128[self.idx].push(ToBytes::to_bytes(val)); - Ok(self.next) - } - - fn accept_str( - &self, - _: &Structure, - context: &mut SerializationContext, - val: &str, - ) -> Result { - let mut buffer = [0; decimal::BUFFER_SIZE_I128]; - context.buffers.u128[self.idx].push(ToBytes::to_bytes(val)); - Ok(self.next) - } -*/ diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 8b43845a..27f2ac01 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,43 +1,43 @@ use crate::internal::{ - arrow::{Array, TimeArray, TimeUnit}, + arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::Result, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] pub struct DurationBuilder { pub unit: TimeUnit, - pub validity: Option, - pub buffer: Vec, + pub array: PrimitiveArray, } impl DurationBuilder { pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { Self { unit, - validity: is_nullable.then(MutableBitBuffer::default), - buffer: Default::default(), + array: new_primitive_array(is_nullable), } } pub fn take(&mut self) -> Self { Self { unit: self.unit, - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), + array: self.array.take(), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.array.validity.is_some() } pub fn into_array(self) -> Result { Ok(Array::Duration(TimeArray { unit: self.unit, - validity: self.validity.map(|b| b.buffer), - values: self.buffer, + validity: self.array.validity, + values: 
self.array.values, })) } } @@ -48,62 +48,42 @@ impl SimpleSerializer for DurationBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(i64::default()); - Ok(()) + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(i64::default()); - Ok(()) + self.array.push_scalar_none() } fn serialize_i8(&mut self, v: i8) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v); - Ok(()) + self.array.push_scalar_value(v) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::from(v)); - Ok(()) + self.array.push_scalar_value(i64::from(v)) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(i64::try_from(v)?); - Ok(()) + self.array.push_scalar_value(i64::try_from(v)?) 
} } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 06fc3193..2a806b7c 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -6,14 +6,16 @@ use crate::internal::{ utils::Mut, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] pub struct FixedSizeBinaryBuilder { - pub validity: Option, + pub seq: CountArray, pub buffer: Vec, - pub len: usize, pub current_n: usize, pub n: usize, } @@ -21,32 +23,30 @@ pub struct FixedSizeBinaryBuilder { impl FixedSizeBinaryBuilder { pub fn new(n: usize, is_nullable: bool) -> Self { Self { - validity: is_nullable.then(MutableBitBuffer::default), + seq: CountArray::new(is_nullable), buffer: Vec::new(), n, - len: 0, current_n: 0, } } pub fn take(&mut self) -> Self { Self { - validity: self.validity.as_mut().map(std::mem::take), + seq: self.seq.take(), buffer: std::mem::take(&mut self.buffer), - len: std::mem::take(&mut self.len), current_n: std::mem::take(&mut self.current_n), n: self.n, } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.seq.validity.is_some() } pub fn into_array(self) -> Result { Ok(Array::FixedSizeBinary(FixedSizeBinaryArray { n: self.n.try_into()?, - validity: self.validity.map(|v| v.buffer), + validity: self.seq.validity, data: self.buffer, })) } @@ -54,9 +54,8 @@ impl FixedSizeBinaryBuilder { impl FixedSizeBinaryBuilder { fn start(&mut self) -> Result<()> { - push_validity(&mut self.validity, true)?; self.current_n = 0; - Ok(()) + self.seq.start_seq() } fn element(&mut self, value: &V) -> Result<()> { @@ -77,9 +76,7 @@ impl FixedSizeBinaryBuilder { expected = self.n, ); } - - self.len += 1; - Ok(()) + self.seq.end_seq() } 
} @@ -89,22 +86,18 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); + self.seq.push_seq_default()?; for _ in 0..self.n { self.buffer.push(0); } - self.len += 1; - Ok(()) } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; + self.seq.push_seq_none()?; for _ in 0..self.n { self.buffer.push(0); } - self.len += 1; - Ok(()) } @@ -153,10 +146,9 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { ); } - push_validity(&mut self.validity, true)?; + self.seq.start_seq()?; self.buffer.extend(v); - self.len += 1; - Ok(()) + self.seq.end_seq() } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index b977d648..fd898a76 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -9,55 +9,51 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{ - meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, - }, + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + utils::{meta_from_field, SimpleSerializer}, }; #[derive(Debug, Clone)] pub struct FixedSizeListBuilder { + pub seq: CountArray, pub field: GenericField, pub n: usize, pub current_count: usize, - pub num_elements: usize, - pub validity: Option, pub element: Box, } impl FixedSizeListBuilder { pub fn new(field: GenericField, element: ArrayBuilder, n: usize, is_nullable: bool) -> Self { Self { + seq: CountArray::new(is_nullable), field, n, current_count: 0, - num_elements: 0, - validity: is_nullable.then(MutableBitBuffer::default), element: Box::new(element), } } pub fn take(&mut self) -> Self { Self { + seq: self.seq.take(), field: self.field.clone(), n: self.n, current_count: std::mem::take(&mut self.current_count), - num_elements: 
std::mem::take(&mut self.num_elements), - validity: self.validity.as_mut().map(std::mem::take), element: Box::new(self.element.take()), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.seq.validity.is_some() } pub fn into_array(self) -> Result { Ok(Array::FixedSizeList(FixedSizeListArray { - len: self.num_elements, + len: self.seq.len, + validity: self.seq.validity, n: self.n.try_into()?, meta: meta_from_field(self.field)?, - validity: self.validity.map(|v| v.buffer), element: Box::new((*self.element).into_array()?), })) } @@ -66,15 +62,17 @@ impl FixedSizeListBuilder { impl FixedSizeListBuilder { fn start(&mut self) -> Result<()> { self.current_count = 0; - push_validity(&mut self.validity, true) + self.seq.start_seq() } fn element(&mut self, value: &V) -> Result<()> { self.current_count += 1; + self.seq.push_seq_elements(1)?; value.serialize(Mut(self.element.as_mut())) } fn end(&mut self) -> Result<()> { + // TODO: fill with default values? would simplify using the builder if self.current_count != self.n { fail!( "Invalid number of elements for FixedSizedList({n}). 
Expected {n}, got {actual}", @@ -82,8 +80,7 @@ impl FixedSizeListBuilder { actual = self.current_count ); } - self.num_elements += 1; - Ok(()) + self.seq.end_seq() } } @@ -93,20 +90,18 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); + self.seq.push_seq_default()?; for _ in 0..self.n { self.element.serialize_default()?; } - self.num_elements += 1; Ok(()) } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; + self.seq.push_seq_none()?; for _ in 0..self.n { self.element.serialize_default()?; } - self.num_elements += 1; Ok(()) } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 6381f905..81af7edd 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -6,37 +6,25 @@ use crate::internal::{ utils::Mut, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; -#[derive(Debug, Clone, Default)] -pub struct FloatBuilder { - pub validity: Option, - pub buffer: Vec, -} +#[derive(Debug, Clone)] +pub struct FloatBuilder(PrimitiveArray); -impl FloatBuilder { +impl FloatBuilder { pub fn new(is_nullable: bool) -> Self { - Self { - validity: is_nullable.then(MutableBitBuffer::default), - buffer: Default::default(), - } + Self(new_primitive_array(is_nullable)) } pub fn take(&mut self) -> Self { - Self { - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), - } + Self(self.0.take()) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() - } - - fn serialize_value(&mut self, value: I) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(value); - Ok(()) + 
self.0.validity.is_some() } } @@ -44,10 +32,7 @@ macro_rules! impl_into_array { ($ty:ty, $var:ident) => { impl FloatBuilder<$ty> { pub fn into_array(self) -> Result { - Ok(Array::$var(PrimitiveArray { - validity: self.validity.map(|b| b.buffer), - values: self.buffer, - })) + Ok(Array::$var(self.0)) } } }; @@ -63,14 +48,11 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_default(&mut self) -> Result<()> { - self.buffer.push(0.0); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(0.0); - Ok(()) + self.0.push_scalar_none() } fn serialize_some(&mut self, value: &V) -> Result<()> { @@ -78,43 +60,43 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.serialize_value(v) + self.0.push_scalar_value(v) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.serialize_value(v as f32) + self.0.push_scalar_value(v as f32) } } @@ -124,55 
+106,51 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(0.0); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(0.0); - Ok(()) + self.0.push_scalar_none() } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.serialize_value(v as f64) + self.0.push_scalar_value(v as f64) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.serialize_value(v) + self.0.push_scalar_value(v) } } @@ -182,22 +160,18 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(f16::ZERO); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(f16::ZERO); - Ok(()) + self.0.push_scalar_none() } fn 
serialize_f32(&mut self, v: f32) -> Result<()> { - self.serialize_value(f16::from_f32(v)) + self.0.push_scalar_value(f16::from_f32(v)) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.serialize_value(f16::from_f64(v)) + self.0.push_scalar_value(f16::from_f64(v)) } } diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 391a1f41..e346c818 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -3,31 +3,25 @@ use crate::internal::{ error::{Error, Result}, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; -#[derive(Debug, Clone, Default)] -pub struct IntBuilder { - pub validity: Option, - pub buffer: Vec, -} +#[derive(Debug, Clone)] +pub struct IntBuilder(PrimitiveArray); -impl IntBuilder { +impl IntBuilder { pub fn new(is_nullable: bool) -> Self { - Self { - validity: is_nullable.then(MutableBitBuffer::default), - buffer: Default::default(), - } + Self(new_primitive_array(is_nullable)) } pub fn take(&mut self) -> Self { - Self { - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), - } + Self(self.0.take()) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.0.validity.is_some() } } @@ -35,10 +29,7 @@ macro_rules! 
impl_into_array { ($ty:ty, $var:ident) => { impl IntBuilder<$ty> { pub fn into_array(self) -> Result { - Ok(Array::$var(PrimitiveArray { - validity: self.validity.map(|b| b.buffer), - values: self.buffer, - })) + Ok(Array::$var(self.0)) } } }; @@ -63,7 +54,8 @@ where + TryFrom + TryFrom + TryFrom - + TryFrom, + + TryFrom + + 'static, Error: From<>::Error>, Error: From<>::Error>, Error: From<>::Error>, @@ -78,68 +70,46 @@ where } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(I::default()); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(I::default()); - Ok(()) + self.0.push_scalar_none() } fn serialize_i8(&mut self, v: i8) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) 
} fn serialize_u32(&mut self, v: u32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(v)?); - Ok(()) + self.0.push_scalar_value(I::try_from(v)?) } fn serialize_char(&mut self, v: char) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(I::try_from(u32::from(v))?); - Ok(()) + self.0.push_scalar_value(I::try_from(u32::from(v))?) } } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 784e2f22..c5454eaa 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - arrow::{Array, ListArray}, + arrow::{Array, FieldMeta, ListArray}, error::Result, schema::GenericField, utils::{Mut, Offset}, @@ -9,52 +9,47 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{ - meta_from_field, push_validity, push_validity_default, MutableBitBuffer, - MutableOffsetBuffer, SimpleSerializer, - }, + array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, + utils::{meta_from_field, SimpleSerializer}, }; #[derive(Debug, Clone)] pub struct ListBuilder { - pub field: GenericField, - pub validity: Option, - pub offsets: MutableOffsetBuffer, + pub meta: FieldMeta, pub element: Box, + pub offsets: OffsetsArray, } impl ListBuilder { - pub fn new(field: GenericField, element: ArrayBuilder, is_nullable: bool) -> Self { - Self { - field, - validity: is_nullable.then(MutableBitBuffer::default), - offsets: Default::default(), + pub fn new(field: GenericField, element: ArrayBuilder, is_nullable: bool) -> Result { + Ok(Self { + meta: meta_from_field(field)?, element: Box::new(element), - } + offsets: 
OffsetsArray::new(is_nullable), + }) } pub fn take(&mut self) -> Self { Self { - field: self.field.clone(), - validity: self.validity.as_mut().map(std::mem::take), - offsets: std::mem::take(&mut self.offsets), + meta: self.meta.clone(), + offsets: self.offsets.take(), element: Box::new(self.element.take()), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.offsets.validity.is_some() } } impl ListBuilder { pub fn into_array(self) -> Result { Ok(Array::List(ListArray { - validity: self.validity.map(|b| b.buffer), + validity: self.offsets.validity, offsets: self.offsets.offsets, element: Box::new(self.element.into_array()?), - meta: meta_from_field(self.field)?, + meta: self.meta, })) } } @@ -62,27 +57,26 @@ impl ListBuilder { impl ListBuilder { pub fn into_array(self) -> Result { Ok(Array::LargeList(ListArray { - validity: self.validity.map(|b| b.buffer), + validity: self.offsets.validity, offsets: self.offsets.offsets, element: Box::new(self.element.into_array()?), - meta: meta_from_field(self.field)?, + meta: self.meta, })) } } impl ListBuilder { fn start(&mut self) -> Result<()> { - push_validity(&mut self.validity, true) + self.offsets.start_seq() } fn element(&mut self, value: &V) -> Result<()> { - self.offsets.inc_current_items()?; + self.offsets.push_seq_elements(1)?; value.serialize(Mut(self.element.as_mut())) } fn end(&mut self) -> Result<()> { - self.offsets.push_current_items(); - Ok(()) + self.offsets.end_seq() } } @@ -92,14 +86,11 @@ impl SimpleSerializer for ListBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.offsets.push_current_items(); - Ok(()) + self.offsets.push_seq_default() } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_current_items(); - push_validity(&mut self.validity, false) + self.offsets.push_seq_none() } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { diff --git a/serde_arrow/src/internal/serialization/map_builder.rs 
b/serde_arrow/src/internal/serialization/map_builder.rs index 6679219d..5cb2ceed 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -8,26 +8,22 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{ - meta_from_field, push_validity, push_validity_default, MutableBitBuffer, - MutableOffsetBuffer, SimpleSerializer, - }, + array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, + utils::{meta_from_field, SimpleSerializer}, }; #[derive(Debug, Clone)] pub struct MapBuilder { pub entry_field: GenericField, - pub validity: Option, - pub offsets: MutableOffsetBuffer, pub entry: Box, + pub offsets: OffsetsArray, } impl MapBuilder { pub fn new(entry_field: GenericField, entry: ArrayBuilder, is_nullable: bool) -> Self { Self { entry_field, - validity: is_nullable.then(MutableBitBuffer::default), - offsets: MutableOffsetBuffer::default(), + offsets: OffsetsArray::new(is_nullable), entry: Box::new(entry), } } @@ -35,21 +31,20 @@ impl MapBuilder { pub fn take(&mut self) -> Self { Self { entry_field: self.entry_field.clone(), - validity: self.validity.as_mut().map(std::mem::take), - offsets: std::mem::take(&mut self.offsets), + offsets: self.offsets.take(), entry: Box::new(self.entry.take()), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.offsets.validity.is_some() } pub fn into_array(self) -> Result { Ok(Array::Map(ListArray { meta: meta_from_field(self.entry_field)?, element: Box::new((*self.entry).into_array()?), - validity: self.validity.map(|v| v.buffer), + validity: self.offsets.validity, offsets: self.offsets.offsets, })) } @@ -61,23 +56,19 @@ impl SimpleSerializer for MapBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.offsets.push_current_items(); - Ok(()) + self.offsets.push_seq_default() } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_current_items(); - 
push_validity(&mut self.validity, false) + self.offsets.push_seq_none() } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - push_validity(&mut self.validity, true)?; - Ok(()) + self.offsets.start_seq() } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - self.offsets.inc_current_items()?; + self.offsets.push_seq_elements(1)?; self.entry.serialize_tuple_start(2)?; self.entry.serialize_tuple_element(key) } @@ -88,7 +79,6 @@ impl SimpleSerializer for MapBuilder { } fn serialize_map_end(&mut self) -> Result<()> { - self.offsets.push_current_items(); - Ok(()) + self.offsets.end_seq() } } diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index f66832b6..898583ed 100644 --- a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -1,6 +1,7 @@ //! A serialization implementation without the event model pub mod array_builder; +pub mod array_ext; pub mod binary_builder; pub mod bool_builder; pub mod date32_builder; diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 98e66ab5..af19c913 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -62,7 +62,7 @@ impl OuterSequenceBuilder { T::F16 => A::F16(FloatBuilder::new(field.nullable)), T::F32 => A::F32(FloatBuilder::new(field.nullable)), T::F64 => A::F64(FloatBuilder::new(field.nullable)), - T::Date32 => A::Date32(Date32Builder::new(field.clone(), field.nullable)), + T::Date32 => A::Date32(Date32Builder::new(field.nullable)), T::Date64 => { let is_utc = match field.strategy.as_ref() { Some(Strategy::UtcStrAsDate64) | None => true, @@ -82,13 +82,13 @@ impl OuterSequenceBuilder { if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { fail!("Only timestamps with second or millisecond unit are supported"); } - 
A::Time32(TimeBuilder::new(field.clone(), field.nullable, *unit)) + A::Time32(TimeBuilder::new(*unit, field.nullable)) } T::Time64(unit) => { if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { fail!("Only timestamps with nanosecond or microsecond unit are supported"); } - A::Time64(TimeBuilder::new(field.clone(), field.nullable, *unit)) + A::Time64(TimeBuilder::new(*unit, field.nullable)) } T::Duration(unit) => A::Duration(DurationBuilder::new(*unit, field.nullable)), T::Decimal128(precision, scale) => { @@ -104,7 +104,7 @@ impl OuterSequenceBuilder { child.clone(), build_builder(child)?, field.nullable, - )) + )?) } T::LargeList => { let Some(child) = field.children.first() else { @@ -114,7 +114,7 @@ impl OuterSequenceBuilder { child.clone(), build_builder(child)?, field.nullable, - )) + )?) } T::FixedSizeList(n) => { let Some(child) = field.children.first() else { diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 2753c03e..76867b7e 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -11,9 +11,8 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, - utils::{ - meta_from_field, push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer, - }, + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + utils::{meta_from_field, SimpleSerializer}, }; const UNKNOWN_KEY: usize = usize::MAX; @@ -22,13 +21,12 @@ const UNKNOWN_KEY: usize = usize::MAX; pub struct StructBuilder { // TODO: clean this up pub fields: Vec, - pub validity: Option, pub named_fields: Vec<(String, ArrayBuilder)>, pub cached_names: Vec>, pub seen: Vec, pub next: usize, pub index: BTreeMap, - pub len: usize, + pub seq: CountArray, } impl StructBuilder { @@ -53,18 +51,16 @@ impl StructBuilder { fields, seen: vec![false; named_fields.len()], cached_names: vec![None; named_fields.len()], - validity: 
is_nullable.then(MutableBitBuffer::default), named_fields, next: 0, index, - len: 0, + seq: CountArray::new(is_nullable), }) } pub fn take(&mut self) -> Self { Self { fields: self.fields.clone(), - validity: self.validity.as_mut().map(std::mem::take), named_fields: self .named_fields .iter_mut() @@ -76,13 +72,13 @@ impl StructBuilder { ), seen: std::mem::replace(&mut self.seen, vec![false; self.named_fields.len()]), next: std::mem::take(&mut self.next), - len: std::mem::take(&mut self.len), index: self.index.clone(), + seq: self.seq.take(), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.seq.validity.is_some() } pub fn into_array(self) -> Result { @@ -94,8 +90,8 @@ impl StructBuilder { } Ok(Array::Struct(StructArray { - len: self.len, - validity: self.validity.map(|b| b.buffer), + len: self.seq.len, + validity: self.seq.validity, fields, })) } @@ -103,8 +99,7 @@ impl StructBuilder { impl StructBuilder { fn start(&mut self) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.len += 1; + self.seq.start_seq()?; self.reset(); Ok(()) } @@ -115,6 +110,7 @@ impl StructBuilder { } fn end(&mut self) -> Result<()> { + self.seq.end_seq()?; for (idx, seen) in self.seen.iter_mut().enumerate() { if !*seen { if !self.named_fields[idx].1.is_nullable() { @@ -131,6 +127,7 @@ impl StructBuilder { } fn element(&mut self, idx: usize, value: &T) -> Result<()> { + self.seq.push_seq_elements(1)?; if self.seen[idx] { fail!("Duplicate field {key}", key = self.named_fields[idx].0); } @@ -148,8 +145,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.len += 1; + self.seq.push_seq_default()?; for (_, field) in &mut self.named_fields { field.serialize_default()?; } @@ -158,9 +154,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.len += 1; - + 
self.seq.push_seq_none()?; for (_, field) in &mut self.named_fields { field.serialize_default()?; } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 71ea9fa1..5d2eb8e8 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -1,42 +1,38 @@ use chrono::Timelike; use crate::internal::{ - arrow::{Array, TimeArray, TimeUnit}, + arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::{Error, Result}, - schema::GenericField, }; -use super::utils::{push_validity, push_validity_default, MutableBitBuffer, SimpleSerializer}; +use super::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, +}; #[derive(Debug, Clone)] pub struct TimeBuilder { - pub field: GenericField, - pub validity: Option, - pub buffer: Vec, pub unit: TimeUnit, + pub array: PrimitiveArray, } -impl TimeBuilder { - pub fn new(field: GenericField, nullable: bool, unit: TimeUnit) -> Self { +impl TimeBuilder { + pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { Self { - field, - validity: nullable.then(MutableBitBuffer::default), - buffer: Vec::new(), unit, + array: new_primitive_array(is_nullable), } } pub fn take(&mut self) -> Self { Self { - field: self.field.clone(), - validity: self.validity.as_mut().map(std::mem::take), - buffer: std::mem::take(&mut self.buffer), unit: self.unit, + array: self.array.take(), } } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.array.validity.is_some() } } @@ -44,8 +40,8 @@ impl TimeBuilder { pub fn into_array(self) -> Result { Ok(Array::Time32(TimeArray { unit: self.unit, - validity: self.validity.map(|v| v.buffer), - values: self.buffer, + validity: self.array.validity, + values: self.array.values, })) } } @@ -54,15 +50,15 @@ impl TimeBuilder { pub fn into_array(self) -> Result { Ok(Array::Time64(TimeArray { unit: self.unit, - validity: 
self.validity.map(|v| v.buffer), - values: self.buffer, + validity: self.array.validity, + values: self.array.values, })) } } impl SimpleSerializer for TimeBuilder where - I: TryFrom + TryFrom + Default, + I: TryFrom + TryFrom + Default + 'static, Error: From<>::Error>, Error: From<>::Error>, { @@ -71,15 +67,11 @@ where } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.buffer.push(I::default()); - Ok(()) + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - push_validity(&mut self.validity, false)?; - self.buffer.push(I::default()); - Ok(()) + self.array.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -95,20 +87,14 @@ where let timestamp = time.num_seconds_from_midnight() as i64 * seconds_factor + time.nanosecond() as i64 / nanoseconds_factor; - push_validity(&mut self.validity, true)?; - self.buffer.push(timestamp.try_into()?); - Ok(()) + self.array.push_scalar_value(timestamp.try_into()?) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v.try_into()?); - Ok(()) + self.array.push_scalar_value(v.try_into()?) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.buffer.push(v.try_into()?); - Ok(()) + self.array.push_scalar_value(v.try_into()?) 
} } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 0ebd6051..19047444 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -4,56 +4,37 @@ use crate::internal::{ utils::Offset, }; -use super::utils::{ - push_validity, push_validity_default, MutableBitBuffer, MutableOffsetBuffer, SimpleSerializer, +use super::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, + utils::SimpleSerializer, }; #[derive(Debug, Clone)] -pub struct Utf8Builder { - pub validity: Option, - pub offsets: MutableOffsetBuffer, - pub buffer: Vec, -} +pub struct Utf8Builder(BytesArray); impl Utf8Builder { pub fn new(is_nullable: bool) -> Self { - Self { - validity: is_nullable.then(MutableBitBuffer::default), - offsets: MutableOffsetBuffer::default(), - buffer: Vec::new(), - } + Self(new_bytes_array(is_nullable)) } pub fn take(&mut self) -> Self { - Self { - validity: self.validity.as_mut().map(std::mem::take), - offsets: std::mem::take(&mut self.offsets), - buffer: std::mem::take(&mut self.buffer), - } + Self(self.0.take()) } pub fn is_nullable(&self) -> bool { - self.validity.is_some() + self.0.validity.is_some() } } impl Utf8Builder { pub fn into_array(self) -> Result { - Ok(Array::Utf8(BytesArray { - validity: self.validity.map(|b| b.buffer), - offsets: self.offsets.offsets, - data: self.buffer, - })) + Ok(Array::Utf8(self.0)) } } impl Utf8Builder { pub fn into_array(self) -> Result { - Ok(Array::LargeUtf8(BytesArray { - validity: self.validity.map(|b| b.buffer), - offsets: self.offsets.offsets, - data: self.buffer, - })) + Ok(Array::LargeUtf8(self.0)) } } @@ -63,23 +44,15 @@ impl SimpleSerializer for Utf8Builder { } fn serialize_default(&mut self) -> Result<()> { - push_validity_default(&mut self.validity); - self.offsets.push_current_items(); - Ok(()) + self.0.push_scalar_default() } fn serialize_none(&mut self) -> 
Result<()> { - push_validity(&mut self.validity, false)?; - self.offsets.push_current_items(); - Ok(()) + self.0.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { - push_validity(&mut self.validity, true)?; - self.offsets.push(v.len())?; - self.buffer.extend(v.as_bytes()); - - Ok(()) + self.0.push_scalar_value(v.as_bytes()) } fn serialize_unit_variant( @@ -88,7 +61,7 @@ impl SimpleSerializer for Utf8Builder { _: u32, variant: &'static str, ) -> Result<()> { - self.serialize_str(variant) + self.0.push_scalar_value(variant.as_bytes()) } fn serialize_tuple_variant_start<'this>( diff --git a/serde_arrow/src/internal/serialization/utils.rs b/serde_arrow/src/internal/serialization/utils.rs index 4a65d9ac..a938eef4 100644 --- a/serde_arrow/src/internal/serialization/utils.rs +++ b/serde_arrow/src/internal/serialization/utils.rs @@ -10,7 +10,7 @@ use crate::internal::{ arrow::FieldMeta, error::{fail, Error, Result}, schema::{merge_strategy_with_metadata, GenericField}, - utils::{Mut, Offset}, + utils::Mut, }; use super::ArrayBuilder; @@ -23,108 +23,6 @@ pub fn meta_from_field(field: GenericField) -> Result { }) } -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct MutableBitBuffer { - pub(crate) buffer: Vec, - pub(crate) len: usize, - pub(crate) capacity: usize, -} - -impl MutableBitBuffer { - pub fn as_bool(&self) -> Vec { - (0..self.len()) - .map(|i| { - let flag = 1 << i; - (self.buffer[i / 8] & flag) == flag - }) - .collect() - } - - #[allow(unused)] - pub fn len(&self) -> usize { - self.len - } - - pub fn push(&mut self, value: bool) { - while self.len >= self.capacity { - for _ in 0..64 { - self.buffer.push(0); - self.capacity += 8; - } - } - - if value { - self.buffer[self.len / 8] |= 1 << (self.len % 8); - } - self.len += 1; - } - - pub fn clear(&mut self) { - *self = Self::default(); - } -} - -#[derive(Debug, Clone)] -pub struct MutableOffsetBuffer { - pub(crate) offsets: Vec, - pub(crate) current_items: O, -} - -impl 
std::default::Default for MutableOffsetBuffer { - fn default() -> Self { - Self { - offsets: vec![O::default()], - current_items: O::default(), - } - } -} - -impl MutableOffsetBuffer { - /// The number of items pushed (one less than the number of offsets) - #[allow(unused)] - pub fn len(&self) -> usize { - self.offsets.len() - 1 - } - - // push a new item with the given number of children - pub fn push(&mut self, num_children: usize) -> Result<()> { - self.current_items = self.current_items + O::try_form_usize(num_children)?; - self.offsets.push(self.current_items); - - Ok(()) - } - - pub fn push_current_items(&mut self) { - self.offsets.push(self.current_items); - } - - pub fn inc_current_items(&mut self) -> Result<()> { - self.current_items = self.current_items + O::try_form_usize(1)?; - Ok(()) - } - - pub fn clear(&mut self) { - *self = Self::default(); - } -} - -pub fn push_validity(buffer: &mut Option, value: bool) -> Result<()> { - if let Some(buffer) = buffer.as_mut() { - buffer.push(value); - Ok(()) - } else if value { - Ok(()) - } else { - fail!("cannot push null for non-nullable array"); - } -} - -pub fn push_validity_default(buffer: &mut Option) { - if let Some(buffer) = buffer.as_mut() { - buffer.push(false); - } -} - /// A simplified serialization trait with default implementations raising an /// error /// diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index caf75003..3ef9892c 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -151,7 +151,7 @@ impl<'a, T: Serialize> Serialize for Items<&'a [T]> { pub struct Mut<'a, T>(pub &'a mut T); /// A trait to handle different offset types -pub trait Offset: std::ops::Add + Clone + Copy + Default { +pub trait Offset: std::ops::Add + Clone + Copy + Default + 'static { fn try_form_usize(val: usize) -> Result; fn try_into_usize(self) -> Result; } From 2ac97de8a685f1d14a0930ea435b2b204ea30a96 Mon Sep 17 00:00:00 2001 From: 
Christopher Prohm Date: Sun, 4 Aug 2024 10:59:05 +0200 Subject: [PATCH 069/178] Reformat --- serde_arrow/src/internal/serialization/array_ext.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/serialization/array_ext.rs b/serde_arrow/src/internal/serialization/array_ext.rs index 3248ae60..bc079ead 100644 --- a/serde_arrow/src/internal/serialization/array_ext.rs +++ b/serde_arrow/src/internal/serialization/array_ext.rs @@ -19,7 +19,7 @@ pub trait ScalarArrayExt<'value>: ArrayExt { } /// An array that models a sequence -/// +/// /// As some sequence arrays, e.g., `ListArrays`, can contain arbitrarily nested subarrays, the /// element itself is not modelled. pub trait SeqArrayExt: ArrayExt { From 8947432e9f440ecc4cdd9c2fbe0ca9ef8e01cc6f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 10:59:20 +0200 Subject: [PATCH 070/178] Use fieldmeta directly in UnionBuilder --- .../serialization/outer_sequence_builder.rs | 6 ++-- .../internal/serialization/union_builder.rs | 33 ++++++++----------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index af19c913..e5d72000 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, - fixed_size_list_builder::FixedSizeListBuilder, + fixed_size_list_builder::FixedSizeListBuilder, utils::meta_from_field, }, utils::Mut, }; @@ -170,10 +170,10 @@ impl OuterSequenceBuilder { T::Union => { let mut fields = Vec::new(); for field in &field.children { - fields.push(build_builder(field)?); + fields.push((build_builder(field)?, meta_from_field(field.clone())?)); } - 
A::Union(UnionBuilder::new(field.clone(), fields)?) + A::Union(UnionBuilder::new(fields)) } }; Ok(builder) diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 720e3dc8..0a9b292e 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,39 +1,36 @@ use crate::internal::{ - arrow::{Array, DenseUnionArray}, + arrow::{Array, DenseUnionArray, FieldMeta}, error::{fail, Result}, - schema::GenericField, utils::Mut, }; -use super::{ - utils::{meta_from_field, SimpleSerializer}, - ArrayBuilder, -}; +use super::{utils::SimpleSerializer, ArrayBuilder}; #[derive(Debug, Clone)] pub struct UnionBuilder { - pub field: GenericField, - pub fields: Vec, + pub fields: Vec<(ArrayBuilder, FieldMeta)>, pub types: Vec, pub offsets: Vec, pub current_offset: Vec, } impl UnionBuilder { - pub fn new(field: GenericField, fields: Vec) -> Result { - Ok(Self { - field, + pub fn new(fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self { + Self { current_offset: vec![0; fields.len()], types: Vec::new(), offsets: Vec::new(), fields, - }) + } } pub fn take(&mut self) -> Self { Self { - field: self.field.clone(), - fields: self.fields.iter_mut().map(|field| field.take()).collect(), + fields: self + .fields + .iter_mut() + .map(|(field, meta)| (field.take(), meta.clone())) + .collect(), types: std::mem::take(&mut self.types), offsets: std::mem::take(&mut self.offsets), current_offset: std::mem::replace(&mut self.current_offset, vec![0; self.fields.len()]), @@ -46,10 +43,8 @@ impl UnionBuilder { pub fn into_array(self) -> Result { let mut fields = Vec::new(); - for (field, builder) in self.field.children.into_iter().zip(self.fields) { - let meta = meta_from_field(field)?; - let array = builder.into_array()?; - fields.push((array, meta)); + for (builder, meta) in self.fields { + fields.push((builder.into_array()?, meta)); } 
Ok(Array::DenseUnion(DenseUnionArray { @@ -63,7 +58,7 @@ impl UnionBuilder { impl UnionBuilder { pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> { let variant_index = variant_index as usize; - let Some(variant_builder) = self.fields.get_mut(variant_index) else { + let Some((variant_builder, _)) = self.fields.get_mut(variant_index) else { fail!("Unknown variant {variant_index}"); }; From a6863880c64f401db21f240c996bcdb37c3ffe3e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 11:20:11 +0200 Subject: [PATCH 071/178] Use FieldMeta directly in builders --- serde_arrow/src/arrow2_impl/array.rs | 2 +- serde_arrow/src/arrow_impl/array.rs | 2 +- .../serialization/fixed_size_list_builder.rs | 15 ++-- .../internal/serialization/list_builder.rs | 7 +- .../src/internal/serialization/map_builder.rs | 15 ++-- .../serialization/outer_sequence_builder.rs | 31 ++++---- .../internal/serialization/struct_builder.rs | 72 +++++++------------ .../src/internal/serialization/utils.rs | 10 --- serde_arrow/src/internal/utils/mod.rs | 13 ++++ 9 files changed, 72 insertions(+), 95 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 36b1d10d..b133a818 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -17,7 +17,7 @@ use crate::{ TimeArrayView, TimestampArrayView, }, error::{fail, Error, Result}, - serialization::utils::meta_from_field, + utils::meta_from_field, }, }; diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index dbc6f5ba..aa034c66 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -31,7 +31,7 @@ use crate::{ }, error::{fail, Error, Result}, schema::GenericField, - serialization::utils::meta_from_field, + utils::meta_from_field, }, }; diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs 
b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index fd898a76..d07ac9bc 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -1,33 +1,32 @@ use serde::Serialize; use crate::internal::{ - arrow::{Array, FixedSizeListArray}, + arrow::{Array, FieldMeta, FixedSizeListArray}, error::{fail, Result}, - schema::GenericField, utils::Mut, }; use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::{meta_from_field, SimpleSerializer}, + utils::SimpleSerializer, }; #[derive(Debug, Clone)] pub struct FixedSizeListBuilder { pub seq: CountArray, - pub field: GenericField, + pub meta: FieldMeta, pub n: usize, pub current_count: usize, pub element: Box, } impl FixedSizeListBuilder { - pub fn new(field: GenericField, element: ArrayBuilder, n: usize, is_nullable: bool) -> Self { + pub fn new(meta: FieldMeta, element: ArrayBuilder, n: usize, is_nullable: bool) -> Self { Self { seq: CountArray::new(is_nullable), - field, + meta, n, current_count: 0, element: Box::new(element), @@ -37,7 +36,7 @@ impl FixedSizeListBuilder { pub fn take(&mut self) -> Self { Self { seq: self.seq.take(), - field: self.field.clone(), + meta: self.meta.clone(), n: self.n, current_count: std::mem::take(&mut self.current_count), element: Box::new(self.element.take()), @@ -53,7 +52,7 @@ impl FixedSizeListBuilder { len: self.seq.len, validity: self.seq.validity, n: self.n.try_into()?, - meta: meta_from_field(self.field)?, + meta: self.meta, element: Box::new((*self.element).into_array()?), })) } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index c5454eaa..9d00e039 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -3,14 +3,13 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, 
FieldMeta, ListArray}, error::Result, - schema::GenericField, utils::{Mut, Offset}, }; use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - utils::{meta_from_field, SimpleSerializer}, + utils::SimpleSerializer, }; #[derive(Debug, Clone)] @@ -22,9 +21,9 @@ pub struct ListBuilder { } impl ListBuilder { - pub fn new(field: GenericField, element: ArrayBuilder, is_nullable: bool) -> Result { + pub fn new(meta: FieldMeta, element: ArrayBuilder, is_nullable: bool) -> Result { Ok(Self { - meta: meta_from_field(field)?, + meta, element: Box::new(element), offsets: OffsetsArray::new(is_nullable), }) diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 5cb2ceed..ef109f44 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -1,28 +1,27 @@ use serde::Serialize; use crate::internal::{ - arrow::{Array, ListArray}, + arrow::{Array, FieldMeta, ListArray}, error::Result, - schema::GenericField, }; use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - utils::{meta_from_field, SimpleSerializer}, + utils::SimpleSerializer, }; #[derive(Debug, Clone)] pub struct MapBuilder { - pub entry_field: GenericField, + pub meta: FieldMeta, pub entry: Box, pub offsets: OffsetsArray, } impl MapBuilder { - pub fn new(entry_field: GenericField, entry: ArrayBuilder, is_nullable: bool) -> Self { + pub fn new(meta: FieldMeta, entry: ArrayBuilder, is_nullable: bool) -> Self { Self { - entry_field, + meta, offsets: OffsetsArray::new(is_nullable), entry: Box::new(entry), } @@ -30,7 +29,7 @@ impl MapBuilder { pub fn take(&mut self) -> Self { Self { - entry_field: self.entry_field.clone(), + meta: self.meta.clone(), offsets: self.offsets.take(), entry: Box::new(self.entry.take()), } @@ -42,7 +41,7 @@ impl MapBuilder { pub fn into_array(self) -> Result { Ok(Array::Map(ListArray { 
- meta: meta_from_field(self.entry_field)?, + meta: self.meta, element: Box::new((*self.entry).into_array()?), validity: self.offsets.validity, offsets: self.offsets.offsets, diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index e5d72000..2875a7c7 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -7,9 +7,9 @@ use crate::internal::{ serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, - fixed_size_list_builder::FixedSizeListBuilder, utils::meta_from_field, + fixed_size_list_builder::FixedSizeListBuilder, }, - utils::Mut, + utils::{meta_from_field, Mut}, }; use super::{ @@ -29,14 +29,12 @@ impl OuterSequenceBuilder { pub fn new(schema: &SerdeArrowSchema) -> Result { return Ok(Self(build_struct(&schema.fields, false)?)); - fn build_struct(fields: &[GenericField], nullable: bool) -> Result { - let mut named_fields = Vec::new(); - for field in fields { - let builder = build_builder(field)?; - named_fields.push((field.name.to_owned(), builder)); + fn build_struct(struct_fields: &[GenericField], nullable: bool) -> Result { + let mut fields = Vec::new(); + for field in struct_fields { + fields.push((build_builder(field)?, meta_from_field(field.clone())?)); } - - StructBuilder::new(fields.to_vec(), named_fields, nullable) + StructBuilder::new(fields, nullable) } fn build_builder(field: &GenericField) -> Result { @@ -101,7 +99,7 @@ impl OuterSequenceBuilder { fail!("cannot build a list without an element field"); }; A::List(ListBuilder::new( - child.clone(), + meta_from_field(child.clone())?, build_builder(child)?, field.nullable, )?) 
@@ -111,7 +109,7 @@ impl OuterSequenceBuilder { fail!("cannot build list without an element field"); }; A::LargeList(ListBuilder::new( - child.clone(), + meta_from_field(child.clone())?, build_builder(child)?, field.nullable, )?) @@ -121,7 +119,7 @@ impl OuterSequenceBuilder { fail!("cannot build list without an element field"); }; A::FixedSizedList(FixedSizeListBuilder::new( - child.clone(), + meta_from_field(child.clone())?, build_builder(child)?, (*n).try_into()?, field.nullable, @@ -141,7 +139,7 @@ impl OuterSequenceBuilder { fail!("Invalid child field for map: {entry_field:?}") } A::Map(MapBuilder::new( - entry_field.clone(), + meta_from_field(entry_field.clone())?, build_builder(entry_field)?, field.nullable, )) @@ -182,13 +180,10 @@ impl OuterSequenceBuilder { /// Extract the contained struct fields pub fn take_records(&mut self) -> Result> { - let builder = self.0.take(); - let mut result = Vec::new(); - for (_, field) in builder.named_fields { - result.push(field); + for (builder, _) in self.0.take().fields { + result.push(builder); } - Ok(result) } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 76867b7e..20560d36 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -3,25 +3,22 @@ use std::collections::BTreeMap; use serde::Serialize; use crate::internal::{ - arrow::{Array, StructArray}, + arrow::{Array, FieldMeta, StructArray}, error::{fail, Result}, - schema::GenericField, utils::Mut, }; use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::{meta_from_field, SimpleSerializer}, + utils::SimpleSerializer, }; const UNKNOWN_KEY: usize = usize::MAX; #[derive(Debug, Clone)] pub struct StructBuilder { - // TODO: clean this up - pub fields: Vec, - pub named_fields: Vec<(String, ArrayBuilder)>, + pub fields: Vec<(ArrayBuilder, FieldMeta)>, pub 
cached_names: Vec>, pub seen: Vec, pub next: usize, @@ -30,28 +27,19 @@ pub struct StructBuilder { } impl StructBuilder { - pub fn new( - fields: Vec, - named_fields: Vec<(String, ArrayBuilder)>, - is_nullable: bool, - ) -> Result { - if fields.len() != named_fields.len() { - fail!("mismatched number of fields and builders"); - } - + pub fn new(fields: Vec<(ArrayBuilder, FieldMeta)>, is_nullable: bool) -> Result { let mut index = BTreeMap::new(); - for (idx, (name, _)) in named_fields.iter().enumerate() { - if index.contains_key(name) { - fail!("Duplicate field {name}"); + for (idx, (_, meta)) in fields.iter().enumerate() { + if index.contains_key(&meta.name) { + fail!("Duplicate field {name}", name = meta.name); } - index.insert(name.to_owned(), idx); + index.insert(meta.name.clone(), idx); } Ok(Self { + seen: vec![false; fields.len()], + cached_names: vec![None; fields.len()], fields, - seen: vec![false; named_fields.len()], - cached_names: vec![None; named_fields.len()], - named_fields, next: 0, index, seq: CountArray::new(is_nullable), @@ -60,17 +48,13 @@ impl StructBuilder { pub fn take(&mut self) -> Self { Self { - fields: self.fields.clone(), - named_fields: self - .named_fields + fields: self + .fields .iter_mut() - .map(|(name, builder)| (name.clone(), builder.take())) + .map(|(builder, meta)| (builder.take(), meta.clone())) .collect(), - cached_names: std::mem::replace( - &mut self.cached_names, - vec![None; self.named_fields.len()], - ), - seen: std::mem::replace(&mut self.seen, vec![false; self.named_fields.len()]), + cached_names: std::mem::replace(&mut self.cached_names, vec![None; self.fields.len()]), + seen: std::mem::replace(&mut self.seen, vec![false; self.fields.len()]), next: std::mem::take(&mut self.next), index: self.index.clone(), seq: self.seq.take(), @@ -83,10 +67,8 @@ impl StructBuilder { pub fn into_array(self) -> Result { let mut fields = Vec::new(); - for (field, (_, builder)) in self.fields.into_iter().zip(self.named_fields) { - let 
meta = meta_from_field(field)?; - let array = builder.into_array()?; - fields.push((array, meta)); + for (builder, meta) in self.fields { + fields.push((builder.into_array()?, meta)); } Ok(Array::Struct(StructArray { @@ -113,14 +95,14 @@ impl StructBuilder { self.seq.end_seq()?; for (idx, seen) in self.seen.iter_mut().enumerate() { if !*seen { - if !self.named_fields[idx].1.is_nullable() { + if !self.fields[idx].1.nullable { fail!( "missing non-nullable field {:?} in struct", - self.named_fields[idx].0 + self.fields[idx].1.name ); } - self.named_fields[idx].1.serialize_none()?; + self.fields[idx].0.serialize_none()?; } } Ok(()) @@ -129,10 +111,10 @@ impl StructBuilder { fn element(&mut self, idx: usize, value: &T) -> Result<()> { self.seq.push_seq_elements(1)?; if self.seen[idx] { - fail!("Duplicate field {key}", key = self.named_fields[idx].0); + fail!("Duplicate field {key}", key = self.fields[idx].1.name); } - value.serialize(Mut(&mut self.named_fields[idx].1))?; + value.serialize(Mut(&mut self.fields[idx].0))?; self.seen[idx] = true; self.next = idx + 1; Ok(()) @@ -146,8 +128,8 @@ impl SimpleSerializer for StructBuilder { fn serialize_default(&mut self) -> Result<()> { self.seq.push_seq_default()?; - for (_, field) in &mut self.named_fields { - field.serialize_default()?; + for (builder, _) in &mut self.fields { + builder.serialize_default()?; } Ok(()) @@ -155,8 +137,8 @@ impl SimpleSerializer for StructBuilder { fn serialize_none(&mut self) -> Result<()> { self.seq.push_seq_none()?; - for (_, field) in &mut self.named_fields { - field.serialize_default()?; + for (builder, _) in &mut self.fields { + builder.serialize_default()?; } Ok(()) } @@ -210,7 +192,7 @@ impl SimpleSerializer for StructBuilder { fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { // ignore extra tuple fields - if self.next < self.named_fields.len() { + if self.next < self.fields.len() { self.element(self.next, value)?; } Ok(()) diff --git 
a/serde_arrow/src/internal/serialization/utils.rs b/serde_arrow/src/internal/serialization/utils.rs index a938eef4..52438114 100644 --- a/serde_arrow/src/internal/serialization/utils.rs +++ b/serde_arrow/src/internal/serialization/utils.rs @@ -7,22 +7,12 @@ use serde::{ }; use crate::internal::{ - arrow::FieldMeta, error::{fail, Error, Result}, - schema::{merge_strategy_with_metadata, GenericField}, utils::Mut, }; use super::ArrayBuilder; -pub fn meta_from_field(field: GenericField) -> Result { - Ok(FieldMeta { - name: field.name, - nullable: field.nullable, - metadata: merge_strategy_with_metadata(field.metadata, field.strategy)?, - }) -} - /// A simplified serialization trait with default implementations raising an /// error /// diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 3ef9892c..4f32d6f1 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -9,6 +9,11 @@ use serde::{ser::SerializeSeq, Deserialize, Serialize}; use crate::internal::error::Result; +use super::{ + arrow::FieldMeta, + schema::{merge_strategy_with_metadata, GenericField}, +}; + /// A wrapper around a sequence of items /// /// When serialized or deserialized, it behaves as if each item was wrapped in a @@ -175,3 +180,11 @@ impl Offset for i64 { Ok(self.try_into()?) 
} } + +pub fn meta_from_field(field: GenericField) -> Result { + Ok(FieldMeta { + name: field.name, + nullable: field.nullable, + metadata: merge_strategy_with_metadata(field.metadata, field.strategy)?, + }) +} From 0f3a27099c1eb2edddec0c11ab95fe91dd40e566 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 11:32:49 +0200 Subject: [PATCH 072/178] Move field lookup logic for StructBuilder into separate struct --- .../internal/serialization/struct_builder.rs | 93 ++++++++++++------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 20560d36..9d55613f 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -19,30 +19,22 @@ const UNKNOWN_KEY: usize = usize::MAX; #[derive(Debug, Clone)] pub struct StructBuilder { pub fields: Vec<(ArrayBuilder, FieldMeta)>, - pub cached_names: Vec>, - pub seen: Vec, + pub lookup: FieldLookup, pub next: usize, - pub index: BTreeMap, + pub seen: Vec, pub seq: CountArray, } impl StructBuilder { pub fn new(fields: Vec<(ArrayBuilder, FieldMeta)>, is_nullable: bool) -> Result { - let mut index = BTreeMap::new(); - for (idx, (_, meta)) in fields.iter().enumerate() { - if index.contains_key(&meta.name) { - fail!("Duplicate field {name}", name = meta.name); - } - index.insert(meta.name.clone(), idx); - } + let lookup = FieldLookup::new(fields.iter().map(|(_, meta)| meta.name.clone()).collect())?; Ok(Self { + seq: CountArray::new(is_nullable), seen: vec![false; fields.len()], - cached_names: vec![None; fields.len()], - fields, next: 0, - index, - seq: CountArray::new(is_nullable), + lookup, + fields, }) } @@ -53,11 +45,10 @@ impl StructBuilder { .iter_mut() .map(|(builder, meta)| (builder.take(), meta.clone())) .collect(), - cached_names: std::mem::replace(&mut self.cached_names, vec![None; self.fields.len()]), + 
lookup: self.lookup.take(), seen: std::mem::replace(&mut self.seen, vec![false; self.fields.len()]), - next: std::mem::take(&mut self.next), - index: self.index.clone(), seq: self.seq.take(), + next: std::mem::take(&mut self.next), } } @@ -152,21 +143,10 @@ impl SimpleSerializer for StructBuilder { key: &'static str, value: &T, ) -> Result<()> { - let fast_key = (key.as_ptr(), key.len()); - let idx = if self.cached_names.get(self.next) == Some(&Some(fast_key)) { - self.next - } else { - let Some(&idx) = self.index.get(key) else { - // ignore unknown fields - return Ok(()); - }; - - if self.cached_names[idx].is_none() { - self.cached_names[idx] = Some(fast_key); - } - idx + let Some(idx) = self.lookup.lookup(self.next, key) else { + // ignore unknown fields + return Ok(()); }; - self.element(idx, value) } @@ -210,7 +190,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - self.next = KeyLookupSerializer::lookup(&self.index, key)?.unwrap_or(UNKNOWN_KEY); + self.next = self.lookup.lookup_serialize(key)?.unwrap_or(UNKNOWN_KEY); Ok(()) } @@ -228,6 +208,55 @@ impl SimpleSerializer for StructBuilder { } } +#[derive(Debug, Clone)] +pub struct FieldLookup { + pub cached_names: Vec>, + pub index: BTreeMap, +} + +impl FieldLookup { + pub fn new(field_names: Vec) -> Result { + let mut index = BTreeMap::new(); + for (idx, name) in field_names.into_iter().enumerate() { + if index.contains_key(&name) { + fail!("Duplicate field {name}"); + } + index.insert(name, idx); + } + Ok(Self { + cached_names: vec![None; index.len()], + index, + }) + } + + pub fn take(&mut self) -> Self { + Self { + cached_names: std::mem::replace(&mut self.cached_names, vec![None; self.index.len()]), + index: self.index.clone(), + } + } + + pub fn lookup(&mut self, guess: usize, key: &'static str) -> Option { + let fast_key = (key.as_ptr(), key.len()); + if self.cached_names.get(guess) == Some(&Some(fast_key)) { + Some(guess) + } else { + let Some(&idx) 
= self.index.get(key) else { + return None; + }; + + if self.cached_names[idx].is_none() { + self.cached_names[idx] = Some(fast_key); + } + Some(idx) + } + } + + pub fn lookup_serialize(&mut self, key: &V) -> Result> { + KeyLookupSerializer::lookup(&self.index, key) + } +} + #[derive(Debug)] pub struct KeyLookupSerializer<'a> { index: &'a BTreeMap, From a6dfeff342f99a4b93717b4835526dfee5e4cc7c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 11:34:28 +0200 Subject: [PATCH 073/178] Rename utils.rs -> simple_serializer.rs --- serde_arrow/src/internal/serialization/array_builder.rs | 2 +- serde_arrow/src/internal/serialization/binary_builder.rs | 2 +- serde_arrow/src/internal/serialization/bool_builder.rs | 2 +- serde_arrow/src/internal/serialization/date32_builder.rs | 2 +- serde_arrow/src/internal/serialization/date64_builder.rs | 2 +- serde_arrow/src/internal/serialization/decimal_builder.rs | 2 +- .../src/internal/serialization/dictionary_utf8_builder.rs | 2 +- serde_arrow/src/internal/serialization/duration_builder.rs | 2 +- .../src/internal/serialization/fixed_size_binary_builder.rs | 2 +- .../src/internal/serialization/fixed_size_list_builder.rs | 2 +- serde_arrow/src/internal/serialization/float_builder.rs | 2 +- serde_arrow/src/internal/serialization/int_builder.rs | 2 +- serde_arrow/src/internal/serialization/list_builder.rs | 2 +- serde_arrow/src/internal/serialization/map_builder.rs | 2 +- serde_arrow/src/internal/serialization/mod.rs | 2 +- serde_arrow/src/internal/serialization/null_builder.rs | 2 +- .../src/internal/serialization/outer_sequence_builder.rs | 2 +- .../internal/serialization/{utils.rs => simple_serializer.rs} | 0 serde_arrow/src/internal/serialization/struct_builder.rs | 2 +- serde_arrow/src/internal/serialization/time_builder.rs | 2 +- serde_arrow/src/internal/serialization/union_builder.rs | 2 +- .../src/internal/serialization/unknown_variant_builder.rs | 2 +- serde_arrow/src/internal/serialization/utf8_builder.rs | 
2 +- 23 files changed, 22 insertions(+), 22 deletions(-) rename serde_arrow/src/internal/serialization/{utils.rs => simple_serializer.rs} (100%) diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index fc73c70d..cc3aacf4 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -12,7 +12,7 @@ use super::{ int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder, null_builder::NullBuilder, struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, - utf8_builder::Utf8Builder, utils::SimpleSerializer, + utf8_builder::Utf8Builder, simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index f926c8bb..2bc1ea5a 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 1c4c24d1..1614bcfe 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -5,7 +5,7 @@ use crate::internal::{ use super::{ array_ext::{set_bit_buffer, set_validity, set_validity_default}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 077c2672..1aaa3079 100644 --- 
a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 7aa44e67..2acc2107 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -6,7 +6,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 437f876e..1caf8802 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -6,7 +6,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index cbbbd028..d7cbb4d0 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ utils::Mut, }; -use super::{array_builder::ArrayBuilder, utils::SimpleSerializer}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct DictionaryUtf8Builder { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs 
b/serde_arrow/src/internal/serialization/duration_builder.rs index 27f2ac01..17ecc99a 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -5,7 +5,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 2a806b7c..2c82b46e 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index d07ac9bc..89007756 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 81af7edd..a9be17b2 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git 
a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index e346c818..acea49ff 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -5,7 +5,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 9d00e039..8a2bdad1 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index ef109f44..fdf557d1 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index 898583ed..22b0088b 100644 --- a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -22,7 +22,7 @@ pub mod time_builder; pub mod union_builder; pub mod unknown_variant_builder; pub mod utf8_builder; -pub mod utils; +pub mod simple_serializer; // #[cfg(test)] // mod test; diff --git a/serde_arrow/src/internal/serialization/null_builder.rs 
b/serde_arrow/src/internal/serialization/null_builder.rs index eb02acdb..60850d94 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -3,7 +3,7 @@ use crate::internal::{ error::Result, }; -use super::utils::SimpleSerializer; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone, Default)] pub struct NullBuilder { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 2875a7c7..3bd6173a 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -19,7 +19,7 @@ use super::{ map_builder::MapBuilder, null_builder::NullBuilder, struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, - utils::SimpleSerializer, ArrayBuilder, + simple_serializer::SimpleSerializer, ArrayBuilder, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/utils.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs similarity index 100% rename from serde_arrow/src/internal/serialization/utils.rs rename to serde_arrow/src/internal/serialization/simple_serializer.rs diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 9d55613f..9c688a32 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ use super::{ array_builder::ArrayBuilder, array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; const UNKNOWN_KEY: usize = usize::MAX; diff --git a/serde_arrow/src/internal/serialization/time_builder.rs 
b/serde_arrow/src/internal/serialization/time_builder.rs index 5d2eb8e8..ed6dae67 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ use super::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 0a9b292e..ade767f4 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -4,7 +4,7 @@ use crate::internal::{ utils::Mut, }; -use super::{utils::SimpleSerializer, ArrayBuilder}; +use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; #[derive(Debug, Clone)] pub struct UnionBuilder { diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 845574de..b6b6c0f3 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -8,7 +8,7 @@ use crate::{ Result, }; -use super::{utils::SimpleSerializer, ArrayBuilder}; +use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; #[derive(Debug, Clone)] pub struct UnknownVariantBuilder; diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 19047444..788123cf 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -6,7 +6,7 @@ use crate::internal::{ use super::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - utils::SimpleSerializer, + simple_serializer::SimpleSerializer, }; #[derive(Debug, Clone)] From 167f16bfd9815b6d2bae180d7b38f3693f8ef2fa Mon Sep 17 
00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 11:51:30 +0200 Subject: [PATCH 074/178] Simplify meta hanlding date64_builder --- .../internal/serialization/array_builder.rs | 6 +-- .../internal/serialization/date64_builder.rs | 44 ++++++++--------- serde_arrow/src/internal/serialization/mod.rs | 2 +- .../serialization/outer_sequence_builder.rs | 48 +++++++++++-------- 4 files changed, 52 insertions(+), 48 deletions(-) diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index cc3aacf4..368261c0 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -10,9 +10,9 @@ use super::{ fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder, - null_builder::NullBuilder, struct_builder::StructBuilder, time_builder::TimeBuilder, - union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, - utf8_builder::Utf8Builder, simple_serializer::SimpleSerializer, + null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder, + time_builder::TimeBuilder, union_builder::UnionBuilder, + unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, }; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 2acc2107..576229c3 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,7 +1,6 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, error::{fail, Result}, - schema::{GenericDataType, GenericField}, }; use super::{ @@ -11,15 +10,15 @@ use super::{ #[derive(Debug, Clone)] pub struct Date64Builder { - pub 
field: GenericField, + pub meta: Option<(TimeUnit, Option)>, pub utc: bool, pub array: PrimitiveArray, } impl Date64Builder { - pub fn new(field: GenericField, utc: bool, is_nullable: bool) -> Self { + pub fn new(meta: Option<(TimeUnit, Option)>, utc: bool, is_nullable: bool) -> Self { Self { - field, + meta, utc, array: new_primitive_array(is_nullable), } @@ -27,7 +26,7 @@ impl Date64Builder { pub fn take(&mut self) -> Self { Self { - field: self.field.clone(), + meta: self.meta.clone(), utc: self.utc, array: self.array.take(), } @@ -37,9 +36,8 @@ impl Date64Builder { self.array.validity.is_some() } - // TODO: fix this pub fn into_array(self) -> Result { - if let GenericDataType::Timestamp(unit, timezone) = self.field.data_type { + if let Some((unit, timezone)) = self.meta { Ok(Array::Timestamp(TimestampArray { unit, timezone, @@ -47,7 +45,6 @@ impl Date64Builder { values: self.array.values, })) } else { - // TOOD: check data type Ok(Array::Date64(PrimitiveArray { validity: self.array.validity, values: self.array.values, @@ -78,24 +75,21 @@ impl SimpleSerializer for Date64Builder { v.parse::()?.and_utc() }; - let timestamp = match self.field.data_type { - GenericDataType::Timestamp(TimeUnit::Nanosecond, _) => { - match date_time.timestamp_nanos_opt() { - Some(timestamp) => timestamp, - _ => fail!( - concat!( - "Timestamp '{date_time}' cannot be converted to nanoseconds. ", - "The dates that can be represented as nanoseconds are between ", - "1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.", - ), - date_time = date_time, + let timestamp = match self.meta.as_ref() { + Some((TimeUnit::Nanosecond, _)) => match date_time.timestamp_nanos_opt() { + Some(timestamp) => timestamp, + _ => fail!( + concat!( + "Timestamp '{date_time}' cannot be converted to nanoseconds. 
", + "The dates that can be represented as nanoseconds are between ", + "1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804.", ), - } - } - GenericDataType::Timestamp(TimeUnit::Microsecond, _) => date_time.timestamp_micros(), - GenericDataType::Timestamp(TimeUnit::Millisecond, _) => date_time.timestamp_millis(), - GenericDataType::Timestamp(TimeUnit::Second, _) => date_time.timestamp(), - _ => date_time.timestamp_millis(), + date_time = date_time, + ), + }, + Some((TimeUnit::Microsecond, _)) => date_time.timestamp_micros(), + Some((TimeUnit::Millisecond, _)) | None => date_time.timestamp_millis(), + Some((TimeUnit::Second, _)) => date_time.timestamp(), }; self.array.push_scalar_value(timestamp) diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index 22b0088b..ac137b17 100644 --- a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -17,12 +17,12 @@ pub mod list_builder; pub mod map_builder; pub mod null_builder; pub mod outer_sequence_builder; +pub mod simple_serializer; pub mod struct_builder; pub mod time_builder; pub mod union_builder; pub mod unknown_variant_builder; pub mod utf8_builder; -pub mod simple_serializer; // #[cfg(test)] // mod test; diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 3bd6173a..21647373 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -16,10 +16,9 @@ use super::{ bool_builder::BoolBuilder, date32_builder::Date32Builder, date64_builder::Date64Builder, decimal_builder::DecimalBuilder, dictionary_utf8_builder::DictionaryUtf8Builder, float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder, - map_builder::MapBuilder, null_builder::NullBuilder, struct_builder::StructBuilder, - time_builder::TimeBuilder, 
union_builder::UnionBuilder, - unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, - simple_serializer::SimpleSerializer, ArrayBuilder, + map_builder::MapBuilder, null_builder::NullBuilder, simple_serializer::SimpleSerializer, + struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder, + unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, ArrayBuilder, }; #[derive(Debug, Clone)] @@ -61,21 +60,16 @@ impl OuterSequenceBuilder { T::F32 => A::F32(FloatBuilder::new(field.nullable)), T::F64 => A::F64(FloatBuilder::new(field.nullable)), T::Date32 => A::Date32(Date32Builder::new(field.nullable)), - T::Date64 => { - let is_utc = match field.strategy.as_ref() { - Some(Strategy::UtcStrAsDate64) | None => true, - Some(Strategy::NaiveStrAsDate64) => false, - Some(st) => fail!("Cannot builder Date64 builder with strategy {st}"), - }; - A::Date64(Date64Builder::new(field.clone(), is_utc, field.nullable)) - } - T::Timestamp(_, tz) => match tz.as_deref() { - None => A::Date64(Date64Builder::new(field.clone(), false, field.nullable)), - Some(tz) if tz.to_uppercase() == "UTC" => { - A::Date64(Date64Builder::new(field.clone(), true, field.nullable)) - } - Some(tz) => fail!("Timezone {tz} is not supported"), - }, + T::Date64 => A::Date64(Date64Builder::new( + None, + is_utc_strategy(field.strategy.as_ref())?, + field.nullable, + )), + T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( + Some((*unit, tz.clone())), + is_utc_tz(tz.as_deref())?, + field.nullable, + )), T::Time32(unit) => { if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { fail!("Only timestamps with second or millisecond unit are supported"); @@ -249,3 +243,19 @@ impl SimpleSerializer for OuterSequenceBuilder { Ok(()) } } + +fn is_utc_tz(tz: Option<&str>) -> Result { + match tz { + None => Ok(false), + Some(tz) if tz.to_uppercase() == "UTC" => Ok(true), + Some(tz) => fail!("Timezone {tz} is not supported"), + } +} + +fn 
is_utc_strategy(strategy: Option<&Strategy>) -> Result { + match strategy { + Some(Strategy::UtcStrAsDate64) | None => Ok(true), + Some(Strategy::NaiveStrAsDate64) => Ok(false), + Some(st) => fail!("Cannot builder Date64 builder with strategy {st}"), + } +} From a299dbe40e5f06600464afff921da3eee4dd2ada Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 11:53:06 +0200 Subject: [PATCH 075/178] Fix clippy --- serde_arrow/src/internal/serialization/array_ext.rs | 2 +- serde_arrow/src/internal/serialization/struct_builder.rs | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/internal/serialization/array_ext.rs b/serde_arrow/src/internal/serialization/array_ext.rs index bc079ead..b2df07f5 100644 --- a/serde_arrow/src/internal/serialization/array_ext.rs +++ b/serde_arrow/src/internal/serialization/array_ext.rs @@ -268,7 +268,7 @@ pub fn duplicate_last(vec: &mut Vec) -> Result<()> { Ok(()) } -pub fn increment_last(vec: &mut Vec, inc: usize) -> Result<()> { +pub fn increment_last(vec: &mut [O], inc: usize) -> Result<()> { let Some(last) = vec.last_mut() else { fail!("invalid offset array") }; diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 9c688a32..264a7b86 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -241,10 +241,7 @@ impl FieldLookup { if self.cached_names.get(guess) == Some(&Some(fast_key)) { Some(guess) } else { - let Some(&idx) = self.index.get(key) else { - return None; - }; - + let &idx = self.index.get(key)?; if self.cached_names[idx].is_none() { self.cached_names[idx] = Some(fast_key); } From d8589565b5e03bb307aa91a3433b64cb8b7b9f1b Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 12:11:17 +0200 Subject: [PATCH 076/178] Use BitsWithOffset directly in primitive deserializers --- 
.../deserialization/array_deserializer.rs | 46 +++++++++---------- .../deserialization/bool_deserializer.rs | 33 +++++++------ .../internal/deserialization/construction.rs | 35 -------------- .../deserialization/date32_deserializer.rs | 9 ++-- .../deserialization/date64_deserializer.rs | 9 ++-- .../deserialization/decimal_deserializer.rs | 8 ++-- .../dictionary_deserializer.rs | 9 ++-- .../deserialization/float_deserializer.rs | 9 ++-- .../deserialization/integer_deserializer.rs | 9 ++-- .../src/internal/deserialization/mod.rs | 1 - .../deserialization/time_deserializer.rs | 9 ++-- .../src/internal/deserialization/utils.rs | 29 +++++++----- 12 files changed, 78 insertions(+), 128 deletions(-) delete mode 100644 serde_arrow/src/internal/deserialization/construction.rs diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 8f52798e..80cea649 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -80,86 +80,84 @@ pub enum ArrayDeserializer<'a> { impl<'a> ArrayDeserializer<'a> { pub fn new(strategy: Option<&Strategy>, array: ArrayView<'a>) -> Result { + use {ArrayDeserializer as D, ArrayView as V}; match array { ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), - ArrayView::Boolean(view) => Ok(Self::Bool(BoolDeserializer::new( - buffer_from_bits_with_offset(view.values, view.len), - buffer_from_bits_with_offset_opt(view.validity, view.len), - ))), + V::Boolean(view) => Ok(D::Bool(BoolDeserializer::new(view))), ArrayView::Int8(view) => Ok(Self::I8(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Int16(view) => Ok(Self::I16(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Int32(view) => 
Ok(Self::I32(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Int64(view) => Ok(Self::I64(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::UInt8(view) => Ok(Self::U8(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::UInt16(view) => Ok(Self::U16(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::UInt32(view) => Ok(Self::U32(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::UInt64(view) => Ok(Self::U64(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Float16(view) => Ok(Self::F16(FloatDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Float32(view) => Ok(Self::F32(FloatDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Float64(view) => Ok(Self::F64(FloatDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Decimal128(view) => Ok(Self::Decimal128(DecimalDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, view.scale, ))), ArrayView::Date32(view) => Ok(Self::Date32(Date32Deserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Date64(view) => Ok(Self::Date64(Date64Deserializer::new( view.values, - 
buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, TimeUnit::Millisecond, is_utc_date64(strategy)?, ))), ArrayView::Time32(view) => Ok(Self::Time32(TimeDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, view.unit, ))), ArrayView::Time64(view) => Ok(Self::Time64(TimeDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, view.unit, ))), ArrayView::Timestamp(view) => match strategy { Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) => { Ok(Self::Date64(Date64Deserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, view.unit, is_utc_timestamp(view.timezone.as_deref())?, ))) @@ -167,7 +165,7 @@ impl<'a> ArrayDeserializer<'a> { Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), None => Ok(Date64Deserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, view.unit, is_utc_timestamp(view.timezone.as_deref())?, ) @@ -175,7 +173,7 @@ impl<'a> ArrayDeserializer<'a> { }, ArrayView::Duration(view) => Ok(Self::I64(IntegerDeserializer::new( view.values, - buffer_from_bits_with_offset_opt(view.validity, view.values.len()), + view.validity, ))), ArrayView::Utf8(view) => Ok(Self::Utf8(StringDeserializer::new( view.data, @@ -359,7 +357,7 @@ fn build_dictionary_array<'a, K: Integer, V: Offset>( } Ok(DictionaryDeserializer::new( keys.values, - buffer_from_bits_with_offset_opt(keys.validity, keys.values.len()), + keys.validity, values.data, values.offsets, )) diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index 5461285e..4e7fd87d 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -1,36 
+1,35 @@ use serde::de::Visitor; -use crate::internal::{error::fail, error::Result, utils::Mut}; +use crate::internal::{ + arrow::BooleanArrayView, + error::{fail, Result}, + utils::Mut, +}; -use super::{simple_deserializer::SimpleDeserializer, utils::BitBuffer}; +use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct BoolDeserializer<'a> { - pub buffer: BitBuffer<'a>, - pub validity: Option>, + pub view: BooleanArrayView<'a>, pub next: usize, } impl<'a> BoolDeserializer<'a> { - pub fn new(buffer: BitBuffer<'a>, validity: Option>) -> Self { - Self { - buffer, - validity, - next: 0, - } + pub fn new(view: BooleanArrayView<'a>) -> Self { + Self { view, next: 0 } } fn next(&mut self) -> Result> { - if self.next >= self.buffer.len() { + if self.next >= self.view.len { fail!("Exhausted BoolDeserializer"); } - if let Some(validty) = &self.validity { - if !validty.is_set(self.next) { + if let Some(validty) = &self.view.validity { + if !bitset_is_set(validty, self.next)? 
{ self.next += 1; return Ok(None); } } - let val = self.buffer.is_set(self.next); + let val = bitset_is_set(&self.view.values, self.next)?; self.next += 1; Ok(Some(val)) } @@ -44,10 +43,10 @@ impl<'a> BoolDeserializer<'a> { } fn peek_next(&self) -> Result { - if self.next >= self.buffer.len() { + if self.next >= self.view.len { fail!("Exhausted BoolDeserializer"); - } else if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next)) + } else if let Some(validity) = &self.view.validity { + bitset_is_set(validity, self.next) } else { Ok(true) } diff --git a/serde_arrow/src/internal/deserialization/construction.rs b/serde_arrow/src/internal/deserialization/construction.rs deleted file mode 100644 index db7e8ab1..00000000 --- a/serde_arrow/src/internal/deserialization/construction.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::internal::{ - deserialization::date64_deserializer::Date64Deserializer, - error::{fail, Result}, - schema::{GenericDataType, GenericField, Strategy}, -}; - -use super::{array_deserializer::ArrayDeserializer, utils::BitBuffer}; - -#[allow(unused)] -pub fn build_timestamp_deserializer<'a>( - field: &GenericField, - values: &'a [i64], - validity: Option>, -) -> Result> { - let strategy = field.strategy.as_ref(); - let GenericDataType::Timestamp(unit, _) = &field.data_type else { - fail!( - "invalid data type for timestamp deserializer: {dt}", - dt = field.data_type - ); - }; - - if matches!( - strategy, - Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) - ) { - return Ok(Date64Deserializer::new(values, validity, *unit, field.is_utc()?).into()); - } - - if let Some(strategy) = strategy { - fail!("invalid strategy {strategy} for timestamp field"); - } - - Ok(Date64Deserializer::new(values, validity, *unit, field.is_utc()?).into()) -} diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index d9306288..fca2d52f 100644 --- 
a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -1,17 +1,14 @@ use chrono::{Duration, NaiveDate, NaiveDateTime}; use serde::de::Visitor; -use crate::internal::{error::Result, utils::Mut}; +use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; -use super::{ - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, -}; +use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub struct Date32Deserializer<'a>(ArrayBufferIterator<'a, i32>); impl<'a> Date32Deserializer<'a> { - pub fn new(buffer: &'a [i32], validity: Option>) -> Self { + pub fn new(buffer: &'a [i32], validity: Option>) -> Self { Self(ArrayBufferIterator::new(buffer, validity)) } diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index bd905560..d6dd2be2 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -2,22 +2,19 @@ use chrono::DateTime; use serde::de::Visitor; use crate::internal::{ - arrow::TimeUnit, + arrow::{BitsWithOffset, TimeUnit}, error::{fail, Result}, utils::Mut, }; -use super::{ - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, -}; +use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, TimeUnit, bool); impl<'a> Date64Deserializer<'a> { pub fn new( buffer: &'a [i64], - validity: Option>, + validity: Option>, unit: TimeUnit, is_utc: bool, ) -> Self { diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index 8bd49a37..cb1789e0 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -1,14 +1,12 @@ use serde::de::Visitor; use crate::internal::{ + arrow::BitsWithOffset, error::Result, utils::{decimal, Mut}, }; -use super::{ - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, -}; +use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub struct DecimalDeserializer<'a> { inner: ArrayBufferIterator<'a, i128>, @@ -16,7 +14,7 @@ pub struct DecimalDeserializer<'a> { } impl<'a> DecimalDeserializer<'a> { - pub fn new(buffer: &'a [i128], validity: Option>, scale: i8) -> Self { + pub fn new(buffer: &'a [i128], validity: Option>, scale: i8) -> Self { Self { inner: ArrayBufferIterator::new(buffer, validity), scale, diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index a3958eb7..5f6a2829 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -1,15 +1,14 @@ use serde::de::Visitor; use crate::internal::{ + arrow::BitsWithOffset, error::{fail, Result}, utils::{Mut, Offset}, }; use super::{ - enums_as_string_impl::EnumAccess, - integer_deserializer::Integer, - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, + enums_as_string_impl::EnumAccess, integer_deserializer::Integer, + simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator, }; pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { @@ -21,7 +20,7 @@ pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { pub fn new( keys_buffer: &'a [K], - keys_validity: Option>, + keys_validity: Option>, data: &'a [u8], offsets: &'a [V], ) -> Self { diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs 
b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 41a8b975..82754609 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -1,11 +1,8 @@ use serde::de::Visitor; -use crate::internal::{error::Result, utils::Mut}; +use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; -use super::{ - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, -}; +use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub trait Float: Copy { fn deserialize_any<'de, S: SimpleDeserializer<'de>, V: Visitor<'de>>( @@ -20,7 +17,7 @@ pub trait Float: Copy { pub struct FloatDeserializer<'a, F: Float>(ArrayBufferIterator<'a, F>); impl<'a, F: Float> FloatDeserializer<'a, F> { - pub fn new(buffer: &'a [F], validity: Option>) -> Self { + pub fn new(buffer: &'a [F], validity: Option>) -> Self { Self(ArrayBufferIterator::new(buffer, validity)) } } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 6e7493d4..82972558 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -1,11 +1,8 @@ use serde::de::Visitor; -use crate::internal::{error::Result, utils::Mut}; +use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; -use super::{ - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, -}; +use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub trait Integer: Sized + Copy { fn deserialize_any<'de, S: SimpleDeserializer<'de>, V: Visitor<'de>>( @@ -29,7 +26,7 @@ pub trait Integer: Sized + Copy { pub struct IntegerDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>); impl<'a, T: Integer> IntegerDeserializer<'a, T> { - pub fn new(buffer: &'a [T], 
validity: Option>) -> Self { + pub fn new(buffer: &'a [T], validity: Option>) -> Self { Self(ArrayBufferIterator::new(buffer, validity)) } } diff --git a/serde_arrow/src/internal/deserialization/mod.rs b/serde_arrow/src/internal/deserialization/mod.rs index 6aed969b..8da9f023 100644 --- a/serde_arrow/src/internal/deserialization/mod.rs +++ b/serde_arrow/src/internal/deserialization/mod.rs @@ -1,7 +1,6 @@ pub mod array_deserializer; pub mod binary_deserializer; pub mod bool_deserializer; -pub mod construction; pub mod date32_deserializer; pub mod date64_deserializer; pub mod decimal_deserializer; diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 9c755fbb..f474311e 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -2,21 +2,20 @@ use chrono::NaiveTime; use serde::de::Visitor; use crate::internal::{ - arrow::TimeUnit, + arrow::{BitsWithOffset, TimeUnit}, error::{fail, Result}, utils::Mut, }; use super::{ - integer_deserializer::Integer, - simple_deserializer::SimpleDeserializer, - utils::{ArrayBufferIterator, BitBuffer}, + integer_deserializer::Integer, simple_deserializer::SimpleDeserializer, + utils::ArrayBufferIterator, }; pub struct TimeDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>, i64, i64); impl<'a, T: Integer> TimeDeserializer<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>, unit: TimeUnit) -> Self { + pub fn new(buffer: &'a [T], validity: Option>, unit: TimeUnit) -> Self { let (seconds_factor, nanoseconds_factor) = match unit { TimeUnit::Nanosecond => (1_000_000_000, 1), TimeUnit::Microsecond => (1_000_000, 1_000), diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index cdf8b32b..ac8cf173 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ 
b/serde_arrow/src/internal/deserialization/utils.rs @@ -1,8 +1,17 @@ use crate::internal::{ + arrow::BitsWithOffset, error::{error, fail, Result}, utils::Offset, }; +pub fn bitset_is_set(set: &BitsWithOffset<'_>, idx: usize) -> Result { + let flag = 1 << ((idx + set.offset) % 8); + let Some(byte) = set.data.get((idx + set.offset) / 8) else { + fail!("invalid access in bitset"); + }; + Ok(byte & flag == flag) +} + #[derive(Debug, PartialEq, Clone, Copy)] pub struct BitBuffer<'a> { pub data: &'a [u8], @@ -24,12 +33,12 @@ impl<'a> BitBuffer<'a> { pub struct ArrayBufferIterator<'a, T: Copy> { pub buffer: &'a [T], - pub validity: Option>, + pub validity: Option>, pub next: usize, } impl<'a, T: Copy> ArrayBufferIterator<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>) -> Self { + pub fn new(buffer: &'a [T], validity: Option>) -> Self { Self { buffer, validity, @@ -43,7 +52,7 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { } if let Some(validity) = &self.validity { - if !validity.is_set(self.next) { + if !bitset_is_set(validity, self.next)? { return Ok(None); } } @@ -63,7 +72,7 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { } if let Some(validity) = &self.validity { - if !validity.is_set(self.next) { + if !bitset_is_set(validity, self.next)? 
{ return Ok(false); } } @@ -92,15 +101,11 @@ pub fn check_supported_list_layout<'a, O: Offset>( return Ok(()); }; - if offsets.len() != validity.len() + 1 { - fail!( - "validity length {val} and offsets length {off} do not match (expected {val}, {exp})", - val = validity.len(), - off = offsets.len(), - exp = validity.len() + 1, - ); + if offsets.is_empty() { + fail!("list offsets must be non empty"); } - for i in 0..validity.len() { + + for i in 0..offsets.len().saturating_sub(1) { let curr = offsets[i].try_into_usize()?; let next = offsets[i + 1].try_into_usize()?; if !validity.is_set(i) && (next - curr) != 0 { From 8b6ff63495d17553569e3f803a860c195310fde9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 12:17:38 +0200 Subject: [PATCH 077/178] Use view directly in BinaryDeserializer --- .../deserialization/array_deserializer.rs | 18 +------ .../deserialization/binary_deserializer.rs | 47 +++++++++---------- 2 files changed, 23 insertions(+), 42 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 80cea649..3a3ba3cc 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -191,22 +191,8 @@ impl<'a> ArrayDeserializer<'a> { view.offsets.len().saturating_sub(1), ), ))), - ArrayView::Binary(view) => Ok(Self::Binary(BinaryDeserializer::new( - view.data, - view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), - ))), - ArrayView::LargeBinary(view) => Ok(Self::LargeBinary(BinaryDeserializer::new( - view.data, - view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), - ))), + V::Binary(view) => Ok(D::Binary(BinaryDeserializer::new(view))), + V::LargeBinary(view) => Ok(D::LargeBinary(BinaryDeserializer::new(view))), 
ArrayView::FixedSizeBinary(view) => { let value_length: usize = view.n.try_into()?; if view.data.len() % value_length != 0 { diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index 55271d0f..7c9ff1df 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -1,35 +1,29 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ + arrow::BytesArrayView, error::{fail, Error, Result}, utils::{Mut, Offset}, }; -use super::{simple_deserializer::SimpleDeserializer, utils::BitBuffer}; +use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct BinaryDeserializer<'a, O: Offset> { - pub buffer: &'a [u8], - pub offsets: &'a [O], - pub validity: Option>, + pub view: BytesArrayView<'a, O>, pub next: (usize, usize), } impl<'a, O: Offset> BinaryDeserializer<'a, O> { - pub fn new(buffer: &'a [u8], offsets: &'a [O], validity: Option>) -> Self { - Self { - buffer, - offsets, - validity, - next: (0, 0), - } + pub fn new(view: BytesArrayView<'a, O>) -> Self { + Self { view, next: (0, 0) } } pub fn peek_next(&self) -> Result { - if self.next.0 + 1 >= self.offsets.len() { + if self.next.0 + 1 >= self.view.offsets.len() { fail!("Exhausted ListDeserializer") } - if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + if let Some(validity) = &self.view.validity { + bitset_is_set(validity, self.next.0) } else { Ok(true) } @@ -39,16 +33,21 @@ impl<'a, O: Offset> BinaryDeserializer<'a, O> { self.next = (self.next.0 + 1, 0); } - pub fn next_slice(&mut self) -> Result<&'a [u8]> { + pub fn peek_next_slice_range(&self) -> Result<(usize, usize)> { let (item, _) = self.next; - if item + 1 >= self.offsets.len() { + if item + 1 >= self.view.offsets.len() { fail!("called next_slices on exhausted BinaryDeserializer"); } - let end = self.offsets[item + 
1].try_into_usize()?; - let start = self.offsets[item].try_into_usize()?; - self.next = (item + 1, 0); + let end = self.view.offsets[item + 1].try_into_usize()?; + let start = self.view.offsets[item].try_into_usize()?; + Ok((start, end)) + } - Ok(&self.buffer[start..end]) + pub fn next_slice(&mut self) -> Result<&'a [u8]> { + let (start, end) = self.peek_next_slice_range()?; + let (item, _) = self.next; + self.next = (item + 1, 0); + Ok(&self.view.data[start..end]) } } @@ -96,11 +95,7 @@ impl<'de, O: Offset> SeqAccess<'de> for BinaryDeserializer<'de, O> { seed: T, ) -> Result> { let (item, offset) = self.next; - if item + 1 >= self.offsets.len() { - return Ok(None); - } - let end = self.offsets[item + 1].try_into_usize()?; - let start = self.offsets[item].try_into_usize()?; + let (start, end) = self.peek_next_slice_range()?; if offset >= end - start { self.next = (item + 1, 0); @@ -108,7 +103,7 @@ impl<'de, O: Offset> SeqAccess<'de> for BinaryDeserializer<'de, O> { } self.next = (item, offset + 1); - let mut item_deserializer = U8Deserializer(self.buffer[start + offset]); + let mut item_deserializer = U8Deserializer(self.view.data[start + offset]); let item = seed.deserialize(Mut(&mut item_deserializer))?; Ok(Some(item)) } From 648890162fa7e092c2ecf1c72ed7badaea2ca2c8 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 12:26:45 +0200 Subject: [PATCH 078/178] Use views directly for {Float,Integer,Sting}Deserializer --- .../deserialization/array_deserializer.rs | 81 ++++--------------- .../deserialization/float_deserializer.rs | 6 +- .../deserialization/integer_deserializer.rs | 6 +- .../deserialization/string_deserializer.rs | 34 ++++---- .../src/internal/deserialization/test.rs | 17 +++- 5 files changed, 51 insertions(+), 93 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 3a3ba3cc..ba9c1b04 100644 --- 
a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -84,50 +84,17 @@ impl<'a> ArrayDeserializer<'a> { match array { ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), V::Boolean(view) => Ok(D::Bool(BoolDeserializer::new(view))), - ArrayView::Int8(view) => Ok(Self::I8(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Int16(view) => Ok(Self::I16(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Int32(view) => Ok(Self::I32(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Int64(view) => Ok(Self::I64(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::UInt8(view) => Ok(Self::U8(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::UInt16(view) => Ok(Self::U16(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::UInt32(view) => Ok(Self::U32(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::UInt64(view) => Ok(Self::U64(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Float16(view) => Ok(Self::F16(FloatDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Float32(view) => Ok(Self::F32(FloatDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Float64(view) => Ok(Self::F64(FloatDeserializer::new( - view.values, - view.validity, - ))), + V::Int8(view) => Ok(D::I8(IntegerDeserializer::new(view))), + V::Int16(view) => Ok(D::I16(IntegerDeserializer::new(view))), + V::Int32(view) => Ok(D::I32(IntegerDeserializer::new(view))), + V::Int64(view) => Ok(D::I64(IntegerDeserializer::new(view))), + V::UInt8(view) => Ok(D::U8(IntegerDeserializer::new(view))), + V::UInt16(view) => Ok(D::U16(IntegerDeserializer::new(view))), + V::UInt32(view) => Ok(D::U32(IntegerDeserializer::new(view))), + V::UInt64(view) => Ok(D::U64(IntegerDeserializer::new(view))), 
+ V::Float16(view) => Ok(D::F16(FloatDeserializer::new(view))), + V::Float32(view) => Ok(D::F32(FloatDeserializer::new(view))), + V::Float64(view) => Ok(D::F64(FloatDeserializer::new(view))), ArrayView::Decimal128(view) => Ok(Self::Decimal128(DecimalDeserializer::new( view.values, view.validity, @@ -171,26 +138,12 @@ impl<'a> ArrayDeserializer<'a> { ) .into()), }, - ArrayView::Duration(view) => Ok(Self::I64(IntegerDeserializer::new( - view.values, - view.validity, - ))), - ArrayView::Utf8(view) => Ok(Self::Utf8(StringDeserializer::new( - view.data, - view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), - ))), - ArrayView::LargeUtf8(view) => Ok(Self::LargeUtf8(StringDeserializer::new( - view.data, - view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), - ))), + V::Duration(view) => Ok(D::I64(IntegerDeserializer::new(PrimitiveArrayView { + values: view.values, + validity: view.validity, + }))), + V::Utf8(view) => Ok(D::Utf8(StringDeserializer::new(view))), + V::LargeUtf8(view) => Ok(D::LargeUtf8(StringDeserializer::new(view))), V::Binary(view) => Ok(D::Binary(BinaryDeserializer::new(view))), V::LargeBinary(view) => Ok(D::LargeBinary(BinaryDeserializer::new(view))), ArrayView::FixedSizeBinary(view) => { diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 82754609..36dba2a1 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -1,6 +1,6 @@ use serde::de::Visitor; -use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; +use crate::internal::{arrow::PrimitiveArrayView, error::Result, utils::Mut}; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -17,8 +17,8 @@ pub trait Float: Copy { pub struct FloatDeserializer<'a, F: 
Float>(ArrayBufferIterator<'a, F>); impl<'a, F: Float> FloatDeserializer<'a, F> { - pub fn new(buffer: &'a [F], validity: Option>) -> Self { - Self(ArrayBufferIterator::new(buffer, validity)) + pub fn new(view: PrimitiveArrayView<'a, F>) -> Self { + Self(ArrayBufferIterator::new(view.values, view.validity)) } } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 82972558..4c5afe1a 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -1,6 +1,6 @@ use serde::de::Visitor; -use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; +use crate::internal::{arrow::PrimitiveArrayView, error::Result, utils::Mut}; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -26,8 +26,8 @@ pub trait Integer: Sized + Copy { pub struct IntegerDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>); impl<'a, T: Integer> IntegerDeserializer<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>) -> Self { - Self(ArrayBufferIterator::new(buffer, validity)) + pub fn new(view: PrimitiveArrayView<'a, T>) -> Self { + Self(ArrayBufferIterator::new(view.values, view.validity)) } } diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 64b4d22c..ab03e2d8 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -1,43 +1,37 @@ use crate::internal::{ + arrow::BytesArrayView, error::{error, fail, Result}, utils::{Mut, Offset}, }; use super::{ - enums_as_string_impl::EnumAccess, simple_deserializer::SimpleDeserializer, utils::BitBuffer, + enums_as_string_impl::EnumAccess, simple_deserializer::SimpleDeserializer, utils::bitset_is_set, }; pub struct StringDeserializer<'a, O: 
Offset> { - pub data: &'a [u8], - pub offsets: &'a [O], - pub validity: Option>, + pub view: BytesArrayView<'a, O>, pub next: usize, } impl<'a, O: Offset> StringDeserializer<'a, O> { - pub fn new(data: &'a [u8], offsets: &'a [O], validity: Option>) -> Self { - Self { - data, - offsets, - validity, - next: 0, - } + pub fn new(view: BytesArrayView<'a, O>) -> Self { + Self { view, next: 0 } } pub fn next(&mut self) -> Result> { - if self.next + 1 > self.offsets.len() { + if self.next + 1 > self.view.offsets.len() { fail!("Tried to deserialize a value from an exhausted StringDeserializer"); } - if let Some(validity) = &self.validity { - if !validity.is_set(self.next) { + if let Some(validity) = &self.view.validity { + if !bitset_is_set(validity, self.next)? { return Ok(None); } } - let start = self.offsets[self.next].try_into_usize()?; - let end = self.offsets[self.next + 1].try_into_usize()?; - let s = std::str::from_utf8(&self.data[start..end])?; + let start = self.view.offsets[self.next].try_into_usize()?; + let end = self.view.offsets[self.next + 1].try_into_usize()?; + let s = std::str::from_utf8(&self.view.data[start..end])?; self.next += 1; @@ -51,12 +45,12 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { } pub fn peek_next(&self) -> Result { - if self.next + 1 > self.offsets.len() { + if self.next + 1 > self.view.offsets.len() { fail!("Tried to deserialize a value from an exhausted StringDeserializer"); } - if let Some(validity) = &self.validity { - if !validity.is_set(self.next) { + if let Some(validity) = &self.view.validity { + if !bitset_is_set(validity, self.next)? 
{ return Ok(false); } } diff --git a/serde_arrow/src/internal/deserialization/test.rs b/serde_arrow/src/internal/deserialization/test.rs index cef01eae..0fd19eef 100644 --- a/serde_arrow/src/internal/deserialization/test.rs +++ b/serde_arrow/src/internal/deserialization/test.rs @@ -1,6 +1,9 @@ use serde::Deserialize; -use crate::internal::{deserialization::integer_deserializer::IntegerDeserializer, utils::Mut}; +use crate::internal::{ + arrow::PrimitiveArrayView, deserialization::integer_deserializer::IntegerDeserializer, + utils::Mut, +}; use super::outer_sequence_deserializer::OuterSequenceDeserializer; @@ -10,11 +13,19 @@ fn example() { vec![ ( String::from("a"), - IntegerDeserializer::new(&[1, 2, 3], None).into(), + IntegerDeserializer::new(PrimitiveArrayView { + values: &[1, 2, 3], + validity: None, + }) + .into(), ), ( String::from("b"), - IntegerDeserializer::new(&[4, 5, 6], None).into(), + IntegerDeserializer::new(PrimitiveArrayView { + values: &[4, 5, 6], + validity: None, + }) + .into(), ), ], 3, From 5be7f0f3bfe376599dd240f899f041001c107595 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 12:35:19 +0200 Subject: [PATCH 079/178] Use views in {Decimal,Dictionary,Time}Deserializers --- .../deserialization/array_deserializer.rs | 149 +++++++----------- .../deserialization/decimal_deserializer.rs | 8 +- .../dictionary_deserializer.rs | 21 ++- .../deserialization/time_deserializer.rs | 8 +- 4 files changed, 74 insertions(+), 112 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index ba9c1b04..0a870aaf 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,32 +2,23 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ - arrow::{ArrayView, BitsWithOffset, BytesArrayView, FieldMeta, 
PrimitiveArrayView, TimeUnit}, + arrow::{ArrayView, BitsWithOffset, FieldMeta, PrimitiveArrayView, TimeUnit}, error::{fail, Error, Result}, schema::{Strategy, STRATEGY_KEY}, - utils::{Mut, Offset}, + utils::Mut, }; use super::{ - binary_deserializer::BinaryDeserializer, - bool_deserializer::BoolDeserializer, - date32_deserializer::Date32Deserializer, - date64_deserializer::Date64Deserializer, - decimal_deserializer::DecimalDeserializer, - dictionary_deserializer::DictionaryDeserializer, + binary_deserializer::BinaryDeserializer, bool_deserializer::BoolDeserializer, + date32_deserializer::Date32Deserializer, date64_deserializer::Date64Deserializer, + decimal_deserializer::DecimalDeserializer, dictionary_deserializer::DictionaryDeserializer, enum_deserializer::EnumDeserializer, fixed_size_binary_deserializer::FixedSizeBinaryDeserializer, - fixed_size_list_deserializer::FixedSizeListDeserializer, - float_deserializer::FloatDeserializer, - integer_deserializer::{Integer, IntegerDeserializer}, - list_deserializer::ListDeserializer, - map_deserializer::MapDeserializer, - null_deserializer::NullDeserializer, - simple_deserializer::SimpleDeserializer, - string_deserializer::StringDeserializer, - struct_deserializer::StructDeserializer, - time_deserializer::TimeDeserializer, - utils::BitBuffer, + fixed_size_list_deserializer::FixedSizeListDeserializer, float_deserializer::FloatDeserializer, + integer_deserializer::IntegerDeserializer, list_deserializer::ListDeserializer, + map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, + simple_deserializer::SimpleDeserializer, string_deserializer::StringDeserializer, + struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::BitBuffer, }; pub enum ArrayDeserializer<'a> { @@ -95,11 +86,7 @@ impl<'a> ArrayDeserializer<'a> { V::Float16(view) => Ok(D::F16(FloatDeserializer::new(view))), V::Float32(view) => Ok(D::F32(FloatDeserializer::new(view))), V::Float64(view) => 
Ok(D::F64(FloatDeserializer::new(view))), - ArrayView::Decimal128(view) => Ok(Self::Decimal128(DecimalDeserializer::new( - view.values, - view.validity, - view.scale, - ))), + V::Decimal128(view) => Ok(D::Decimal128(DecimalDeserializer::new(view))), ArrayView::Date32(view) => Ok(Self::Date32(Date32Deserializer::new( view.values, view.validity, @@ -110,16 +97,8 @@ impl<'a> ArrayDeserializer<'a> { TimeUnit::Millisecond, is_utc_date64(strategy)?, ))), - ArrayView::Time32(view) => Ok(Self::Time32(TimeDeserializer::new( - view.values, - view.validity, - view.unit, - ))), - ArrayView::Time64(view) => Ok(Self::Time64(TimeDeserializer::new( - view.values, - view.validity, - view.unit, - ))), + V::Time32(view) => Ok(D::Time32(TimeDeserializer::new(view))), + V::Time64(view) => Ok(D::Time64(TimeDeserializer::new(view))), ArrayView::Timestamp(view) => match strategy { Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) => { Ok(Self::Date64(Date64Deserializer::new( @@ -221,55 +200,55 @@ impl<'a> ArrayDeserializer<'a> { ), )?)) } - ArrayView::Dictionary(view) => match (*view.indices, *view.values) { - (ArrayView::Int8(keys), ArrayView::Utf8(values)) => { - Ok(Self::DictionaryI8I32(build_dictionary_array(keys, values)?)) - } - (ArrayView::Int16(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI16I32( - build_dictionary_array(keys, values)?, + V::Dictionary(view) => match (*view.indices, *view.values) { + (V::Int8(keys), V::Utf8(values)) => Ok(D::DictionaryI8I32( + DictionaryDeserializer::new(keys, values)?, )), - (ArrayView::Int32(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI32I32( - build_dictionary_array(keys, values)?, + (V::Int16(keys), V::Utf8(values)) => Ok(D::DictionaryI16I32( + DictionaryDeserializer::new(keys, values)?, )), - (ArrayView::Int64(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryI64I32( - build_dictionary_array(keys, values)?, + (V::Int32(keys), V::Utf8(values)) => Ok(D::DictionaryI32I32( + DictionaryDeserializer::new(keys, 
values)?, )), - (ArrayView::UInt8(keys), ArrayView::Utf8(values)) => { - Ok(Self::DictionaryU8I32(build_dictionary_array(keys, values)?)) - } - (ArrayView::UInt16(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU16I32( - build_dictionary_array(keys, values)?, + (V::Int64(keys), V::Utf8(values)) => Ok(D::DictionaryI64I32( + DictionaryDeserializer::new(keys, values)?, )), - (ArrayView::UInt32(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU32I32( - build_dictionary_array(keys, values)?, + (V::UInt8(keys), V::Utf8(values)) => Ok(Self::DictionaryU8I32( + DictionaryDeserializer::new(keys, values)?, )), - (ArrayView::UInt64(keys), ArrayView::Utf8(values)) => Ok(Self::DictionaryU64I32( - build_dictionary_array(keys, values)?, + (V::UInt16(keys), V::Utf8(values)) => Ok(D::DictionaryU16I32( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt32(keys), V::Utf8(values)) => Ok(D::DictionaryU32I32( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt64(keys), V::Utf8(values)) => Ok(D::DictionaryU64I32( + DictionaryDeserializer::new(keys, values)?, + )), + (V::Int8(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI8I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::Int16(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI16I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::Int32(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI32I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::Int64(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI64I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt8(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU8I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt16(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU16I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt32(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU32I64( + DictionaryDeserializer::new(keys, values)?, + )), + (V::UInt64(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU64I64( + 
DictionaryDeserializer::new(keys, values)?, )), - (ArrayView::Int8(keys), ArrayView::LargeUtf8(values)) => { - Ok(Self::DictionaryI8I64(build_dictionary_array(keys, values)?)) - } - (ArrayView::Int16(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryI16I64(build_dictionary_array(keys, values)?), - ), - (ArrayView::Int32(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryI32I64(build_dictionary_array(keys, values)?), - ), - (ArrayView::Int64(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryI64I64(build_dictionary_array(keys, values)?), - ), - (ArrayView::UInt8(keys), ArrayView::LargeUtf8(values)) => { - Ok(Self::DictionaryU8I64(build_dictionary_array(keys, values)?)) - } - (ArrayView::UInt16(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryU16I64(build_dictionary_array(keys, values)?), - ), - (ArrayView::UInt32(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryU32I64(build_dictionary_array(keys, values)?), - ), - (ArrayView::UInt64(keys), ArrayView::LargeUtf8(values)) => Ok( - Self::DictionaryU64I64(build_dictionary_array(keys, values)?), - ), _ => fail!("unsupported dictionary array"), }, ArrayView::DenseUnion(view) => { @@ -286,22 +265,6 @@ impl<'a> ArrayDeserializer<'a> { } } -fn build_dictionary_array<'a, K: Integer, V: Offset>( - keys: PrimitiveArrayView<'a, K>, - values: BytesArrayView<'a, V>, -) -> Result> { - if values.validity.is_some() { - // TODO: check whether all values are defined? 
- fail!("dictionaries with nullable values are not supported"); - } - Ok(DictionaryDeserializer::new( - keys.values, - keys.validity, - values.data, - values.offsets, - )) -} - fn is_utc_timestamp(timezone: Option<&str>) -> Result { match timezone { Some(tz) if tz.to_lowercase() == "utc" => Ok(true), diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index cb1789e0..11b62a45 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -1,7 +1,7 @@ use serde::de::Visitor; use crate::internal::{ - arrow::BitsWithOffset, + arrow::DecimalArrayView, error::Result, utils::{decimal, Mut}, }; @@ -14,10 +14,10 @@ pub struct DecimalDeserializer<'a> { } impl<'a> DecimalDeserializer<'a> { - pub fn new(buffer: &'a [i128], validity: Option>, scale: i8) -> Self { + pub fn new(view: DecimalArrayView<'a, i128>) -> Self { Self { - inner: ArrayBufferIterator::new(buffer, validity), - scale, + inner: ArrayBufferIterator::new(view.values, view.validity), + scale: view.scale, } } } diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index 5f6a2829..ef9bb397 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -1,7 +1,7 @@ use serde::de::Visitor; use crate::internal::{ - arrow::BitsWithOffset, + arrow::{BytesArrayView, PrimitiveArrayView}, error::{fail, Result}, utils::{Mut, Offset}, }; @@ -18,17 +18,16 @@ pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { } impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { - pub fn new( - keys_buffer: &'a [K], - keys_validity: Option>, - data: &'a [u8], - offsets: &'a [V], - ) -> Self { - Self { - keys: ArrayBufferIterator::new(keys_buffer, 
keys_validity), - offsets, - data, + pub fn new(keys: PrimitiveArrayView<'a, K>, values: BytesArrayView<'a, V>) -> Result { + if values.validity.is_some() { + // TODO: check whether all values are defined? + fail!("dictionaries with nullable values are not supported"); } + Ok(Self { + keys: ArrayBufferIterator::new(keys.values, keys.validity), + offsets: values.offsets, + data: values.data, + }) } pub fn next_str(&mut self) -> Result<&str> { diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index f474311e..ce5ae075 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -2,7 +2,7 @@ use chrono::NaiveTime; use serde::de::Visitor; use crate::internal::{ - arrow::{BitsWithOffset, TimeUnit}, + arrow::{TimeArrayView, TimeUnit}, error::{fail, Result}, utils::Mut, }; @@ -15,8 +15,8 @@ use super::{ pub struct TimeDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>, i64, i64); impl<'a, T: Integer> TimeDeserializer<'a, T> { - pub fn new(buffer: &'a [T], validity: Option>, unit: TimeUnit) -> Self { - let (seconds_factor, nanoseconds_factor) = match unit { + pub fn new(view: TimeArrayView<'a, T>) -> Self { + let (seconds_factor, nanoseconds_factor) = match view.unit { TimeUnit::Nanosecond => (1_000_000_000, 1), TimeUnit::Microsecond => (1_000_000, 1_000), TimeUnit::Millisecond => (1_000, 1_000_000), @@ -24,7 +24,7 @@ impl<'a, T: Integer> TimeDeserializer<'a, T> { }; Self( - ArrayBufferIterator::new(buffer, validity), + ArrayBufferIterator::new(view.values, view.validity), seconds_factor, nanoseconds_factor, ) From d3d44d38db6c5d2dd7cf11e5c4af7d8eccdc9356 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 13:45:02 +0200 Subject: [PATCH 080/178] Remove BitBuffer --- .../deserialization/array_deserializer.rs | 336 ++---------------- .../fixed_size_binary_deserializer.rs | 36 +- 
.../fixed_size_list_deserializer.rs | 23 +- .../deserialization/list_deserializer.rs | 9 +- .../deserialization/map_deserializer.rs | 9 +- .../deserialization/struct_deserializer.rs | 9 +- .../src/internal/deserialization/test.rs | 15 +- .../src/internal/deserialization/utils.rs | 23 +- 8 files changed, 82 insertions(+), 378 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 0a870aaf..6c8b86c6 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -2,7 +2,7 @@ use half::f16; use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ - arrow::{ArrayView, BitsWithOffset, FieldMeta, PrimitiveArrayView, TimeUnit}, + arrow::{ArrayView, FieldMeta, PrimitiveArrayView, TimeUnit}, error::{fail, Error, Result}, schema::{Strategy, STRATEGY_KEY}, utils::Mut, @@ -18,7 +18,7 @@ use super::{ integer_deserializer::IntegerDeserializer, list_deserializer::ListDeserializer, map_deserializer::MapDeserializer, null_deserializer::NullDeserializer, simple_deserializer::SimpleDeserializer, string_deserializer::StringDeserializer, - struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, utils::BitBuffer, + struct_deserializer::StructDeserializer, time_deserializer::TimeDeserializer, }; pub enum ArrayDeserializer<'a> { @@ -109,13 +109,12 @@ impl<'a> ArrayDeserializer<'a> { ))) } Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), - None => Ok(Date64Deserializer::new( + None => Ok(Self::Date64(Date64Deserializer::new( view.values, view.validity, view.unit, is_utc_timestamp(view.timezone.as_deref())?, - ) - .into()), + ))), }, V::Duration(view) => Ok(D::I64(IntegerDeserializer::new(PrimitiveArrayView { values: view.values, @@ -125,44 +124,26 @@ impl<'a> ArrayDeserializer<'a> { V::LargeUtf8(view) => 
Ok(D::LargeUtf8(StringDeserializer::new(view))), V::Binary(view) => Ok(D::Binary(BinaryDeserializer::new(view))), V::LargeBinary(view) => Ok(D::LargeBinary(BinaryDeserializer::new(view))), - ArrayView::FixedSizeBinary(view) => { - let value_length: usize = view.n.try_into()?; - if view.data.len() % value_length != 0 { - fail!("Invalid FixedSizeBinary array: Data is not evenly divisible into chunks of size {value_length}"); - } - let len = view.data.len() / value_length; - - Ok(Self::FixedSizeBinary(FixedSizeBinaryDeserializer::new( - (len, value_length), - view.data, - buffer_from_bits_with_offset_opt(view.validity, len), - ))) + V::FixedSizeBinary(view) => { + Ok(D::FixedSizeBinary(FixedSizeBinaryDeserializer::new(view)?)) } - ArrayView::List(view) => Ok(Self::List(ListDeserializer::new( + V::List(view) => Ok(D::List(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), + view.validity, )?)), - ArrayView::LargeList(view) => Ok(Self::LargeList(ListDeserializer::new( + V::LargeList(view) => Ok(D::LargeList(ListDeserializer::new( ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), + view.validity, )?)), - ArrayView::FixedSizeList(view) => { - Ok(Self::FixedSizeList(FixedSizeListDeserializer::new( - ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, - buffer_from_bits_with_offset_opt(view.validity, view.len), - view.n.try_into()?, - view.len, - ))) - } - ArrayView::Struct(view) => { + V::FixedSizeList(view) => Ok(D::FixedSizeList(FixedSizeListDeserializer::new( + ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, + view.validity, + view.n.try_into()?, + view.len, + ))), + V::Struct(view) => { let mut fields = Vec::new(); for 
(field_view, field_meta) in view.fields { let field_deserializer = @@ -172,13 +153,13 @@ impl<'a> ArrayDeserializer<'a> { fields.push((field_name, field_deserializer)); } - Ok(Self::Struct(StructDeserializer::new( + Ok(D::Struct(StructDeserializer::new( fields, - buffer_from_bits_with_offset_opt(view.validity, view.len), + view.validity, view.len, ))) } - ArrayView::Map(view) => { + V::Map(view) => { let ArrayView::Struct(entries_view) = *view.element else { fail!("invalid entries field in map array"); }; @@ -190,14 +171,11 @@ impl<'a> ArrayDeserializer<'a> { let values = ArrayDeserializer::new(get_strategy(&values_meta)?.as_ref(), values_view)?; - Ok(Self::Map(MapDeserializer::new( + Ok(D::Map(MapDeserializer::new( keys, values, view.offsets, - buffer_from_bits_with_offset_opt( - view.validity, - view.offsets.len().saturating_sub(1), - ), + view.validity, )?)) } V::Dictionary(view) => match (*view.indices, *view.values) { @@ -288,276 +266,6 @@ fn get_strategy(meta: &FieldMeta) -> Result> { Ok(Some(strategy.parse()?)) } -fn buffer_from_bits_with_offset(bits: BitsWithOffset, len: usize) -> BitBuffer { - BitBuffer { - data: bits.data, - offset: bits.offset, - number_of_bits: len, - } -} - -fn buffer_from_bits_with_offset_opt(bits: Option, len: usize) -> Option { - Some(buffer_from_bits_with_offset(bits?, len)) -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, i8>) -> Self { - Self::I8(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, i16>) -> Self { - Self::I16(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, i32>) -> Self { - Self::I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, i64>) -> Self { - Self::I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, u8>) -> Self { - Self::U8(value) - } -} - -impl<'a> From> for 
ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, u16>) -> Self { - Self::U16(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, u32>) -> Self { - Self::U32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: IntegerDeserializer<'a, u64>) -> Self { - Self::U64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: FloatDeserializer<'a, f16>) -> Self { - Self::F16(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: FloatDeserializer<'a, f32>) -> Self { - Self::F32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: FloatDeserializer<'a, f64>) -> Self { - Self::F64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DecimalDeserializer<'a>) -> Self { - Self::Decimal128(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: Date32Deserializer<'a>) -> Self { - Self::Date32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: Date64Deserializer<'a>) -> Self { - Self::Date64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: TimeDeserializer<'a, i32>) -> Self { - Self::Time32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: TimeDeserializer<'a, i64>) -> Self { - Self::Time64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: StructDeserializer<'a>) -> Self { - Self::Struct(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: ListDeserializer<'a, i32>) -> Self { - Self::List(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: ListDeserializer<'a, i64>) -> Self { - Self::LargeList(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: FixedSizeListDeserializer<'a>) -> Self { - Self::FixedSizeList(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: BinaryDeserializer<'a, 
i32>) -> Self { - Self::Binary(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: BinaryDeserializer<'a, i64>) -> Self { - Self::LargeBinary(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: FixedSizeBinaryDeserializer<'a>) -> Self { - Self::FixedSizeBinary(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: StringDeserializer<'a, i32>) -> Self { - Self::Utf8(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: StringDeserializer<'a, i64>) -> Self { - Self::LargeUtf8(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u8, i32>) -> Self { - Self::DictionaryU8I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u16, i32>) -> Self { - Self::DictionaryU16I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u32, i32>) -> Self { - Self::DictionaryU32I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u64, i32>) -> Self { - Self::DictionaryU64I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i8, i32>) -> Self { - Self::DictionaryI8I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i16, i32>) -> Self { - Self::DictionaryI16I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i32, i32>) -> Self { - Self::DictionaryI32I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i64, i32>) -> Self { - Self::DictionaryI64I32(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u8, i64>) -> Self { - Self::DictionaryU8I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: 
DictionaryDeserializer<'a, u16, i64>) -> Self { - Self::DictionaryU16I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u32, i64>) -> Self { - Self::DictionaryU32I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, u64, i64>) -> Self { - Self::DictionaryU64I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i8, i64>) -> Self { - Self::DictionaryI8I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i16, i64>) -> Self { - Self::DictionaryI16I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i32, i64>) -> Self { - Self::DictionaryI32I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: DictionaryDeserializer<'a, i64, i64>) -> Self { - Self::DictionaryI64I64(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: MapDeserializer<'a>) -> Self { - Self::Map(value) - } -} - -impl<'a> From> for ArrayDeserializer<'a> { - fn from(value: EnumDeserializer<'a>) -> Self { - Self::Enum(value) - } -} - macro_rules! 
dispatch { ($obj:expr, $wrapper:ident($name:ident) => $expr:expr) => { match $obj { diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index b2cdc611..757a6989 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -1,35 +1,47 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ + arrow::FixedSizeBinaryArrayView, error::{fail, Error, Result}, utils::Mut, }; -use super::{simple_deserializer::SimpleDeserializer, utils::BitBuffer}; +use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct FixedSizeBinaryDeserializer<'a> { - pub buffer: &'a [u8], - pub validity: Option>, + pub view: FixedSizeBinaryArrayView<'a>, pub next: (usize, usize), pub shape: (usize, usize), } impl<'a> FixedSizeBinaryDeserializer<'a> { - pub fn new(shape: (usize, usize), buffer: &'a [u8], validity: Option>) -> Self { - Self { - buffer, - validity, + pub fn new(view: FixedSizeBinaryArrayView<'a>) -> Result { + let n = usize::try_from(view.n)?; + if view.data.len() % n != 0 { + fail!( + concat!( + "Invalid FixedSizeBinary array: Data of len {len} is not ", + "evenly divisible into chunks of size {n}", + ), + len = view.data.len(), + n = n, + ); + } + + let shape = (view.data.len() / n, n); + Ok(Self { + view, shape, next: (0, 0), - } + }) } pub fn peek_next(&self) -> Result { if self.next.0 >= self.shape.0 { fail!("Exhausted ListDeserializer") } - if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + if let Some(validity) = &self.view.validity { + Ok(bitset_is_set(validity, self.next.0)?) 
} else { Ok(true) } @@ -46,7 +58,7 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { } self.next = (item + 1, 0); - Ok(&self.buffer[item * self.shape.1..(item + 1) * self.shape.1]) + Ok(&self.view.data[item * self.shape.1..(item + 1) * self.shape.1]) } } @@ -102,7 +114,7 @@ impl<'de> SeqAccess<'de> for FixedSizeBinaryDeserializer<'de> { return Ok(None); } self.next = (item, offset + 1); - let mut item_deserializer = U8Deserializer(self.buffer[item * self.shape.1 + offset]); + let mut item_deserializer = U8Deserializer(self.view.data[item * self.shape.1 + offset]); let item = seed.deserialize(Mut(&mut item_deserializer))?; Ok(Some(item)) } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index 28e2a1e1..e146b309 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -1,52 +1,51 @@ use serde::de::{IgnoredAny, SeqAccess, Visitor}; use crate::internal::{ + arrow::BitsWithOffset, error::{fail, Error, Result}, utils::Mut, }; use super::{ array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::BitBuffer, + utils::bitset_is_set, }; pub struct FixedSizeListDeserializer<'a> { pub item: Box>, - pub validity: Option>, - pub n: usize, - pub len: usize, + pub validity: Option>, + pub shape: (usize, usize), pub next: (usize, usize), } impl<'a> FixedSizeListDeserializer<'a> { pub fn new( item: ArrayDeserializer<'a>, - validity: Option>, + validity: Option>, n: usize, len: usize, ) -> Self { Self { item: Box::new(item), validity, - n, - len, + shape: (len, n), next: (0, 0), } } pub fn peek_next(&self) -> Result { - if self.next.0 >= self.len { + if self.next.0 >= self.shape.0 { fail!("Exhausted ListDeserializer") } if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + Ok(bitset_is_set(validity, self.next.0)?) 
} else { Ok(true) } } pub fn consume_next(&mut self) -> Result<()> { - for _ in 0..self.n { + for _ in 0..self.shape.1 { self.item.deserialize_ignored_any(IgnoredAny)?; } @@ -91,11 +90,11 @@ impl<'de> SeqAccess<'de> for FixedSizeListDeserializer<'de> { seed: T, ) -> Result> { let (item, offset) = self.next; - if item >= self.len { + if item >= self.shape.0 { return Ok(None); } - if offset >= self.n { + if offset >= self.shape.1 { self.next = (item + 1, 0); return Ok(None); } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index ce4fccb7..ebf45562 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -1,6 +1,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ + arrow::BitsWithOffset, error::{fail, Error, Result}, utils::{Mut, Offset}, }; @@ -8,13 +9,13 @@ use crate::internal::{ use super::{ array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::{check_supported_list_layout, BitBuffer}, + utils::{bitset_is_set, check_supported_list_layout}, }; pub struct ListDeserializer<'a, O: Offset> { pub item: Box>, pub offsets: &'a [O], - pub validity: Option>, + pub validity: Option>, pub next: (usize, usize), } @@ -22,7 +23,7 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { pub fn new( item: ArrayDeserializer<'a>, offsets: &'a [O], - validity: Option>, + validity: Option>, ) -> Result { check_supported_list_layout(validity, offsets)?; @@ -39,7 +40,7 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { fail!("Exhausted ListDeserializer") } if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + Ok(bitset_is_set(validity, self.next.0)?) 
} else { Ok(true) } diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 10f18c8c..fde9e0a0 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -1,6 +1,7 @@ use serde::de::{DeserializeSeed, MapAccess, Visitor}; use crate::internal::{ + arrow::BitsWithOffset, error::{fail, Error, Result}, utils::Mut, }; @@ -8,14 +9,14 @@ use crate::internal::{ use super::{ array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::{check_supported_list_layout, BitBuffer}, + utils::{bitset_is_set, check_supported_list_layout}, }; pub struct MapDeserializer<'a> { key: Box>, value: Box>, offsets: &'a [i32], - validity: Option>, + validity: Option>, next: (usize, usize), } @@ -24,7 +25,7 @@ impl<'a> MapDeserializer<'a> { key: ArrayDeserializer<'a>, value: ArrayDeserializer<'a>, offsets: &'a [i32], - validity: Option>, + validity: Option>, ) -> Result { check_supported_list_layout(validity, offsets)?; @@ -42,7 +43,7 @@ impl<'a> MapDeserializer<'a> { fail!("Exhausted ListDeserializer") } if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + Ok(bitset_is_set(validity, self.next.0)?) 
} else { Ok(true) } diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index 24f4efcf..df7b1879 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -3,18 +3,19 @@ use serde::de::{ }; use crate::internal::{ + arrow::BitsWithOffset, error::{fail, Error, Result}, utils::Mut, }; use super::{ array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer, - utils::BitBuffer, + utils::bitset_is_set, }; pub struct StructDeserializer<'a> { pub fields: Vec<(String, ArrayDeserializer<'a>)>, - pub validity: Option>, + pub validity: Option>, pub next: (usize, usize), pub len: usize, } @@ -22,7 +23,7 @@ pub struct StructDeserializer<'a> { impl<'a> StructDeserializer<'a> { pub fn new( fields: Vec<(String, ArrayDeserializer<'a>)>, - validity: Option>, + validity: Option>, len: usize, ) -> Self { Self { @@ -38,7 +39,7 @@ impl<'a> StructDeserializer<'a> { fail!("Exhausted StructDeserializer"); } if let Some(validity) = &self.validity { - Ok(validity.is_set(self.next.0)) + Ok(bitset_is_set(validity, self.next.0)?) 
} else { Ok(true) } diff --git a/serde_arrow/src/internal/deserialization/test.rs b/serde_arrow/src/internal/deserialization/test.rs index 0fd19eef..362f4fee 100644 --- a/serde_arrow/src/internal/deserialization/test.rs +++ b/serde_arrow/src/internal/deserialization/test.rs @@ -1,7 +1,10 @@ use serde::Deserialize; use crate::internal::{ - arrow::PrimitiveArrayView, deserialization::integer_deserializer::IntegerDeserializer, + arrow::PrimitiveArrayView, + deserialization::{ + array_deserializer::ArrayDeserializer, integer_deserializer::IntegerDeserializer, + }, utils::Mut, }; @@ -13,19 +16,17 @@ fn example() { vec![ ( String::from("a"), - IntegerDeserializer::new(PrimitiveArrayView { + ArrayDeserializer::I32(IntegerDeserializer::new(PrimitiveArrayView { values: &[1, 2, 3], validity: None, - }) - .into(), + })), ), ( String::from("b"), - IntegerDeserializer::new(PrimitiveArrayView { + ArrayDeserializer::I32(IntegerDeserializer::new(PrimitiveArrayView { values: &[4, 5, 6], validity: None, - }) - .into(), + })), ), ], 3, diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index ac8cf173..6f4e21b4 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ b/serde_arrow/src/internal/deserialization/utils.rs @@ -12,25 +12,6 @@ pub fn bitset_is_set(set: &BitsWithOffset<'_>, idx: usize) -> Result { Ok(byte & flag == flag) } -#[derive(Debug, PartialEq, Clone, Copy)] -pub struct BitBuffer<'a> { - pub data: &'a [u8], - pub offset: usize, - pub number_of_bits: usize, -} - -impl<'a> BitBuffer<'a> { - pub fn is_set(&self, idx: usize) -> bool { - let flag = 1 << ((idx + self.offset) % 8); - let byte = self.data[(idx + self.offset) / 8]; - byte & flag == flag - } - - pub fn len(&self) -> usize { - self.number_of_bits - } -} - pub struct ArrayBufferIterator<'a, T: Copy> { pub buffer: &'a [T], pub validity: Option>, @@ -94,7 +75,7 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { /// /// [arrow format spec]: 
https://arrow.apache.org/docs/format/Columnar.html#variable-size-list-layout pub fn check_supported_list_layout<'a, O: Offset>( - validity: Option>, + validity: Option>, offsets: &'a [O], ) -> Result<()> { let Some(validity) = validity else { @@ -108,7 +89,7 @@ pub fn check_supported_list_layout<'a, O: Offset>( for i in 0..offsets.len().saturating_sub(1) { let curr = offsets[i].try_into_usize()?; let next = offsets[i + 1].try_into_usize()?; - if !validity.is_set(i) && (next - curr) != 0 { + if !bitset_is_set(&validity, i)? && (next - curr) != 0 { fail!("lists with data in null values are currently not supported in deserialization"); } } From b5f43c1076de9d81c273e0acd49a397d7b3bbe17 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 13:53:06 +0200 Subject: [PATCH 081/178] Change unused markers --- serde_arrow/src/internal/error.rs | 1 - serde_arrow/src/lib.rs | 2 +- .../src/test_with_arrow/impls/issue_90_type_tracing.rs | 5 ++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 441596c9..610cd1ab 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -203,7 +203,6 @@ impl From for Error { pub type PanicOnError = std::result::Result; /// An error type for testing, that panics once an error is converted -#[allow(unused)] #[derive(Debug)] pub struct PanicOnErrorError; diff --git a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index 364e4ccf..b3a85ce5 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -165,7 +165,7 @@ //! | `arrow2-0-16` | `arrow2=0.16` | // be more forgiving without any active implementation -#[cfg_attr(all(not(has_arrow), not(has_arrow2)), allow(unused))] +#[cfg_attr(not(any(has_arrow, has_arrow2)), allow(unused))] mod internal; /// *Internal. 
Do not use* diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index f8233491..d025a1ce 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -16,17 +16,16 @@ fn trace_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> F { #[test] fn issue_90() { + #[allow(unused)] #[derive(Deserialize)] pub struct Distribution { - #[allow(unused)] pub samples: Vec, - #[allow(unused)] pub statistic: String, } + #[allow(unused)] #[derive(Deserialize)] pub struct VectorMetric { - #[allow(unused)] pub distribution: Option, } From 3d343b9a9e3915bf46b2b0ee54b1cf0a0db05ae2 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 14:03:13 +0200 Subject: [PATCH 082/178] Remove unused BaseDataTypeDisplay --- serde_arrow/src/arrow_impl/schema.rs | 26 --------------- serde_arrow/src/internal/arrow/data_type.rs | 36 +-------------------- serde_arrow/src/internal/arrow/mod.rs | 3 +- 3 files changed, 3 insertions(+), 62 deletions(-) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 242440c1..7d3dcc00 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -387,29 +387,3 @@ impl From for ArrowTimeUnit { } } } - -impl TryFrom for DataType { - type Error = Error; - - fn try_from(value: crate::internal::arrow::DataType) -> Result { - use {crate::internal::arrow::DataType as DT, DataType as ArrowDT}; - - match value { - DT::Int8 => Ok(ArrowDT::Int8), - DT::Int16 => Ok(ArrowDT::Int16), - DT::Int32 => Ok(ArrowDT::Int32), - DT::Int64 => Ok(ArrowDT::Int64), - DT::UInt8 => Ok(ArrowDT::UInt8), - DT::UInt16 => Ok(ArrowDT::UInt16), - DT::UInt32 => Ok(ArrowDT::UInt32), - DT::UInt64 => Ok(ArrowDT::UInt64), - DT::Float16 => Ok(ArrowDT::Float16), - DT::Float32 => Ok(ArrowDT::Float32), - DT::Float64 => Ok(ArrowDT::Float64), - 
dt => fail!( - "{} not supported", - crate::internal::arrow::BaseDataTypeDisplay(&dt) - ), - } - } -} diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs index a732e466..896b8106 100644 --- a/serde_arrow/src/internal/arrow/data_type.rs +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -11,6 +11,7 @@ pub struct Field { pub metadata: HashMap, } +#[allow(unused)] #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] pub enum DataType { @@ -42,41 +43,6 @@ pub enum DataType { LargeList(Box), } -pub struct BaseDataTypeDisplay<'a>(pub &'a DataType); - -impl<'a> std::fmt::Display for BaseDataTypeDisplay<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0 { - DataType::Null => write!(f, "Null"), - DataType::Boolean => write!(f, "Boolean"), - DataType::Int8 => write!(f, "Int8"), - DataType::Int16 => write!(f, "Int16"), - DataType::Int32 => write!(f, "Int32"), - DataType::Int64 => write!(f, "Int64"), - DataType::UInt8 => write!(f, "UInt8"), - DataType::UInt16 => write!(f, "UInt16"), - DataType::UInt32 => write!(f, "UInt32"), - DataType::UInt64 => write!(f, "UInt64"), - DataType::Float16 => write!(f, "Float16"), - DataType::Float32 => write!(f, "Float32"), - DataType::Float64 => write!(f, "Float64"), - DataType::Utf8 => write!(f, "Utf8"), - DataType::LargeUtf8 => write!(f, "LargeUtf8"), - DataType::Binary => write!(f, "Binary"), - DataType::LargeBinary => write!(f, "LargeBinary"), - DataType::Date32 => write!(f, "Date32"), - DataType::Date64 => write!(f, "Date64"), - DataType::Timestamp(_, _) => write!(f, "Timestamp"), - DataType::Time32(_) => write!(f, "Time32"), - DataType::Time64(_) => write!(f, "Time64"), - DataType::Decimal128 => write!(f, "Decimal128"), - DataType::Struct(_) => write!(f, "Struct"), - DataType::List(_) => write!(f, "List"), - DataType::LargeList(_) => write!(f, "LargeList"), - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, 
Deserialize)] pub enum TimeUnit { Second, diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index d439e73e..98e6b749 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -11,4 +11,5 @@ pub use array::{ PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, }; -pub use data_type::{BaseDataTypeDisplay, DataType, TimeUnit}; +#[allow(unused)] +pub use data_type::{Field, DataType, TimeUnit}; From c674a61060985013dc23754d8937f2ea589582fc Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 14:36:58 +0200 Subject: [PATCH 083/178] Add serde tests for fields --- serde_arrow/src/internal/arrow/mod.rs | 2 +- serde_arrow/src/internal/schema/mod.rs | 3 +- serde_arrow/src/internal/schema/serde/mod.rs | 6 ++ .../src/internal/schema/serde/serialize.rs | 43 +++++++++ serde_arrow/src/internal/schema/serde/test.rs | 96 +++++++++++++++++++ 5 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 serde_arrow/src/internal/schema/serde/mod.rs create mode 100644 serde_arrow/src/internal/schema/serde/serialize.rs create mode 100644 serde_arrow/src/internal/schema/serde/test.rs diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 98e6b749..78c62c9a 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -12,4 +12,4 @@ pub use array::{ TimestampArray, TimestampArrayView, }; #[allow(unused)] -pub use data_type::{Field, DataType, TimeUnit}; +pub use data_type::{DataType, Field, TimeUnit}; diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 13b33ceb..ba20a418 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -4,6 +4,7 @@ mod data_type; mod deserialization; mod from_samples; mod from_type; +mod serde; mod strategy; pub mod 
tracer; mod tracing_options; @@ -19,7 +20,7 @@ use crate::internal::{ utils::value, }; -use serde::{Deserialize, Serialize}; +use ::serde::{Deserialize, Serialize}; pub use data_type::GenericDataType; pub use strategy::{ diff --git a/serde_arrow/src/internal/schema/serde/mod.rs b/serde_arrow/src/internal/schema/serde/mod.rs new file mode 100644 index 00000000..e08d9999 --- /dev/null +++ b/serde_arrow/src/internal/schema/serde/mod.rs @@ -0,0 +1,6 @@ +//! Group all serialization / deserialization related functionality +//! +mod serialize; + +#[cfg(test)] +mod test; diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs new file mode 100644 index 00000000..1579a560 --- /dev/null +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -0,0 +1,43 @@ +//! Serialize and deserialize a field split into + +use std::collections::HashMap; + +use serde::ser::SerializeStruct; + +use crate::{internal::schema::GenericField, schema::STRATEGY_KEY}; + +/*impl serde::Serialize for GenericField { + fn serialize(&self, serializer: S) -> Result { + let non_strategy_metadata = self.metadata.iter().filter(|(key, _)| *key != STRATEGY_KEY).collect::>(); + + let mut num_fields = 2; + if !non_strategy_metadata.is_empty() { + num_fields += 1; + } + if self.metadata.contains_key(STRATEGY_KEY) { + num_fields += 1; + } + if self.nullable { + num_fields += 1; + } + if !self.children.is_empty() { + num_fields += 1; + } + + let mut s = serializer.serialize_struct("Field", num_fields)?; + s.serialize_field("name", &self.name)?; + s.serialize_field("data_type", &self.data_type)?; + + if !non_strategy_metadata.is_empty() { + s.serialize_field("metadata", &non_strategy_metadata)?; + } + if let Some(strategy) = self.metadata.get(STRATEGY_KEY) { + s.serialize_field("strategy", strategy)?; + } + if !self.children.is_empty() { + s.serialize_field("children", &self.children)?; + } + s.end() + } +} +*/ diff --git 
a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs new file mode 100644 index 00000000..7362f172 --- /dev/null +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -0,0 +1,96 @@ +use serde_json::json; + +use crate::internal::{ + error::PanicOnError, + schema::{GenericDataType, GenericField, Strategy}, + testing::hash_map, +}; + +#[test] +fn i16_field_simple() -> PanicOnError<()> { + let field = GenericField { + name: String::from("my_field_name"), + data_type: GenericDataType::I16, + metadata: hash_map!(), + strategy: None, + nullable: false, + children: vec![], + }; + let expected = json!({ + "name": "my_field_name", + "data_type": "I16", + }); + + let actual = serde_json::to_value(&field)?; + assert_eq!(actual, expected); + + let roundtripped = serde_json::from_value::(actual)?; + assert_eq!(roundtripped, field); + + Ok(()) +} + +#[test] +fn date64_field_complex() -> PanicOnError<()> { + let field = GenericField { + name: String::from("my_field_name"), + data_type: GenericDataType::Date64, + metadata: hash_map!("foo" => "bar"), + strategy: Some(Strategy::NaiveStrAsDate64), + nullable: true, + children: vec![], + }; + let expected = json!({ + "name": "my_field_name", + "data_type": "Date64", + "metadata": { + "foo": "bar", + }, + "strategy": "NaiveStrAsDate64", + "nullable": true, + }); + + let actual = serde_json::to_value(&field)?; + assert_eq!(actual, expected); + + let roundtripped = serde_json::from_value::(actual)?; + assert_eq!(roundtripped, field); + + Ok(()) +} + +#[test] +fn list_field_complex() -> PanicOnError<()> { + let field = GenericField { + name: String::from("my_field_name"), + data_type: GenericDataType::List, + metadata: hash_map!("foo" => "bar"), + strategy: None, + nullable: true, + children: vec![GenericField { + name: String::from("element"), + data_type: GenericDataType::I64, + metadata: hash_map!(), + strategy: None, + nullable: false, + children: vec![], + }], + }; + let expected = json!({ + 
"name": "my_field_name", + "data_type": "List", + "metadata": {"foo": "bar"}, + "nullable": true, + "children": [ + {"name": "element", "data_type": "I64"}, + ] + }); + + let actual = serde_json::to_value(&field)?; + assert_eq!(actual, expected); + + let roundtripped = serde_json::from_value::(actual)?; + assert_eq!(roundtripped, field); + + Ok(()) +} From a1aaa61f7bca5b3df7199d7df633f22668448735 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 14:40:15 +0200 Subject: [PATCH 084/178] Add custom Serialize impl for Field --- serde_arrow/src/internal/schema/mod.rs | 14 +------------- .../src/internal/schema/serde/serialize.rs | 16 +++++++++++----- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index ba20a418..712c7200 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -309,28 +309,16 @@ impl SchemaLike for SerdeArrowSchema { } } -#[derive(Serialize, Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub struct GenericField { pub name: String, pub data_type: GenericDataType, - - #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub metadata: HashMap, - - #[serde(default, skip_serializing_if = "Option::is_none")] pub strategy: Option, - - #[serde(default, skip_serializing_if = "is_false")] pub nullable: bool, - - #[serde(default, skip_serializing_if = "Vec::is_empty")] pub children: Vec, } -fn is_false(val: &bool) -> bool { - !*val -} - impl GenericField { pub fn new(name: &str, data_type: GenericDataType, nullable: bool) -> Self { Self { diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index 1579a560..bbe93f51 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -6,15 +6,19 @@ use serde::ser::SerializeStruct; use 
crate::{internal::schema::GenericField, schema::STRATEGY_KEY}; -/*impl serde::Serialize for GenericField { +impl serde::Serialize for GenericField { fn serialize(&self, serializer: S) -> Result { - let non_strategy_metadata = self.metadata.iter().filter(|(key, _)| *key != STRATEGY_KEY).collect::>(); + let non_strategy_metadata = self + .metadata + .iter() + .filter(|(key, _)| *key != STRATEGY_KEY) + .collect::>(); let mut num_fields = 2; if !non_strategy_metadata.is_empty() { num_fields += 1; } - if self.metadata.contains_key(STRATEGY_KEY) { + if self.strategy.is_some() { num_fields += 1; } if self.nullable { @@ -28,10 +32,13 @@ use crate::{internal::schema::GenericField, schema::STRATEGY_KEY}; s.serialize_field("name", &self.name)?; s.serialize_field("data_type", &self.data_type)?; + if self.nullable { + s.serialize_field("nullable", &self.nullable)?; + } if !non_strategy_metadata.is_empty() { s.serialize_field("metadata", &non_strategy_metadata)?; } - if let Some(strategy) = self.metadata.get(STRATEGY_KEY) { + if let Some(strategy) = self.strategy.as_ref() { s.serialize_field("strategy", strategy)?; } if !self.children.is_empty() { @@ -40,4 +47,3 @@ use crate::{internal::schema::GenericField, schema::STRATEGY_KEY}; s.end() } } -*/ From 35f8ee8151d6c849102e0a85c893336b24678986 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 4 Aug 2024 15:12:21 +0200 Subject: [PATCH 085/178] Remove metadata / schema split --- serde_arrow/src/arrow2_impl/api.rs | 5 +- serde_arrow/src/arrow2_impl/schema.rs | 30 ++++------ serde_arrow/src/arrow_impl/api.rs | 5 +- serde_arrow/src/arrow_impl/schema.rs | 21 ++----- .../src/internal/schema/deserialization.rs | 12 +--- serde_arrow/src/internal/schema/extensions.rs | 2 - serde_arrow/src/internal/schema/mod.rs | 55 +++++++------------ .../src/internal/schema/serde/serialize.rs | 4 +- serde_arrow/src/internal/schema/serde/test.rs | 11 ++-- serde_arrow/src/internal/schema/strategy.rs | 45 +++++---------- 
serde_arrow/src/internal/schema/test.rs | 12 ++-- serde_arrow/src/internal/schema/tracer.rs | 28 +++++++--- .../serialization/outer_sequence_builder.rs | 17 +++--- serde_arrow/src/internal/utils/mod.rs | 7 +-- .../src/test_with_arrow/impls/examples.rs | 14 +++-- .../impls/issue_90_type_tracing.rs | 18 ++++-- serde_arrow/src/test_with_arrow/impls/map.rs | 22 ++++---- .../src/test_with_arrow/impls/struct.rs | 6 +- .../src/test_with_arrow/impls/tuple.rs | 47 ++++++++++++---- .../src/test_with_arrow/impls/union.rs | 12 +++- .../issue_35_preserve_metadata.rs | 16 +++--- 21 files changed, 195 insertions(+), 194 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index 912fd33d..4f0f1288 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -15,7 +15,7 @@ use crate::{ }, deserializer::Deserializer, error::{fail, Result}, - schema::{GenericField, SerdeArrowSchema}, + schema::{get_strategy_from_metadata, GenericField, SerdeArrowSchema}, serializer::Serializer, }, }; @@ -175,7 +175,8 @@ impl<'de> Deserializer<'de> { if array.len() != len { fail!("arrays of different lengths are not supported"); } - let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + let strategy = get_strategy_from_metadata(&field.metadata)?; + let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; deserializers.push((field.name.clone(), deserializer)); } diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index c19d6d96..7ef95ce0 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -7,10 +7,7 @@ use crate::{ internal::{ arrow::TimeUnit, error::{error, fail, Error, Result}, - schema::{ - merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, SchemaLike, Sealed, SerdeArrowSchema, - }, + schema::{GenericDataType, GenericField, SchemaLike, 
Sealed, SerdeArrowSchema}, }, }; @@ -72,16 +69,6 @@ impl TryFrom<&Field> for GenericField { fn try_from(field: &Field) -> Result { use {GenericDataType as T, TimeUnit as U}; - let metadata = field - .metadata - .clone() - .into_iter() - .collect::>(); - let (metadata, strategy) = split_strategy_from_metadata(metadata)?; - - let name = field.name.to_owned(); - let nullable = field.is_nullable; - let mut children = Vec::::new(); let data_type = match &field.data_type { DataType::Boolean => T::Bool, @@ -181,11 +168,19 @@ impl TryFrom<&Field> for GenericField { dt => fail!("Cannot convert data type {dt:?}"), }; + let name = field.name.to_owned(); + let nullable = field.is_nullable; + + let metadata = field + .metadata + .clone() + .into_iter() + .collect::>(); + let field = GenericField { name, data_type, metadata, - strategy, children, nullable, }; @@ -301,11 +296,8 @@ impl TryFrom<&GenericField> for Field { } }; - let metadata = - merge_strategy_with_metadata(value.metadata.clone(), value.strategy.clone())?; - let mut field = Field::new(&value.name, data_type, value.nullable); - field.metadata = metadata.into_iter().collect(); + field.metadata = value.metadata.clone().into_iter().collect(); Ok(field) } diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 347979c1..852a0198 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -16,7 +16,7 @@ use crate::{ }, deserializer::Deserializer, error::{fail, Result}, - schema::SerdeArrowSchema, + schema::{get_strategy_from_metadata, SerdeArrowSchema}, serializer::Serializer, }, }; @@ -262,7 +262,8 @@ impl<'de> Deserializer<'de> { fail!("arrays of different lengths are not supported"); } - let deserializer = ArrayDeserializer::new(field.strategy.as_ref(), array.try_into()?)?; + let strategy = get_strategy_from_metadata(&field.metadata)?; + let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; 
deserializers.push((field.name.clone(), deserializer)); } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 7d3dcc00..a938641c 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -5,10 +5,7 @@ use crate::{ internal::{ arrow::TimeUnit, error::{error, fail, Error, Result}, - schema::{ - merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, - GenericField, SchemaLike, Sealed, SerdeArrowSchema, - }, + schema::{GenericDataType, GenericField, SchemaLike, Sealed, SerdeArrowSchema}, }, }; @@ -181,12 +178,6 @@ impl TryFrom<&Field> for GenericField { type Error = Error; fn try_from(field: &Field) -> Result { - let metadata = field.metadata().clone(); - let (metadata, strategy) = split_strategy_from_metadata(metadata)?; - - let name = field.name().to_owned(); - let nullable = field.is_nullable(); - let mut children = Vec::::new(); let data_type = match field.data_type() { DataType::List(field) => { @@ -236,11 +227,14 @@ impl TryFrom<&Field> for GenericField { dt => dt.try_into()?, }; + let name = field.name().to_owned(); + let nullable = field.is_nullable(); + let metadata = field.metadata().clone(); + let field = GenericField { name, data_type, metadata, - strategy, children, nullable, }; @@ -367,11 +361,8 @@ impl TryFrom<&GenericField> for Field { T::Duration(unit) => DataType::Duration((*unit).into()), }; - let metadata = - merge_strategy_with_metadata(value.metadata.clone(), value.strategy.clone())?; - let mut field = Field::new(&value.name, data_type, value.nullable); - field.set_metadata(metadata); + field.set_metadata(value.metadata.clone()); Ok(field) } diff --git a/serde_arrow/src/internal/schema/deserialization.rs b/serde_arrow/src/internal/schema/deserialization.rs index cb3e909f..229624a5 100644 --- a/serde_arrow/src/internal/schema/deserialization.rs +++ b/serde_arrow/src/internal/schema/deserialization.rs @@ -9,8 +9,7 @@ use crate::internal::{ 
arrow::TimeUnit, error::{fail, Error, Result}, schema::{ - merge_strategy_with_metadata, split_strategy_from_metadata, GenericDataType, GenericField, - SerdeArrowSchema, Strategy, + merge_strategy_with_metadata, GenericDataType, GenericField, SerdeArrowSchema, Strategy, }, }; @@ -158,16 +157,12 @@ impl TryFrom for GenericField { fn try_from(value: ArrowField) -> Result { let (data_type, children) = value.data_type.into_generic()?; - - let (metadata, strategy) = split_strategy_from_metadata(value.metadata)?; - Ok(GenericField { name: value.name, nullable: value.nullable, + metadata: value.metadata, data_type, - metadata, children, - strategy, }) } } @@ -304,8 +299,6 @@ impl<'de> Deserialize<'de> for GenericField { let metadata = merge_strategy_with_metadata(metadata.unwrap_or_default(), strategy.flatten()) .map_err(A::Error::custom)?; - let (metadata, strategy) = - split_strategy_from_metadata(metadata).map_err(A::Error::custom)?; Ok(GenericField { name: name.ok_or_else(|| A::Error::custom("missing field `name`"))?, @@ -313,7 +306,6 @@ impl<'de> Deserialize<'de> for GenericField { children, nullable: nullable.unwrap_or_default(), metadata, - strategy, }) } } diff --git a/serde_arrow/src/internal/schema/extensions.rs b/serde_arrow/src/internal/schema/extensions.rs index aa670f57..7b825872 100644 --- a/serde_arrow/src/internal/schema/extensions.rs +++ b/serde_arrow/src/internal/schema/extensions.rs @@ -139,7 +139,6 @@ impl TryFrom<&FixedShapeTensorField> for GenericField { nullable: value.nullable, data_type: GenericDataType::FixedSizeList(n.try_into()?), children: vec![value.element.clone()], - strategy: None, metadata, }) } @@ -298,7 +297,6 @@ impl TryFrom<&VariableShapeTensorField> for GenericField { false, )), ], - strategy: None, metadata, }) } diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 712c7200..fb77d80e 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ 
-23,9 +23,8 @@ use crate::internal::{ use ::serde::{Deserialize, Serialize}; pub use data_type::GenericDataType; -pub use strategy::{ - merge_strategy_with_metadata, split_strategy_from_metadata, Strategy, STRATEGY_KEY, -}; +pub use strategy::get_strategy_from_metadata; +pub use strategy::{merge_strategy_with_metadata, Strategy, STRATEGY_KEY}; use tracer::Tracer; pub use tracing_options::{Overwrites, TracingMode, TracingOptions}; @@ -314,7 +313,6 @@ pub struct GenericField { pub name: String, pub data_type: GenericDataType, pub metadata: HashMap, - pub strategy: Option, pub nullable: bool, pub children: Vec, } @@ -327,7 +325,6 @@ impl GenericField { metadata: HashMap::new(), nullable, children: Vec::new(), - strategy: None, } } @@ -374,7 +371,7 @@ impl GenericField { pub fn is_utc(&self) -> Result { match &self.data_type { - GenericDataType::Date64 => match &self.strategy { + GenericDataType::Date64 => match get_strategy_from_metadata(&self.metadata)? { None | Some(Strategy::UtcStrAsDate64) => Ok(true), Some(Strategy::NaiveStrAsDate64) => Ok(false), Some(strategy) => fail!("invalid strategy for date64 deserializer: {strategy}"), @@ -392,27 +389,17 @@ impl GenericField { self } - pub fn with_strategy(mut self, strategy: Strategy) -> Self { - self.strategy = Some(strategy); - self - } - - pub fn with_optional_strategy(mut self, strategy: Option) -> Self { - self.strategy = strategy; + pub fn with_metadata(mut self, key: String, value: String) -> Self { + self.metadata.insert(key, value); self } } impl GenericField { pub(crate) fn validate_null(&self) -> Result<()> { - if !matches!( - self.strategy, - None | Some(Strategy::InconsistentTypes) | Some(Strategy::UnknownVariant) - ) { - fail!( - "invalid strategy for Null field: {}", - self.strategy.as_ref().unwrap() - ); + match get_strategy_from_metadata(&self.metadata)? 
{ + None | Some(Strategy::InconsistentTypes) | Some(Strategy::UnknownVariant) => {} + Some(strategy) => fail!("invalid strategy for Null field: {strategy}"), } if !self.children.is_empty() { fail!("Null field must not have children"); @@ -429,7 +416,7 @@ impl GenericField { } pub(crate) fn validate_date64(&self) -> Result<()> { - match self.strategy.as_ref() { + match get_strategy_from_metadata(&self.metadata)? { None | Some(Strategy::UtcStrAsDate64) | Some(Strategy::NaiveStrAsDate64) => {} Some(strategy) => fail!("invalid strategy for Date64 field: {strategy}"), } @@ -440,7 +427,7 @@ impl GenericField { } pub(crate) fn validate_timestamp(&self) -> Result<()> { - match &self.strategy { + match get_strategy_from_metadata(&self.metadata)? { None => Ok(()), Some(strategy @ Strategy::UtcStrAsDate64) => { if !matches!(&self.data_type, GenericDataType::Timestamp(_, Some(tz)) if tz.to_uppercase() == "UTC") @@ -472,12 +459,8 @@ impl GenericField { } pub(crate) fn validate_time32(&self) -> Result<()> { - if self.strategy.is_some() { - fail!( - "invalid strategy for {}: {}", - self.data_type, - self.strategy.as_ref().unwrap() - ); + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { + fail!("invalid strategy for {dt}: {strategy}", dt = self.data_type); } if !self.children.is_empty() { fail!("{} field must not have children", self.data_type); @@ -492,7 +475,7 @@ impl GenericField { } pub(crate) fn validate_time64(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { fail!( "invalid strategy for {data_type}: {strategy}", data_type = self.data_type, @@ -516,7 +499,7 @@ impl GenericField { pub(crate) fn validate_struct(&self) -> Result<()> { // NOTE: do not check number of children: arrow-rs can 0 children, arrow2 not - match self.strategy.as_ref() { + match get_strategy_from_metadata(&self.metadata)? 
{ None | Some(Strategy::MapAsStruct) | Some(Strategy::TupleAsStruct) => {} Some(strategy) => fail!("invalid strategy for Struct field: {strategy}"), } @@ -527,7 +510,7 @@ impl GenericField { } pub(crate) fn validate_map(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { fail!("invalid strategy for Map field: {strategy}"); } if self.children.len() != 1 { @@ -561,7 +544,7 @@ impl GenericField { } pub(crate) fn validate_list(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { fail!("invalid strategy for List field: {strategy}"); } if self.children.len() != 1 { @@ -587,7 +570,7 @@ impl GenericField { } pub(crate) fn validate_union(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { fail!("invalid strategy for Union field: {strategy}"); } if self.children.is_empty() { @@ -600,7 +583,7 @@ impl GenericField { } pub(crate) fn validate_dictionary(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { fail!("invalid strategy for Dictionary field: {strategy}"); } if self.children.len() != 2 { @@ -641,7 +624,7 @@ impl GenericField { } pub(crate) fn validate_no_strategy_no_children(&self) -> Result<()> { - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? 
{ fail!( "invalid strategy for {data_type}: {strategy}", data_type = self.data_type, diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index bbe93f51..0fd5cd0a 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -18,7 +18,7 @@ impl serde::Serialize for GenericField { if !non_strategy_metadata.is_empty() { num_fields += 1; } - if self.strategy.is_some() { + if self.metadata.contains_key(STRATEGY_KEY) { num_fields += 1; } if self.nullable { @@ -38,7 +38,7 @@ impl serde::Serialize for GenericField { if !non_strategy_metadata.is_empty() { s.serialize_field("metadata", &non_strategy_metadata)?; } - if let Some(strategy) = self.strategy.as_ref() { + if let Some(strategy) = self.metadata.get(STRATEGY_KEY) { s.serialize_field("strategy", strategy)?; } if !self.children.is_empty() { diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index 7362f172..19fa5c6a 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -2,7 +2,7 @@ use serde_json::json; use crate::internal::{ error::PanicOnError, - schema::{GenericDataType, GenericField, Strategy}, + schema::{GenericDataType, GenericField, STRATEGY_KEY}, testing::hash_map, }; @@ -12,7 +12,6 @@ fn i16_field_simple() -> PanicOnError<()> { name: String::from("my_field_name"), data_type: GenericDataType::I16, metadata: hash_map!(), - strategy: None, nullable: false, children: vec![], }; @@ -35,8 +34,10 @@ fn date64_field_complex() -> PanicOnError<()> { let field = GenericField { name: String::from("my_field_name"), data_type: GenericDataType::Date64, - metadata: hash_map!("foo" => "bar"), - strategy: Some(Strategy::NaiveStrAsDate64), + metadata: hash_map!( + "foo" => "bar", + STRATEGY_KEY => "NaiveStrAsDate64", + ), nullable: true, children: vec![], }; @@ -65,13 +66,11 @@ fn 
list_field_complex() -> PanicOnError<()> { name: String::from("my_field_name"), data_type: GenericDataType::List, metadata: hash_map!("foo" => "bar"), - strategy: None, nullable: true, children: vec![GenericField { name: String::from("element"), data_type: GenericDataType::I64, metadata: hash_map!(), - strategy: None, nullable: false, children: vec![], }], diff --git a/serde_arrow/src/internal/schema/strategy.rs b/serde_arrow/src/internal/schema/strategy.rs index b7fa5338..7622df25 100644 --- a/serde_arrow/src/internal/schema/strategy.rs +++ b/serde_arrow/src/internal/schema/strategy.rs @@ -130,16 +130,11 @@ impl From for HashMap { } } -pub fn split_strategy_from_metadata( - mut metadata: HashMap, -) -> Result<(HashMap, Option)> { - let strategy = if let Some(strategy_str) = metadata.remove(STRATEGY_KEY) { - Some(strategy_str.parse::()?) - } else { - None +pub fn get_strategy_from_metadata(metadata: &HashMap) -> Result> { + let Some(strategy) = metadata.get(STRATEGY_KEY) else { + return Ok(None); }; - - Ok((metadata, strategy)) + Ok(Some(strategy.parse()?)) } pub fn merge_strategy_with_metadata( @@ -159,47 +154,37 @@ pub fn merge_strategy_with_metadata( fn test_split_strategy_from_metadata_with_metadata() { use crate::internal::testing::hash_map; - let input: HashMap = hash_map!( - "SERDE_ARROW:strategy" => "TupleAsStruct", + let metadata: HashMap = hash_map!( "key1" => "value1", "key2" => "value2", ); + let strategy: Option = Some(Strategy::TupleAsStruct); - let expected_metadata: HashMap = hash_map!( + let expected: HashMap = hash_map!( + "SERDE_ARROW:strategy" => "TupleAsStruct", "key1" => "value1", "key2" => "value2", ); - let expected_strategy: Option = Some(Strategy::TupleAsStruct); - - let (actual_metadata, actual_strategy) = split_strategy_from_metadata(input.clone()).unwrap(); - let roundtripped = - merge_strategy_with_metadata(actual_metadata.clone(), actual_strategy.clone()).unwrap(); - assert_eq!(actual_metadata, expected_metadata); - 
assert_eq!(actual_strategy, expected_strategy); - assert_eq!(roundtripped, input); + let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); + assert_eq!(actual, expected); } #[test] fn test_split_strategy_from_metadata_without_metadata() { use crate::internal::testing::hash_map; - let input: HashMap = hash_map!( + let metadata: HashMap = hash_map!( "key1" => "value1", "key2" => "value2", ); + let strategy: Option = None; - let expected_metadata: HashMap = hash_map!( + let expected: HashMap = hash_map!( "key1" => "value1", "key2" => "value2", ); - let expected_strategy: Option = None; - - let (actual_metadata, actual_strategy) = split_strategy_from_metadata(input.clone()).unwrap(); - let roundtripped = - merge_strategy_with_metadata(actual_metadata.clone(), actual_strategy.clone()).unwrap(); - assert_eq!(actual_metadata, expected_metadata); - assert_eq!(actual_strategy, expected_strategy); - assert_eq!(roundtripped, input); + let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); + assert_eq!(actual, expected); } diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index 8e19822c..a4719d82 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -80,8 +80,10 @@ fn doc_schema() { #[test] fn date64_with_strategy() { let schema = SerdeArrowSchema::default().with_field( - GenericField::new("item", GenericDataType::Date64, false) - .with_strategy(Strategy::NaiveStrAsDate64), + GenericField::new("item", GenericDataType::Date64, false).with_metadata( + STRATEGY_KEY.to_string(), + Strategy::NaiveStrAsDate64.to_string(), + ), ); let actual = serde_json::to_string(&schema).unwrap(); @@ -213,10 +215,9 @@ fn test_metadata_strategy_from_explicit() { ])) .unwrap(); - assert_eq!(schema.fields[0].strategy, Some(Strategy::UtcStrAsDate64)); assert_eq!( schema.fields[0].metadata, - hash_map!("foo" => "bar", "hello" => "world") + hash_map!("foo" => "bar", "hello" => 
"world", STRATEGY_KEY => "UtcStrAsDate64"), ); let schema_value = serde_json::to_value(&schema).unwrap(); @@ -252,10 +253,9 @@ fn test_metadata_strategy_from_metadata() { ])) .unwrap(); - assert_eq!(schema.fields[0].strategy, Some(Strategy::UtcStrAsDate64)); assert_eq!( schema.fields[0].metadata, - hash_map!("foo" => "bar", "hello" => "world") + hash_map!("foo" => "bar", "hello" => "world", STRATEGY_KEY => "UtcStrAsDate64") ); // NOTE: the strategy is always normalized to be an extra field diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 9e4e6358..1777a544 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -10,7 +10,7 @@ use crate::internal::{ use super::{ tracing_options::{TracingMode, TracingOptions}, - Overwrites, + Overwrites, STRATEGY_KEY, }; // TODO: allow to customize @@ -692,7 +692,10 @@ impl TupleTracer { for tracer in &self.field_tracers { field.children.push(tracer.to_field()?); } - field.strategy = Some(Strategy::TupleAsStruct); + field.metadata.insert( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ); Ok(field) } @@ -808,7 +811,9 @@ impl StructTracer { if let StructMode::Map = self.mode { res_field.children.sort_by(|a, b| a.name.cmp(&b.name)); - res_field.strategy = Some(Strategy::MapAsStruct); + res_field + .metadata + .insert(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()); } Ok(res_field) } @@ -911,8 +916,10 @@ impl UnionTracer { field.children.push(variant.tracer.to_field()?); } else { field.children.push( - GenericField::new("", GenericDataType::Null, true) - .with_strategy(Strategy::UnknownVariant), + GenericField::new("", GenericDataType::Null, true).with_metadata( + STRATEGY_KEY.to_string(), + Strategy::UnknownVariant.to_string(), + ), ); }; } @@ -994,8 +1001,15 @@ impl PrimitiveTracer { Ok(default_dictionary_field(&self.name, self.nullable)) } } - dt => Ok(GenericField::new(&self.name, dt.clone(), 
self.nullable) - .with_optional_strategy(self.strategy.clone())), + dt => { + let mut field = GenericField::new(&self.name, dt.clone(), self.nullable); + if let Some(strategy) = self.strategy.as_ref() { + field + .metadata + .insert(STRATEGY_KEY.to_string(), strategy.to_string()); + } + Ok(field) + } } } } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 21647373..b2140657 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -3,7 +3,9 @@ use serde::Serialize; use crate::internal::{ arrow::TimeUnit, error::{fail, Result}, - schema::{GenericDataType, GenericField, SerdeArrowSchema, Strategy}, + schema::{ + get_strategy_from_metadata, GenericDataType, GenericField, SerdeArrowSchema, Strategy, + }, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, @@ -40,13 +42,10 @@ impl OuterSequenceBuilder { use {ArrayBuilder as A, GenericDataType as T}; let builder = match &field.data_type { - T::Null => { - if matches!(&field.strategy, Some(Strategy::UnknownVariant)) { - A::UnknownVariant(UnknownVariantBuilder) - } else { - A::Null(NullBuilder::new()) - } - } + T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ + Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), + _ => A::Null(NullBuilder::new()), + }, T::Bool => A::Bool(BoolBuilder::new(field.nullable)), T::I8 => A::I8(IntBuilder::new(field.nullable)), T::I16 => A::I16(IntBuilder::new(field.nullable)), @@ -62,7 +61,7 @@ impl OuterSequenceBuilder { T::Date32 => A::Date32(Date32Builder::new(field.nullable)), T::Date64 => A::Date64(Date64Builder::new( None, - is_utc_strategy(field.strategy.as_ref())?, + is_utc_strategy(get_strategy_from_metadata(&field.metadata)?.as_ref())?, field.nullable, )), T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 4f32d6f1..7c58941a 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -9,10 +9,7 @@ use serde::{ser::SerializeSeq, Deserialize, Serialize}; use crate::internal::error::Result; -use super::{ - arrow::FieldMeta, - schema::{merge_strategy_with_metadata, GenericField}, -}; +use super::{arrow::FieldMeta, schema::GenericField}; /// A wrapper around a sequence of items /// @@ -185,6 +182,6 @@ pub fn meta_from_field(field: GenericField) -> Result { Ok(FieldMeta { name: field.name, nullable: field.nullable, - metadata: merge_strategy_with_metadata(field.metadata, field.strategy)?, + metadata: field.metadata, }) } diff --git a/serde_arrow/src/test_with_arrow/impls/examples.rs b/serde_arrow/src/test_with_arrow/impls/examples.rs index 84adc0e0..677b7f2b 100644 --- a/serde_arrow/src/test_with_arrow/impls/examples.rs +++ b/serde_arrow/src/test_with_arrow/impls/examples.rs @@ -1,7 +1,7 @@ use super::utils::Test; use crate::{ internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions}, + schema::{Strategy, TracingOptions, STRATEGY_KEY}, utils::Item, }; @@ -357,8 +357,10 @@ fn issue_57() { GenericField::new("filename", GenericDataType::LargeUtf8, false), GenericField::new("game_type", 
GenericDataType::Union, false) .with_child( - GenericField::new("", GenericDataType::Null, true) - .with_strategy(Strategy::UnknownVariant), + GenericField::new("", GenericDataType::Null, true).with_metadata( + STRATEGY_KEY.to_string(), + Strategy::UnknownVariant.to_string(), + ), ) .with_child(GenericField::new( "RegularSeason", @@ -367,8 +369,10 @@ fn issue_57() { )), GenericField::new("account_type", GenericDataType::Union, false) .with_child( - GenericField::new("", GenericDataType::Null, true) - .with_strategy(Strategy::UnknownVariant), + GenericField::new("", GenericDataType::Null, true).with_metadata( + STRATEGY_KEY.to_string(), + Strategy::UnknownVariant.to_string(), + ), ) .with_child(GenericField::new("Deduced", GenericDataType::Null, true)), GenericField::new("file_index", GenericDataType::U64, false), diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index d025a1ce..8a74a386 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -2,10 +2,15 @@ use std::collections::HashMap; use serde::Deserialize; -use crate::internal::{ - schema::{tracer::Tracer, GenericDataType as T, GenericField as F, Strategy, TracingOptions}, - testing::assert_error, - utils::Item, +use crate::{ + internal::{ + schema::{ + tracer::Tracer, GenericDataType as T, GenericField as F, Strategy, TracingOptions, + }, + testing::assert_error, + utils::Item, + }, + schema::STRATEGY_KEY, }; fn trace_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> F { @@ -128,7 +133,10 @@ fn trace_tuple_as_struct() { let expected = F::new("item", T::Struct, false) .with_child(F::new("0", T::Bool, false)) .with_child(F::new("1", T::I8, true)) - .with_strategy(Strategy::TupleAsStruct); + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ); assert_eq!(actual, expected); } diff --git 
a/serde_arrow/src/test_with_arrow/impls/map.rs b/serde_arrow/src/test_with_arrow/impls/map.rs index 47fe5106..062d614b 100644 --- a/serde_arrow/src/test_with_arrow/impls/map.rs +++ b/serde_arrow/src/test_with_arrow/impls/map.rs @@ -1,9 +1,11 @@ use std::collections::{BTreeMap, HashMap}; use crate::{ - internal::schema::{GenericDataType, GenericField}, - internal::testing::{btree_map, hash_map}, - schema::{Strategy, TracingOptions}, + internal::{ + schema::{GenericDataType, GenericField}, + testing::{btree_map, hash_map}, + }, + schema::{Strategy, TracingOptions, STRATEGY_KEY}, utils::Item, }; @@ -14,7 +16,7 @@ use super::utils::Test; #[test] fn map_as_struct() { let field = GenericField::new("item", GenericDataType::Struct, false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, false)) .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = BTreeMap; @@ -35,7 +37,7 @@ fn map_as_struct() { #[test] fn hash_map_as_struct() { let field = GenericField::new("item", GenericDataType::Struct, false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, false)) .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = HashMap; @@ -55,7 +57,7 @@ fn hash_map_as_struct() { #[test] fn map_as_struct_nullable() { let field = GenericField::new("item", GenericDataType::Struct, true) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, false)) .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = Option>; @@ -76,7 +78,7 @@ fn map_as_struct_nullable() { #[test] fn map_as_struct_missing_fields() { let field = GenericField::new("item", GenericDataType::Struct, 
false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, false)) .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; @@ -95,7 +97,7 @@ fn map_as_struct_missing_fields() { #[test] fn map_as_struct_missing_fields_2() { let field = GenericField::new("item", GenericDataType::Struct, false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, true)) .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; @@ -116,7 +118,7 @@ fn map_as_struct_missing_fields_2() { #[test] fn map_as_struct_missing_fields_3() { let field = GenericField::new("item", GenericDataType::Struct, false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, true)) .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; @@ -137,7 +139,7 @@ fn map_as_struct_missing_fields_3() { #[test] fn map_as_struct_nullable_fields() { let field = GenericField::new("item", GenericDataType::Struct, false) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::U32, true)) .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap>; diff --git a/serde_arrow/src/test_with_arrow/impls/struct.rs b/serde_arrow/src/test_with_arrow/impls/struct.rs index 000b9e4c..ea0f8108 100644 --- a/serde_arrow/src/test_with_arrow/impls/struct.rs +++ b/serde_arrow/src/test_with_arrow/impls/struct.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::{ internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions}, 
+ schema::{Strategy, TracingOptions, STRATEGY_KEY}, utils::Item, }; @@ -242,7 +242,7 @@ fn nullable_struct_list_field() { #[test] fn serde_flatten() { let field = GenericField::new("item", GenericDataType::Struct, true) - .with_strategy(Strategy::MapAsStruct) + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) .with_child(GenericField::new("a", GenericDataType::I8, false)) .with_child(GenericField::new("value", GenericDataType::Bool, false)); let values = [Item(Some(LocalItem { @@ -275,7 +275,7 @@ fn flattened_structures() { .with_child(GenericField::new("a", GenericDataType::I64, false)) .with_child(GenericField::new("b", GenericDataType::F32, false)) .with_child(GenericField::new("c", GenericDataType::F64, false)) - .with_strategy(Strategy::MapAsStruct); + .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()); let values = [ Item(Outer { diff --git a/serde_arrow/src/test_with_arrow/impls/tuple.rs b/serde_arrow/src/test_with_arrow/impls/tuple.rs index edb481d0..e90a78b9 100644 --- a/serde_arrow/src/test_with_arrow/impls/tuple.rs +++ b/serde_arrow/src/test_with_arrow/impls/tuple.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::{ internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions}, + schema::{Strategy, TracingOptions, STRATEGY_KEY}, utils::Item, }; @@ -18,7 +18,10 @@ fn tuple_u64_bool() { GenericDataType::Struct, false, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U64, false)) .with_child(GenericField::new("1", GenericDataType::Bool, false))]) .trace_schema_from_type::>(TracingOptions::default()) @@ -41,7 +44,10 @@ fn tuple_struct_u64_bool() { GenericDataType::Struct, false, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) 
.with_child(GenericField::new("0", GenericDataType::U64, false)) .with_child(GenericField::new("1", GenericDataType::Bool, false))]) .trace_schema_from_type::>(TracingOptions::default()) @@ -68,7 +74,10 @@ fn nullbale_tuple_u64_bool() { GenericDataType::Struct, true, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U64, false)) .with_child(GenericField::new("1", GenericDataType::Bool, false))]) .trace_schema_from_type::>>(TracingOptions::default()) @@ -88,7 +97,10 @@ fn tuple_nullable_u64() { GenericDataType::Struct, false, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U64, true))]) .trace_schema_from_type::,)>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) @@ -107,10 +119,16 @@ fn tuple_nested() { GenericDataType::Struct, false, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child( GenericField::new("0", GenericDataType::Struct, false) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U64, false)), )]) .trace_schema_from_type::>(TracingOptions::default()) @@ -134,7 +152,10 @@ fn tuple_nullable() { GenericDataType::Struct, true, ) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::Bool, false)) .with_child(GenericField::new("1", GenericDataType::I64, false))]) .trace_schema_from_type::>>(TracingOptions::default()) @@ -158,10 +179,16 @@ fn tuple_nullable_nested() { GenericDataType::Struct, true, ) - 
.with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child( GenericField::new("0", GenericDataType::Struct, false) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::Bool, false)) .with_child(GenericField::new("1", GenericDataType::I64, false)), ) diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index c16b5f3b..6c2b5c92 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -3,7 +3,7 @@ use serde_json::json; use crate::{ internal::schema::{GenericDataType, GenericField}, - schema::{SchemaLike, Strategy, TracingOptions}, + schema::{SchemaLike, Strategy, TracingOptions, STRATEGY_KEY}, utils::{Item, Items}, }; @@ -225,13 +225,19 @@ fn enums_tuple() { let field = GenericField::new("item", GenericDataType::Union, false) .with_child( GenericField::new("A", GenericDataType::Struct, false) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U8, false)) .with_child(GenericField::new("1", GenericDataType::U32, false)), ) .with_child( GenericField::new("B", GenericDataType::Struct, false) - .with_strategy(Strategy::TupleAsStruct) + .with_metadata( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ) .with_child(GenericField::new("0", GenericDataType::U16, false)) .with_child(GenericField::new("1", GenericDataType::U64, false)), ); diff --git a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs index befb32b6..22168490 100644 --- a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs +++ 
b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs @@ -6,7 +6,7 @@ use serde_json::json; use crate::{ _impl::{arrow, arrow2}, internal::{schema::GenericField, testing::hash_map}, - schema::{SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, + schema::{SchemaLike, SerdeArrowSchema, STRATEGY_KEY}, }; fn example_field_desc() -> serde_json::Value { @@ -31,9 +31,10 @@ fn example_field_desc() -> serde_json::Value { #[test] fn arrow() { let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); - - assert_eq!(initial_field.metadata, hash_map!("foo" => "bar")); - assert_eq!(initial_field.strategy, Some(Strategy::MapAsStruct)); + assert_eq!( + initial_field.metadata, + hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") + ); let arrow_field = arrow::datatypes::Field::try_from(&initial_field).unwrap(); assert_eq!( @@ -54,9 +55,10 @@ fn arrow() { #[test] fn arrow2() { let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); - - assert_eq!(initial_field.metadata, hash_map!("foo" => "bar")); - assert_eq!(initial_field.strategy, Some(Strategy::MapAsStruct)); + assert_eq!( + initial_field.metadata, + hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") + ); let arrow_field = arrow2::datatypes::Field::try_from(&initial_field).unwrap(); assert_eq!( From f04554b39e94f74cf1139d31b6751c36e962a2c4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 5 Aug 2024 11:06:14 +0200 Subject: [PATCH 086/178] Broken. 
Switch API to use arrow::Field internally --- serde_arrow/src/arrow2_impl/api.rs | 15 +- serde_arrow/src/arrow2_impl/schema.rs | 367 +++++-------- serde_arrow/src/arrow_impl/array.rs | 24 +- serde_arrow/src/arrow_impl/schema.rs | 328 ++++-------- serde_arrow/src/arrow_impl/type_support.rs | 15 +- serde_arrow/src/internal/arrow/data_type.rs | 47 +- serde_arrow/src/internal/arrow/mod.rs | 3 +- .../src/internal/schema/deserialization.rs | 56 +- serde_arrow/src/internal/schema/extensions.rs | 82 +-- .../src/internal/schema/from_samples/mod.rs | 50 +- .../src/internal/schema/from_type/mod.rs | 47 +- serde_arrow/src/internal/schema/mod.rs | 484 +++++++----------- .../src/internal/schema/serde/deserialize.rs | 113 ++++ serde_arrow/src/internal/schema/serde/mod.rs | 3 +- .../src/internal/schema/serde/serialize.rs | 162 +++++- serde_arrow/src/internal/schema/serde/test.rs | 40 +- serde_arrow/src/internal/schema/test.rs | 283 +++++----- serde_arrow/src/internal/schema/tracer.rs | 284 ++++++---- .../src/internal/schema/tracing_options.rs | 12 +- .../serialization/dictionary_utf8_builder.rs | 6 +- .../src/internal/serialization/map_builder.rs | 17 +- .../serialization/outer_sequence_builder.rs | 136 ++--- serde_arrow/src/internal/utils/mod.rs | 4 +- serde_arrow/src/test/schema_like.rs | 20 +- .../src/test_with_arrow/impls/examples.rs | 88 ++-- .../impls/issue_90_type_tracing.rs | 141 +++-- serde_arrow/src/test_with_arrow/impls/list.rs | 127 +++-- serde_arrow/src/test_with_arrow/impls/map.rs | 302 +++++++---- .../src/test_with_arrow/impls/primitives.rs | 45 +- .../src/test_with_arrow/impls/struct.rs | 277 ++++++---- .../src/test_with_arrow/impls/tuple.rs | 190 +++---- .../src/test_with_arrow/impls/union.rs | 261 ++++++---- .../issue_35_preserve_metadata.rs | 17 +- 33 files changed, 2177 insertions(+), 1869 deletions(-) create mode 100644 serde_arrow/src/internal/schema/serde/deserialize.rs diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs 
index 4f0f1288..10a75f1f 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -6,16 +6,17 @@ use serde::{Deserialize, Serialize}; use crate::{ - _impl::arrow2::{array::Array, datatypes::Field}, + _impl::arrow2::{array::Array, datatypes::Field as ArrowField}, internal::{ array_builder::ArrayBuilder, + arrow::Field, deserialization::{ array_deserializer::ArrayDeserializer, outer_sequence_deserializer::OuterSequenceDeserializer, }, deserializer::Deserializer, error::{fail, Result}, - schema::{get_strategy_from_metadata, GenericField, SerdeArrowSchema}, + schema::{get_strategy_from_metadata, SerdeArrowSchema}, serializer::Serializer, }, }; @@ -55,7 +56,7 @@ use crate::{ /// # } /// ``` /// -pub fn to_arrow2(fields: &[Field], items: &T) -> Result>> +pub fn to_arrow2(fields: &[ArrowField], items: &T) -> Result>> where T: Serialize + ?Sized, { @@ -93,7 +94,7 @@ where /// # } /// ``` /// -pub fn from_arrow2<'de, T, A>(fields: &[Field], arrays: &'de [A]) -> Result +pub fn from_arrow2<'de, T, A>(fields: &[ArrowField], arrays: &'de [A]) -> Result where T: Deserialize<'de>, A: AsRef, @@ -106,7 +107,7 @@ where impl crate::internal::array_builder::ArrayBuilder { /// Build an ArrayBuilder from `arrow2` fields (*requires one of the /// `arrow2-*` features*) - pub fn from_arrow2(fields: &[Field]) -> Result { + pub fn from_arrow2(fields: &[ArrowField]) -> Result { Self::new(SerdeArrowSchema::try_from(fields)?) 
} @@ -148,13 +149,13 @@ impl<'de> Deserializer<'de> { /// # Ok(()) /// # } /// ``` - pub fn from_arrow2(fields: &[Field], arrays: &'de [A]) -> Result + pub fn from_arrow2(fields: &[ArrowField], arrays: &'de [A]) -> Result where A: AsRef, { let fields = fields .iter() - .map(GenericField::try_from) + .map(Field::try_from) .collect::>>()?; let arrays = arrays .iter() diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 7ef95ce0..6cb1cb73 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -1,50 +1,46 @@ -use std::collections::HashMap; - use crate::{ _impl::arrow2::datatypes::{ - DataType, Field, IntegerType, TimeUnit as ArrowTimeUnit, UnionMode, + DataType as ArrowDataType, Field as ArrowField, IntegerType, TimeUnit as ArrowTimeUnit, + UnionMode as ArrowUnionMode, }, internal::{ - arrow::TimeUnit, - error::{error, fail, Error, Result}, - schema::{GenericDataType, GenericField, SchemaLike, Sealed, SerdeArrowSchema}, + arrow::{DataType, Field, TimeUnit, UnionMode}, + error::{fail, Error, Result}, + schema::{validate_field, SchemaLike, Sealed, SerdeArrowSchema}, }, }; -impl TryFrom for Vec { +impl TryFrom for Vec { type Error = Error; fn try_from(value: SerdeArrowSchema) -> Result { - Vec::::try_from(&value) + Vec::::try_from(&value) } } -impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { +impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { type Error = Error; fn try_from(value: &'a SerdeArrowSchema) -> Result { - value.fields.iter().map(Field::try_from).collect() + value.fields.iter().map(ArrowField::try_from).collect() } } -impl<'a> TryFrom<&'a [Field]> for SerdeArrowSchema { +impl<'a> TryFrom<&'a [ArrowField]> for SerdeArrowSchema { type Error = Error; - fn try_from(fields: &'a [Field]) -> std::prelude::v1::Result { + fn try_from(fields: &'a [ArrowField]) -> std::prelude::v1::Result { Ok(Self { - fields: fields - .iter() - .map(GenericField::try_from) - .collect::>()?, + fields: 
fields.iter().map(Field::try_from).collect::>()?, }) } } -impl Sealed for Vec {} +impl Sealed for Vec {} /// Schema support for `Vec` (*requires one of the /// `arrow2-*` features*) -impl SchemaLike for Vec { +impl SchemaLike for Vec { fn from_value(value: &T) -> Result { SerdeArrowSchema::from_value(value)?.try_into() } @@ -63,243 +59,140 @@ impl SchemaLike for Vec { } } -impl TryFrom<&Field> for GenericField { +impl TryFrom<&ArrowDataType> for DataType { type Error = Error; - fn try_from(field: &Field) -> Result { - use {GenericDataType as T, TimeUnit as U}; - - let mut children = Vec::::new(); - let data_type = match &field.data_type { - DataType::Boolean => T::Bool, - DataType::Null => T::Null, - DataType::Int8 => T::I8, - DataType::Int16 => T::I16, - DataType::Int32 => T::I32, - DataType::Int64 => T::I64, - DataType::UInt8 => T::U8, - DataType::UInt16 => T::U16, - DataType::UInt32 => T::U32, - DataType::UInt64 => T::U64, - DataType::Float16 => T::F16, - DataType::Float32 => T::F32, - DataType::Float64 => T::F64, - DataType::Utf8 => T::Utf8, - DataType::LargeUtf8 => T::LargeUtf8, - DataType::Date32 => T::Date32, - DataType::Date64 => T::Date64, - DataType::Decimal(precision, scale) => { + fn try_from(value: &ArrowDataType) -> Result { + use {ArrowDataType as AT, DataType as T}; + match value { + AT::Null => Ok(T::Null), + AT::Boolean => Ok(T::Boolean), + AT::Int8 => Ok(T::Int8), + AT::Int16 => Ok(T::Int16), + AT::Int32 => Ok(T::Int32), + AT::Int64 => Ok(T::Int64), + AT::UInt8 => Ok(T::UInt8), + AT::UInt16 => Ok(T::UInt16), + AT::UInt32 => Ok(T::UInt32), + AT::UInt64 => Ok(T::UInt64), + AT::Float16 => Ok(T::Float16), + AT::Float32 => Ok(T::Float32), + AT::Float64 => Ok(T::Float64), + AT::Date32 => Ok(T::Date32), + AT::Date64 => Ok(T::Date64), + AT::Time32(unit) => Ok(T::Time32((*unit).into())), + AT::Time64(unit) => Ok(T::Time64((*unit).into())), + AT::Duration(unit) => Ok(T::Duration((*unit).into())), + AT::Timestamp(unit, tz) => Ok(T::Timestamp((*unit).into(), 
tz.clone())), + AT::Decimal(precision, scale) => { if *precision > u8::MAX as usize || *scale > i8::MAX as usize { fail!("cannot represent precision / scale of the decimal"); } - T::Decimal128(*precision as u8, *scale as i8) - } - DataType::Time32(ArrowTimeUnit::Second) => T::Time32(U::Second), - DataType::Time32(ArrowTimeUnit::Millisecond) => T::Time32(U::Millisecond), - DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(ArrowTimeUnit::Microsecond) => T::Time64(U::Microsecond), - DataType::Time64(ArrowTimeUnit::Nanosecond) => T::Time64(U::Nanosecond), - DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(ArrowTimeUnit::Second, tz) => T::Timestamp(U::Second, tz.clone()), - DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => { - T::Timestamp(U::Millisecond, tz.clone()) - } - DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => { - T::Timestamp(U::Microsecond, tz.clone()) - } - DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => { - T::Timestamp(U::Nanosecond, tz.clone()) - } - DataType::Duration(ArrowTimeUnit::Second) => T::Duration(U::Second), - DataType::Duration(ArrowTimeUnit::Millisecond) => T::Duration(U::Millisecond), - DataType::Duration(ArrowTimeUnit::Microsecond) => T::Duration(U::Microsecond), - DataType::Duration(ArrowTimeUnit::Nanosecond) => T::Duration(U::Nanosecond), - DataType::List(field) => { - children.push(GenericField::try_from(field.as_ref())?); - T::List + Ok(T::Decimal128(*precision as u8, *scale as i8)) } - DataType::LargeList(field) => { - children.push(field.as_ref().try_into()?); - T::LargeList - } - DataType::Struct(fields) => { - for field in fields { - children.push(field.try_into()?); - } - T::Struct - } - DataType::Map(field, _) => { - children.push(field.as_ref().try_into()?); - T::Map - } - DataType::Union(fields, field_indices, mode) => { - if field_indices.is_some() { - fail!("Union types with explicit field indices are not supported"); - 
} - if !mode.is_dense() { - fail!("Only dense unions are supported at the moment"); - } - + AT::Utf8 => Ok(T::Utf8), + AT::LargeUtf8 => Ok(T::LargeUtf8), + AT::Binary => Ok(T::Binary), + AT::LargeBinary => Ok(T::LargeBinary), + AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(i32::try_from(*n)?)), + AT::List(entry) => Ok(T::List(Box::new(entry.as_ref().try_into()?))), + AT::LargeList(entry) => Ok(T::LargeList(Box::new(entry.as_ref().try_into()?))), + AT::FixedSizeList(entry, n) => Ok(T::FixedSizeList( + Box::new(entry.as_ref().try_into()?), + i32::try_from(*n)?, + )), + AT::Map(field, sorted) => Ok(T::Map(Box::new(field.as_ref().try_into()?), *sorted)), + AT::Struct(fields) => { + let mut res_fields = Vec::new(); for field in fields { - children.push(field.try_into()?); + res_fields.push(Field::try_from(field)?); } - T::Union + Ok(T::Struct(res_fields)) } - DataType::Dictionary(int_type, data_type, sorted) => { - if *sorted { - fail!("Sorted dictionary are not supported"); - } - let key_type = match int_type { - IntegerType::Int8 => DataType::Int8, - IntegerType::Int16 => DataType::Int16, - IntegerType::Int32 => DataType::Int32, - IntegerType::Int64 => DataType::Int64, - IntegerType::UInt8 => DataType::UInt8, - IntegerType::UInt16 => DataType::UInt16, - IntegerType::UInt32 => DataType::UInt32, - IntegerType::UInt64 => DataType::UInt64, + AT::Dictionary(key, value, sorted) => { + let key = match key { + IntegerType::Int8 => T::Int8, + IntegerType::Int16 => T::Int16, + IntegerType::Int32 => T::Int32, + IntegerType::Int64 => T::Int64, + IntegerType::UInt8 => T::UInt8, + IntegerType::UInt16 => T::UInt16, + IntegerType::UInt32 => T::UInt32, + IntegerType::UInt64 => T::UInt64, }; - children.push((&Field::new("", key_type, false)).try_into()?); - children.push((&Field::new("", data_type.as_ref().clone(), false)).try_into()?); - T::Dictionary + Ok(T::Dictionary( + Box::new(key), + Box::new(value.as_ref().try_into()?), + *sorted, + )) } - DataType::Binary => T::Binary, - 
DataType::LargeBinary => T::LargeBinary, - dt => fail!("Cannot convert data type {dt:?}"), - }; - - let name = field.name.to_owned(); - let nullable = field.is_nullable; + AT::Union(fields, type_ids, mode) => todo!(), + dt => fail!("Cannot convert data type {dt:?} to internal data type"), + } + } +} - let metadata = field - .metadata - .clone() - .into_iter() - .collect::>(); +impl TryFrom<&ArrowField> for Field { + type Error = Error; - let field = GenericField { - name, - data_type, - metadata, - children, - nullable, + fn try_from(field: &ArrowField) -> Result { + let field = Field { + name: field.name.to_owned(), + data_type: DataType::try_from(&field.data_type)?, + nullable: field.is_nullable, + metadata: field.metadata.clone().into_iter().collect(), }; - field.validate()?; - + validate_field(&field)?; Ok(field) } } -impl TryFrom<&GenericField> for Field { +impl TryFrom<&DataType> for ArrowDataType { type Error = Error; - fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, TimeUnit as U}; - - let data_type = match &value.data_type { - T::Null => DataType::Null, - T::Bool => DataType::Boolean, - T::I8 => DataType::Int8, - T::I16 => DataType::Int16, - T::I32 => DataType::Int32, - T::I64 => DataType::Int64, - T::U8 => DataType::UInt8, - T::U16 => DataType::UInt16, - T::U32 => DataType::UInt32, - T::U64 => DataType::UInt64, - T::F16 => DataType::Float16, - T::F32 => DataType::Float32, - T::F64 => DataType::Float64, - T::Date32 => DataType::Date32, - T::Date64 => DataType::Date64, - T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), - T::Time32(unit) => fail!("Invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), - T::Time64(unit) => fail!("Invalid time unit {unit} for Time64"), - T::Timestamp(unit, tz) => 
DataType::Timestamp((*unit).into(), tz.clone()), - T::Duration(unit) => DataType::Duration((*unit).into()), + fn try_from(value: &DataType) -> std::result::Result { + use {ArrowDataType as AT, DataType as T}; + match value { + T::Null => Ok(AT::Null), + T::Boolean => Ok(AT::Boolean), + T::Int8 => Ok(AT::Int8), + T::Int16 => Ok(AT::Int16), + T::Int32 => Ok(AT::Int32), + T::Int64 => Ok(AT::Int64), + T::UInt8 => Ok(AT::UInt8), + T::UInt16 => Ok(AT::UInt16), + T::UInt32 => Ok(AT::UInt32), + T::UInt64 => Ok(AT::UInt64), + T::Float16 => Ok(AT::Float16), + T::Float32 => Ok(AT::Float32), + T::Float64 => Ok(AT::Float64), + T::Date32 => Ok(AT::Date32), + T::Date64 => Ok(AT::Date64), + T::Duration(unit) => Ok(AT::Duration((*unit).into())), + T::Time32(unit) => Ok(AT::Time32((*unit).into())), + T::Time64(unit) => Ok(AT::Time64((*unit).into())), + T::Timestamp(unit, tz) => Ok(AT::Timestamp((*unit).into(), tz.clone())), T::Decimal128(precision, scale) => { if *scale < 0 { fail!("arrow2 does not support decimals with negative scale"); } - DataType::Decimal(*precision as usize, *scale as usize) + Ok(AT::Decimal((*precision).try_into()?, (*scale).try_into()?)) } - T::Utf8 => DataType::Utf8, - T::LargeUtf8 => DataType::LargeUtf8, - T::List => DataType::List(Box::new( - value - .children - .first() - .ok_or_else(|| error!("List must a single child"))? - .try_into()?, - )), - T::LargeList => DataType::LargeList(Box::new( - value - .children - .first() - .ok_or_else(|| error!("List must a single child"))? - .try_into()?, - )), - T::FixedSizeList(_) => fail!("FixedSizedList is not supported by arrow2"), - T::Binary => DataType::Binary, - T::LargeBinary => DataType::LargeBinary, - T::FixedSizeBinary(n) => DataType::FixedSizeBinary((*n).try_into()?), - T::Struct => DataType::Struct( - value - .children - .iter() - .map(Field::try_from) - .collect::>>()?, - ), - T::Map => { - let element_field: Field = value - .children - .first() - .ok_or_else(|| error!("Map must a two children"))? 
- .try_into()?; - DataType::Map(Box::new(element_field), false) - } - T::Union => DataType::Union( - value - .children - .iter() - .map(Field::try_from) - .collect::>>()?, - None, - UnionMode::Dense, - ), - T::Dictionary => { - let Some(key_field) = value.children.first() else { - fail!("Dictionary must a two children"); - }; - let val_field: Field = value - .children - .get(1) - .ok_or_else(|| error!("Dictionary must a two children"))? - .try_into()?; - - let key_type = match &key_field.data_type { - T::U8 => IntegerType::UInt8, - T::U16 => IntegerType::UInt16, - T::U32 => IntegerType::UInt32, - T::U64 => IntegerType::UInt64, - T::I8 => IntegerType::Int8, - T::I16 => IntegerType::Int16, - T::I32 => IntegerType::Int32, - T::I64 => IntegerType::Int64, - _ => fail!("Invalid key type for dictionary"), - }; - - DataType::Dictionary(key_type, Box::new(val_field.data_type), false) - } - }; + _ => todo!(), + } + } +} - let mut field = Field::new(&value.name, data_type, value.nullable); - field.metadata = value.metadata.clone().into_iter().collect(); +impl TryFrom<&Field> for ArrowField { + type Error = Error; - Ok(field) + fn try_from(value: &Field) -> Result { + Ok(ArrowField { + name: value.name.to_owned(), + data_type: ArrowDataType::try_from(&value.data_type)?, + is_nullable: value.nullable, + metadata: value.metadata.clone().into_iter().collect(), + }) } } @@ -324,3 +217,21 @@ impl From for TimeUnit { } } } + +impl From for UnionMode { + fn from(value: ArrowUnionMode) -> Self { + match value { + ArrowUnionMode::Dense => UnionMode::Dense, + ArrowUnionMode::Sparse => UnionMode::Sparse, + } + } +} + +impl From for ArrowUnionMode { + fn from(value: UnionMode) -> Self { + match value { + UnionMode::Dense => ArrowUnionMode::Dense, + UnionMode::Sparse => ArrowUnionMode::Sparse, + } + } +} diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index aa034c66..8c7a2d6b 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ 
b/serde_arrow/src/arrow_impl/array.rs @@ -14,23 +14,22 @@ use crate::{ datatypes::{ ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, - DurationNanosecondType, DurationSecondType, Field, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + DurationNanosecondType, DurationSecondType, Field as ArrowField, Float16Type, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionMode, }, }, internal::{ - arrow::FieldMeta, arrow::{ ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, DenseUnionArrayView, DictionaryArrayView, FixedSizeListArrayView, ListArrayView, NullArrayView, PrimitiveArrayView, StructArrayView, TimeArrayView, TimeUnit, TimestampArrayView, }, + arrow::{Field, FieldMeta}, error::{fail, Error, Result}, - schema::GenericField, utils::meta_from_field, }, }; @@ -105,8 +104,9 @@ impl TryFrom for ArrayData { for (field, meta) in arr.fields { let child: ArrayData = field.try_into()?; - let field = Field::new(meta.name, child.data_type().clone(), meta.nullable) - .with_metadata(meta.metadata); + let field = + ArrowField::new(meta.name, child.data_type().clone(), meta.nullable) + .with_metadata(meta.metadata); fields.push(Arc::new(field)); data.push(child); } @@ -465,7 +465,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { let mut fields = Vec::new(); for (field, array) in std::iter::zip(column_fields, array.columns()) { let view = ArrayView::try_from(array.as_ref())?; - let meta = meta_from_field(GenericField::try_from(field.as_ref())?)?; + let meta = 
meta_from_field(Field::try_from(field.as_ref())?)?; fields.push((view, meta)); } @@ -483,7 +483,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { Ok(ArrayView::Map(ListArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - meta: meta_from_field(GenericField::try_from(entries_field.as_ref())?)?, + meta: meta_from_field(Field::try_from(entries_field.as_ref())?)?, element: Box::new(entries_array.try_into()?), })) } else if let Some(array) = any.downcast_ref::>() { @@ -513,7 +513,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { fail!("invalid union, only unions with consecutive variants are supported"); } - let meta = meta_from_field(GenericField::try_from(field.as_ref())?)?; + let meta = meta_from_field(Field::try_from(field.as_ref())?)?; let view: ArrayView = array.child(type_id).as_ref().try_into()?; fields.push((view, meta)); } @@ -535,8 +535,8 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { } } -fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> Field { - Field::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) +fn field_from_data_and_meta(data: &ArrayData, meta: FieldMeta) -> ArrowField { + ArrowField::new(meta.name, data.data_type().clone(), meta.nullable).with_metadata(meta.metadata) } fn primitive_into_data( diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index a938641c..b76f1ef3 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -1,15 +1,18 @@ use std::sync::Arc; use crate::{ - _impl::arrow::datatypes::{DataType, Field, FieldRef, TimeUnit as ArrowTimeUnit, UnionMode}, + _impl::arrow::datatypes::{ + DataType as ArrowDataType, Field as ArrowField, FieldRef, TimeUnit as ArrowTimeUnit, + UnionMode as ArrowUnionMode, + }, internal::{ - arrow::TimeUnit, - error::{error, fail, Error, Result}, - schema::{GenericDataType, GenericField, SchemaLike, Sealed, SerdeArrowSchema}, + 
arrow::{DataType, Field, TimeUnit, UnionMode}, + error::{fail, Error, Result}, + schema::{validate_field, SchemaLike, Sealed, SerdeArrowSchema}, }, }; -impl TryFrom for Vec { +impl TryFrom for Vec { type Error = Error; fn try_from(value: SerdeArrowSchema) -> Result { @@ -17,11 +20,11 @@ impl TryFrom for Vec { } } -impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { +impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { type Error = Error; fn try_from(value: &'a SerdeArrowSchema) -> Result { - value.fields.iter().map(Field::try_from).collect() + value.fields.iter().map(ArrowField::try_from).collect() } } @@ -40,20 +43,17 @@ impl<'a> TryFrom<&'a SerdeArrowSchema> for Vec { value .fields .iter() - .map(|f| Ok(Arc::new(Field::try_from(f)?))) + .map(|f| Ok(Arc::new(ArrowField::try_from(f)?))) .collect() } } -impl<'a> TryFrom<&'a [Field]> for SerdeArrowSchema { +impl<'a> TryFrom<&'a [ArrowField]> for SerdeArrowSchema { type Error = Error; - fn try_from(fields: &'a [Field]) -> Result { + fn try_from(fields: &'a [ArrowField]) -> Result { Ok(Self { - fields: fields - .iter() - .map(GenericField::try_from) - .collect::>()?, + fields: fields.iter().map(Field::try_from).collect::>()?, }) } } @@ -65,17 +65,17 @@ impl<'a> TryFrom<&'a [FieldRef]> for SerdeArrowSchema { Ok(Self { fields: fields .iter() - .map(|f| GenericField::try_from(f.as_ref())) + .map(|f| Field::try_from(f.as_ref())) .collect::>()?, }) } } -impl Sealed for Vec {} +impl Sealed for Vec {} /// Schema support for `Vec` (*requires one of the /// `arrow-*` features*) -impl SchemaLike for Vec { +impl SchemaLike for Vec { fn from_value(value: &T) -> Result { SerdeArrowSchema::from_value(value)?.try_into() } @@ -117,251 +117,95 @@ impl SchemaLike for Vec { } } -impl TryFrom<&DataType> for GenericDataType { +impl TryFrom<&ArrowDataType> for DataType { type Error = Error; - fn try_from(value: &DataType) -> Result { - use {GenericDataType as T, TimeUnit as U}; + fn try_from(value: &ArrowDataType) -> Result { + use {DataType as T, 
TimeUnit as U}; match value { - DataType::Boolean => Ok(T::Bool), - DataType::Null => Ok(T::Null), - DataType::Int8 => Ok(T::I8), - DataType::Int16 => Ok(T::I16), - DataType::Int32 => Ok(T::I32), - DataType::Int64 => Ok(T::I64), - DataType::UInt8 => Ok(T::U8), - DataType::UInt16 => Ok(T::U16), - DataType::UInt32 => Ok(T::U32), - DataType::UInt64 => Ok(T::U64), - DataType::Float16 => Ok(T::F16), - DataType::Float32 => Ok(T::F32), - DataType::Float64 => Ok(T::F64), - DataType::Utf8 => Ok(T::Utf8), - DataType::LargeUtf8 => Ok(T::LargeUtf8), - DataType::Date32 => Ok(T::Date32), - DataType::Date64 => Ok(T::Date64), - DataType::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - DataType::Time32(ArrowTimeUnit::Second) => Ok(T::Time32(U::Second)), - DataType::Time32(ArrowTimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), - DataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - DataType::Time64(ArrowTimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), - DataType::Time64(ArrowTimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), - DataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - DataType::Timestamp(ArrowTimeUnit::Second, tz) => { + ArrowDataType::Boolean => Ok(T::Boolean), + ArrowDataType::Null => Ok(T::Null), + ArrowDataType::Int8 => Ok(T::Int8), + ArrowDataType::Int16 => Ok(T::Int16), + ArrowDataType::Int32 => Ok(T::Int32), + ArrowDataType::Int64 => Ok(T::Int64), + ArrowDataType::UInt8 => Ok(T::UInt8), + ArrowDataType::UInt16 => Ok(T::UInt16), + ArrowDataType::UInt32 => Ok(T::UInt32), + ArrowDataType::UInt64 => Ok(T::UInt64), + ArrowDataType::Float16 => Ok(T::Float16), + ArrowDataType::Float32 => Ok(T::Float32), + ArrowDataType::Float64 => Ok(T::Float64), + ArrowDataType::Utf8 => Ok(T::Utf8), + ArrowDataType::LargeUtf8 => Ok(T::LargeUtf8), + ArrowDataType::Date32 => Ok(T::Date32), + ArrowDataType::Date64 => Ok(T::Date64), + ArrowDataType::Decimal128(precision, scale) => 
Ok(T::Decimal128(*precision, *scale)), + ArrowDataType::Time32(ArrowTimeUnit::Second) => Ok(T::Time32(U::Second)), + ArrowDataType::Time32(ArrowTimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), + ArrowDataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), + ArrowDataType::Time64(ArrowTimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), + ArrowDataType::Time64(ArrowTimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), + ArrowDataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), + ArrowDataType::Timestamp(ArrowTimeUnit::Second, tz) => { Ok(T::Timestamp(U::Second, tz.as_ref().map(|s| s.to_string()))) } - DataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => Ok(T::Timestamp( + ArrowDataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => Ok(T::Timestamp( U::Millisecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => Ok(T::Timestamp( + ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => Ok(T::Timestamp( U::Microsecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => Ok(T::Timestamp( + ArrowDataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => Ok(T::Timestamp( U::Nanosecond, tz.as_ref().map(|s| s.to_string()), )), - DataType::Duration(ArrowTimeUnit::Second) => Ok(T::Duration(U::Second)), - DataType::Duration(ArrowTimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), - DataType::Duration(ArrowTimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), - DataType::Duration(ArrowTimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), - DataType::Binary => Ok(T::Binary), - DataType::LargeBinary => Ok(T::LargeBinary), - DataType::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), + ArrowDataType::Duration(ArrowTimeUnit::Second) => Ok(T::Duration(U::Second)), + ArrowDataType::Duration(ArrowTimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), + ArrowDataType::Duration(ArrowTimeUnit::Microsecond) => 
Ok(T::Duration(U::Microsecond)), + ArrowDataType::Duration(ArrowTimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), + ArrowDataType::Binary => Ok(T::Binary), + ArrowDataType::LargeBinary => Ok(T::LargeBinary), + ArrowDataType::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), _ => fail!("Only primitive data types can be converted to T"), } } } -impl TryFrom<&Field> for GenericField { +impl TryFrom<&ArrowField> for Field { type Error = Error; - fn try_from(field: &Field) -> Result { - let mut children = Vec::::new(); - let data_type = match field.data_type() { - DataType::List(field) => { - children.push(GenericField::try_from(field.as_ref())?); - GenericDataType::List - } - DataType::LargeList(field) => { - children.push(field.as_ref().try_into()?); - GenericDataType::LargeList - } - DataType::FixedSizeList(field, n) => { - children.push(field.as_ref().try_into()?); - GenericDataType::FixedSizeList(*n) - } - DataType::Struct(fields) => { - for field in fields { - children.push(field.as_ref().try_into()?); - } - GenericDataType::Struct - } - DataType::Map(field, _) => { - children.push(field.as_ref().try_into()?); - GenericDataType::Map - } - DataType::Union(fields, mode) => { - if !matches!(mode, UnionMode::Dense) { - fail!("Only dense unions are supported at the moment"); - } - - for (pos, (idx, field)) in fields.iter().enumerate() { - if pos as i8 != idx { - fail!("Union types with non-sequential field indices are not supported"); - } - children.push(field.as_ref().try_into()?); - } - GenericDataType::Union - } - DataType::Dictionary(key_type, value_type) => { - children.push(GenericField::new("", key_type.as_ref().try_into()?, false)); - children.push(GenericField::new( - "", - value_type.as_ref().try_into()?, - false, - )); - GenericDataType::Dictionary - } - dt => dt.try_into()?, - }; - - let name = field.name().to_owned(); - let nullable = field.is_nullable(); - let metadata = field.metadata().clone(); - - let field = GenericField { - name, - data_type, 
- metadata, - children, - nullable, + fn try_from(field: &ArrowField) -> Result { + let field = Field { + name: field.name().to_owned(), + data_type: DataType::try_from(field.data_type())?, + metadata: field.metadata().clone(), + nullable: field.is_nullable(), }; - field.validate()?; - + validate_field(&field)?; Ok(field) } } -impl TryFrom<&GenericField> for Field { +impl TryFrom<&DataType> for ArrowDataType { type Error = Error; - fn try_from(value: &GenericField) -> Result { - use {GenericDataType as T, TimeUnit as U}; - - let data_type = match &value.data_type { - T::Null => DataType::Null, - T::Bool => DataType::Boolean, - T::I8 => DataType::Int8, - T::I16 => DataType::Int16, - T::I32 => DataType::Int32, - T::I64 => DataType::Int64, - T::U8 => DataType::UInt8, - T::U16 => DataType::UInt16, - T::U32 => DataType::UInt32, - T::U64 => DataType::UInt64, - T::F16 => DataType::Float16, - T::F32 => DataType::Float32, - T::F64 => DataType::Float64, - T::Date32 => DataType::Date32, - T::Date64 => DataType::Date64, - T::Decimal128(precision, scale) => DataType::Decimal128(*precision, *scale), - T::Utf8 => DataType::Utf8, - T::LargeUtf8 => DataType::LargeUtf8, - T::List => DataType::List( - Box::::new( - value - .children - .first() - .ok_or_else(|| error!("List must a single child"))? - .try_into()?, - ) - .into(), - ), - T::LargeList => DataType::LargeList( - Box::::new( - value - .children - .first() - .ok_or_else(|| error!("List must a single child"))? - .try_into()?, - ) - .into(), - ), - T::FixedSizeList(n) => DataType::FixedSizeList( - Box::::new( - value - .children - .first() - .ok_or_else(|| error!("List must a single child"))? 
- .try_into()?, - ) - .into(), - *n, - ), - T::Binary => DataType::Binary, - T::LargeBinary => DataType::LargeBinary, - T::FixedSizeBinary(n) => DataType::FixedSizeBinary(*n), - T::Struct => DataType::Struct( - value - .children - .iter() - .map(Field::try_from) - .collect::>()?, - ), - T::Map => { - let element_field: Field = value - .children - .first() - .ok_or_else(|| error!("Map must a single child"))? - .try_into()?; - DataType::Map(Box::new(element_field).into(), false) - } - T::Union => { - let mut fields = Vec::new(); - for (idx, field) in value.children.iter().enumerate() { - fields.push((idx as i8, std::sync::Arc::new(Field::try_from(field)?))); - } - DataType::Union(fields.into_iter().collect(), UnionMode::Dense) - } - T::Dictionary => { - let Some(key_field) = value.children.first() else { - fail!("Dictionary must a two children"); - }; - let val_field: Field = value - .children - .get(1) - .ok_or_else(|| error!("Dictionary must a two children"))? - .try_into()?; - - let key_type = match &key_field.data_type { - GenericDataType::U8 => DataType::UInt8, - GenericDataType::U16 => DataType::UInt16, - GenericDataType::U32 => DataType::UInt32, - GenericDataType::U64 => DataType::UInt64, - GenericDataType::I8 => DataType::Int8, - GenericDataType::I16 => DataType::Int16, - GenericDataType::I32 => DataType::Int32, - GenericDataType::I64 => DataType::Int64, - _ => fail!("Invalid key type for dictionary"), - }; - - DataType::Dictionary(Box::new(key_type), Box::new(val_field.data_type().clone())) - } - T::Time32(U::Second) => DataType::Time32(ArrowTimeUnit::Second), - T::Time32(U::Millisecond) => DataType::Time32(ArrowTimeUnit::Millisecond), - T::Time32(unit) => fail!("invalid time unit {unit} for Time32"), - T::Time64(U::Microsecond) => DataType::Time64(ArrowTimeUnit::Microsecond), - T::Time64(U::Nanosecond) => DataType::Time64(ArrowTimeUnit::Nanosecond), - T::Time64(unit) => fail!("invalid time unit {unit} for Time64"), - T::Timestamp(unit, tz) => { - 
DataType::Timestamp((*unit).into(), tz.clone().map(|s| s.into())) - } - T::Duration(unit) => DataType::Duration((*unit).into()), - }; + fn try_from(value: &DataType) -> std::result::Result { + todo!() + } +} - let mut field = Field::new(&value.name, data_type, value.nullable); +impl TryFrom<&Field> for ArrowField { + type Error = Error; + + fn try_from(value: &Field) -> Result { + let mut field = ArrowField::new( + &value.name, + ArrowDataType::try_from(&value.data_type)?, + value.nullable, + ); field.set_metadata(value.metadata.clone()); Ok(field) @@ -378,3 +222,21 @@ impl From for ArrowTimeUnit { } } } + +impl From for UnionMode { + fn from(value: ArrowUnionMode) -> Self { + match value { + ArrowUnionMode::Dense => UnionMode::Dense, + ArrowUnionMode::Sparse => UnionMode::Sparse, + } + } +} + +impl From for ArrowUnionMode { + fn from(value: UnionMode) -> Self { + match value { + UnionMode::Dense => ArrowUnionMode::Dense, + UnionMode::Sparse => ArrowUnionMode::Sparse, + } + } +} diff --git a/serde_arrow/src/arrow_impl/type_support.rs b/serde_arrow/src/arrow_impl/type_support.rs index 8fdbddf0..7f34f4c3 100644 --- a/serde_arrow/src/arrow_impl/type_support.rs +++ b/serde_arrow/src/arrow_impl/type_support.rs @@ -1,11 +1,12 @@ use crate::_impl::arrow::{ - datatypes::{Field, FieldRef}, + datatypes::{Field as ArrowField, FieldRef}, error::ArrowError, }; use crate::internal::{ + arrow::Field, error::{Error, Result}, - schema::{extensions::FixedShapeTensorField, GenericField}, + schema::extensions::FixedShapeTensorField, }; impl From for Error { @@ -14,15 +15,15 @@ impl From for Error { } } -impl TryFrom<&FixedShapeTensorField> for Field { +impl TryFrom<&FixedShapeTensorField> for ArrowField { type Error = Error; fn try_from(value: &FixedShapeTensorField) -> Result { - Self::try_from(&GenericField::try_from(value)?) + Self::try_from(&Field::try_from(value)?) 
} } -impl TryFrom for Field { +impl TryFrom for ArrowField { type Error = Error; fn try_from(value: FixedShapeTensorField) -> Result { @@ -30,9 +31,9 @@ impl TryFrom for Field { } } -pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result> { +pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result> { fields .iter() - .map(|field| GenericField::try_from(field.as_ref())) + .map(|field| Field::try_from(field.as_ref())) .collect() } diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs index 896b8106..a17405bd 100644 --- a/serde_arrow/src/internal/arrow/data_type.rs +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -1,18 +1,18 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use serde::{Deserialize, Serialize}; use crate::internal::error::{fail, Error, Result}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Field { pub name: String, pub data_type: DataType, + pub nullable: bool, pub metadata: HashMap, } -#[allow(unused)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[non_exhaustive] pub enum DataType { Null, @@ -32,15 +32,21 @@ pub enum DataType { LargeUtf8, Binary, LargeBinary, + FixedSizeBinary(i32), Date32, Date64, - Timestamp(TimeUnit, Option>), + Timestamp(TimeUnit, Option), Time32(TimeUnit), Time64(TimeUnit), - Decimal128, + Duration(TimeUnit), + Decimal128(u8, i8), Struct(Vec), List(Box), LargeList(Box), + FixedSizeList(Box, i32), + Map(Box, bool), + Dictionary(Box, Box, bool), + Union(Vec<(i8, Field)>, UnionMode), } #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] @@ -71,7 +77,34 @@ impl std::str::FromStr for TimeUnit { "Millisecond" => Ok(Self::Millisecond), "Microsecond" => Ok(Self::Microsecond), "Nanosecond" => Ok(Self::Nanosecond), - s => fail!("Invalid time unit {s}"), + s => fail!("Invalid TimeUnit: 
{s}"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum UnionMode { + Sparse, + Dense, +} + +impl std::fmt::Display for UnionMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UnionMode::Sparse => write!(f, "Sparse"), + UnionMode::Dense => write!(f, "Dense"), + } + } +} + +impl std::str::FromStr for UnionMode { + type Err = Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "Sparse" => Ok(UnionMode::Sparse), + "Dense" => Ok(UnionMode::Dense), + s => fail!("Invalid UnionMode: {s}"), } } } diff --git a/serde_arrow/src/internal/arrow/mod.rs b/serde_arrow/src/internal/arrow/mod.rs index 78c62c9a..8943c04d 100644 --- a/serde_arrow/src/internal/arrow/mod.rs +++ b/serde_arrow/src/internal/arrow/mod.rs @@ -11,5 +11,4 @@ pub use array::{ PrimitiveArray, PrimitiveArrayView, StructArray, StructArrayView, TimeArray, TimeArrayView, TimestampArray, TimestampArrayView, }; -#[allow(unused)] -pub use data_type::{DataType, Field, TimeUnit}; +pub use data_type::{DataType, Field, TimeUnit, UnionMode}; diff --git a/serde_arrow/src/internal/schema/deserialization.rs b/serde_arrow/src/internal/schema/deserialization.rs index 229624a5..476f58ae 100644 --- a/serde_arrow/src/internal/schema/deserialization.rs +++ b/serde_arrow/src/internal/schema/deserialization.rs @@ -1,6 +1,8 @@ //! Deserialization of SchemaLike objects with explicit support to deserialize //! 
from arrow-rs types +// TODO: delete me + use std::{collections::HashMap, str::FromStr}; use serde::{de::Visitor, Deserialize}; @@ -9,7 +11,7 @@ use crate::internal::{ arrow::TimeUnit, error::{fail, Error, Result}, schema::{ - merge_strategy_with_metadata, GenericDataType, GenericField, SerdeArrowSchema, Strategy, + merge_strategy_with_metadata, GenericDataType, SerdeArrowSchema, Strategy, }, }; @@ -315,55 +317,3 @@ impl<'de> Deserialize<'de> for GenericField { Ok(res) } } - -// A custom impl of untagged-enum repr with better error messages -impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { - fn deserialize>(deserializer: D) -> Result { - struct VisitorImpl; - - impl<'de> Visitor<'de> for VisitorImpl { - type Value = SerdeArrowSchema; - - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "a sequence of fields or a struct with key 'fields' containing a sequence of fields") - } - - fn visit_seq>( - self, - mut seq: A, - ) -> Result { - let mut fields = Vec::new(); - while let Some(item) = seq.next_element::()? { - fields.push(item); - } - - Ok(SerdeArrowSchema { fields }) - } - - fn visit_map>( - self, - mut map: A, - ) -> Result { - use serde::de::Error; - - let mut fields = None; - - while let Some(key) = map.next_key::()? 
{ - if key == "fields" { - fields = Some(map.next_value::>()?); - } else { - map.next_value::()?; - } - } - - let Some(fields) = fields else { - return Err(A::Error::custom("missing field `fields`")); - }; - - Ok(SerdeArrowSchema { fields }) - } - } - - deserializer.deserialize_any(VisitorImpl) - } -} diff --git a/serde_arrow/src/internal/schema/extensions.rs b/serde_arrow/src/internal/schema/extensions.rs index 7b825872..221f8d20 100644 --- a/serde_arrow/src/internal/schema/extensions.rs +++ b/serde_arrow/src/internal/schema/extensions.rs @@ -2,9 +2,10 @@ use std::collections::HashMap; use serde::Serialize; -use super::{GenericDataType, GenericField}; use crate::internal::{ + arrow::{DataType, Field}, error::{fail, Error, Result}, + schema::ArrowOrCustomField, utils::value, }; @@ -44,7 +45,7 @@ use crate::internal::{ pub struct FixedShapeTensorField { name: String, nullable: bool, - element: GenericField, + element: Field, shape: Vec, dim_names: Option>, permutation: Option>, @@ -57,7 +58,8 @@ impl FixedShapeTensorField { /// with the the name `"element"`. The field type can be any valid Arrow /// type. 
pub fn new(name: &str, element: impl Serialize, shape: Vec) -> Result { - let element: GenericField = value::transmute(&element)?; + let element: ArrowOrCustomField = value::transmute(&element)?; + let element = element.into_field()?; if element.name != "element" { fail!("The element field of FixedShapeTensorField must be named \"element\""); } @@ -118,7 +120,7 @@ impl FixedShapeTensorField { } } -impl TryFrom<&FixedShapeTensorField> for GenericField { +impl TryFrom<&FixedShapeTensorField> for Field { type Error = Error; fn try_from(value: &FixedShapeTensorField) -> Result { @@ -134,11 +136,10 @@ impl TryFrom<&FixedShapeTensorField> for GenericField { ); metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - Ok(GenericField { - name: value.name.clone(), + Ok(Field { + name: value.name.to_owned(), nullable: value.nullable, - data_type: GenericDataType::FixedSizeList(n.try_into()?), - children: vec![value.element.clone()], + data_type: DataType::FixedSizeList(Box::new(value.element.clone()), n.try_into()?), metadata, }) } @@ -146,10 +147,11 @@ impl TryFrom<&FixedShapeTensorField> for GenericField { impl serde::ser::Serialize for FixedShapeTensorField { fn serialize(&self, serializer: S) -> Result { - use serde::ser::Error; - GenericField::try_from(self) - .map_err(S::Error::custom)? - .serialize(serializer) + // use serde::ser::Error; + // Field::try_from(self) + // .map_err(S::Error::custom)? 
+ // .serialize(serializer) + todo!() } } @@ -162,7 +164,7 @@ impl serde::ser::Serialize for FixedShapeTensorField { /// https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor pub struct VariableShapeTensorField { name: String, - element: GenericField, + element: Field, ndim: usize, nullable: bool, dim_names: Option>, @@ -172,7 +174,8 @@ pub struct VariableShapeTensorField { impl VariableShapeTensorField { pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { - let element: GenericField = value::transmute(&element)?; + let element: ArrowOrCustomField = value::transmute(&element)?; + let element = element.into_field()?; if element.name != "element" { fail!("The element field of FixedShapeTensorField must be named \"element\""); } @@ -268,7 +271,7 @@ impl VariableShapeTensorField { } } -impl TryFrom<&VariableShapeTensorField> for GenericField { +impl TryFrom<&VariableShapeTensorField> for Field { type Error = Error; fn try_from(value: &VariableShapeTensorField) -> Result { @@ -279,24 +282,32 @@ impl TryFrom<&VariableShapeTensorField> for GenericField { ); metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - Ok(GenericField { + let mut fields = Vec::new(); + fields.push(Field { + name: String::from("data"), + data_type: DataType::List(Box::new(value.element.clone())), + nullable: false, + metadata: HashMap::new(), + }); + fields.push(Field { + name: String::from("shape"), + data_type: DataType::FixedSizeList( + Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: HashMap::new(), + }), + value.ndim.try_into()?, + ), + nullable: false, + metadata: HashMap::new(), + }); + + Ok(Field { name: value.name.clone(), nullable: value.nullable, - data_type: GenericDataType::Struct, - children: vec![ - GenericField::new("data", GenericDataType::List, false) - .with_child(value.element.clone()), - GenericField::new( - "shape", - 
GenericDataType::FixedSizeList(value.ndim.try_into()?), - false, - ) - .with_child(GenericField::new( - "element", - GenericDataType::I32, - false, - )), - ], + data_type: DataType::Struct(fields), metadata, }) } @@ -304,10 +315,11 @@ impl TryFrom<&VariableShapeTensorField> for GenericField { impl serde::ser::Serialize for VariableShapeTensorField { fn serialize(&self, serializer: S) -> Result { - use serde::ser::Error; - GenericField::try_from(self) - .map_err(S::Error::custom)? - .serialize(serializer) + // use serde::ser::Error; + // GenericField::try_from(self) + // .map_err(S::Error::custom)? + // .serialize(serializer) + todo!() } } diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index d87378db..1748d05d 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -8,14 +8,13 @@ use std::sync::Arc; use serde::{ser::Impossible, Serialize}; use crate::internal::{ + arrow::DataType, error::{fail, Error, Result}, - schema::{GenericDataType, Strategy}, + schema::{Strategy, TracingMode, TracingOptions}, }; -use super::tracing_options::TracingOptions; -use super::{ - tracer::{ListTracer, MapTracer, StructMode, StructTracer, Tracer, TupleTracer, UnionVariant}, - TracingMode, +use super::tracer::{ + ListTracer, MapTracer, StructMode, StructTracer, Tracer, TupleTracer, UnionVariant, }; impl Tracer { @@ -172,72 +171,72 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { type SerializeTupleVariant = TupleSerializer<'a>; fn serialize_bool(self, _: bool) -> Result { - self.0.ensure_primitive(GenericDataType::Bool) + self.0.ensure_primitive(DataType::Boolean) } fn serialize_i8(self, _: i8) -> Result { - self.0.ensure_number(GenericDataType::I8) + self.0.ensure_number(DataType::Int8) } fn serialize_i16(self, _: i16) -> Result { - self.0.ensure_number(GenericDataType::I16) + self.0.ensure_number(DataType::Int16) } fn 
serialize_i32(self, _: i32) -> Result { - self.0.ensure_number(GenericDataType::I32) + self.0.ensure_number(DataType::Int32) } fn serialize_i64(self, _: i64) -> Result { - self.0.ensure_number(GenericDataType::I64) + self.0.ensure_number(DataType::Int64) } fn serialize_u8(self, _: u8) -> Result { - self.0.ensure_number(GenericDataType::U8) + self.0.ensure_number(DataType::UInt8) } fn serialize_u16(self, _: u16) -> Result { - self.0.ensure_number(GenericDataType::U16) + self.0.ensure_number(DataType::UInt16) } fn serialize_u32(self, _: u32) -> Result { - self.0.ensure_number(GenericDataType::U32) + self.0.ensure_number(DataType::UInt32) } fn serialize_u64(self, _: u64) -> Result { - self.0.ensure_number(GenericDataType::U64) + self.0.ensure_number(DataType::UInt64) } fn serialize_f32(self, _: f32) -> Result { - self.0.ensure_number(GenericDataType::F32) + self.0.ensure_number(DataType::Float32) } fn serialize_f64(self, _: f64) -> Result { - self.0.ensure_number(GenericDataType::F64) + self.0.ensure_number(DataType::Float64) } fn serialize_char(self, _: char) -> Result { - self.0.ensure_primitive(GenericDataType::U32) + self.0.ensure_primitive(DataType::UInt32) } fn serialize_unit(self) -> Result { - self.0.ensure_primitive(GenericDataType::Null) + self.0.ensure_primitive(DataType::Null) } fn serialize_str(self, s: &str) -> Result { let guess_dates = self.0.get_options().guess_dates; if guess_dates && chrono::matches_naive_datetime(s) { self.0 - .ensure_utf8(GenericDataType::Date64, Some(Strategy::NaiveStrAsDate64)) + .ensure_utf8(DataType::Date64, Some(Strategy::NaiveStrAsDate64)) } else if guess_dates && chrono::matches_utc_datetime(s) { self.0 - .ensure_utf8(GenericDataType::Date64, Some(Strategy::UtcStrAsDate64)) + .ensure_utf8(DataType::Date64, Some(Strategy::UtcStrAsDate64)) } else { - self.0.ensure_utf8(GenericDataType::LargeUtf8, None) + self.0.ensure_utf8(DataType::LargeUtf8, None) } } fn serialize_bytes(self, _: &[u8]) -> Result { - 
self.0.ensure_primitive(GenericDataType::LargeBinary) + self.0.ensure_primitive(DataType::LargeBinary) } fn serialize_none(self) -> Result { @@ -321,7 +320,7 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { variant_name: &'static str, ) -> Result { let variant = self.ensure_union_variant(variant_name, variant_index)?; - variant.tracer.ensure_primitive(GenericDataType::Null) + variant.tracer.ensure_primitive(DataType::Null) } fn serialize_newtype_variant( @@ -591,14 +590,15 @@ mod test { use serde::Serialize; use serde_json::{json, Value}; - use crate::internal::schema::{GenericField, TracingOptions}; + use crate::internal::schema::{ArrowOrCustomField, TracingOptions}; use super::*; fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); - let expected = serde_json::from_value::(expected).unwrap(); + let expected = serde_json::from_value::(expected).unwrap(); + let expected = expected.into_field().unwrap(); assert_eq!(field, expected); } diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index 28155398..d127ca08 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -9,13 +9,14 @@ use serde::{ Deserialize, Deserializer, }; -use crate::internal::error::{fail, Error, Result}; - -use super::{ - tracer::{StructField, StructMode, Tracer}, - GenericDataType, TracingMode, TracingOptions, +use crate::internal::{ + arrow::DataType, + error::{fail, Error, Result}, + schema::{TracingMode, TracingOptions}, }; +use super::tracer::{StructField, StructMode, Tracer}; + impl Tracer { pub fn from_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> Result { let options = options.tracing_mode(TracingMode::FromType); @@ -58,82 +59,82 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { } fn deserialize_bool>(self, 
visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::Bool)?; + self.0.ensure_primitive(DataType::Boolean)?; visitor.visit_bool(Default::default()) } fn deserialize_i8>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::I8)?; + self.0.ensure_primitive(DataType::Int8)?; visitor.visit_i8(Default::default()) } fn deserialize_i16>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::I16)?; + self.0.ensure_primitive(DataType::Int16)?; visitor.visit_i16(Default::default()) } fn deserialize_i32>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::I32)?; + self.0.ensure_primitive(DataType::Int32)?; visitor.visit_i32(Default::default()) } fn deserialize_i64>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::I64)?; + self.0.ensure_primitive(DataType::Int64)?; visitor.visit_i64(Default::default()) } fn deserialize_u8>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::U8)?; + self.0.ensure_primitive(DataType::UInt8)?; visitor.visit_u8(Default::default()) } fn deserialize_u16>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::U16)?; + self.0.ensure_primitive(DataType::UInt16)?; visitor.visit_u16(Default::default()) } fn deserialize_u32>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::U32)?; + self.0.ensure_primitive(DataType::UInt32)?; visitor.visit_u32(Default::default()) } fn deserialize_u64>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::U64)?; + self.0.ensure_primitive(DataType::UInt64)?; visitor.visit_u64(Default::default()) } fn deserialize_f32>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::F32)?; + self.0.ensure_primitive(DataType::Float32)?; visitor.visit_f32(Default::default()) } fn deserialize_f64>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::F64)?; + self.0.ensure_primitive(DataType::Float64)?; 
visitor.visit_f64(Default::default()) } fn deserialize_char>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::U32)?; + self.0.ensure_primitive(DataType::UInt32)?; visitor.visit_char(Default::default()) } fn deserialize_str>(self, visitor: V) -> Result { - self.0.ensure_utf8(GenericDataType::LargeUtf8, None)?; + self.0.ensure_utf8(DataType::LargeUtf8, None)?; visitor.visit_borrowed_str("") } fn deserialize_string>(self, visitor: V) -> Result { - self.0.ensure_utf8(GenericDataType::LargeUtf8, None)?; + self.0.ensure_utf8(DataType::LargeUtf8, None)?; visitor.visit_string(Default::default()) } fn deserialize_bytes>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::LargeBinary)?; + self.0.ensure_primitive(DataType::LargeBinary)?; visitor.visit_borrowed_bytes(&[]) } fn deserialize_byte_buf>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::LargeBinary)?; + self.0.ensure_primitive(DataType::LargeBinary)?; visitor.visit_byte_buf(Default::default()) } @@ -143,7 +144,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { } fn deserialize_unit>(self, visitor: V) -> Result { - self.0.ensure_primitive(GenericDataType::Null)?; + self.0.ensure_primitive(DataType::Null)?; visitor.visit_unit() } @@ -152,7 +153,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { _name: &'static str, visitor: V, ) -> Result { - self.0.ensure_primitive(GenericDataType::Null)?; + self.0.ensure_primitive(DataType::Null)?; visitor.visit_unit() } diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index fb77d80e..c651e988 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -1,7 +1,6 @@ pub mod extensions; mod data_type; -mod deserialization; mod from_samples; mod from_type; mod serde; @@ -12,22 +11,22 @@ mod tracing_options; #[cfg(test)] mod test; -use std::collections::HashMap; - use crate::internal::{ - arrow::TimeUnit, 
+ arrow::{TimeUnit, UnionMode}, error::{fail, Result}, utils::value, }; use ::serde::{Deserialize, Serialize}; -pub use data_type::GenericDataType; +pub use serde::deserialize::ArrowOrCustomField; pub use strategy::get_strategy_from_metadata; pub use strategy::{merge_strategy_with_metadata, Strategy, STRATEGY_KEY}; use tracer::Tracer; pub use tracing_options::{Overwrites, TracingMode, TracingOptions}; +use super::arrow::{DataType, Field}; + pub trait Sealed {} /// A sealed trait to add support for constructing schema-like objects @@ -287,9 +286,9 @@ pub trait SchemaLike: Sized + Sealed { /// /// It can be converted from / to arrow or arrow2 fields. /// -#[derive(Default, Debug, PartialEq, Clone, Serialize)] +#[derive(Default, Debug, PartialEq, Clone)] pub struct SerdeArrowSchema { - pub(crate) fields: Vec, + pub(crate) fields: Vec, } impl Sealed for SerdeArrowSchema {} @@ -308,331 +307,232 @@ impl SchemaLike for SerdeArrowSchema { } } -#[derive(Debug, Clone, PartialEq)] -pub struct GenericField { - pub name: String, - pub data_type: GenericDataType, - pub metadata: HashMap, - pub nullable: bool, - pub children: Vec, -} - -impl GenericField { - pub fn new(name: &str, data_type: GenericDataType, nullable: bool) -> Self { - Self { - name: name.to_string(), - data_type, - metadata: HashMap::new(), - nullable, - children: Vec::new(), +pub fn validate_field(field: &Field) -> Result<()> { + match &field.data_type { + DataType::Null => validate_null_field(field), + DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Decimal128(_, _) + | DataType::Date32 + | DataType::Binary + | DataType::LargeBinary + | DataType::Duration(_) => validate_primitive_field(field), + DataType::FixedSizeBinary(n) => 
validate_fixed_size_binary_field(field, *n), + DataType::Date64 => validate_date64_field(field), + DataType::Timestamp(unit, tz) => validate_timestamp_field(field, *unit, tz.as_deref()), + DataType::Time32(unit) => validate_time32_field(field, *unit), + DataType::Time64(unit) => validate_time64_field(field, *unit), + DataType::Struct(fields) => validate_struct_field(field, fields.as_slice()), + DataType::Map(entry, _) => validate_map_field(field, entry.as_ref()), + DataType::List(entry) => validate_list_field(field, entry.as_ref()), + DataType::LargeList(entry) => validate_list_field(field, entry.as_ref()), + DataType::FixedSizeList(entry, n) => { + validate_fixed_size_list_field(field, entry.as_ref(), *n) + } + DataType::Union(fields, mode) => validate_union_field(field, fields.as_slice(), *mode), + DataType::Dictionary(key, values, _) => { + validate_dictionary_field(field, key.as_ref(), values.as_ref()) } } +} - pub fn is_valid(&self) -> bool { - self.validate().is_ok() +fn validate_null_field(field: &Field) -> Result<()> { + match get_strategy_from_metadata(&field.metadata)? 
{ + None | Some(Strategy::InconsistentTypes) | Some(Strategy::UnknownVariant) => Ok(()), + Some(strategy) => fail!("invalid strategy for Null field: {strategy}"), } +} - pub fn validate(&self) -> Result<()> { - match self.data_type { - GenericDataType::Null => self.validate_null(), - GenericDataType::Bool => self.validate_primitive(), - GenericDataType::U8 => self.validate_primitive(), - GenericDataType::U16 => self.validate_primitive(), - GenericDataType::U32 => self.validate_primitive(), - GenericDataType::U64 => self.validate_primitive(), - GenericDataType::I8 => self.validate_primitive(), - GenericDataType::I16 => self.validate_primitive(), - GenericDataType::I32 => self.validate_primitive(), - GenericDataType::I64 => self.validate_primitive(), - GenericDataType::F16 => self.validate_primitive(), - GenericDataType::F32 => self.validate_primitive(), - GenericDataType::F64 => self.validate_primitive(), - GenericDataType::Utf8 => self.validate_primitive(), - GenericDataType::LargeUtf8 => self.validate_primitive(), - GenericDataType::Date32 => self.validate_date32(), - GenericDataType::Date64 => self.validate_date64(), - GenericDataType::Struct => self.validate_struct(), - GenericDataType::Map => self.validate_map(), - GenericDataType::List => self.validate_list(), - GenericDataType::LargeList => self.validate_list(), - GenericDataType::FixedSizeList(n) => self.validate_fixed_size_list(n), - GenericDataType::Binary => self.validate_binary(), - GenericDataType::LargeBinary => self.validate_binary(), - GenericDataType::FixedSizeBinary(n) => self.validate_fixed_size_binary(n), - GenericDataType::Union => self.validate_union(), - GenericDataType::Dictionary => self.validate_dictionary(), - GenericDataType::Timestamp(_, _) => self.validate_timestamp(), - GenericDataType::Time32(_) => self.validate_time32(), - GenericDataType::Time64(_) => self.validate_time64(), - GenericDataType::Duration(_) => self.validate_duration(), - GenericDataType::Decimal128(_, _) => 
self.validate_primitive(), - } +fn validate_primitive_field(field: &Field) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!( + "invalid strategy for {data_type}: {strategy}", + data_type = DataTypeDisplay(&field.data_type), + ); } + Ok(()) +} - pub fn is_utc(&self) -> Result { - match &self.data_type { - GenericDataType::Date64 => match get_strategy_from_metadata(&self.metadata)? { - None | Some(Strategy::UtcStrAsDate64) => Ok(true), - Some(Strategy::NaiveStrAsDate64) => Ok(false), - Some(strategy) => fail!("invalid strategy for date64 deserializer: {strategy}"), - }, - GenericDataType::Timestamp(_, tz) => match tz { - Some(tz) => Ok(tz.to_lowercase() == "utc"), - None => Ok(false), - }, - _ => fail!("non date time type {}", self.data_type), - } +fn validate_fixed_size_binary_field(field: &Field, n: i32) -> Result<()> { + if n < 0 { + fail!("Invalid FixedSizedBinary with negative number of elements"); } + validate_primitive_field(field) +} - pub fn with_child(mut self, child: GenericField) -> Self { - self.children.push(child); - self +fn validate_fixed_size_list_field(field: &Field, child: &Field, n: i32) -> Result<()> { + if n < 0 { + fail!("Invalid FixedSizeList with negative number of elements"); } + validate_list_field(field, child) +} - pub fn with_metadata(mut self, key: String, value: String) -> Self { - self.metadata.insert(key, value); - self +fn validate_list_field(field: &Field, child: &Field) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for List field: {strategy}"); } + validate_field(child) } -impl GenericField { - pub(crate) fn validate_null(&self) -> Result<()> { - match get_strategy_from_metadata(&self.metadata)? 
{ - None | Some(Strategy::InconsistentTypes) | Some(Strategy::UnknownVariant) => {} - Some(strategy) => fail!("invalid strategy for Null field: {strategy}"), - } - if !self.children.is_empty() { - fail!("Null field must not have children"); - } - Ok(()) +fn validate_dictionary_field(field: &Field, key: &DataType, value: &DataType) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for Dictionary field: {strategy}"); } - - pub(crate) fn validate_primitive(&self) -> Result<()> { - self.validate_no_strategy_no_children() + if !matches!( + key, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + ) { + fail!( + "invalid child for Dictionary. Expected integer keys, found: {key}", + key = DataTypeDisplay(key), + ); } - - pub(crate) fn validate_date32(&self) -> Result<()> { - self.validate_no_strategy_no_children() + if !matches!(value, DataType::Utf8 | DataType::LargeUtf8) { + fail!( + "invalid child for Dictionary. Expected string values, found: {value}", + value = DataTypeDisplay(value) + ); } + Ok(()) +} - pub(crate) fn validate_date64(&self) -> Result<()> { - match get_strategy_from_metadata(&self.metadata)? { - None | Some(Strategy::UtcStrAsDate64) | Some(Strategy::NaiveStrAsDate64) => {} - Some(strategy) => fail!("invalid strategy for Date64 field: {strategy}"), - } - if !self.children.is_empty() { - fail!("{} field must not have children", self.data_type); - } - Ok(()) +fn validate_date64_field(field: &Field) -> Result<()> { + match get_strategy_from_metadata(&field.metadata)? { + None | Some(Strategy::UtcStrAsDate64) | Some(Strategy::NaiveStrAsDate64) => Ok(()), + Some(strategy) => fail!("invalid strategy for Date64 field: {strategy}"), } +} - pub(crate) fn validate_timestamp(&self) -> Result<()> { - match get_strategy_from_metadata(&self.metadata)? 
{ - None => Ok(()), - Some(strategy @ Strategy::UtcStrAsDate64) => { - if !matches!(&self.data_type, GenericDataType::Timestamp(_, Some(tz)) if tz.to_uppercase() == "UTC") - { - fail!( - "invalid strategy for timestamp field {}: {}", - self.data_type, - strategy, - ); - } - Ok(()) - } - Some(strategy @ Strategy::NaiveStrAsDate64) => { - if !matches!(&self.data_type, GenericDataType::Timestamp(_, None)) { - fail!( - "invalid strategy for timestamp field {}: {}", - self.data_type, - strategy, - ); - } - Ok(()) +fn validate_timestamp_field(field: &Field, unit: TimeUnit, tz: Option<&str>) -> Result<()> { + match get_strategy_from_metadata(&field.metadata)? { + None => {} + Some(strategy @ Strategy::UtcStrAsDate64) => { + if !matches!(tz, Some(tz) if tz.to_uppercase() == "UTC") { + fail!("invalid strategy for Timestamp({unit}, {tz:?}) field: {strategy}"); } - Some(strategy) => fail!( - "invalid strategy for timestamp field {}: {}", - self.data_type, - strategy - ), - } - } - - pub(crate) fn validate_time32(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { - fail!("invalid strategy for {dt}: {strategy}", dt = self.data_type); } - if !self.children.is_empty() { - fail!("{} field must not have children", self.data_type); - } - if !matches!( - self.data_type, - GenericDataType::Time32(TimeUnit::Second | TimeUnit::Millisecond) - ) { - fail!("Time32 field must have Second or Millisecond unit"); + Some(strategy @ Strategy::NaiveStrAsDate64) => { + if tz.is_some() { + fail!("invalid strategy for Timestamp({unit}, {tz:?}) field: {strategy}"); + } } - Ok(()) + Some(strategy) => fail!("invalid strategy for Timestamp({unit}, {tz:?}) field: {strategy}"), } + Ok(()) +} - pub(crate) fn validate_time64(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? 
{ - fail!( - "invalid strategy for {data_type}: {strategy}", - data_type = self.data_type, - ); - } - if !self.children.is_empty() { - fail!("{} field must not have children", self.data_type); - } - if !matches!( - self.data_type, - GenericDataType::Time64(TimeUnit::Microsecond | TimeUnit::Nanosecond) - ) { - fail!("Time64 field must have Microsecond or Nanosecond unit"); - } - Ok(()) +fn validate_time32_field(field: &Field, unit: TimeUnit) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for Time32({unit}) field: {strategy}"); } - - pub(crate) fn validate_duration(&self) -> Result<()> { - self.validate_no_strategy_no_children() + if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { + fail!("Time32 field must have Second or Millisecond unit"); } + Ok(()) +} - pub(crate) fn validate_struct(&self) -> Result<()> { - // NOTE: do not check number of children: arrow-rs can 0 children, arrow2 not - match get_strategy_from_metadata(&self.metadata)? { - None | Some(Strategy::MapAsStruct) | Some(Strategy::TupleAsStruct) => {} - Some(strategy) => fail!("invalid strategy for Struct field: {strategy}"), - } - for child in &self.children { - child.validate()?; - } - Ok(()) +fn validate_time64_field(field: &Field, unit: TimeUnit) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for Time64({unit}) field: {strategy}"); } - - pub(crate) fn validate_map(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? 
{ - fail!("invalid strategy for Map field: {strategy}"); - } - if self.children.len() != 1 { - fail!( - "invalid number of children for Map field: {}", - self.children.len() - ); - } - if self.children[0].data_type != GenericDataType::Struct { - fail!( - "invalid child for Map field, expected Struct, found: {}", - self.children[0].data_type - ); - } - if self.children[0].children.len() != 2 { - fail!("invalid child for Map field, expected Struct with two fields, found Struct wiht {} fields", self.children[0].children.len()); - } - - for child in &self.children { - child.validate()?; - } - - Ok(()) + if !matches!(unit, TimeUnit::Microsecond | TimeUnit::Nanosecond) { + fail!("Time64 field must have Microsecond or Nanosecond unit"); } + Ok(()) +} - pub(crate) fn validate_fixed_size_list(&self, n: i32) -> Result<()> { - if n < 0 { - fail!("Invalid FixedSizeList with negative number of elements"); - } - self.validate_list() +fn validate_struct_field(field: &Field, children: &[Field]) -> Result<()> { + // NOTE: do not check number of children: arrow-rs can 0 children, arrow2 not + match get_strategy_from_metadata(&field.metadata)? { + None | Some(Strategy::MapAsStruct) | Some(Strategy::TupleAsStruct) => {} + Some(strategy) => fail!("invalid strategy for Struct field: {strategy}"), } - - pub(crate) fn validate_list(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { - fail!("invalid strategy for List field: {strategy}"); - } - if self.children.len() != 1 { - fail!( - "invalid number of children for List field. 
Expected 1, found: {}", - self.children.len() - ); - } - self.children[0].validate()?; - - Ok(()) + for child in children { + validate_field(child)?; } + Ok(()) +} - pub(crate) fn validate_fixed_size_binary(&self, n: i32) -> Result<()> { - if n < 0 { - fail!("Invalid FixedSizedBinary with negative number of elements"); - } - self.validate_binary() +fn validate_map_field(field: &Field, entry: &Field) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for Map field: {strategy}"); } + // TODO: validate entry - pub(crate) fn validate_binary(&self) -> Result<()> { - self.validate_no_strategy_no_children() - } + Ok(()) +} - pub(crate) fn validate_union(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { - fail!("invalid strategy for Union field: {strategy}"); - } - if self.children.is_empty() { - fail!("Union field without children"); - } - for child in &self.children { - child.validate()?; - } - Ok(()) +fn validate_union_field(field: &Field, children: &[(i8, Field)], mode: UnionMode) -> Result<()> { + if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { + fail!("invalid strategy for Union field: {strategy}"); } - - pub(crate) fn validate_dictionary(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { - fail!("invalid strategy for Dictionary field: {strategy}"); - } - if self.children.len() != 2 { - fail!( - "invalid number of children for Dictionary field. Expected 2, found: {}", - self.children.len() - ); - } - if !matches!( - self.children[0].data_type, - GenericDataType::U8 - | GenericDataType::U16 - | GenericDataType::U32 - | GenericDataType::U64 - | GenericDataType::I8 - | GenericDataType::I16 - | GenericDataType::I32 - | GenericDataType::I64 - ) { - fail!( - "invalid child for Dictionary. 
Expected integer keys, found: {}", - self.children[0].data_type - ); - } - if !matches!( - self.children[1].data_type, - GenericDataType::Utf8 | GenericDataType::LargeUtf8 - ) { - fail!( - "invalid child for Dictionary. Expected string values, found: {}", - self.children[1].data_type - ); - } - for child in &self.children { - child.validate()?; - } - Ok(()) + for (_, child) in children { + validate_field(child)?; } + Ok(()) +} - pub(crate) fn validate_no_strategy_no_children(&self) -> Result<()> { - if let Some(strategy) = get_strategy_from_metadata(&self.metadata)? { - fail!( - "invalid strategy for {data_type}: {strategy}", - data_type = self.data_type, - ); - } - if !self.children.is_empty() { - fail!("{} field must not have children", self.data_type); +pub struct DataTypeDisplay<'a>(pub &'a DataType); + +impl<'a> std::fmt::Display for DataTypeDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + DataType::Null => write!(f, "Null"), + DataType::Boolean => write!(f, "Boolean"), + DataType::Int8 => write!(f, "Int8"), + DataType::Int16 => write!(f, "Int16"), + DataType::Int32 => write!(f, "Int32"), + DataType::Int64 => write!(f, "Int64"), + DataType::UInt8 => write!(f, "UInt8"), + DataType::UInt16 => write!(f, "UInt16"), + DataType::UInt32 => write!(f, "UInt32"), + DataType::UInt64 => write!(f, "UInt64"), + DataType::Float16 => write!(f, "Float16"), + DataType::Float32 => write!(f, "Float32"), + DataType::Float64 => write!(f, "Float64"), + DataType::Utf8 => write!(f, "Utf8"), + DataType::LargeUtf8 => write!(f, "LargeUtf8"), + DataType::Binary => write!(f, "Binary"), + DataType::LargeBinary => write!(f, "LargeBinary"), + DataType::FixedSizeBinary(n) => write!(f, "FixedSizeBinary({n})"), + DataType::Date32 => write!(f, "Date32"), + DataType::Date64 => write!(f, "Date64"), + DataType::Time32(unit) => write!(f, "Time32({unit})"), + DataType::Time64(unit) => write!(f, "Time64({unit})"), + DataType::Timestamp(unit, tz) => 
write!(f, "Timestamp({unit}, {tz:?})"), + DataType::Duration(unit) => write!(f, "Duration({unit})"), + DataType::List(_) => write!(f, "List"), + DataType::LargeList(_) => write!(f, "LargeList"), + DataType::FixedSizeList(_, n) => write!(f, "FixedSizeList({n})"), + DataType::Decimal128(precision, scale) => write!(f, "Decimal128({precision}, {scale}"), + DataType::Struct(_) => write!(f, "Struct"), + DataType::Map(_, sorted) => write!(f, "Map({sorted})"), + DataType::Dictionary(key, value, sorted) => write!( + f, + "Dictionary({key}, {value}, {sorted})", + key = DataTypeDisplay(key), + value = DataTypeDisplay(value), + ), + DataType::Union(_, mode) => write!(f, "Union({mode})"), } - Ok(()) } } diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs new file mode 100644 index 00000000..c5159773 --- /dev/null +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -0,0 +1,113 @@ +use std::collections::HashMap; + +use serde::de::Visitor; + +use crate::internal::{ + arrow::{DataType, Field}, + error::{fail, Error, Result}, + schema::{SerdeArrowSchema, Strategy, STRATEGY_KEY}, +}; + +// A custom impl of untagged-enum repr with better error messages +impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { + fn deserialize>(deserializer: D) -> Result { + struct VisitorImpl; + + impl<'de> Visitor<'de> for VisitorImpl { + type Value = SerdeArrowSchema; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "a sequence of fields or a struct with key 'fields' containing a sequence of fields") + } + + fn visit_seq>( + self, + mut seq: A, + ) -> Result { + use serde::de::Error; + + let mut fields = Vec::new(); + while let Some(item) = seq.next_element::()? 
{ + fields.push(item.into_field().map_err(A::Error::custom)?); + } + + Ok(SerdeArrowSchema { fields }) + } + + fn visit_map>( + self, + mut map: A, + ) -> Result { + use serde::de::Error; + + let mut fields = None; + + while let Some(key) = map.next_key::()? { + if key == "fields" { + fields = Some(map.next_value::>()?); + } else { + map.next_value::()?; + } + } + + let Some(fields) = fields else { + return Err(A::Error::custom("missing field `fields`")); + }; + + let mut converted_fields = Vec::new(); + for field in fields { + converted_fields.push(field.into_field().map_err(A::Error::custom)?); + } + + Ok(SerdeArrowSchema { + fields: converted_fields, + }) + } + } + + deserializer.deserialize_any(VisitorImpl) + } +} + +pub enum ArrowOrCustomField { + Arrow(Field), + Custom(CustomField), +} + +impl ArrowOrCustomField { + pub fn into_field(self) -> Result { + let field = match self { + ArrowOrCustomField::Arrow(field) => return Ok(field), + ArrowOrCustomField::Custom(field) => field, + }; + + todo!() + } +} + +impl<'de> serde::Deserialize<'de> for ArrowOrCustomField { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + todo!() + } +} + +pub struct CustomField { + name: String, + data_type: ArrowOrCustomDataType, + strategy: Option, + children: Vec, + metadata: HashMap, +} + +pub enum ArrowOrCustomDataType { + Arrow(DataType), + Custom(String), +} + +impl ArrowOrCustomDataType { + pub fn into_data_type(self, children: Vec) -> Result { + todo!() + } +} diff --git a/serde_arrow/src/internal/schema/serde/mod.rs b/serde_arrow/src/internal/schema/serde/mod.rs index e08d9999..494ec4cf 100644 --- a/serde_arrow/src/internal/schema/serde/mod.rs +++ b/serde_arrow/src/internal/schema/serde/mod.rs @@ -1,6 +1,7 @@ //! Group all serialization / deserialization related functionality //! 
-mod serialize; +pub mod deserialize; +pub mod serialize; #[cfg(test)] mod test; diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index 0fd5cd0a..ae63f888 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -2,13 +2,40 @@ use std::collections::HashMap; -use serde::ser::SerializeStruct; +use serde::ser::{SerializeSeq, SerializeStruct}; -use crate::{internal::schema::GenericField, schema::STRATEGY_KEY}; +use crate::{ + internal::arrow::{DataType, Field}, + schema::{SerdeArrowSchema, STRATEGY_KEY}, +}; -impl serde::Serialize for GenericField { +impl serde::Serialize for SerdeArrowSchema { + fn serialize(&self, serializer: S) -> Result { + let mut s = serializer.serialize_struct("SerdeArrowSchema", 1)?; + s.serialize_field("fields", &SerializableFields(&self.fields))?; + s.end() + } +} + +pub struct SerializableFields<'a>(pub &'a [Field]); + +impl<'a> serde::Serialize for SerializableFields<'a> { + fn serialize(&self, serializer: S) -> Result { + let mut s = serializer.serialize_seq(Some(self.0.len()))?; + for field in self.0 { + s.serialize_element(&SerializableField(field))?; + } + + s.end() + } +} + +pub struct SerializableField<'a>(pub &'a Field); + +impl<'a> serde::Serialize for SerializableField<'a> { fn serialize(&self, serializer: S) -> Result { let non_strategy_metadata = self + .0 .metadata .iter() .filter(|(key, _)| *key != STRATEGY_KEY) @@ -18,32 +45,141 @@ impl serde::Serialize for GenericField { if !non_strategy_metadata.is_empty() { num_fields += 1; } - if self.metadata.contains_key(STRATEGY_KEY) { + if self.0.metadata.contains_key(STRATEGY_KEY) { num_fields += 1; } - if self.nullable { + if self.0.nullable { num_fields += 1; } - if !self.children.is_empty() { + if is_data_type_with_children(&self.0.data_type) { num_fields += 1; } let mut s = serializer.serialize_struct("Field", num_fields)?; - 
s.serialize_field("name", &self.name)?; - s.serialize_field("data_type", &self.data_type)?; + s.serialize_field("name", &self.0.name)?; + s.serialize_field("data_type", &SerializableDataType(&self.0.data_type))?; - if self.nullable { - s.serialize_field("nullable", &self.nullable)?; + if self.0.nullable { + s.serialize_field("nullable", &self.0.nullable)?; } if !non_strategy_metadata.is_empty() { s.serialize_field("metadata", &non_strategy_metadata)?; } - if let Some(strategy) = self.metadata.get(STRATEGY_KEY) { + if let Some(strategy) = self.0.metadata.get(STRATEGY_KEY) { s.serialize_field("strategy", strategy)?; } - if !self.children.is_empty() { - s.serialize_field("children", &self.children)?; + if is_data_type_with_children(&self.0.data_type) { + s.serialize_field("children", &SerializableDataTypeChildren(&self.0.data_type))?; } s.end() } } + +pub struct SerializableDataType<'a>(pub &'a DataType); + +impl<'a> serde::Serialize for SerializableDataType<'a> { + fn serialize(&self, serializer: S) -> Result { + use DataType as T; + match self.0 { + T::Null => "Null".serialize(serializer), + T::Boolean => "Boolean".serialize(serializer), + T::Int8 => "Int8".serialize(serializer), + T::Int16 => "Int16".serialize(serializer), + T::Int32 => "Int32".serialize(serializer), + T::Int64 => "Int64".serialize(serializer), + T::UInt8 => "UInt8".serialize(serializer), + T::UInt16 => "UInt16".serialize(serializer), + T::UInt32 => "UInt32".serialize(serializer), + T::UInt64 => "UInt64".serialize(serializer), + T::Float16 => "Float16".serialize(serializer), + T::Float32 => "Float32".serialize(serializer), + T::Float64 => "Float64".serialize(serializer), + T::Utf8 => "Utf8".serialize(serializer), + T::LargeUtf8 => "LargeUtf8".serialize(serializer), + T::Binary => "Binary".serialize(serializer), + T::LargeBinary => "LargeBinary".serialize(serializer), + T::Date32 => "Date32".serialize(serializer), + T::Date64 => "Date64".serialize(serializer), + T::Decimal128(precision, scale) => { 
+ format!("Decimal128({precision}, {scale})").serialize(serializer) + } + T::Duration(unit) => format!("Duration({unit})").serialize(serializer), + T::Time32(unit) => format!("Time32({unit})").serialize(serializer), + T::Time64(unit) => format!("Time64({unit})").serialize(serializer), + T::Timestamp(unit, tz) => format!("Timestamp({unit}, {tz:?})").serialize(serializer), + T::FixedSizeBinary(n) => format!("FixedSizeBinary({n})").serialize(serializer), + T::FixedSizeList(_, n) => format!("FixedSizeList({n})").serialize(serializer), + T::Struct(_) => "Struct".serialize(serializer), + T::Map(_, _) => "Map".serialize(serializer), + T::Union(_, _) => "Union".serialize(serializer), + T::Dictionary(_, _, _) => "Dictionary".serialize(serializer), + T::LargeList(_) => "LargeList".serialize(serializer), + T::List(_) => "List".serialize(serializer), + } + } +} + +pub struct SerializableDataTypeChildren<'a>(pub &'a DataType); + +impl<'a> serde::Serialize for SerializableDataTypeChildren<'a> { + fn serialize(&self, serializer: S) -> Result { + use DataType as T; + + match self.0 { + T::FixedSizeList(entry, _) + | T::Map(entry, _) + | T::LargeList(entry) + | T::List(entry) => { + let mut s = serializer.serialize_seq(Some(1))?; + s.serialize_element(&SerializableField(entry.as_ref()))?; + s.end() + } + T::Struct(fields) => { + let mut s = serializer.serialize_seq(Some(fields.len()))?; + for field in fields { + s.serialize_element(&SerializableField(field))?; + } + s.end() + } + T::Union(fields, _) => { + let mut s = serializer.serialize_seq(Some(fields.len()))?; + for (_, field) in fields { + s.serialize_element(&SerializableField(field))?; + } + s.end() + } + T::Dictionary(key, value, _) => { + let mut s = serializer.serialize_seq(Some(2))?; + s.serialize_element(&DictionaryField("key", key))?; + s.serialize_element(&DictionaryField("value", value))?; + s.end() + } + _ => serializer.serialize_seq(Some(0))?.end(), + } + } +} + +struct DictionaryField<'a>(&'a str, &'a DataType); + 
+impl<'a> serde::Serialize for DictionaryField<'a> { + fn serialize(&self, serializer: S) -> Result { + let mut s = serializer.serialize_struct("Field", 2)?; + s.serialize_field("name", self.0)?; + s.serialize_field("data_type", &SerializableDataType(self.1))?; + s.end() + } +} + +fn is_data_type_with_children(data_type: &DataType) -> bool { + use DataType as T; + matches!( + data_type, + T::FixedSizeList(_, _) + | T::Struct(_) + | T::Map(_, _) + | T::Union(_, _) + | T::Dictionary(_, _, _) + | T::LargeList(_) + | T::List(_) + ) +} diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index 19fa5c6a..585012fb 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -1,29 +1,32 @@ use serde_json::json; use crate::internal::{ + arrow::{DataType, Field}, error::PanicOnError, - schema::{GenericDataType, GenericField, STRATEGY_KEY}, + schema::{ArrowOrCustomField, STRATEGY_KEY}, testing::hash_map, }; +use super::serialize::SerializableField; + #[test] fn i16_field_simple() -> PanicOnError<()> { - let field = GenericField { + let field = Field { name: String::from("my_field_name"), - data_type: GenericDataType::I16, + data_type: DataType::Int16, metadata: hash_map!(), nullable: false, - children: vec![], }; let expected = json!({ "name": "my_field_name", "data_type": "I16", }); - let actual = serde_json::to_value(&field)?; + let actual = serde_json::to_value(&SerializableField(&field))?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = roundtripped.into_field()?; assert_eq!(roundtripped, field); Ok(()) @@ -31,15 +34,14 @@ fn i16_field_simple() -> PanicOnError<()> { #[test] fn date64_field_complex() -> PanicOnError<()> { - let field = GenericField { + let field = Field { name: String::from("my_field_name"), - data_type: GenericDataType::Date64, + 
data_type: DataType::Date64, metadata: hash_map!( "foo" => "bar", STRATEGY_KEY => "NaiveStrAsDate64", ), nullable: true, - children: vec![], }; let expected = json!({ "name": "my_field_name", @@ -54,7 +56,8 @@ fn date64_field_complex() -> PanicOnError<()> { let actual = serde_json::to_value(&field)?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = roundtripped.into_field()?; assert_eq!(roundtripped, field); Ok(()) @@ -62,18 +65,16 @@ fn date64_field_complex() -> PanicOnError<()> { #[test] fn list_field_complex() -> PanicOnError<()> { - let field = GenericField { + let field = Field { name: String::from("my_field_name"), - data_type: GenericDataType::List, - metadata: hash_map!("foo" => "bar"), - nullable: true, - children: vec![GenericField { + data_type: DataType::List(Box::new(Field { name: String::from("element"), - data_type: GenericDataType::I64, + data_type: DataType::Int64, metadata: hash_map!(), nullable: false, - children: vec![], - }], + })), + metadata: hash_map!("foo" => "bar"), + nullable: true, }; let expected = json!({ "name": "my_field_name", @@ -88,7 +89,8 @@ fn list_field_complex() -> PanicOnError<()> { let actual = serde_json::to_value(&field)?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = serde_json::from_value::(actual)?; + let roundtripped = roundtripped.into_field()?; assert_eq!(roundtripped, field); Ok(()) diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index a4719d82..8bd0bff9 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -1,23 +1,28 @@ +use std::collections::HashMap; + use serde_json::json; use crate::internal::{ - arrow::TimeUnit, - schema::{GenericDataType, GenericField, SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, + arrow::{DataType, Field, TimeUnit}, 
+ schema::{SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, testing::{assert_error, hash_map}, }; -impl SerdeArrowSchema { - fn with_field(mut self, field: GenericField) -> Self { - self.fields.push(field); - self - } -} - #[test] fn example() { - let schema = SerdeArrowSchema::default() - .with_field(GenericField::new("foo", GenericDataType::U8, false)) - .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); + let mut schema = SerdeArrowSchema::default(); + schema.fields.push(Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }); + schema.fields.push(Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }); let actual = serde_json::to_string(&schema).unwrap(); assert_eq!( @@ -31,9 +36,19 @@ fn example() { #[test] fn example_without_wrapper() { - let expected = SerdeArrowSchema::default() - .with_field(GenericField::new("foo", GenericDataType::U8, false)) - .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); + let mut expected = SerdeArrowSchema::default(); + expected.fields.push(Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }); + expected.fields.push(Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }); let input = r#"[{"name":"foo","data_type":"U8"},{"name":"bar","data_type":"Utf8"}]"#; let actual: SerdeArrowSchema = serde_json::from_str(&input).unwrap(); @@ -42,13 +57,18 @@ fn example_without_wrapper() { #[test] fn list() { - let schema = SerdeArrowSchema::default().with_field( - GenericField::new("value", GenericDataType::List, false).with_child(GenericField::new( - "element", - GenericDataType::I32, - false, - )), - ); + let mut schema = SerdeArrowSchema::default(); + schema.fields.push(Field { + name: String::from("value"), + data_type: 
DataType::List(Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: Default::default(), + })), + nullable: false, + metadata: Default::default(), + }); let actual = serde_json::to_string(&schema).unwrap(); assert_eq!( @@ -70,21 +90,33 @@ fn doc_schema() { "#; let actual: SerdeArrowSchema = serde_json::from_str(&schema).unwrap(); - let expected = SerdeArrowSchema::default() - .with_field(GenericField::new("foo", GenericDataType::U8, false)) - .with_field(GenericField::new("bar", GenericDataType::Utf8, false)); + + let mut expected = SerdeArrowSchema::default(); + expected.fields.push(Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }); + expected.fields.push(Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }); assert_eq!(actual, expected); } #[test] fn date64_with_strategy() { - let schema = SerdeArrowSchema::default().with_field( - GenericField::new("item", GenericDataType::Date64, false).with_metadata( - STRATEGY_KEY.to_string(), - Strategy::NaiveStrAsDate64.to_string(), - ), - ); + let mut schema = SerdeArrowSchema::default(); + schema.fields.push(Field { + name: String::from("item"), + data_type: DataType::Date64, + nullable: false, + metadata: hash_map!( STRATEGY_KEY => Strategy::NaiveStrAsDate64 ), + }); let actual = serde_json::to_string(&schema).unwrap(); assert_eq!( @@ -105,100 +137,101 @@ fn date64_with_strategy() { assert_eq!(from_json, schema); } -#[test] -fn timestamp_second_serialization() { - let dt = super::GenericDataType::Timestamp(TimeUnit::Second, None); - - let s = serde_json::to_string(&dt).unwrap(); - assert_eq!(s, r#""Timestamp(Second, None)""#); - - let rt = serde_json::from_str(&s).unwrap(); - assert_eq!(dt, rt); -} - -#[test] -fn timestamp_second_utc_serialization() { - let dt = super::GenericDataType::Timestamp(TimeUnit::Second, 
Some(String::from("Utc"))); - - let s = serde_json::to_string(&dt).unwrap(); - assert_eq!(s, r#""Timestamp(Second, Some(\"Utc\"))""#); - - let rt = serde_json::from_str(&s).unwrap(); - assert_eq!(dt, rt); -} - -#[test] -fn test_date32() { - use GenericDataType as DT; - - assert_eq!(DT::Date32.to_string(), "Date32"); - assert_eq!("Date32".parse::
().unwrap(), DT::Date32); -} - -#[test] -fn time64_data_type_format() { - use {GenericDataType as DT, TimeUnit as TU}; - - for (dt, s) in [ - (DT::Time64(TU::Microsecond), "Time64(Microsecond)"), - (DT::Time64(TU::Nanosecond), "Time64(Nanosecond)"), - ] { - assert_eq!(dt.to_string(), s); - assert_eq!(s.parse::
().unwrap(), dt); - } -} - -#[test] -fn test_long_form_types() { - use super::GenericDataType as DT; - use std::str::FromStr; - - assert_eq!(DT::from_str("Boolean").unwrap(), DT::Bool); - assert_eq!(DT::from_str("Int8").unwrap(), DT::I8); - assert_eq!(DT::from_str("Int16").unwrap(), DT::I16); - assert_eq!(DT::from_str("Int32").unwrap(), DT::I32); - assert_eq!(DT::from_str("Int64").unwrap(), DT::I64); - assert_eq!(DT::from_str("UInt8").unwrap(), DT::U8); - assert_eq!(DT::from_str("UInt16").unwrap(), DT::U16); - assert_eq!(DT::from_str("UInt32").unwrap(), DT::U32); - assert_eq!(DT::from_str("UInt64").unwrap(), DT::U64); - assert_eq!(DT::from_str("Float16").unwrap(), DT::F16); - assert_eq!(DT::from_str("Float32").unwrap(), DT::F32); - assert_eq!(DT::from_str("Float64").unwrap(), DT::F64); - assert_eq!( - DT::from_str("Decimal128(8,-2)").unwrap(), - DT::Decimal128(8, -2) - ); - assert_eq!( - DT::from_str("Decimal128( 8 , -2 )").unwrap(), - DT::Decimal128(8, -2) - ); -} - -macro_rules! test_data_type { - ($($variant:ident,)*) => { - mod test_data_type { - $( - #[allow(non_snake_case)] - #[test] - fn $variant() { - let ty = super::super::GenericDataType::$variant; - - let s = serde_json::to_string(&ty).unwrap(); - assert_eq!(s, concat!("\"", stringify!($variant), "\"")); - - let rt = serde_json::from_str(&s).unwrap(); - assert_eq!(ty, rt); - } - )* - } - }; -} - -test_data_type!( - Null, Bool, I8, I16, I32, I64, U8, U16, U32, U64, F16, F32, F64, Utf8, LargeUtf8, List, - LargeList, Struct, Dictionary, Union, Map, Date64, -); +// TODO: fix these tests (or move them somewhere else) +// #[test] +// fn timestamp_second_serialization() { +// let dt = super::GenericDataType::Timestamp(TimeUnit::Second, None); + +// let s = serde_json::to_string(&dt).unwrap(); +// assert_eq!(s, r#""Timestamp(Second, None)""#); + +// let rt = serde_json::from_str(&s).unwrap(); +// assert_eq!(dt, rt); +// } +// +// #[test] +// fn timestamp_second_utc_serialization() { +// let dt = 
super::GenericDataType::Timestamp(TimeUnit::Second, Some(String::from("Utc"))); + +// let s = serde_json::to_string(&dt).unwrap(); +// assert_eq!(s, r#""Timestamp(Second, Some(\"Utc\"))""#); + +// let rt = serde_json::from_str(&s).unwrap(); +// assert_eq!(dt, rt); +// } + +// #[test] +// fn test_date32() { +// use GenericDataType as DT; + +// assert_eq!(DT::Date32.to_string(), "Date32"); +// assert_eq!("Date32".parse::
().unwrap(), DT::Date32); +// } + +// #[test] +// fn time64_data_type_format() { +// use {GenericDataType as DT, TimeUnit as TU}; + +// for (dt, s) in [ +// (DT::Time64(TU::Microsecond), "Time64(Microsecond)"), +// (DT::Time64(TU::Nanosecond), "Time64(Nanosecond)"), +// ] { +// assert_eq!(dt.to_string(), s); +// assert_eq!(s.parse::
().unwrap(), dt); +// } +// } + +// #[test] +// fn test_long_form_types() { +// use super::GenericDataType as DT; +// use std::str::FromStr; + +// assert_eq!(DT::from_str("Boolean").unwrap(), DT::Bool); +// assert_eq!(DT::from_str("Int8").unwrap(), DT::I8); +// assert_eq!(DT::from_str("Int16").unwrap(), DT::I16); +// assert_eq!(DT::from_str("Int32").unwrap(), DT::I32); +// assert_eq!(DT::from_str("Int64").unwrap(), DT::I64); +// assert_eq!(DT::from_str("UInt8").unwrap(), DT::U8); +// assert_eq!(DT::from_str("UInt16").unwrap(), DT::U16); +// assert_eq!(DT::from_str("UInt32").unwrap(), DT::U32); +// assert_eq!(DT::from_str("UInt64").unwrap(), DT::U64); +// assert_eq!(DT::from_str("Float16").unwrap(), DT::F16); +// assert_eq!(DT::from_str("Float32").unwrap(), DT::F32); +// assert_eq!(DT::from_str("Float64").unwrap(), DT::F64); +// assert_eq!( +// DT::from_str("Decimal128(8,-2)").unwrap(), +// DT::Decimal128(8, -2) +// ); +// assert_eq!( +// DT::from_str("Decimal128( 8 , -2 )").unwrap(), +// DT::Decimal128(8, -2) +// ); +// } + +// macro_rules! 
test_data_type { +// ($($variant:ident,)*) => { +// mod test_data_type { +// $( +// #[allow(non_snake_case)] +// #[test] +// fn $variant() { +// let ty = super::super::GenericDataType::$variant; + +// let s = serde_json::to_string(&ty).unwrap(); +// assert_eq!(s, concat!("\"", stringify!($variant), "\"")); + +// let rt = serde_json::from_str(&s).unwrap(); +// assert_eq!(ty, rt); +// } +// )* +// } +// }; +// } + +// test_data_type!( +// Null, Bool, I8, I16, I32, I64, U8, U16, U32, U64, F16, F32, F64, Utf8, LargeUtf8, List, +// LargeList, Struct, Dictionary, Union, Map, Date64, +// ); #[test] fn test_metadata_strategy_from_explicit() { diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 1777a544..861b386c 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -4,13 +4,12 @@ use std::{ }; use crate::internal::{ + arrow::{DataType, Field, UnionMode}, error::{fail, Result}, - schema::{GenericDataType, GenericField, SerdeArrowSchema, Strategy}, -}; - -use super::{ - tracing_options::{TracingMode, TracingOptions}, - Overwrites, STRATEGY_KEY, + schema::{ + DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, TracingMode, TracingOptions, + STRATEGY_KEY, + }, }; // TODO: allow to customize @@ -18,14 +17,28 @@ const MAX_TYPE_DEPTH: usize = 20; const RECURSIVE_TYPE_WARNING: &str = "too deeply nested type detected. 
Recursive types are not supported in schema tracing"; -fn default_dictionary_field(name: &str, nullable: bool) -> GenericField { - GenericField::new(name, GenericDataType::Dictionary, nullable) - .with_child(GenericField::new("key", GenericDataType::U32, nullable)) - .with_child(GenericField::new( - "value", - GenericDataType::LargeUtf8, +fn default_dictionary_field(name: &str, nullable: bool) -> Field { + Field { + name: name.to_owned(), + nullable: nullable, + metadata: HashMap::new(), + data_type: DataType::Dictionary( + Box::new(DataType::UInt32), + Box::new(DataType::LargeUtf8), false, - )) + ), + } +} + +fn unknown_variant_field() -> Field { + let mut metadata = HashMap::new(); + metadata.insert(STRATEGY_KEY.into(), Strategy::UnknownVariant.into()); + Field { + name: String::from(""), + nullable: true, + data_type: DataType::Null, + metadata, + } } struct NullFieldMessage<'a>(&'a str); @@ -103,15 +116,15 @@ impl Tracer { let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode); let fields = match root.data_type { - GenericDataType::Struct => root.children, - GenericDataType::Null => fail!("No records found to determine schema"), + DataType::Struct(children) => children, + DataType::Null => fail!("No records found to determine schema"), dt => fail!( concat!( "Schema tracing is not directly supported for the root data type {dt}. ", "Only struct-like types are supported as root types in schema tracing. 
", "{mitigation}", ), - dt = dt, + dt = DataTypeDisplay(&dt), mitigation = match tracing_mode { TracingMode::FromType => "Consider using the `Item` wrapper, i.e., `::from_type>()`.", TracingMode::FromSamples => "Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`.", @@ -133,11 +146,11 @@ impl Tracer { dispatch_tracer!(self, tracer => tracer.is_complete()) } - pub fn get_type(&self) -> Option<&GenericDataType> { + pub fn get_type(&self) -> Option<&str> { dispatch_tracer!(self, tracer => tracer.get_type()) } - pub fn to_field(&self) -> Result { + pub fn to_field(&self) -> Result { let path = dispatch_tracer!(self, tracer => &tracer.path); if let Some(overwrite) = dispatch_tracer!(self, tracer => tracer.options.get_overwrite(path)) @@ -437,11 +450,7 @@ impl Tracer { Ok(()) } - pub fn ensure_utf8( - &mut self, - item_type: GenericDataType, - strategy: Option, - ) -> Result<()> { + pub fn ensure_utf8(&mut self, item_type: DataType, strategy: Option) -> Result<()> { if self.is_unknown() { let tracer = dispatch_tracer!(self, tracer => PrimitiveTracer::new( tracer.name.clone(), @@ -454,7 +463,7 @@ impl Tracer { *self = Self::Primitive(tracer); } else if let Tracer::Primitive(tracer) = self { use { - GenericDataType::Date64, GenericDataType::LargeUtf8, Strategy::NaiveStrAsDate64, + DataType::Date64, DataType::LargeUtf8, Strategy::NaiveStrAsDate64, Strategy::UtcStrAsDate64, }; let (item_type, strategy) = match ((&tracer.item_type), (item_type)) { @@ -468,7 +477,11 @@ impl Tracer { }, (LargeUtf8, _) | (_, LargeUtf8) => (LargeUtf8, None), (prev_ty, new_ty) => { - fail!("mismatched types, previous {prev_ty}, current {new_ty}") + fail!( + "mismatched types, previous {prev_ty}, current {new_ty}", + prev_ty = DataTypeDisplay(prev_ty), + new_ty = DataTypeDisplay(&new_ty), + ) } }; tracer.item_type = item_type; @@ -477,12 +490,15 @@ impl Tracer { let Some(ty) = self.get_type() else { unreachable!("tracer cannot be unknown"); }; - fail!("mismatched types, 
previous {ty}, current {item_type}"); + fail!( + "mismatched types, previous {ty}, current {item_type}", + item_type = DataTypeDisplay(&item_type), + ); } Ok(()) } - pub fn ensure_primitive(&mut self, item_type: GenericDataType) -> Result<()> { + pub fn ensure_primitive(&mut self, item_type: DataType) -> Result<()> { match self { this @ Self::Unknown(_) => { let tracer = dispatch_tracer!(this, tracer => PrimitiveTracer::new( @@ -504,7 +520,7 @@ impl Tracer { Ok(()) } - pub fn ensure_number(&mut self, item_type: GenericDataType) -> Result<()> { + pub fn ensure_number(&mut self, item_type: DataType) -> Result<()> { match self { this @ Self::Unknown(_) => { let tracer = dispatch_tracer!(this, tracer => PrimitiveTracer::new( @@ -517,23 +533,35 @@ impl Tracer { *this = Self::Primitive(tracer); } Self::Primitive(tracer) if tracer.options.coerce_numbers => { - use GenericDataType::{F32, F64, I16, I32, I64, I8, U16, U32, U64, U8}; + use DataType::{ + Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, + }; let item_type = match (&tracer.item_type, item_type) { // unsigned x unsigned -> u64 - (U8 | U16 | U32 | U64, U8 | U16 | U32 | U64) => U64, + (UInt8 | UInt16 | UInt32 | UInt64, UInt8 | UInt16 | UInt32 | UInt64) => UInt64, // signed x signed -> i64 - (I8 | I16 | I32 | I64, I8 | I16 | I32 | I64) => I64, + (Int8 | Int16 | Int32 | Int64, Int8 | Int16 | Int32 | Int64) => Int64, // signed x unsigned -> i64 - (I8 | I16 | I32 | I64, U8 | U16 | U32 | U64) => I64, + (Int8 | Int16 | Int32 | Int64, UInt8 | UInt16 | UInt32 | UInt64) => Int64, // unsigned x signed -> i64 - (U8 | U16 | U32 | U64, I8 | I16 | I32 | I64) => I64, + (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) => Int64, // float x float -> f64 - (F32 | F64, F32 | F64) => F64, + (Float32 | Float64, Float32 | Float64) => Float64, // int x float -> f64 - (I8 | I16 | I32 | I64 | U8 | U16 | U32 | U64, F32 | F64) => F64, + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | 
UInt64, + Float32 | Float64, + ) => Float64, // float x int -> f64 - (F32 | F64, I8 | I16 | I32 | I64 | U8 | U16 | U32 | U64) => F64, - (ty, ev) => fail!("Cannot accept event {ev} for tracer of primitive type {ty}"), + ( + Float32 | Float64, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Float64, + (ty, ev) => fail!( + "Cannot accept event {ev} for tracer of primitive type {ty}", + ev = DataTypeDisplay(&ev), + ty = DataTypeDisplay(&ty), + ), }; tracer.item_type = item_type; } @@ -566,15 +594,16 @@ impl UnknownTracer { } } - pub fn to_field(&self) -> Result { + pub fn to_field(&self) -> Result { if !self.options.allow_null_fields { fail!("{}", NullFieldMessage(&self.name)); } - Ok(GenericField::new( - &self.name, - GenericDataType::Null, - self.nullable, - )) + Ok(Field { + name: self.name.to_owned(), + data_type: DataType::Null, + nullable: self.nullable, + metadata: HashMap::new(), + }) } pub fn finish(&mut self) -> Result<()> { @@ -589,7 +618,7 @@ impl UnknownTracer { false } - pub fn get_type(&self) -> Option<&GenericDataType> { + pub fn get_type(&self) -> Option<&str> { None } } @@ -613,19 +642,27 @@ impl MapTracer { self.key_tracer.is_complete() && self.value_tracer.is_complete() } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&GenericDataType::Map) + pub fn get_type(&self) -> Option<&str> { + Some("Map") } - pub fn to_field(&self) -> Result { - let mut entries = GenericField::new("entries", GenericDataType::Struct, false); - entries.children.push(self.key_tracer.to_field()?); - entries.children.push(self.value_tracer.to_field()?); - - let mut field = GenericField::new(&self.name, GenericDataType::Map, self.nullable); - field.children.push(entries); + pub fn to_field(&self) -> Result { + let entry = Field { + name: String::from("entries"), + nullable: false, + metadata: HashMap::new(), + data_type: DataType::Struct(vec![ + self.key_tracer.to_field()?, + self.value_tracer.to_field()?, + ]), + }; - Ok(field) + Ok(Field 
{ + name: self.name.to_owned(), + data_type: DataType::Map(Box::new(entry), false), + nullable: self.nullable, + metadata: HashMap::new(), + }) } pub fn finish(&mut self) -> Result<()> { @@ -653,15 +690,17 @@ impl ListTracer { self.item_tracer.is_complete() } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&GenericDataType::LargeList) + pub fn get_type(&self) -> Option<&str> { + Some("List") } - pub fn to_field(&self) -> Result { - let mut field = GenericField::new(&self.name, GenericDataType::LargeList, self.nullable); - field.children.push(self.item_tracer.to_field()?); - - Ok(field) + pub fn to_field(&self) -> Result { + Ok(Field { + name: self.name.to_owned(), + nullable: self.nullable, + metadata: HashMap::new(), + data_type: DataType::LargeList(Box::new(self.item_tracer.to_field()?)), + }) } pub fn finish(&mut self) -> Result<()> { @@ -687,20 +726,28 @@ impl TupleTracer { self.field_tracers.iter().all(|tracer| tracer.is_complete()) } - pub fn to_field(&self) -> Result { - let mut field = GenericField::new(&self.name, GenericDataType::Struct, self.nullable); - for tracer in &self.field_tracers { - field.children.push(tracer.to_field()?); - } - field.metadata.insert( + pub fn to_field(&self) -> Result { + let mut metadata = HashMap::new(); + metadata.insert( STRATEGY_KEY.to_string(), Strategy::TupleAsStruct.to_string(), ); - Ok(field) + + let mut fields = Vec::new(); + for tracer in &self.field_tracers { + fields.push(tracer.to_field()?); + } + + Ok(Field { + name: self.name.to_owned(), + data_type: DataType::Struct(fields), + nullable: self.nullable, + metadata, + }) } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&GenericDataType::Struct) + pub fn get_type(&self) -> Option<&str> { + Some("Struct") } pub fn finish(&mut self) -> Result<()> { @@ -803,23 +850,28 @@ impl StructTracer { self.fields.iter().all(|field| field.tracer.is_complete()) } - pub fn to_field(&self) -> Result { - let mut res_field = GenericField::new(&self.name, 
GenericDataType::Struct, self.nullable); + pub fn to_field(&self) -> Result { + let mut fields = Vec::new(); for field in &self.fields { - res_field.children.push(field.tracer.to_field()?); + fields.push(field.tracer.to_field()?); } + let mut metadata = HashMap::new(); if let StructMode::Map = self.mode { - res_field.children.sort_by(|a, b| a.name.cmp(&b.name)); - res_field - .metadata - .insert(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()); + fields.sort_by(|a, b| a.name.cmp(&b.name)); + metadata.insert(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()); } - Ok(res_field) + + Ok(Field { + name: self.name.to_owned(), + data_type: DataType::Struct(fields), + nullable: self.nullable, + metadata, + }) } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&GenericDataType::Struct) + pub fn get_type(&self) -> Option<&str> { + Some("Struct") } pub fn finish(&mut self) -> Result<()> { @@ -848,8 +900,11 @@ pub struct UnionVariant { impl UnionVariant { fn is_null_variant(&self) -> bool { - // Note: unknown tracers are treated as Null tracers - matches!(self.tracer.get_type(), None | Some(GenericDataType::Null)) + match &self.tracer { + Tracer::Unknown(_) => true, + Tracer::Primitive(tracer) if matches!(tracer.item_type, DataType::Null) => true, + _ => false, + } } } @@ -896,11 +951,11 @@ impl UnionTracer { .all(|variant| variant.tracer.is_complete()) } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&GenericDataType::Union) + pub fn get_type(&self) -> Option<&str> { + Some("Union") } - pub fn to_field(&self) -> Result { + pub fn to_field(&self) -> Result { if self.is_without_data() { if self.options.enums_without_data_as_strings { return Ok(default_dictionary_field(&self.name, self.nullable)); @@ -910,21 +965,21 @@ impl UnionTracer { } } - let mut field = GenericField::new(&self.name, GenericDataType::Union, self.nullable); - for variant in &self.variants { + let mut fields = Vec::new(); + for (idx, variant) in 
self.variants.iter().enumerate() { if let Some(variant) = variant { - field.children.push(variant.tracer.to_field()?); + fields.push((i8::try_from(idx)?, variant.tracer.to_field()?)); } else { - field.children.push( - GenericField::new("", GenericDataType::Null, true).with_metadata( - STRATEGY_KEY.to_string(), - Strategy::UnknownVariant.to_string(), - ), - ); + fields.push((i8::try_from(idx)?, unknown_variant_field())); }; } - Ok(field) + Ok(Field { + name: self.name.to_owned(), + data_type: DataType::Union(fields, UnionMode::Dense), + nullable: self.nullable, + metadata: HashMap::new(), + }) } pub fn is_without_data(&self) -> bool { @@ -955,7 +1010,7 @@ pub struct PrimitiveTracer { pub options: Arc, pub nullable: bool, pub strategy: Option, - pub item_type: GenericDataType, + pub item_type: DataType, } impl PrimitiveTracer { @@ -963,7 +1018,7 @@ impl PrimitiveTracer { name: String, path: String, options: Arc, - item_type: GenericDataType, + item_type: DataType, nullable: bool, ) -> Self { Self { @@ -985,30 +1040,43 @@ impl PrimitiveTracer { Ok(()) } - pub fn to_field(&self) -> Result { - type D = GenericDataType; + pub fn to_field(&self) -> Result { + type D = DataType; if !self.options.allow_null_fields && matches!(self.item_type, D::Null) { fail!("{}", NullFieldMessage(&self.name)); } match &self.item_type { - D::Null => Ok(GenericField::new(&self.name, D::Null, true)), + D::Null => Ok(Field { + name: self.name.to_owned(), + data_type: DataType::Null, + nullable: true, + metadata: HashMap::new(), + }), dt @ (D::LargeUtf8 | D::Utf8) => { if !self.options.string_dictionary_encoding { - Ok(GenericField::new(&self.name, dt.clone(), self.nullable)) + Ok(Field { + name: self.name.to_owned(), + data_type: dt.clone(), + nullable: self.nullable, + metadata: HashMap::new(), + }) } else { Ok(default_dictionary_field(&self.name, self.nullable)) } } dt => { - let mut field = GenericField::new(&self.name, dt.clone(), self.nullable); + let mut metadata = HashMap::new(); if let 
Some(strategy) = self.strategy.as_ref() { - field - .metadata - .insert(STRATEGY_KEY.to_string(), strategy.to_string()); + metadata.insert(STRATEGY_KEY.to_string(), strategy.to_string()); } - Ok(field) + Ok(Field { + name: self.name.to_owned(), + data_type: dt.clone(), + nullable: self.nullable, + metadata, + }) } } } @@ -1023,7 +1091,7 @@ impl PrimitiveTracer { true } - pub fn get_type(&self) -> Option<&GenericDataType> { - Some(&self.item_type) + pub fn get_type(&self) -> Option<&str> { + Some("Primitive") } } diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index 1acad363..b20ac84e 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -2,9 +2,7 @@ use std::collections::HashMap; use serde::Serialize; -use crate::internal::{error::Result, utils::value}; - -use super::GenericField; +use crate::internal::{arrow::Field, error::Result, schema::ArrowOrCustomField, utils::value}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum TracingMode { @@ -281,9 +279,9 @@ impl TracingOptions { pub fn overwrite, F: Serialize>(mut self, path: P, field: F) -> Result { let path = path.into(); let path = format!("$.{path}"); - let field: GenericField = value::transmute(&field)?; + let field: ArrowOrCustomField = value::transmute(&field)?; - self.overwrites.0.insert(path, field); + self.overwrites.0.insert(path, field.into_field()?); Ok(self) } @@ -292,11 +290,11 @@ impl TracingOptions { self } - pub(crate) fn get_overwrite(&self, path: &str) -> Option<&GenericField> { + pub(crate) fn get_overwrite(&self, path: &str) -> Option<&Field> { self.overwrites.0.get(path) } } /// An opaque mapping of field paths to field definitions #[derive(Debug, Clone, Default, PartialEq)] -pub struct Overwrites(pub(crate) HashMap); +pub struct Overwrites(pub(crate) HashMap); diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs 
b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index d7cbb4d0..491c6976 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -5,7 +5,6 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, error::{fail, Result}, - schema::GenericField, utils::Mut, }; @@ -13,16 +12,14 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct DictionaryUtf8Builder { - pub field: GenericField, pub indices: Box, pub values: Box, pub index: HashMap, } impl DictionaryUtf8Builder { - pub fn new(field: GenericField, indices: ArrayBuilder, values: ArrayBuilder) -> Self { + pub fn new(indices: ArrayBuilder, values: ArrayBuilder) -> Self { Self { - field, indices: Box::new(indices), values: Box::new(values), index: HashMap::new(), @@ -31,7 +28,6 @@ impl DictionaryUtf8Builder { pub fn take(&mut self) -> Self { Self { - field: self.field.clone(), indices: Box::new(self.indices.take()), values: Box::new(self.values.take()), index: std::mem::take(&mut self.index), diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index fdf557d1..8c0499d1 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::Result, + error::{fail, Result}, }; use super::{ @@ -19,12 +19,23 @@ pub struct MapBuilder { } impl MapBuilder { - pub fn new(meta: FieldMeta, entry: ArrayBuilder, is_nullable: bool) -> Self { - Self { + pub fn new(meta: FieldMeta, entry: ArrayBuilder, is_nullable: bool) -> Result { + Self::validate_entry(&entry)?; + Ok(Self { meta, offsets: OffsetsArray::new(is_nullable), entry: Box::new(entry), + }) + } + + fn validate_entry(entry: &ArrayBuilder) -> 
Result<()> { + let ArrayBuilder::Struct(entry) = entry else { + fail!("entry field of a map must be a struct field"); + }; + if entry.fields.len() != 2 { + fail!("entry field of a map must be a struct field with 2 fields"); } + Ok(()) } pub fn take(&mut self) -> Self { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index b2140657..1022db58 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -1,11 +1,11 @@ +use std::collections::HashMap; + use serde::Serialize; use crate::internal::{ - arrow::TimeUnit, + arrow::{DataType, Field, TimeUnit}, error::{fail, Result}, - schema::{ - get_strategy_from_metadata, GenericDataType, GenericField, SerdeArrowSchema, Strategy, - }, + schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy}, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, @@ -30,7 +30,7 @@ impl OuterSequenceBuilder { pub fn new(schema: &SerdeArrowSchema) -> Result { return Ok(Self(build_struct(&schema.fields, false)?)); - fn build_struct(struct_fields: &[GenericField], nullable: bool) -> Result { + fn build_struct(struct_fields: &[Field], nullable: bool) -> Result { let mut fields = Vec::new(); for field in struct_fields { fields.push((build_builder(field)?, meta_from_field(field.clone())?)); @@ -38,26 +38,26 @@ impl OuterSequenceBuilder { StructBuilder::new(fields, nullable) } - fn build_builder(field: &GenericField) -> Result { - use {ArrayBuilder as A, GenericDataType as T}; + fn build_builder(field: &Field) -> Result { + use {ArrayBuilder as A, DataType as T}; let builder = match &field.data_type { T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), _ => A::Null(NullBuilder::new()), }, - T::Bool => A::Bool(BoolBuilder::new(field.nullable)), - T::I8 => A::I8(IntBuilder::new(field.nullable)), - T::I16 => A::I16(IntBuilder::new(field.nullable)), - T::I32 => A::I32(IntBuilder::new(field.nullable)), - T::I64 => A::I64(IntBuilder::new(field.nullable)), - T::U8 => A::U8(IntBuilder::new(field.nullable)), - T::U16 => A::U16(IntBuilder::new(field.nullable)), - T::U32 => A::U32(IntBuilder::new(field.nullable)), - T::U64 => A::U64(IntBuilder::new(field.nullable)), - T::F16 => A::F16(FloatBuilder::new(field.nullable)), - T::F32 => A::F32(FloatBuilder::new(field.nullable)), - T::F64 => A::F64(FloatBuilder::new(field.nullable)), + T::Boolean => A::Bool(BoolBuilder::new(field.nullable)), + T::Int8 => A::I8(IntBuilder::new(field.nullable)), + T::Int16 => A::I16(IntBuilder::new(field.nullable)), + T::Int32 => A::I32(IntBuilder::new(field.nullable)), + T::Int64 => A::I64(IntBuilder::new(field.nullable)), + T::UInt8 => A::U8(IntBuilder::new(field.nullable)), + T::UInt16 => A::U16(IntBuilder::new(field.nullable)), + T::UInt32 => A::U32(IntBuilder::new(field.nullable)), + T::UInt64 => A::U64(IntBuilder::new(field.nullable)), + T::Float16 => A::F16(FloatBuilder::new(field.nullable)), + T::Float32 => A::F32(FloatBuilder::new(field.nullable)), + T::Float64 => A::F64(FloatBuilder::new(field.nullable)), T::Date32 => A::Date32(Date32Builder::new(field.nullable)), T::Date64 => A::Date64(Date64Builder::new( None, @@ -87,80 +87,56 @@ impl OuterSequenceBuilder { } T::Utf8 => A::Utf8(Utf8Builder::new(field.nullable)), T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(field.nullable)), - T::List => { - let Some(child) = field.children.first() else { - fail!("cannot build a list without an element field"); - }; - A::List(ListBuilder::new( - meta_from_field(child.clone())?, - build_builder(child)?, - field.nullable, - )?) 
- } - T::LargeList => { - let Some(child) = field.children.first() else { - fail!("cannot build list without an element field"); - }; - A::LargeList(ListBuilder::new( - meta_from_field(child.clone())?, - build_builder(child)?, - field.nullable, - )?) - } - T::FixedSizeList(n) => { - let Some(child) = field.children.first() else { - fail!("cannot build list without an element field"); - }; - A::FixedSizedList(FixedSizeListBuilder::new( - meta_from_field(child.clone())?, - build_builder(child)?, - (*n).try_into()?, - field.nullable, - )) - } + T::List(child) => A::List(ListBuilder::new( + meta_from_field(*child.clone())?, + build_builder(child.as_ref())?, + field.nullable, + )?), + T::LargeList(child) => A::LargeList(ListBuilder::new( + meta_from_field(*child.clone())?, + build_builder(child.as_ref())?, + field.nullable, + )?), + T::FixedSizeList(child, n) => A::FixedSizedList(FixedSizeListBuilder::new( + meta_from_field(*child.clone())?, + build_builder(child.as_ref())?, + (*n).try_into()?, + field.nullable, + )), T::Binary => A::Binary(BinaryBuilder::new(field.nullable)), T::LargeBinary => A::LargeBinary(BinaryBuilder::new(field.nullable)), T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( (*n).try_into()?, field.nullable, )), - T::Map => { - let Some(entry_field) = field.children.first() else { - fail!("Cannot build a map with an entry field"); - }; - if entry_field.data_type != T::Struct && entry_field.children.len() != 2 { - fail!("Invalid child field for map: {entry_field:?}") - } - A::Map(MapBuilder::new( - meta_from_field(entry_field.clone())?, - build_builder(entry_field)?, - field.nullable, - )) - } - T::Struct => A::Struct(build_struct(&field.children, field.nullable)?), - T::Dictionary => { - let Some(indices) = field.children.first() else { - fail!("Cannot build a dictionary without index field"); + T::Map(entry_field, _) => A::Map(MapBuilder::new( + meta_from_field(*entry_field.clone())?, + build_builder(entry_field.as_ref())?, + 
field.nullable, + )?), + T::Struct(children) => A::Struct(build_struct(&children, field.nullable)?), + T::Dictionary(key, value, _) => { + let key_field = Field { + name: "key".to_string(), + data_type: *key.clone(), + nullable: field.nullable, + metadata: HashMap::new(), }; - let Some(values) = field.children.get(1) else { - fail!("Cannot build a dictionary without values field"); + let value_field = Field { + name: "value".to_string(), + data_type: *value.clone(), + nullable: false, + metadata: HashMap::new(), }; - if !matches!(values.data_type, T::Utf8 | T::LargeUtf8) { - fail!("At the moment only string dictionaries are supported"); - } - // TODO: figure out how arrow encodes nullability and fix this - let mut indices = indices.clone(); - indices.nullable = field.nullable; A::DictionaryUtf8(DictionaryUtf8Builder::new( - field.clone(), - build_builder(&indices)?, - build_builder(values)?, + build_builder(&key_field)?, + build_builder(&value_field)?, )) } - T::Union => { + T::Union(union_fields, _) => { let mut fields = Vec::new(); - for field in &field.children { + for (_, field) in union_fields { fields.push((build_builder(field)?, meta_from_field(field.clone())?)); } diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 7c58941a..6995f332 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -9,7 +9,7 @@ use serde::{ser::SerializeSeq, Deserialize, Serialize}; use crate::internal::error::Result; -use super::{arrow::FieldMeta, schema::GenericField}; +use super::arrow::{Field, FieldMeta}; /// A wrapper around a sequence of items /// @@ -178,7 +178,7 @@ impl Offset for i64 { } } -pub fn meta_from_field(field: GenericField) -> Result { +pub fn meta_from_field(field: Field) -> Result { Ok(FieldMeta { name: field.name, nullable: field.nullable, diff --git a/serde_arrow/src/test/schema_like.rs b/serde_arrow/src/test/schema_like.rs index 4012fa34..2a8e1000 100644 --- 
a/serde_arrow/src/test/schema_like.rs +++ b/serde_arrow/src/test/schema_like.rs @@ -1,7 +1,9 @@ use serde_json::json; -use crate::internal::schema::{GenericDataType, GenericField}; -use crate::schema::{SchemaLike, SerdeArrowSchema}; +use crate::internal::{ + arrow::{DataType, Field}, + schema::{SchemaLike, SerdeArrowSchema}, +}; #[test] fn extra_attributes_trailing() { @@ -14,7 +16,12 @@ fn extra_attributes_trailing() { .unwrap(); assert_eq!( schema.fields, - vec![GenericField::new("foo", GenericDataType::F32, false)] + vec![Field { + name: String::from("foo"), + data_type: DataType::Float32, + nullable: false, + metadata: Default::default(), + }] ); } @@ -29,7 +36,12 @@ fn extra_attributes_leading() { .unwrap(); assert_eq!( schema.fields, - vec![GenericField::new("foo", GenericDataType::F32, false)] + vec![Field { + name: String::from("foo"), + data_type: DataType::Float32, + nullable: false, + metadata: Default::default(), + }] ); } diff --git a/serde_arrow/src/test_with_arrow/impls/examples.rs b/serde_arrow/src/test_with_arrow/impls/examples.rs index 677b7f2b..c2fd5900 100644 --- a/serde_arrow/src/test_with_arrow/impls/examples.rs +++ b/serde_arrow/src/test_with_arrow/impls/examples.rs @@ -1,9 +1,5 @@ use super::utils::Test; -use crate::{ - internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions, STRATEGY_KEY}, - utils::Item, -}; +use crate::internal::{schema::TracingOptions, utils::Item}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -296,14 +292,19 @@ fn fieldless_unions_in_a_struct() { ]; Test::new() - .with_schema(vec![ - GenericField::new("foo", GenericDataType::U32, false), - GenericField::new("bar", GenericDataType::Union, false) - .with_child(GenericField::new("A", GenericDataType::Null, true)) - .with_child(GenericField::new("B", GenericDataType::Null, true)) - .with_child(GenericField::new("C", GenericDataType::Null, true)), - GenericField::new("baz", GenericDataType::F32, false), - ]) + 
.with_schema(json!([ + {"name": "foo", "data_type": "U32"}, + { + "name": "bar", + "data_type": "Union", + "children": [ + {"name": "A", "data_type": "Null"}, + {"name": "B", "data_type": "Null"}, + {"name": "C", "data_type": "Null"}, + ], + }, + {"name": "baz", "data_type": "F32"}, + ])) .trace_schema_from_samples(&items, TracingOptions::default().allow_null_fields(true)) .trace_schema_from_type::(TracingOptions::default().allow_null_fields(true)) .serialize(&items) @@ -353,30 +354,25 @@ fn issue_57() { }]; Test::new() - .with_schema(vec![ - GenericField::new("filename", GenericDataType::LargeUtf8, false), - GenericField::new("game_type", GenericDataType::Union, false) - .with_child( - GenericField::new("", GenericDataType::Null, true).with_metadata( - STRATEGY_KEY.to_string(), - Strategy::UnknownVariant.to_string(), - ), - ) - .with_child(GenericField::new( - "RegularSeason", - GenericDataType::Null, - true, - )), - GenericField::new("account_type", GenericDataType::Union, false) - .with_child( - GenericField::new("", GenericDataType::Null, true).with_metadata( - STRATEGY_KEY.to_string(), - Strategy::UnknownVariant.to_string(), - ), - ) - .with_child(GenericField::new("Deduced", GenericDataType::Null, true)), - GenericField::new("file_index", GenericDataType::U64, false), - ]) + .with_schema(json!([ + {"name": "filename", "data_type": "LargeUtf8"}, + { + "name": "game_type", + "data_type": "Union", + "children": [ + {"name": "", "data_type": "Null", "strategy": "UnknownVariant"}, + {"name": "RegularSeason", "data_type": "Null"}, + ], + }, + { + "name": "account_type", + "data_type": "Union", + "children": [ + {"name": "", "data_type": "Null", "strategy": "UnknownVariant"}, + ], + }, + {"name": "file_index", "data_type": "U64"}, + ])) .trace_schema_from_samples(&items, TracingOptions::default().allow_null_fields(true)) // NOTE: trace_from_type discovers all variants // .trace_schema_from_type::(TracingOptions::default().allow_null_fields(true)) @@ -396,10 +392,10 
@@ fn simple_example() { let items = &[S { a: 2.0, b: 4 }, S { a: -123.0, b: 9 }]; Test::new() - .with_schema(vec![ - GenericField::new("a", GenericDataType::F32, false), - GenericField::new("b", GenericDataType::U32, false), - ]) + .with_schema(json!([ + {"name": "a", "data_type": "F32", "nullable": false}, + {"name": "b", "data_type": "U32", "nullable": false}, + ])) .trace_schema_from_samples(items, TracingOptions::default().allow_null_fields(true)) .serialize(items) .deserialize(items) @@ -426,10 +422,10 @@ fn top_level_nullables() { ]; Test::new() - .with_schema(vec![ - GenericField::new("a", GenericDataType::F32, true), - GenericField::new("b", GenericDataType::U32, true), - ]) + .with_schema(json!([ + {"name": "a", "data_type": "F32", "nullable": true}, + {"name": "b", "data_type": "U32", "nullable": true}, + ])) .trace_schema_from_samples(items, TracingOptions::default().allow_null_fields(true)) .serialize(items) .deserialize(items) @@ -444,7 +440,7 @@ fn new_type_wrappers() { let items = [Item(U64(0)), Item(U64(1)), Item(U64(2))]; Test::new() - .with_schema(vec![GenericField::new("item", GenericDataType::U64, false)]) + .with_schema(json!([{"name": "item", "data_type": "U64"}])) .trace_schema_from_samples(&items, TracingOptions::default().allow_null_fields(true)) .trace_schema_from_type::>(TracingOptions::default().allow_null_fields(true)) .serialize(&items) diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index 8a74a386..4264c4e0 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -2,23 +2,28 @@ use std::collections::HashMap; use serde::Deserialize; -use crate::{ - internal::{ - schema::{ - tracer::Tracer, GenericDataType as T, GenericField as F, Strategy, TracingOptions, - }, - testing::assert_error, - utils::Item, - }, - schema::STRATEGY_KEY, +use crate::internal::{ + 
arrow::{DataType, Field, UnionMode}, + schema::{tracer::Tracer, Strategy, TracingOptions, STRATEGY_KEY}, + testing::assert_error, + utils::Item, }; -fn trace_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> F { +fn trace_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> Field { let tracer = Tracer::from_type::>(options).unwrap(); let schema = tracer.to_schema().unwrap(); schema.fields.into_iter().next().unwrap() } +fn new_field(name: &str, nullable: bool, data_type: DataType) -> Field { + Field { + name: name.to_owned(), + data_type, + nullable, + metadata: Default::default(), + } +} + #[test] fn issue_90() { #[allow(unused)] @@ -35,14 +40,17 @@ fn issue_90() { } let actual = trace_type::(TracingOptions::default()); - let expected = F::new("item", T::Struct, false).with_child( - F::new("distribution", T::Struct, true) - .with_child(F::new("samples", T::LargeList, false).with_child(F::new( - "element", - T::F64, - false, - ))) - .with_child(F::new("statistic", T::LargeUtf8, false)), + let expected = new_field( + "item", + false, + DataType::Struct(vec![ + new_field( + "distribution", + true, + DataType::Struct(vec![new_field("element", false, DataType::Float64)]), + ), + new_field("statistic", false, DataType::LargeUtf8), + ]), ); assert_eq!(actual, expected); @@ -52,49 +60,49 @@ fn issue_90() { fn trace_primitives() { assert_eq!( trace_type::<()>(TracingOptions::default().allow_null_fields(true)), - F::new("item", T::Null, true), + new_field("item", true, DataType::Null), ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::I8, false) + new_field("item", false, DataType::Int8) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::I16, false) + new_field("item", false, DataType::Int16) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::I32, false) + new_field("item", false, DataType::Int32) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::I64, false) + 
new_field("item", false, DataType::Int64) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::U8, false) + new_field("item", false, DataType::UInt8) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::U16, false) + new_field("item", false, DataType::UInt16) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::U32, false) + new_field("item", false, DataType::UInt32) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::U64, false) + new_field("item", false, DataType::UInt64) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::F32, false) + new_field("item", false, DataType::Float32) ); assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::F64, false) + new_field("item", false, DataType::Float64) ); } @@ -102,11 +110,11 @@ fn trace_primitives() { fn trace_option() { assert_eq!( trace_type::(TracingOptions::default()), - F::new("item", T::I8, false) + new_field("item", false, DataType::Int8) ); assert_eq!( trace_type::>(TracingOptions::default()), - F::new("item", T::I8, true) + new_field("item", true, DataType::Int8) ); } @@ -120,9 +128,14 @@ fn trace_struct() { } let actual = trace_type::(TracingOptions::default()); - let expected = F::new("item", T::Struct, false) - .with_child(F::new("a", T::Bool, false)) - .with_child(F::new("b", T::I8, true)); + let expected = new_field( + "item", + false, + DataType::Struct(vec![ + new_field("a", false, DataType::Boolean), + new_field("b", true, DataType::Int8), + ]), + ); assert_eq!(actual, expected); } @@ -130,13 +143,19 @@ fn trace_struct() { #[test] fn trace_tuple_as_struct() { let actual = trace_type::<(bool, Option)>(TracingOptions::default()); - let expected = F::new("item", T::Struct, false) - .with_child(F::new("0", T::Bool, false)) - .with_child(F::new("1", T::I8, true)) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ); + + let mut expected = 
new_field( + "item", + false, + DataType::Struct(vec![ + new_field("0", false, DataType::Boolean), + new_field("1", true, DataType::Int8), + ]), + ); + expected.metadata.insert( + STRATEGY_KEY.to_string(), + Strategy::TupleAsStruct.to_string(), + ); assert_eq!(actual, expected); } @@ -151,9 +170,17 @@ fn trace_union() { } let actual = trace_type::(TracingOptions::default()); - let expected = F::new("item", T::Union, false) - .with_child(F::new("A", T::I8, false)) - .with_child(F::new("B", T::F32, false)); + let expected = new_field( + "item", + false, + DataType::Union( + vec![ + (0, new_field("A", false, DataType::Int8)), + (1, new_field("B", false, DataType::Float32)), + ], + UnionMode::Dense, + ), + ); assert_eq!(actual, expected); } @@ -161,8 +188,11 @@ fn trace_union() { #[test] fn trace_list() { let actual = trace_type::>(TracingOptions::default()); - let expected = - F::new("item", T::LargeList, false).with_child(F::new("element", T::LargeUtf8, false)); + let expected = new_field( + "item", + false, + DataType::LargeList(Box::new(new_field("element", false, DataType::LargeUtf8))), + ); assert_eq!(actual, expected); } @@ -170,12 +200,21 @@ fn trace_list() { #[test] fn trace_map() { let actual = trace_type::>(TracingOptions::default().map_as_struct(false)); - let expected = F::new("item", T::Map, false).with_child( - F::new("entries", T::Struct, false) - .with_child(F::new("key", T::I8, false)) - .with_child(F::new("value", T::LargeUtf8, false)), + let expected = new_field( + "item", + false, + DataType::Map( + Box::new(new_field( + "entries", + false, + DataType::Struct(vec![ + new_field("key", false, DataType::Int8), + new_field("value", false, DataType::LargeUtf8), + ]), + )), + false, + ), ); - assert_eq!(actual, expected); } diff --git a/serde_arrow/src/test_with_arrow/impls/list.rs b/serde_arrow/src/test_with_arrow/impls/list.rs index 8a391ccb..4e6d48e5 100644 --- a/serde_arrow/src/test_with_arrow/impls/list.rs +++ 
b/serde_arrow/src/test_with_arrow/impls/list.rs @@ -1,8 +1,6 @@ -use crate::{ - internal::schema::{GenericDataType, GenericField}, - schema::TracingOptions, - utils::Item, -}; +use serde_json::json; + +use crate::internal::{schema::TracingOptions, utils::Item}; use super::utils::Test; @@ -11,12 +9,11 @@ fn large_list_u32() { let items = [Item(vec![0_u32, 1, 2]), Item(vec![3, 4]), Item(vec![])]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - false, - ) - .with_child(GenericField::new("element", GenericDataType::U32, false))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "children": [{"name": "element", "data_type": "U32"}], + }])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -33,12 +30,11 @@ fn large_list_nullable_u64() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - false, - ) - .with_child(GenericField::new("element", GenericDataType::U64, true))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "children": [{"name": "element", "data_type": "U32", "nullable": true}], + }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -55,12 +51,12 @@ fn nullable_large_list_u32() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - true, - ) - .with_child(GenericField::new("element", GenericDataType::U32, false))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "nullable": true, + "children": [{"name": "element", "data_type": "U32"}], + }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -72,12 +68,11 @@ fn list_u32() { let items = [Item(vec![0_u32, 1, 2]), Item(vec![3, 4]), Item(vec![])]; 
Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::List, - false, - ) - .with_child(GenericField::new("element", GenericDataType::U32, false))]) + .with_schema(json!([{ + "name": "item", + "data_type": "List", + "children": [{"name": "element", "data_type": "U32"}], + }])) .serialize(&items) .deserialize(&items); } @@ -91,15 +86,18 @@ fn nested_large_list_u32() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - false, - ) - .with_child( - GenericField::new("element", GenericDataType::LargeList, false) - .with_child(GenericField::new("element", GenericDataType::U32, false)), - )]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "children": [{ + "name": "element", + "data_type": "List", + "children": [{ + "name": "element", + "data_type": "U32", + }], + }], + }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -115,12 +113,12 @@ fn nullable_vec_bool() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - true, - ) - .with_child(GenericField::new("element", GenericDataType::Bool, false))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "nullable": true, + "children": [{"name": "element", "data_type": "Bool"}], + }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -136,15 +134,16 @@ fn nullable_vec_bool_nested() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - true, - ) - .with_child( - GenericField::new("element", GenericDataType::LargeList, false) - .with_child(GenericField::new("element", GenericDataType::Bool, false)), - )]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "nullable": true, + "children": [{ + "name": "element", + "data_type": 
"LargeList", + "children": [{"name": "element", "data_type": "Bool"}], + }], + }])) .trace_schema_from_type::>>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -160,12 +159,11 @@ fn vec_nullable_bool() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - false, - ) - .with_child(GenericField::new("element", GenericDataType::Bool, true))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "children": [{"name": "element", "data_type": "Bool", "nullable": true}], + }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -177,12 +175,11 @@ fn byte_arrays() { let items = [Item(b"hello".to_vec()), Item(b"world!".to_vec())]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::LargeList, - false, - ) - .with_child(GenericField::new("element", GenericDataType::U8, false))]) + .with_schema(json!([{ + "name": "item", + "data_type": "LargeList", + "children": [{"name": "element", "data_type": "U8"}], + }])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) diff --git a/serde_arrow/src/test_with_arrow/impls/map.rs b/serde_arrow/src/test_with_arrow/impls/map.rs index 062d614b..bf9ee470 100644 --- a/serde_arrow/src/test_with_arrow/impls/map.rs +++ b/serde_arrow/src/test_with_arrow/impls/map.rs @@ -1,11 +1,10 @@ use std::collections::{BTreeMap, HashMap}; -use crate::{ - internal::{ - schema::{GenericDataType, GenericField}, - testing::{btree_map, hash_map}, - }, - schema::{Strategy, TracingOptions, STRATEGY_KEY}, +use serde_json::json; + +use crate::internal::{ + schema::TracingOptions, + testing::{btree_map, hash_map}, utils::Item, }; @@ -15,12 +14,7 @@ use super::utils::Test; #[test] fn map_as_struct() { - let field = GenericField::new("item", 
GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = BTreeMap; - let values: &[Item] = &[ Item(btree_map! { "a" => 1_u32, "b" => 2_u32 }), Item(btree_map! { "a" => 3_u32, "b" => 4_u32 }), @@ -28,7 +22,17 @@ fn map_as_struct() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "U32"}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values) .deserialize(values); @@ -36,10 +40,6 @@ fn map_as_struct() { #[test] fn hash_map_as_struct() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = HashMap; let values: &[Item] = &[ Item(hash_map! 
{ "a" => 1_u32, "b" => 2_u32 }), @@ -48,7 +48,17 @@ fn hash_map_as_struct() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "U32"}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values) .deserialize(values); @@ -56,10 +66,6 @@ fn hash_map_as_struct() { #[test] fn map_as_struct_nullable() { - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::U32, false)); type Ty = Option>; let values: &[Item] = &[ Item(Some(btree_map! { "a" => 1_u32, "b" => 2_u32 })), @@ -69,7 +75,18 @@ fn map_as_struct_nullable() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "U32"}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values) .deserialize(values); @@ -77,10 +94,6 @@ fn map_as_struct_nullable() { #[test] fn map_as_struct_missing_fields() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! 
{ "a" => 1_u32 }), @@ -89,17 +102,23 @@ fn map_as_struct_missing_fields() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "U32", "nullable": true}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values); } #[test] fn map_as_struct_missing_fields_2() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, true)) - .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! { "a" => 1_u32, "b" => 2_u32 }), @@ -110,17 +129,23 @@ fn map_as_struct_missing_fields_2() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32", "nullable": true}, + {"name": "b", "data_type": "U32", "nullable": true}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values); } #[test] fn map_as_struct_missing_fields_3() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, true)) - .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! 
{}), @@ -131,17 +156,23 @@ fn map_as_struct_missing_fields_3() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32", "nullable": true}, + {"name": "b", "data_type": "U32", "nullable": true}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values); } #[test] fn map_as_struct_nullable_fields() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::U32, true)) - .with_child(GenericField::new("b", GenericDataType::U32, true)); type Ty = BTreeMap>; let values: &[Item] = &[ Item(btree_map! { "a" => Some(1_u32), "b" => Some(4_u32) }), @@ -150,7 +181,17 @@ fn map_as_struct_nullable_fields() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "MapAsStruct", + "children": [ + {"name": "a", "data_type": "U32", "nullable": true}, + {"name": "b", "data_type": "U32", "nullable": true}, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .serialize(values) .deserialize(values); @@ -159,11 +200,6 @@ fn map_as_struct_nullable_fields() { #[test] fn map_as_map() { let tracing_options = TracingOptions::default().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::LargeUtf8, false)) - .with_child(GenericField::new("value", GenericDataType::U32, false)), - ); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! 
{ "a" => 1_u32, "b" => 2_u32 }), @@ -171,7 +207,22 @@ fn map_as_map() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "LargeUtf8"}, + {"name": "value", "data_type": "U32"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -181,11 +232,6 @@ fn map_as_map() { #[test] fn map_as_map_empty() { let tracing_options = TracingOptions::default().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::LargeUtf8, false)) - .with_child(GenericField::new("value", GenericDataType::U32, false)), - ); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! {}), @@ -194,7 +240,22 @@ fn map_as_map_empty() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "LargeUtf8"}, + {"name": "value", "data_type": "U32"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -204,11 +265,6 @@ fn map_as_map_empty() { #[test] fn map_as_map_int_keys() { let tracing_options = TracingOptions::default().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I32, false)) - .with_child(GenericField::new("value", GenericDataType::U32, false)), - ); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! 
{ -1_i32 => 1_u32, -2_i32 => 2_u32 }), @@ -216,7 +272,22 @@ fn map_as_map_int_keys() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I32"}, + {"name": "value", "data_type": "U32"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -226,11 +297,6 @@ fn map_as_map_int_keys() { #[test] fn hash_maps() { let tracing_options = TracingOptions::new().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I64, false)) - .with_child(GenericField::new("value", GenericDataType::Bool, false)), - ); type Ty = HashMap; let values: &[Item] = &[ Item(hash_map! {0 => true, 1 => false, 2 => true}), @@ -239,7 +305,22 @@ fn hash_maps() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I64"}, + {"name": "value", "data_type": "Bool"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -249,11 +330,6 @@ fn hash_maps() { #[test] fn hash_maps_nullable() { let tracing_options = TracingOptions::new().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, true).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I64, false)) - .with_child(GenericField::new("value", GenericDataType::Bool, false)), - ); type Ty = Option>; let values: &[Item] = &[ Item(Some(hash_map! 
{0 => true, 1 => false, 2 => true})), @@ -262,7 +338,23 @@ fn hash_maps_nullable() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "nullable": true, + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I64"}, + {"name": "value", "data_type": "Bool"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -272,11 +364,6 @@ fn hash_maps_nullable() { #[test] fn hash_maps_nullable_keys() { let tracing_options = TracingOptions::new().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I64, true)) - .with_child(GenericField::new("value", GenericDataType::Bool, false)), - ); type Ty = HashMap, bool>; let values: &[Item] = &[ Item(hash_map! 
{Some(0) => true, Some(1) => false, Some(2) => true}), @@ -285,7 +372,22 @@ fn hash_maps_nullable_keys() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I64", "nullable": true}, + {"name": "value", "data_type": "Bool"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -295,11 +397,6 @@ fn hash_maps_nullable_keys() { #[test] fn hash_maps_nullable_values() { let tracing_options = TracingOptions::new().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I64, false)) - .with_child(GenericField::new("value", GenericDataType::Bool, true)), - ); type Ty = HashMap>; let values: &[Item] = &[ Item(hash_map! 
{0 => Some(true), 1 => Some(false), 2 => Some(true)}), @@ -308,7 +405,22 @@ fn hash_maps_nullable_values() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I64"}, + {"name": "value", "data_type": "Bool", "nullable": true}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) @@ -318,11 +430,6 @@ fn hash_maps_nullable_values() { #[test] fn btree_maps() { let tracing_options = TracingOptions::new().map_as_struct(false); - let field = GenericField::new("item", GenericDataType::Map, false).with_child( - GenericField::new("entries", GenericDataType::Struct, false) - .with_child(GenericField::new("key", GenericDataType::I64, false)) - .with_child(GenericField::new("value", GenericDataType::Bool, false)), - ); type Ty = BTreeMap; let values: &[Item] = &[ Item(btree_map! 
{0 => true, 1 => false, 2 => true}), @@ -331,7 +438,22 @@ fn btree_maps() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Map", + "children": [ + { + "name": "entries", + "type": "Struct", + "children": [ + {"name": "key", "data_type": "I64"}, + {"name": "value", "data_type": "Bool"}, + ], + }, + ], + }, + ])) .trace_schema_from_samples(values, tracing_options.clone()) .trace_schema_from_type::>(tracing_options.clone()) .serialize(values) diff --git a/serde_arrow/src/test_with_arrow/impls/primitives.rs b/serde_arrow/src/test_with_arrow/impls/primitives.rs index 2407757c..7de4e922 100644 --- a/serde_arrow/src/test_with_arrow/impls/primitives.rs +++ b/serde_arrow/src/test_with_arrow/impls/primitives.rs @@ -1,14 +1,23 @@ use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{ - internal::schema::{GenericDataType, GenericField}, +use crate::internal::{ + arrow::{DataType, Field}, schema::TracingOptions, utils::Item, }; use super::utils::Test; +fn new_field(name: &str, data_type: DataType, nullable: bool) -> Field { + Field { + name: name.to_owned(), + data_type, + nullable, + metadata: Default::default(), + } +} + #[test] fn null() { let items = &[Item(()), Item(()), Item(())]; @@ -300,7 +309,7 @@ fn nullable_f64() { #[test] fn f32_from_f64() { let values = [Item(-1.0_f64), Item(2.0), Item(-3.0), Item(4.0)]; - let field = GenericField::new("item", GenericDataType::F32, false); + let field = new_field("item", DataType::Float32, false); Test::new() .with_schema(vec![field]) @@ -310,7 +319,7 @@ fn f32_from_f64() { #[test] fn f64_from_f32() { - let field = GenericField::new("item", GenericDataType::F64, false); + let field = new_field("item", DataType::Float64, false); let values = [Item(-1.0_f32), Item(2.0), Item(-3.0), Item(4.0)]; Test::new() @@ -321,7 +330,7 @@ fn f64_from_f32() { #[test] fn f16_from_f32() { - let field = GenericField::new("item", GenericDataType::F16, false); + let field = 
new_field("item", DataType::Float16, false); let values = [Item(-1.0_f32), Item(2.0), Item(-3.0), Item(4.0)]; Test::new() @@ -332,7 +341,7 @@ fn f16_from_f32() { #[test] fn f16_from_f64() { - let field = GenericField::new("item", GenericDataType::F16, false); + let field = new_field("item", DataType::Float16, false); let values = [Item(-1.0_f64), Item(2.0), Item(-3.0), Item(4.0)]; Test::new() @@ -343,7 +352,7 @@ fn f16_from_f64() { #[test] fn str() { - let field = GenericField::new("item", GenericDataType::LargeUtf8, false); + let field = new_field("item", DataType::LargeUtf8, false); type Ty = String; let values = [ Item(String::from("a")), @@ -362,7 +371,7 @@ fn str() { #[test] fn nullable_str() { - let field = GenericField::new("item", GenericDataType::LargeUtf8, true); + let field = new_field("item", DataType::LargeUtf8, true); type Ty = Option; let values = [ Item(Some(String::from("a"))), @@ -381,7 +390,7 @@ fn nullable_str() { #[test] fn str_u32() { - let field = GenericField::new("item", GenericDataType::Utf8, false); + let field = new_field("item", DataType::Utf8, false); let values = [ Item(String::from("a")), Item(String::from("b")), @@ -397,7 +406,7 @@ fn str_u32() { #[test] fn nullable_str_u32() { - let field = GenericField::new("item", GenericDataType::Utf8, true); + let field = new_field("item", DataType::Utf8, true); let values = [ Item(Some(String::from("a"))), Item(None), @@ -413,7 +422,7 @@ fn nullable_str_u32() { #[test] fn borrowed_str() { - let field = GenericField::new("item", GenericDataType::LargeUtf8, false); + let field = new_field("item", DataType::LargeUtf8, false); type Ty<'a> = &'a str; @@ -429,7 +438,7 @@ fn borrowed_str() { #[test] fn nullabe_borrowed_str() { - let field = GenericField::new("item", GenericDataType::LargeUtf8, true); + let field = new_field("item", DataType::LargeUtf8, true); type Ty<'a> = Option<&'a str>; @@ -445,7 +454,7 @@ fn nullabe_borrowed_str() { #[test] fn borrowed_str_u32() { - let field = 
GenericField::new("item", GenericDataType::Utf8, false); + let field = new_field("item", DataType::Utf8, false); let values = [Item("a"), Item("b"), Item("c"), Item("d")]; @@ -457,7 +466,7 @@ fn borrowed_str_u32() { #[test] fn nullabe_borrowed_str_u32() { - let field = GenericField::new("item", GenericDataType::Utf8, true); + let field = new_field("item", DataType::Utf8, true); let values = [Item(Some("a")), Item(None), Item(None), Item(Some("d"))]; @@ -472,7 +481,7 @@ fn newtype_i64() { #[derive(Serialize, Deserialize, Debug, PartialEq)] struct I64(i64); - let field = GenericField::new("item", GenericDataType::I64, false); + let field = new_field("item", DataType::Int64, false); type Ty = I64; let values = [Item(I64(-1)), Item(I64(2)), Item(I64(3)), Item(I64(-4))]; @@ -487,7 +496,7 @@ fn newtype_i64() { #[test] fn u8_to_u16() { - let field = GenericField::new("item", GenericDataType::U16, false); + let field = new_field("item", DataType::UInt16, false); let values = [Item(1_u8), Item(2), Item(3), Item(4)]; Test::new() @@ -498,7 +507,7 @@ fn u8_to_u16() { #[test] fn u32_to_i64() { - let field = GenericField::new("item", GenericDataType::I64, false); + let field = new_field("item", DataType::Int64, false); let values = [Item(1_u32), Item(2), Item(3), Item(4)]; Test::new() @@ -509,7 +518,7 @@ fn u32_to_i64() { #[test] fn chars() { - let field = GenericField::new("item", GenericDataType::U32, false); + let field = new_field("item", DataType::UInt32, false); type Ty = char; let values = [Item('a'), Item('b'), Item('c')]; diff --git a/serde_arrow/src/test_with_arrow/impls/struct.rs b/serde_arrow/src/test_with_arrow/impls/struct.rs index ea0f8108..85c6d68d 100644 --- a/serde_arrow/src/test_with_arrow/impls/struct.rs +++ b/serde_arrow/src/test_with_arrow/impls/struct.rs @@ -1,10 +1,7 @@ use serde::{Deserialize, Serialize}; +use serde_json::json; -use crate::{ - internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions, STRATEGY_KEY}, - 
utils::Item, -}; +use crate::internal::{schema::TracingOptions, utils::Item}; use super::utils::Test; @@ -15,18 +12,21 @@ fn r#struct() { a: u32, b: bool, } - - type Ty = S; - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::Bool, false)); - let values = [Item(S { a: 1, b: true }), Item(S { a: 2, b: false })]; let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "Bool"}, + ], + } + ])) + .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -34,16 +34,6 @@ fn r#struct() { #[test] fn struct_nested() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::Bool, false)) - .with_child( - GenericField::new("c", GenericDataType::Struct, false) - .with_child(GenericField::new("d", GenericDataType::I32, false)) - .with_child(GenericField::new("e", GenericDataType::U16, false)), - ); - - type Ty = S; let values = [Item(S::default()), Item(S::default())]; #[derive(Default, Serialize, Deserialize, Debug, PartialEq)] @@ -60,8 +50,25 @@ fn struct_nested() { } let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "Bool"}, + { + "name": "c", + "data_type": "Struct", + "children": [ + {"name": "d", "data_type": "I32"}, + 
{"name": "e", "data_type": "U16"}, + ], + } + ], + } + ])) + .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -69,9 +76,6 @@ fn struct_nested() { #[test] fn struct_nullable_field() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::U32, true)) - .with_child(GenericField::new("b", GenericDataType::Bool, false)); type Ty = S; let values = [ Item(S { @@ -92,8 +96,17 @@ fn struct_nullable_field() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "U32", "nullable": true}, + {"name": "b", "data_type": "Bool"}, + ], + } + ])) + .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -101,10 +114,6 @@ fn struct_nullable_field() { #[test] fn nullable_struct() { - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::Bool, false)); - type Ty = Option; let values = [ Item(Some(S { a: 1, b: true })), Item(None), @@ -118,8 +127,18 @@ fn nullable_struct() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "Bool"}, + ], + } + ])) + .trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ 
-127,15 +146,6 @@ fn nullable_struct() { #[test] fn nullable_nested_struct() { - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child( - GenericField::new("b", GenericDataType::Struct, true) - .with_child(GenericField::new("c", GenericDataType::I16, false)) - .with_child(GenericField::new("d", GenericDataType::F64, false)), - ); - type Ty = Option; - let values = [ Item(Some(S1 { a: 1, b: None })), Item(None), @@ -159,8 +169,26 @@ fn nullable_nested_struct() { let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "a", "data_type": "U32"}, + { + "name": "b", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "c", "data_type": "I16"}, + {"name": "d", "data_type": "F64"}, + ] + }, + ], + } + ])) + .trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -168,10 +196,6 @@ fn nullable_nested_struct() { #[test] fn nullable_struct_nullable_fields() { - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_child(GenericField::new("a", GenericDataType::U32, true)) - .with_child(GenericField::new("b", GenericDataType::Bool, true)); - type Ty = Option; let values = [ Item(Some(S { a: Some(1), @@ -195,8 +219,18 @@ fn nullable_struct_nullable_fields() { } let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "a", "data_type": "U32", "nullable": true}, + {"name": "b", "data_type": "Bool", "nullable": true}, + ], + } + ])) + 
.trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -204,14 +238,6 @@ fn nullable_struct_nullable_fields() { #[test] fn nullable_struct_list_field() { - let field = - GenericField::new("item", GenericDataType::Struct, true) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child( - GenericField::new("b", GenericDataType::LargeList, true) - .with_child(GenericField::new("element", GenericDataType::Bool, false)), - ); - type Ty = Option; let values = [ Item(Some(S { a: 1, b: None })), Item(Some(S { @@ -232,8 +258,25 @@ fn nullable_struct_list_field() { } let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "a", "data_type": "U32"}, + { + "name": "b", + "data_type": "LargeList", + "nullable": true, + "children": [ + {"name": "element", "data_type": "Bool"}, + ], + }, + ], + } + ])) + .trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -241,10 +284,6 @@ fn nullable_struct_list_field() { #[test] fn serde_flatten() { - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()) - .with_child(GenericField::new("a", GenericDataType::I8, false)) - .with_child(GenericField::new("value", GenericDataType::Bool, false)); let values = [Item(Some(LocalItem { a: 0, b: Inner { value: true }, @@ -263,7 +302,18 @@ fn serde_flatten() { } let tracing_options = TracingOptions::default().map_as_struct(true); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "strategy": 
"MapAsStruct", + "children": [ + {"name": "a", "data_type": "I8"}, + {"name": "value", "data_type": "Bool"}, + ], + }, + ])) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -271,12 +321,6 @@ fn serde_flatten() { #[test] fn flattened_structures() { - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::I64, false)) - .with_child(GenericField::new("b", GenericDataType::F32, false)) - .with_child(GenericField::new("c", GenericDataType::F64, false)) - .with_metadata(STRATEGY_KEY.to_string(), Strategy::MapAsStruct.to_string()); - let values = [ Item(Outer { a: 0, @@ -305,7 +349,17 @@ fn flattened_structures() { } let tracing_options = TracingOptions::default(); Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "I64"}, + {"name": "b", "data_type": "F32"}, + {"name": "c", "data_type": "F64"}, + ], + } + ])) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -314,12 +368,6 @@ fn flattened_structures() { #[test] fn struct_nullable() { let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Struct, true) - .with_child(GenericField::new("a", GenericDataType::Bool, false)) - .with_child(GenericField::new("b", GenericDataType::I64, false)) - .with_child(GenericField::new("c", GenericDataType::Null, true)) - .with_child(GenericField::new("d", GenericDataType::LargeUtf8, false)); - type Ty = Option; let values = [ Item(Some(Struct { a: true, @@ -348,8 +396,20 @@ fn struct_nullable() { d: String, } Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + {"name": "a", "data_type": 
"Bool"}, + {"name": "b", "data_type": "I64"}, + {"name": "c", "data_type": "Null"}, + {"name": "d", "data_type": "LargeUtf8"}, + ], + } + ])) + .trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -358,14 +418,6 @@ fn struct_nullable() { #[test] fn struct_nullable_nested() { let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Struct, true).with_child( - GenericField::new("inner", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::Bool, false)) - .with_child(GenericField::new("b", GenericDataType::I64, false)) - .with_child(GenericField::new("c", GenericDataType::Null, true)) - .with_child(GenericField::new("d", GenericDataType::LargeUtf8, false)), - ); - type Ty = Option; let values = [ Item(Some(Outer { inner: Struct { @@ -399,8 +451,26 @@ fn struct_nullable_nested() { } Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "children": [ + { + "name": "inner", + "type": "Struct", + "children": [ + {"name": "a", "data_type": "Bool"}, + {"name": "b", "data_type": "I64"}, + {"name": "c", "data_type": "Null"}, + {"name": "d", "data_type": "LargeUtf8"}, + ] + }, + ], + } + ])) + .trace_schema_from_type::>>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -409,12 +479,6 @@ fn struct_nullable_nested() { #[test] fn struct_nullable_item() { let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::Bool, true)) - .with_child(GenericField::new("b", GenericDataType::I64, true)) - 
.with_child(GenericField::new("c", GenericDataType::Null, true)) - .with_child(GenericField::new("d", GenericDataType::LargeUtf8, true)); - type Ty = StructNullable; let values = [ Item(StructNullable { a: None, @@ -439,8 +503,19 @@ fn struct_nullable_item() { } Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "Bool", "nullable": true}, + {"name": "b", "data_type": "I64", "nullable": true}, + {"name": "c", "data_type": "Null", "nullable": true}, + {"name": "d", "data_type": "LargeUtf8", "nullable": true}, + ], + } + ])) + .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); diff --git a/serde_arrow/src/test_with_arrow/impls/tuple.rs b/serde_arrow/src/test_with_arrow/impls/tuple.rs index e90a78b9..adabae7f 100644 --- a/serde_arrow/src/test_with_arrow/impls/tuple.rs +++ b/serde_arrow/src/test_with_arrow/impls/tuple.rs @@ -1,10 +1,7 @@ use serde::{Deserialize, Serialize}; +use serde_json::json; -use crate::{ - internal::schema::{GenericDataType, GenericField}, - schema::{Strategy, TracingOptions, STRATEGY_KEY}, - utils::Item, -}; +use crate::internal::{schema::TracingOptions, utils::Item}; use super::utils::Test; @@ -13,17 +10,17 @@ fn tuple_u64_bool() { let items = [Item((1_u64, true)), Item((2_u64, false))]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - false, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U64, false)) - .with_child(GenericField::new("1", GenericDataType::Bool, false))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U64"}, + {"name": "1", "data_type": 
"Bool"}, + ], + } + ])) .trace_schema_from_type::>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -39,17 +36,17 @@ fn tuple_struct_u64_bool() { let items = [Item(S(1_u64, true)), Item(S(2_u64, false))]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - false, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U64, false)) - .with_child(GenericField::new("1", GenericDataType::Bool, false))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U64"}, + {"name": "1", "data_type": "Bool"}, + ], + } + ])) .trace_schema_from_type::>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -69,17 +66,18 @@ fn nullbale_tuple_u64_bool() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - true, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U64, false)) - .with_child(GenericField::new("1", GenericDataType::Bool, false))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U64"}, + {"name": "1", "data_type": "Bool"}, + ], + } + ])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -92,16 +90,16 @@ fn tuple_nullable_u64() { let items = [Item((Some(1_u64),)), Item((Some(2_u64),)), Item((None,))]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - false, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - 
.with_child(GenericField::new("0", GenericDataType::U64, true))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U64", "nullable": true}, + ], + } + ])) .trace_schema_from_type::,)>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -114,23 +112,23 @@ fn tuple_nested() { let items = [Item(((1_u64,),)), Item(((2_u64,),))]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - false, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child( - GenericField::new("0", GenericDataType::Struct, false) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U64, false)), - )]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + { + "name": "0", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U64"}, + ] + }, + ], + } + ])) .trace_schema_from_type::>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -147,17 +145,18 @@ fn tuple_nullable() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - true, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::Bool, false)) - .with_child(GenericField::new("1", GenericDataType::I64, false))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "Bool"}, + {"name": "1", "data_type": "I64"}, + ], + } + ])) .trace_schema_from_type::>>(TracingOptions::default()) 
.trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) @@ -174,25 +173,26 @@ fn tuple_nullable_nested() { ]; Test::new() - .with_schema(vec![GenericField::new( - "item", - GenericDataType::Struct, - true, - ) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child( - GenericField::new("0", GenericDataType::Struct, false) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::Bool, false)) - .with_child(GenericField::new("1", GenericDataType::I64, false)), - ) - .with_child(GenericField::new("1", GenericDataType::I64, false))]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Struct", + "nullable": true, + "strategy": "TupleAsStruct", + "children": [ + { + "name": "0", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "Bool"}, + {"name": "1", "data_type": "I64"}, + ], + }, + {"name": "1", "data_type": "Bool"}, + ], + } + ])) .trace_schema_from_type::>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) .serialize(&items) diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index 6c2b5c92..75f84690 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -1,9 +1,8 @@ use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::{ - internal::schema::{GenericDataType, GenericField}, - schema::{SchemaLike, Strategy, TracingOptions, STRATEGY_KEY}, +use crate::internal::{ + schema::{SchemaLike, TracingOptions}, utils::{Item, Items}, }; @@ -21,15 +20,20 @@ fn fieldless_unions() { type Ty = U; let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("A", GenericDataType::Null, 
true)) - .with_child(GenericField::new("B", GenericDataType::Null, true)) - .with_child(GenericField::new("C", GenericDataType::Null, true)); - let values = [Item(U::A), Item(U::B), Item(U::C), Item(U::A)]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "A", "data_type": "Null"}, + {"name": "B", "data_type": "Null"}, + {"name": "C", "data_type": "Null"}, + ], + }, + ])) .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) @@ -48,15 +52,20 @@ fn fieldless_union_out_of_order() { type Ty = U; let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("A", GenericDataType::Null, true)) - .with_child(GenericField::new("B", GenericDataType::Null, true)) - .with_child(GenericField::new("C", GenericDataType::Null, true)); - let values = [Item(U::B), Item(U::A), Item(U::C)]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "A", "data_type": "Null"}, + {"name": "B", "data_type": "Null"}, + {"name": "C", "data_type": "Null"}, + ], + }, + ])) .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) @@ -75,11 +84,6 @@ fn union_simple() { type Ty = U; let tracing_options = TracingOptions::default(); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("U32", GenericDataType::U32, false)) - .with_child(GenericField::new("Bool", GenericDataType::Bool, false)) - .with_child(GenericField::new("Str", GenericDataType::LargeUtf8, false)); - let values = [ Item(U::U32(32)), Item(U::Bool(true)), @@ -87,7 +91,17 @@ fn union_simple() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ 
+ { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "U32", "data_type": "U32"}, + {"name": "Bool", "data_type": "Bool"}, + {"name": "Str", "data_type": "LargeUtf8"}, + ], + }, + ])) .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) @@ -110,19 +124,6 @@ fn union_mixed() { type Ty = U; let tracing_options = TracingOptions::default(); - let field = - GenericField::new("item", GenericDataType::Union, false) - .with_child( - GenericField::new("V1", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::U32, false)) - .with_child(GenericField::new("b", GenericDataType::U64, false)), - ) - .with_child(GenericField::new("Bool", GenericDataType::Bool, false)) - .with_child( - GenericField::new("S", GenericDataType::Struct, false) - .with_child(GenericField::new("s", GenericDataType::LargeUtf8, false)), - ); - let values = [ Item(U::V1 { a: 32, b: 13 }), Item(U::Bool(true)), @@ -132,7 +133,30 @@ fn union_mixed() { ]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + { + "name": "V1", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "U32"}, + {"name": "b", "data_type": "U64"}, + ], + }, + {"name": "Bool", "data_type": "Bool"}, + { + "name": "S", + "data_type": "Struct", + "children": [ + {"name": "s", "data_type": "LargeUtf8"}, + ] + }, + ], + }, + ])) .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) @@ -153,17 +177,6 @@ fn union_nested() { Str(String), } - type Ty = U; - - let tracing_options = TracingOptions::default(); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("U32", GenericDataType::U32, false)) - .with_child( - GenericField::new("O", GenericDataType::Union, false) - 
.with_child(GenericField::new("Bool", GenericDataType::Bool, false)) - .with_child(GenericField::new("Str", GenericDataType::LargeUtf8, false)), - ); - let values = [ Item(U::U32(32)), Item(U::O(O::Bool(true))), @@ -172,9 +185,25 @@ fn union_nested() { ]; Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) - .trace_schema_from_samples(&values, tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "U32", "data_type": "U32"}, + { + "name": "O", + "data_type": "Union", + "children": [ + {"name": "Bool", "data_type": "Bool"}, + {"name": "Str", "data_type": "LargeUtf8"}, + ], + }, + ], + }, + ])) + .trace_schema_from_type::>(TracingOptions::default()) + .trace_schema_from_samples(&values, TracingOptions::default()) .serialize(&values) .deserialize(&values); } @@ -188,15 +217,8 @@ fn enums() { U32(u32), U64(u64), } - type Ty = U; let tracing_options = TracingOptions::default(); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("U8", GenericDataType::U8, false)) - .with_child(GenericField::new("U16", GenericDataType::U16, false)) - .with_child(GenericField::new("U32", GenericDataType::U32, false)) - .with_child(GenericField::new("U64", GenericDataType::U64, false)); - let values = [ Item(U::U32(2)), Item(U::U64(3)), @@ -205,8 +227,19 @@ fn enums() { ]; Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "U8", "data_type": "U8"}, + {"name": "U16", "data_type": "U16"}, + {"name": "U32", "data_type": "U32"}, + {"name": "U64", "data_type": "U64"} + ], + }, + ])) + .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) .deserialize(&values); @@ -219,35 +252,38 @@ fn enums_tuple() { A(u8, u32), B(u16, u64), 
} - type Ty = U; - - let tracing_options = TracingOptions::default(); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child( - GenericField::new("A", GenericDataType::Struct, false) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U8, false)) - .with_child(GenericField::new("1", GenericDataType::U32, false)), - ) - .with_child( - GenericField::new("B", GenericDataType::Struct, false) - .with_metadata( - STRATEGY_KEY.to_string(), - Strategy::TupleAsStruct.to_string(), - ) - .with_child(GenericField::new("0", GenericDataType::U16, false)) - .with_child(GenericField::new("1", GenericDataType::U64, false)), - ); let values = [Item(U::A(2, 3)), Item(U::B(0, 1))]; Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) - .trace_schema_from_samples(&values, tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + { + "name": "A", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U8"}, + {"name": "1", "data_type": "U32"}, + ], + }, + { + "name": "B", + "data_type": "Struct", + "strategy": "TupleAsStruct", + "children": [ + {"name": "0", "data_type": "U16"}, + {"name": "1", "data_type": "U64"}, + ], + }, + ], + }, + ])) + .trace_schema_from_type::>(TracingOptions::default()) + .trace_schema_from_samples(&values, TracingOptions::default()) .serialize(&values) .deserialize(&values); } @@ -259,27 +295,35 @@ fn enums_struct() { A { a: u8, b: u32 }, B { c: u16, d: u64 }, } - type Ty = U; - - let tracing_options = TracingOptions::default(); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child( - GenericField::new("A", GenericDataType::Struct, false) - .with_child(GenericField::new("a", GenericDataType::U8, false)) - .with_child(GenericField::new("b", GenericDataType::U32, false)), - 
) - .with_child( - GenericField::new("B", GenericDataType::Struct, false) - .with_child(GenericField::new("c", GenericDataType::U16, false)) - .with_child(GenericField::new("d", GenericDataType::U64, false)), - ); - let values = [Item(U::A { a: 2, b: 3 }), Item(U::B { c: 0, d: 1 })]; Test::new() - .with_schema(vec![field]) - .trace_schema_from_type::>(tracing_options.clone()) - .trace_schema_from_samples(&values, tracing_options.clone()) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + { + "name": "A", + "data_type": "Struct", + "children": [ + {"name": "a", "data_type": "U8"}, + {"name": "b", "data_type": "U32"}, + ], + }, + { + "name": "B", + "data_type": "Struct", + "children": [ + {"name": "c", "data_type": "U16"}, + {"name": "d", "data_type": "U64"}, + ], + }, + ], + }, + ])) + .trace_schema_from_type::>(TracingOptions::default()) + .trace_schema_from_samples(&values, TracingOptions::default()) .serialize(&values) .deserialize(&values); } @@ -294,14 +338,19 @@ fn enums_union() { type Ty = U; let tracing_options = TracingOptions::default().allow_null_fields(true); - let field = GenericField::new("item", GenericDataType::Union, false) - .with_child(GenericField::new("A", GenericDataType::Null, true)) - .with_child(GenericField::new("B", GenericDataType::Null, true)); - let values = [Item(U::A), Item(U::B)]; Test::new() - .with_schema(vec![field]) + .with_schema(json!([ + { + "name": "item", + "data_type": "Union", + "children": [ + {"name": "A", "data_type": "Null"}, + {"name": "B", "data_type": "Null"}, + ], + }, + ])) .trace_schema_from_type::>(tracing_options.clone()) .trace_schema_from_samples(&values, tracing_options.clone()) .serialize(&values) diff --git a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs index 22168490..13e711ba 100644 --- a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs +++ 
b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs @@ -5,8 +5,11 @@ use serde_json::json; use crate::{ _impl::{arrow, arrow2}, - internal::{schema::GenericField, testing::hash_map}, - schema::{SchemaLike, SerdeArrowSchema, STRATEGY_KEY}, + internal::{ + arrow::Field, + schema::{ArrowOrCustomField, SchemaLike, SerdeArrowSchema, STRATEGY_KEY}, + testing::hash_map, + }, }; fn example_field_desc() -> serde_json::Value { @@ -30,7 +33,8 @@ fn example_field_desc() -> serde_json::Value { #[test] fn arrow() { - let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); + let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); + let initial_field = initial_field.into_field().unwrap(); assert_eq!( initial_field.metadata, hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") @@ -43,7 +47,7 @@ fn arrow() { ); // roundtrip via try_from - let generic_field = GenericField::try_from(&arrow_field).unwrap(); + let generic_field = Field::try_from(&arrow_field).unwrap(); assert_eq!(generic_field, initial_field); // roundtrip via serialize @@ -54,7 +58,8 @@ fn arrow() { #[test] fn arrow2() { - let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); + let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); + let initial_field = initial_field.into_field().unwrap(); assert_eq!( initial_field.metadata, hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") @@ -71,7 +76,7 @@ fn arrow2() { ); // roundtrip via try_from - let generic_field = GenericField::try_from(&arrow_field).unwrap(); + let generic_field = Field::try_from(&arrow_field).unwrap(); assert_eq!(generic_field, initial_field); // note: arrow2 Field does not support serialize From e07641b79b5c06519548701dcd57c742d9b7dc83 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 5 Aug 2024 11:21:03 +0200 Subject: [PATCH 087/178] Remove ArrowOrCustomField --- serde_arrow/src/internal/schema/extensions.rs | 9 ++-- 
.../src/internal/schema/from_samples/mod.rs | 9 ++-- serde_arrow/src/internal/schema/mod.rs | 12 ++++- .../src/internal/schema/serde/deserialize.rs | 46 ++++++++----------- serde_arrow/src/internal/schema/serde/test.rs | 11 ++--- .../src/internal/schema/tracing_options.rs | 11 ++--- .../issue_35_preserve_metadata.rs | 8 ++-- 7 files changed, 51 insertions(+), 55 deletions(-) diff --git a/serde_arrow/src/internal/schema/extensions.rs b/serde_arrow/src/internal/schema/extensions.rs index 221f8d20..6aa1b6ee 100644 --- a/serde_arrow/src/internal/schema/extensions.rs +++ b/serde_arrow/src/internal/schema/extensions.rs @@ -5,8 +5,7 @@ use serde::Serialize; use crate::internal::{ arrow::{DataType, Field}, error::{fail, Error, Result}, - schema::ArrowOrCustomField, - utils::value, + schema::transmute_field, }; /// Easily construct a field for tensors with fixed shape @@ -58,8 +57,7 @@ impl FixedShapeTensorField { /// with the the name `"element"`. The field type can be any valid Arrow /// type. 
pub fn new(name: &str, element: impl Serialize, shape: Vec) -> Result { - let element: ArrowOrCustomField = value::transmute(&element)?; - let element = element.into_field()?; + let element = transmute_field(element)?; if element.name != "element" { fail!("The element field of FixedShapeTensorField must be named \"element\""); } @@ -174,8 +172,7 @@ pub struct VariableShapeTensorField { impl VariableShapeTensorField { pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { - let element: ArrowOrCustomField = value::transmute(&element)?; - let element = element.into_field()?; + let element = transmute_field(element)?; if element.name != "element" { fail!("The element field of FixedShapeTensorField must be named \"element\""); } diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 1748d05d..a8428ae3 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -590,15 +590,18 @@ mod test { use serde::Serialize; use serde_json::{json, Value}; - use crate::internal::schema::{ArrowOrCustomField, TracingOptions}; + use crate::{ + internal::schema::{SerdeArrowSchema, TracingOptions}, + schema::SchemaLike, + }; use super::*; fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); - let expected = serde_json::from_value::(expected).unwrap(); - let expected = expected.into_field().unwrap(); + let expected = SerdeArrowSchema::from_value(&[expected]).unwrap(); + let expected = expected.fields.into_iter().next().unwrap(); assert_eq!(field, expected); } diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index c651e988..e33de0d9 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -19,7 +19,6 @@ use 
crate::internal::{ use ::serde::{Deserialize, Serialize}; -pub use serde::deserialize::ArrowOrCustomField; pub use strategy::get_strategy_from_metadata; pub use strategy::{merge_strategy_with_metadata, Strategy, STRATEGY_KEY}; use tracer::Tracer; @@ -307,6 +306,17 @@ impl SchemaLike for SerdeArrowSchema { } } +/// Wrapper around `SerdeArrowSchema::from_value` to convert a single field +/// +/// This function takes anything that serialized into a field and converts it into a field. +pub fn transmute_field(field: impl Serialize) -> Result { + let expected = SerdeArrowSchema::from_value(&[field])?; + let Some(field) = expected.fields.into_iter().next() else { + fail!("unexpected error in transmute_field: no field found"); + }; + Ok(field) +} + pub fn validate_field(field: &Field) -> Result<()> { match &field.data_type { DataType::Null => validate_null_field(field), diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs index c5159773..389c6989 100644 --- a/serde_arrow/src/internal/schema/serde/deserialize.rs +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use serde::de::Visitor; +use serde::{de::Visitor, Deserialize}; use crate::internal::{ arrow::{DataType, Field}, @@ -27,7 +27,7 @@ impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { use serde::de::Error; let mut fields = Vec::new(); - while let Some(item) = seq.next_element::()? { + while let Some(item) = seq.next_element::()? { fields.push(item.into_field().map_err(A::Error::custom)?); } @@ -44,7 +44,7 @@ impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { while let Some(key) = map.next_key::()? 
{ if key == "fields" { - fields = Some(map.next_value::>()?); + fields = Some(map.next_value::>()?); } else { map.next_value::()?; } @@ -69,30 +69,7 @@ impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { } } -pub enum ArrowOrCustomField { - Arrow(Field), - Custom(CustomField), -} - -impl ArrowOrCustomField { - pub fn into_field(self) -> Result { - let field = match self { - ArrowOrCustomField::Arrow(field) => return Ok(field), - ArrowOrCustomField::Custom(field) => field, - }; - - todo!() - } -} - -impl<'de> serde::Deserialize<'de> for ArrowOrCustomField { - fn deserialize>( - deserializer: D, - ) -> std::result::Result { - todo!() - } -} - +#[derive(Debug, Clone, Deserialize)] pub struct CustomField { name: String, data_type: ArrowOrCustomDataType, @@ -101,6 +78,13 @@ pub struct CustomField { metadata: HashMap, } +impl CustomField { + pub fn into_field(self) -> Result { + todo!() + } +} + +#[derive(Debug, Clone)] pub enum ArrowOrCustomDataType { Arrow(DataType), Custom(String), @@ -111,3 +95,11 @@ impl ArrowOrCustomDataType { todo!() } } + +impl<'de> serde::Deserialize<'de> for ArrowOrCustomDataType { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + todo!() + } +} diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index 585012fb..874857f5 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -3,7 +3,7 @@ use serde_json::json; use crate::internal::{ arrow::{DataType, Field}, error::PanicOnError, - schema::{ArrowOrCustomField, STRATEGY_KEY}, + schema::{transmute_field, STRATEGY_KEY}, testing::hash_map, }; @@ -25,8 +25,7 @@ fn i16_field_simple() -> PanicOnError<()> { let actual = serde_json::to_value(&SerializableField(&field))?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; - let roundtripped = roundtripped.into_field()?; + let roundtripped = transmute_field(&actual)?; 
assert_eq!(roundtripped, field); Ok(()) @@ -56,8 +55,7 @@ fn date64_field_complex() -> PanicOnError<()> { let actual = serde_json::to_value(&field)?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; - let roundtripped = roundtripped.into_field()?; + let roundtripped = transmute_field(&actual)?; assert_eq!(roundtripped, field); Ok(()) @@ -89,8 +87,7 @@ fn list_field_complex() -> PanicOnError<()> { let actual = serde_json::to_value(&field)?; assert_eq!(actual, expected); - let roundtripped = serde_json::from_value::(actual)?; - let roundtripped = roundtripped.into_field()?; + let roundtripped = transmute_field(&actual)?; assert_eq!(roundtripped, field); Ok(()) diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index b20ac84e..025af322 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use serde::Serialize; -use crate::internal::{arrow::Field, error::Result, schema::ArrowOrCustomField, utils::value}; +use crate::internal::{arrow::Field, error::Result, schema::transmute_field, utils::value}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum TracingMode { @@ -277,11 +277,10 @@ impl TracingOptions { /// Add an overwrite to [`overwrites`](#structfield.overwrites) pub fn overwrite, F: Serialize>(mut self, path: P, field: F) -> Result { - let path = path.into(); - let path = format!("$.{path}"); - let field: ArrowOrCustomField = value::transmute(&field)?; - - self.overwrites.0.insert(path, field.into_field()?); + self.overwrites.0.insert( + format!("$.{path}", path = path.into()), + transmute_field(field)?, + ); Ok(self) } diff --git a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs index 13e711ba..33aa305d 100644 --- a/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs 
+++ b/serde_arrow/src/test_with_arrow/issue_35_preserve_metadata.rs @@ -7,7 +7,7 @@ use crate::{ _impl::{arrow, arrow2}, internal::{ arrow::Field, - schema::{ArrowOrCustomField, SchemaLike, SerdeArrowSchema, STRATEGY_KEY}, + schema::{transmute_field, SchemaLike, SerdeArrowSchema, STRATEGY_KEY}, testing::hash_map, }, }; @@ -33,8 +33,7 @@ fn example_field_desc() -> serde_json::Value { #[test] fn arrow() { - let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); - let initial_field = initial_field.into_field().unwrap(); + let initial_field = transmute_field(example_field_desc()).unwrap(); assert_eq!( initial_field.metadata, hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") @@ -58,8 +57,7 @@ fn arrow() { #[test] fn arrow2() { - let initial_field = serde_json::from_value::(example_field_desc()).unwrap(); - let initial_field = initial_field.into_field().unwrap(); + let initial_field = transmute_field(example_field_desc()).unwrap(); assert_eq!( initial_field.metadata, hash_map!("foo" => "bar", STRATEGY_KEY => "MapAsStruct") From f6779f69946c2234b7b6d41fbca73ec38792c18d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 5 Aug 2024 11:47:30 +0200 Subject: [PATCH 088/178] Implement Arrow -> Arrow field conversion --- serde_arrow/src/arrow_impl/schema.rs | 154 ++++++++++-------- serde_arrow/src/internal/schema/mod.rs | 3 +- .../src/internal/schema/serde/deserialize.rs | 124 +++++++++++++- serde_arrow/src/internal/schema/strategy.rs | 52 ------ .../src/test_with_arrow/impls/struct.rs | 1 - 5 files changed, 208 insertions(+), 126 deletions(-) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index b76f1ef3..f119d7e7 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -121,54 +121,36 @@ impl TryFrom<&ArrowDataType> for DataType { type Error = Error; fn try_from(value: &ArrowDataType) -> Result { - use {DataType as T, TimeUnit as U}; + use {ArrowDataType as AT, 
DataType as T}; match value { - ArrowDataType::Boolean => Ok(T::Boolean), - ArrowDataType::Null => Ok(T::Null), - ArrowDataType::Int8 => Ok(T::Int8), - ArrowDataType::Int16 => Ok(T::Int16), - ArrowDataType::Int32 => Ok(T::Int32), - ArrowDataType::Int64 => Ok(T::Int64), - ArrowDataType::UInt8 => Ok(T::UInt8), - ArrowDataType::UInt16 => Ok(T::UInt16), - ArrowDataType::UInt32 => Ok(T::UInt32), - ArrowDataType::UInt64 => Ok(T::UInt64), - ArrowDataType::Float16 => Ok(T::Float16), - ArrowDataType::Float32 => Ok(T::Float32), - ArrowDataType::Float64 => Ok(T::Float64), - ArrowDataType::Utf8 => Ok(T::Utf8), - ArrowDataType::LargeUtf8 => Ok(T::LargeUtf8), - ArrowDataType::Date32 => Ok(T::Date32), - ArrowDataType::Date64 => Ok(T::Date64), - ArrowDataType::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - ArrowDataType::Time32(ArrowTimeUnit::Second) => Ok(T::Time32(U::Second)), - ArrowDataType::Time32(ArrowTimeUnit::Millisecond) => Ok(T::Time32(U::Millisecond)), - ArrowDataType::Time32(unit) => fail!("Invalid time unit {unit:?} for Time32"), - ArrowDataType::Time64(ArrowTimeUnit::Microsecond) => Ok(T::Time64(U::Microsecond)), - ArrowDataType::Time64(ArrowTimeUnit::Nanosecond) => Ok(T::Time64(U::Nanosecond)), - ArrowDataType::Time64(unit) => fail!("Invalid time unit {unit:?} for Time64"), - ArrowDataType::Timestamp(ArrowTimeUnit::Second, tz) => { - Ok(T::Timestamp(U::Second, tz.as_ref().map(|s| s.to_string()))) - } - ArrowDataType::Timestamp(ArrowTimeUnit::Millisecond, tz) => Ok(T::Timestamp( - U::Millisecond, - tz.as_ref().map(|s| s.to_string()), - )), - ArrowDataType::Timestamp(ArrowTimeUnit::Microsecond, tz) => Ok(T::Timestamp( - U::Microsecond, + AT::Boolean => Ok(T::Boolean), + AT::Null => Ok(T::Null), + AT::Int8 => Ok(T::Int8), + AT::Int16 => Ok(T::Int16), + AT::Int32 => Ok(T::Int32), + AT::Int64 => Ok(T::Int64), + AT::UInt8 => Ok(T::UInt8), + AT::UInt16 => Ok(T::UInt16), + AT::UInt32 => Ok(T::UInt32), + AT::UInt64 => Ok(T::UInt64), + AT::Float16 => 
Ok(T::Float16), + AT::Float32 => Ok(T::Float32), + AT::Float64 => Ok(T::Float64), + AT::Utf8 => Ok(T::Utf8), + AT::LargeUtf8 => Ok(T::LargeUtf8), + AT::Date32 => Ok(T::Date32), + AT::Date64 => Ok(T::Date64), + AT::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), + AT::Time32(unit) => Ok(T::Time32(unit.clone().into())), + AT::Time64(unit) => Ok(T::Time64(unit.clone().into())), + AT::Timestamp(unit, tz) => Ok(T::Timestamp( + unit.clone().into(), tz.as_ref().map(|s| s.to_string()), )), - ArrowDataType::Timestamp(ArrowTimeUnit::Nanosecond, tz) => Ok(T::Timestamp( - U::Nanosecond, - tz.as_ref().map(|s| s.to_string()), - )), - ArrowDataType::Duration(ArrowTimeUnit::Second) => Ok(T::Duration(U::Second)), - ArrowDataType::Duration(ArrowTimeUnit::Millisecond) => Ok(T::Duration(U::Millisecond)), - ArrowDataType::Duration(ArrowTimeUnit::Microsecond) => Ok(T::Duration(U::Microsecond)), - ArrowDataType::Duration(ArrowTimeUnit::Nanosecond) => Ok(T::Duration(U::Nanosecond)), - ArrowDataType::Binary => Ok(T::Binary), - ArrowDataType::LargeBinary => Ok(T::LargeBinary), - ArrowDataType::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), + AT::Duration(unit) => Ok(T::Duration(unit.clone().into())), + AT::Binary => Ok(T::Binary), + AT::LargeBinary => Ok(T::LargeBinary), + AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), _ => fail!("Only primitive data types can be converted to T"), } } @@ -193,7 +175,38 @@ impl TryFrom<&DataType> for ArrowDataType { type Error = Error; fn try_from(value: &DataType) -> std::result::Result { - todo!() + use {ArrowDataType as AT, DataType as T}; + match value { + T::Boolean => Ok(AT::Boolean), + T::Null => Ok(AT::Null), + T::Int8 => Ok(AT::Int8), + T::Int16 => Ok(AT::Int16), + T::Int32 => Ok(AT::Int32), + T::Int64 => Ok(AT::Int64), + T::UInt8 => Ok(AT::UInt8), + T::UInt16 => Ok(AT::UInt16), + T::UInt32 => Ok(AT::UInt32), + T::UInt64 => Ok(AT::UInt64), + T::Float16 => Ok(AT::Float16), + T::Float32 => Ok(AT::Float32), + T::Float64 
=> Ok(AT::Float64), + T::Utf8 => Ok(AT::Utf8), + T::LargeUtf8 => Ok(AT::LargeUtf8), + T::Date32 => Ok(AT::Date32), + T::Date64 => Ok(AT::Date64), + T::Decimal128(precision, scale) => Ok(AT::Decimal128(*precision, *scale)), + T::Time32(unit) => Ok(AT::Time32((*unit).into())), + T::Time64(unit) => Ok(AT::Time64((*unit).into())), + T::Timestamp(unit, tz) => Ok(AT::Timestamp( + (*unit).into(), + tz.as_ref().map(|s| s.to_string().into()), + )), + T::Duration(unit) => Ok(AT::Duration((*unit).into())), + T::Binary => Ok(AT::Binary), + T::LargeBinary => Ok(AT::LargeBinary), + T::FixedSizeBinary(n) => Ok(AT::FixedSizeBinary(*n)), + _ => fail!("Only primitive data types can be converted to T"), + } } } @@ -212,31 +225,34 @@ impl TryFrom<&Field> for ArrowField { } } -impl From for ArrowTimeUnit { - fn from(value: TimeUnit) -> Self { - match value { - TimeUnit::Second => Self::Second, - TimeUnit::Millisecond => Self::Millisecond, - TimeUnit::Microsecond => Self::Microsecond, - TimeUnit::Nanosecond => Self::Nanosecond, +macro_rules! 
impl_from_one_to_one { + ( + $src_ty:ty => $dst_ty:ty, + [ + $($src_variant:ident => $dst_variant:ident),* + ] + ) => { + impl From<$dst_ty> for $src_ty { + fn from(value: $dst_ty) -> Self { + match value { + $(<$dst_ty>::$dst_variant => <$src_ty>::$src_variant,)* + } + } } - } -} -impl From for UnionMode { - fn from(value: ArrowUnionMode) -> Self { - match value { - ArrowUnionMode::Dense => UnionMode::Dense, - ArrowUnionMode::Sparse => UnionMode::Sparse, + impl From<$src_ty> for $dst_ty { + fn from(value: $src_ty) -> Self { + match value { + $(<$src_ty>::$src_variant => <$dst_ty>::$dst_variant,)* + } + } } - } + }; } -impl From for ArrowUnionMode { - fn from(value: UnionMode) -> Self { - match value { - UnionMode::Dense => ArrowUnionMode::Dense, - UnionMode::Sparse => ArrowUnionMode::Sparse, - } - } -} +impl_from_one_to_one!( + TimeUnit => ArrowTimeUnit, + [Second => Second, Millisecond => Millisecond, Microsecond => Microsecond, Nanosecond => Nanosecond] +); + +impl_from_one_to_one!(UnionMode => ArrowUnionMode, [Sparse => Sparse, Dense => Dense]); diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index e33de0d9..f0eca260 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -19,8 +19,7 @@ use crate::internal::{ use ::serde::{Deserialize, Serialize}; -pub use strategy::get_strategy_from_metadata; -pub use strategy::{merge_strategy_with_metadata, Strategy, STRATEGY_KEY}; +pub use strategy::{get_strategy_from_metadata, Strategy, STRATEGY_KEY}; use tracer::Tracer; pub use tracing_options::{Overwrites, TracingMode, TracingOptions}; diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs index 389c6989..b72c797f 100644 --- a/serde_arrow/src/internal/schema/serde/deserialize.rs +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -73,14 +73,36 @@ impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { pub 
struct CustomField { name: String, data_type: ArrowOrCustomDataType, + #[serde(default)] + nullable: bool, + #[serde(default)] strategy: Option, + #[serde(default)] children: Vec, + #[serde(default)] metadata: HashMap, } impl CustomField { pub fn into_field(self) -> Result { - todo!() + match self.data_type { + ArrowOrCustomDataType::Arrow(data_type) => { + if !self.children.is_empty() { + fail!("Cannot use children with an arrow data type"); + } + + let metadata = merge_strategy_with_metadata(self.metadata, self.strategy)?; + Ok(Field { + name: self.name, + nullable: self.nullable, + data_type, + metadata, + }) + } + ArrowOrCustomDataType::Custom(data_type) => { + todo!() + } + } } } @@ -100,6 +122,104 @@ impl<'de> serde::Deserialize<'de> for ArrowOrCustomDataType { fn deserialize>( deserializer: D, ) -> std::result::Result { - todo!() + struct VisitorImpl; + + impl<'de> Visitor<'de> for VisitorImpl { + type Value = ArrowOrCustomDataType; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "string or DataType variant") + } + + fn visit_newtype_struct>( + self, + deserializer: D, + ) -> Result { + ArrowOrCustomDataType::deserialize(deserializer) + } + + fn visit_str(self, v: &str) -> Result { + Ok(ArrowOrCustomDataType::Custom(v.to_string())) + } + + fn visit_enum>( + self, + data: A, + ) -> Result { + let field = DataType::deserialize(EnumDeserializer(data))?; + Ok(ArrowOrCustomDataType::Arrow(field)) + } + } + + deserializer.deserialize_any(VisitorImpl) + } +} + +/// A helper to deserialize from an `EnumAccess` object directly +struct EnumDeserializer(A); + +impl<'de, A: serde::de::EnumAccess<'de>> serde::de::Deserializer<'de> for EnumDeserializer { + type Error = A::Error; + + fn deserialize_any>(self, visitor: V) -> Result { + visitor.visit_enum(self.0) + } + + serde::forward_to_deserialize_any! 
{ + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum identifier ignored_any + } +} + +pub fn merge_strategy_with_metadata( + mut metadata: HashMap, + strategy: Option, +) -> Result> { + if metadata.contains_key(STRATEGY_KEY) && strategy.is_some() { + fail!("Duplicate strategy: metadata map contains {STRATEGY_KEY} and strategy given"); + } + if let Some(strategy) = strategy { + metadata.insert(STRATEGY_KEY.to_owned(), strategy.to_string()); } + Ok(metadata) +} + +#[test] +fn test_split_strategy_from_metadata_with_metadata() { + use crate::internal::testing::hash_map; + + let metadata: HashMap = hash_map!( + "key1" => "value1", + "key2" => "value2", + ); + let strategy: Option = Some(Strategy::TupleAsStruct); + + let expected: HashMap = hash_map!( + "SERDE_ARROW:strategy" => "TupleAsStruct", + "key1" => "value1", + "key2" => "value2", + ); + + let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); + assert_eq!(actual, expected); +} + +#[test] +fn test_split_strategy_from_metadata_without_metadata() { + use crate::internal::testing::hash_map; + + let metadata: HashMap = hash_map!( + "key1" => "value1", + "key2" => "value2", + ); + let strategy: Option = None; + + let expected: HashMap = hash_map!( + "key1" => "value1", + "key2" => "value2", + ); + + let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); + assert_eq!(actual, expected); } diff --git a/serde_arrow/src/internal/schema/strategy.rs b/serde_arrow/src/internal/schema/strategy.rs index 7622df25..29162d64 100644 --- a/serde_arrow/src/internal/schema/strategy.rs +++ b/serde_arrow/src/internal/schema/strategy.rs @@ -136,55 +136,3 @@ pub fn get_strategy_from_metadata(metadata: &HashMap) -> Result< }; Ok(Some(strategy.parse()?)) } - -pub fn merge_strategy_with_metadata( - mut metadata: HashMap, - strategy: Option, -) -> Result> { - if 
metadata.contains_key(STRATEGY_KEY) && strategy.is_some() { - fail!("Duplicate strategy: metadata map contains {STRATEGY_KEY} and strategy given"); - } - if let Some(strategy) = strategy { - metadata.insert(STRATEGY_KEY.to_owned(), strategy.to_string()); - } - Ok(metadata) -} - -#[test] -fn test_split_strategy_from_metadata_with_metadata() { - use crate::internal::testing::hash_map; - - let metadata: HashMap = hash_map!( - "key1" => "value1", - "key2" => "value2", - ); - let strategy: Option = Some(Strategy::TupleAsStruct); - - let expected: HashMap = hash_map!( - "SERDE_ARROW:strategy" => "TupleAsStruct", - "key1" => "value1", - "key2" => "value2", - ); - - let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); - assert_eq!(actual, expected); -} - -#[test] -fn test_split_strategy_from_metadata_without_metadata() { - use crate::internal::testing::hash_map; - - let metadata: HashMap = hash_map!( - "key1" => "value1", - "key2" => "value2", - ); - let strategy: Option = None; - - let expected: HashMap = hash_map!( - "key1" => "value1", - "key2" => "value2", - ); - - let actual = merge_strategy_with_metadata(metadata, strategy).unwrap(); - assert_eq!(actual, expected); -} diff --git a/serde_arrow/src/test_with_arrow/impls/struct.rs b/serde_arrow/src/test_with_arrow/impls/struct.rs index 85c6d68d..81b7c134 100644 --- a/serde_arrow/src/test_with_arrow/impls/struct.rs +++ b/serde_arrow/src/test_with_arrow/impls/struct.rs @@ -76,7 +76,6 @@ fn struct_nested() { #[test] fn struct_nullable_field() { - type Ty = S; let values = [ Item(S { a: Some(1), From 0646dc269968e4e88148999e9fecec664766a79f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 5 Aug 2024 12:06:49 +0200 Subject: [PATCH 089/178] Initial impl of deserialization logic --- serde_arrow/src/internal/schema/data_type.rs | 161 ------------------ serde_arrow/src/internal/schema/mod.rs | 6 +- .../src/internal/schema/serde/deserialize.rs | 136 ++++++++++++--- 
.../src/internal/schema/serde/serialize.rs | 1 + .../src/internal/schema/tracing_options.rs | 2 +- .../src/test_with_arrow/impls/playground.rs | 154 ----------------- 6 files changed, 115 insertions(+), 345 deletions(-) delete mode 100644 serde_arrow/src/internal/schema/data_type.rs delete mode 100644 serde_arrow/src/test_with_arrow/impls/playground.rs diff --git a/serde_arrow/src/internal/schema/data_type.rs b/serde_arrow/src/internal/schema/data_type.rs deleted file mode 100644 index 0e9e9aa8..00000000 --- a/serde_arrow/src/internal/schema/data_type.rs +++ /dev/null @@ -1,161 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::internal::{ - arrow::TimeUnit, - error::{fail, Error, Result}, - utils::dsl::Term, -}; - -#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Serialize, Deserialize)] -#[serde(try_from = "GenericDataTypeString", into = "GenericDataTypeString")] -pub enum GenericDataType { - Null, - Bool, - I8, - I16, - I32, - I64, - U8, - U16, - U32, - U64, - F16, - F32, - F64, - Utf8, - LargeUtf8, - Date32, - Date64, - Time32(TimeUnit), - Time64(TimeUnit), - Duration(TimeUnit), - Struct, - List, - LargeList, - FixedSizeList(i32), - Binary, - LargeBinary, - FixedSizeBinary(i32), - Union, - Map, - Dictionary, - Timestamp(TimeUnit, Option), - Decimal128(u8, i8), -} - -impl std::fmt::Display for GenericDataType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use GenericDataType::*; - match self { - Null => write!(f, "Null"), - Bool => write!(f, "Bool"), - Utf8 => write!(f, "Utf8"), - LargeUtf8 => write!(f, "LargeUtf8"), - I8 => write!(f, "I8"), - I16 => write!(f, "I16"), - I32 => write!(f, "I32"), - I64 => write!(f, "I64"), - U8 => write!(f, "U8"), - U16 => write!(f, "U16"), - U32 => write!(f, "U32"), - U64 => write!(f, "U64"), - F16 => write!(f, "F16"), - F32 => write!(f, "F32"), - F64 => write!(f, "F64"), - Date32 => write!(f, "Date32"), - Date64 => write!(f, "Date64"), - Struct => write!(f, "Struct"), - List => 
write!(f, "List"), - LargeList => write!(f, "LargeList"), - FixedSizeList(n) => write!(f, "FixedSizeList({n})"), - Binary => write!(f, "Binary"), - LargeBinary => write!(f, "LargeBinary"), - FixedSizeBinary(n) => write!(f, "FixedSizeBinary({n})"), - Union => write!(f, "Union"), - Map => write!(f, "Map"), - Dictionary => write!(f, "Dictionary"), - Timestamp(unit, timezone) => { - if let Some(timezone) = timezone { - write!(f, "Timestamp({unit}, Some(\"{timezone}\"))") - } else { - write!(f, "Timestamp({unit}, None)") - } - } - Time32(unit) => write!(f, "Time32({unit})"), - Time64(unit) => write!(f, "Time64({unit})"), - Duration(unit) => write!(f, "Duration({unit})"), - Decimal128(precision, scale) => write!(f, "Decimal128({precision}, {scale})"), - } - } -} - -impl std::str::FromStr for GenericDataType { - type Err = Error; - - fn from_str(s: &str) -> Result { - use GenericDataType as T; - - let res = match Term::from_str(s)?.as_call()? { - ("Null", []) => T::Null, - ("Bool" | "Boolean", []) => T::Bool, - ("Utf8", []) => T::Utf8, - ("LargeUtf8", []) => T::LargeUtf8, - ("U8" | "UInt8", []) => T::U8, - ("U16" | "UInt16", []) => T::U16, - ("U32" | "UInt32", []) => T::U32, - ("U64" | "UInt64", []) => T::U64, - ("I8" | "Int8", []) => T::I8, - ("I16" | "Int16", []) => T::I16, - ("I32" | "Int32", []) => T::I32, - ("I64" | "Int64", []) => T::I64, - ("F16" | "Float16", []) => T::F16, - ("F32" | "Float32", []) => T::F32, - ("F64" | "Float64", []) => T::F64, - ("Date32", []) => T::Date32, - ("Date64", []) => T::Date64, - ("Struct", []) => T::Struct, - ("List", []) => T::List, - ("LargeList", []) => T::LargeList, - ("FixedSizeList", [n]) => T::FixedSizeList(n.as_ident()?.parse()?), - ("Binary", []) => T::Binary, - ("LargeBinary", []) => T::LargeBinary, - ("FixedSizeBinary", [n]) => T::FixedSizeBinary(n.as_ident()?.parse()?), - ("Union", []) => T::Union, - ("Map", []) => T::Map, - ("Dictionary", []) => T::Dictionary, - ("Timestamp", [unit, timezone]) => { - let unit: TimeUnit = 
unit.as_ident()?.parse()?; - let timezone = timezone - .as_option()? - .map(|term| term.as_string()) - .transpose()?; - T::Timestamp(unit, timezone.map(|s| s.to_owned())) - } - ("Time32", [unit]) => T::Time32(unit.as_ident()?.parse()?), - ("Time64", [unit]) => T::Time64(unit.as_ident()?.parse()?), - ("Duration", [unit]) => T::Duration(unit.as_ident()?.parse()?), - ("Decimal128", [precision, scale]) => { - T::Decimal128(precision.as_ident()?.parse()?, scale.as_ident()?.parse()?) - } - _ => fail!("invalid data type {s}"), - }; - Ok(res) - } -} - -#[derive(Serialize, Deserialize)] -struct GenericDataTypeString(String); - -impl TryFrom for GenericDataType { - type Error = Error; - - fn try_from(value: GenericDataTypeString) -> std::result::Result { - value.0.parse() - } -} - -impl From for GenericDataTypeString { - fn from(value: GenericDataType) -> Self { - Self(value.to_string()) - } -} diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index f0eca260..e48e152d 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -1,6 +1,4 @@ pub mod extensions; - -mod data_type; mod from_samples; mod from_type; mod serde; @@ -481,7 +479,7 @@ fn validate_struct_field(field: &Field, children: &[Field]) -> Result<()> { Ok(()) } -fn validate_map_field(field: &Field, entry: &Field) -> Result<()> { +fn validate_map_field(field: &Field, _entry: &Field) -> Result<()> { if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? { fail!("invalid strategy for Map field: {strategy}"); } @@ -490,7 +488,7 @@ fn validate_map_field(field: &Field, entry: &Field) -> Result<()> { Ok(()) } -fn validate_union_field(field: &Field, children: &[(i8, Field)], mode: UnionMode) -> Result<()> { +fn validate_union_field(field: &Field, children: &[(i8, Field)], _mode: UnionMode) -> Result<()> { if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? 
{ fail!("invalid strategy for Union field: {strategy}"); } diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs index b72c797f..6d7d3740 100644 --- a/serde_arrow/src/internal/schema/serde/deserialize.rs +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -1,11 +1,12 @@ -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use serde::{de::Visitor, Deserialize}; use crate::internal::{ - arrow::{DataType, Field}, - error::{fail, Error, Result}, + arrow::{DataType, Field, TimeUnit, UnionMode}, + error::{fail, Result}, schema::{SerdeArrowSchema, Strategy, STRATEGY_KEY}, + utils::dsl::Term, }; // A custom impl of untagged-enum repr with better error messages @@ -70,7 +71,7 @@ impl<'de> serde::Deserialize<'de> for SerdeArrowSchema { } #[derive(Debug, Clone, Deserialize)] -pub struct CustomField { +struct CustomField { name: String, data_type: ArrowOrCustomDataType, #[serde(default)] @@ -84,40 +85,125 @@ pub struct CustomField { } impl CustomField { - pub fn into_field(self) -> Result { - match self.data_type { - ArrowOrCustomDataType::Arrow(data_type) => { - if !self.children.is_empty() { - fail!("Cannot use children with an arrow data type"); - } - - let metadata = merge_strategy_with_metadata(self.metadata, self.strategy)?; - Ok(Field { - name: self.name, - nullable: self.nullable, - data_type, - metadata, - }) - } - ArrowOrCustomDataType::Custom(data_type) => { - todo!() - } + fn into_field(self) -> Result { + let mut children = Vec::new(); + for child in self.children { + children.push(child.into_field()?); } + + let data_type = self.data_type.into_data_type(children)?; + let metadata = merge_strategy_with_metadata(self.metadata, self.strategy)?; + + Ok(Field { + name: self.name, + nullable: self.nullable, + data_type, + metadata, + }) } } #[derive(Debug, Clone)] -pub enum ArrowOrCustomDataType { +enum ArrowOrCustomDataType { Arrow(DataType), Custom(String), } impl 
ArrowOrCustomDataType { - pub fn into_data_type(self, children: Vec) -> Result { - todo!() + fn into_data_type(self, children: Vec) -> Result { + match self { + Self::Custom(data_type) => build_data_type(data_type, children), + Self::Arrow(data_type) => { + if !children.is_empty() { + fail!("Cannot use children with an arrow data type"); + } + Ok(data_type) + } + } } } +fn build_data_type(data_type: String, children: Vec) -> Result { + use DataType as T; + + let res = match Term::from_str(&data_type)?.as_call()? { + ("Null", []) => T::Null, + ("Bool" | "Boolean", []) => T::Boolean, + ("Utf8", []) => T::Utf8, + ("LargeUtf8", []) => T::LargeUtf8, + ("U8" | "UInt8", []) => T::UInt8, + ("U16" | "UInt16", []) => T::UInt16, + ("U32" | "UInt32", []) => T::UInt32, + ("U64" | "UInt64", []) => T::UInt64, + ("I8" | "Int8", []) => T::Int8, + ("I16" | "Int16", []) => T::Int16, + ("I32" | "Int32", []) => T::Int32, + ("I64" | "Int64", []) => T::Int64, + ("F16" | "Float16", []) => T::Float16, + ("F32" | "Float32", []) => T::Float32, + ("F64" | "Float64", []) => T::Float64, + ("Date32", []) => T::Date32, + ("Date64", []) => T::Date64, + ("Binary", []) => T::Binary, + ("LargeBinary", []) => T::LargeBinary, + ("FixedSizeBinary", [n]) => T::FixedSizeBinary(n.as_ident()?.parse()?), + ("Timestamp", [unit, timezone]) => { + let unit: TimeUnit = unit.as_ident()?.parse()?; + let timezone = timezone + .as_option()? + .map(|term| term.as_string()) + .transpose()?; + T::Timestamp(unit, timezone.map(|s| s.to_owned())) + } + ("Time32", [unit]) => T::Time32(unit.as_ident()?.parse()?), + ("Time64", [unit]) => T::Time64(unit.as_ident()?.parse()?), + ("Duration", [unit]) => T::Duration(unit.as_ident()?.parse()?), + ("Decimal128", [precision, scale]) => { + T::Decimal128(precision.as_ident()?.parse()?, scale.as_ident()?.parse()?) 
+ } + ("Struct", []) => T::Struct(children), + ("List", []) => { + let Ok([child]) = <[_; 1]>::try_from(children) else { + fail!("Invalid children for List: expected one child"); + }; + T::List(Box::new(child)) + } + ("LargeList", []) => { + let Ok([child]) = <[_; 1]>::try_from(children) else { + fail!("Invalid children for List: expected one child"); + }; + T::LargeList(Box::new(child)) + } + ("FixedSizeList", [n]) => { + let Ok([child]) = <[_; 1]>::try_from(children) else { + fail!("Invalid children for LargeList: expected one child"); + }; + T::FixedSizeList(Box::new(child), n.as_ident()?.parse()?) + } + ("Dictionary", []) => { + let Ok([key, value]) = <[_; 2]>::try_from(children) else { + fail!("Invalid children for Dictionary: expected two children"); + }; + T::Dictionary(Box::new(key.data_type), Box::new(value.data_type), false) + } + ("Map", []) => { + let Ok([child]) = <[_; 1]>::try_from(children) else { + fail!("Invalid children for Map: expected one child"); + }; + T::Map(Box::new(child), false) + } + ("Union", []) => { + let mut children_with_type_ids = Vec::new(); + for (idx, child) in children.into_iter().enumerate() { + children_with_type_ids.push((idx.try_into()?, child)); + } + T::Union(children_with_type_ids, UnionMode::Dense) + } + _ => fail!("invalid data type {data_type}"), + }; + Ok(res) +} + impl<'de> serde::Deserialize<'de> for ArrowOrCustomDataType { fn deserialize>( deserializer: D, diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index ae63f888..dca618f1 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -149,6 +149,7 @@ impl<'a> serde::Serialize for SerializableDataTypeChildren<'a> { s.end() } T::Dictionary(key, value, _) => { + // TODO: this is incorrect, serialize as struct let mut s = serializer.serialize_seq(Some(2))?; s.serialize_element(&DictionaryField("key", key))?; 
s.serialize_element(&DictionaryField("value", value))?; diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index 025af322..754a25ae 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use serde::Serialize; -use crate::internal::{arrow::Field, error::Result, schema::transmute_field, utils::value}; +use crate::internal::{arrow::Field, error::Result, schema::transmute_field}; #[derive(Debug, Clone, Copy, PartialEq)] pub enum TracingMode { diff --git a/serde_arrow/src/test_with_arrow/impls/playground.rs b/serde_arrow/src/test_with_arrow/impls/playground.rs deleted file mode 100644 index 935edd98..00000000 --- a/serde_arrow/src/test_with_arrow/impls/playground.rs +++ /dev/null @@ -1,154 +0,0 @@ -macro_rules! test_roundtrip_arrays { - ( - $name:ident { - $($setup:tt)* - } - assert_round_trip( - $fields:expr, - $inputs:expr - $(, expected: $expected:expr)? - ); - ) => { - mod $name { - use serde::{Deserialize, Serialize}; - - use crate::{ - arrow, arrow2, - internal::{ - deserialize_from_arrays, - schema::{GenericDataType, GenericField}, - }, - Result, - }; - - #[test] - fn arrow2() { - use crate::_impl::arrow2::datatypes::Field; - $($setup)* - - let fields = $fields; - let inputs = $inputs; - - let expected = inputs; - $(let expected = $expected;)? - - - let fields = fields.iter().map(|f| Field::try_from(f)).collect::>>().unwrap(); - let arrays = arrow2::serialize_into_arrays(&fields, inputs).unwrap(); - - let reconstructed: Vec = deserialize_from_arrays(&fields, &arrays).unwrap(); - assert_eq!(reconstructed, expected); - } - - #[test] - fn arrow() { - use crate::_impl::arrow::datatypes::Field; - $($setup)* - - let fields = $fields; - let inputs = $inputs; - - let expected = inputs; - $(let expected = $expected;)? 
- - let fields = fields.iter().map(|f| Field::try_from(f)).collect::>>().unwrap(); - let arrays = arrow::serialize_into_arrays(&fields, inputs).unwrap(); - - let reconstructed: Vec = deserialize_from_arrays(&fields, &arrays).unwrap(); - assert_eq!(reconstructed, expected); - } - } - }; -} - -test_roundtrip_arrays!( - example { - #[derive(Debug, PartialEq, Deserialize, Serialize)] - struct S { - a: i32, - b: f32, - } - - let items = &[S { a: 0, b: 2.0 }, S { a: 1, b: 3.0 }, S { a: 2, b: 4.0 }]; - - let fields = vec![ - GenericField::new("a", GenericDataType::I32, false), - GenericField::new("b", GenericDataType::F16, false), - ]; - } - assert_round_trip(fields, items); -); - -test_roundtrip_arrays!( - primitives { - #[derive(Debug, Default, PartialEq, Deserialize, Serialize)] - struct S { - a: u8, - b: u16, - c: u32, - d: u64, - e: u8, - f: u16, - g: u32, - h: u64, - i: f32, - j: f32, - k: f64, - } - - let items = &[ - S::default(), - S::default(), - S::default(), - ]; - - let fields = vec![ - GenericField::new("a", GenericDataType::U8, false), - GenericField::new("b", GenericDataType::U16, false), - GenericField::new("c", GenericDataType::U32, false), - GenericField::new("d", GenericDataType::U64, false), - GenericField::new("e", GenericDataType::I8, false), - GenericField::new("f", GenericDataType::I16, false), - GenericField::new("g", GenericDataType::I32, false), - GenericField::new("h", GenericDataType::I64, false), - GenericField::new("i", GenericDataType::F16, false), - GenericField::new("j", GenericDataType::F32, false), - GenericField::new("k", GenericDataType::F64, false), - ]; - } - assert_round_trip(fields, items); -); - -test_roundtrip_arrays!( - example_field_order_different_from_struct { - #[derive(Debug, PartialEq, Deserialize, Serialize)] - struct S { - a: i32, - b: f32, - } - - let items = &[S { a: 0, b: 2.0 }, S { a: 1, b: 3.0 }, S { a: 2, b: 4.0 }]; - - let fields = vec![ - GenericField::new("b", GenericDataType::F16, false), - 
GenericField::new("a", GenericDataType::I32, false), - ]; - } - assert_round_trip(fields, items); -); - -test_roundtrip_arrays!( - example_optional_fields { - #[derive(Debug, PartialEq, Deserialize, Serialize)] - struct S { - a: Option, - } - - let items = &[S { a: Some(0) }, S { a: None }, S { a: Some(2) }]; - - let fields = vec![ - GenericField::new("a", GenericDataType::I32, true), - ]; - } - assert_round_trip(fields, items); -); From 26964a7477553492d8e77ae1938d0abb4cdc60f9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 09:51:25 +0200 Subject: [PATCH 090/178] Implement more type conversions --- serde_arrow/src/arrow2_impl/schema.rs | 68 ++++++++++++++++++- serde_arrow/src/arrow_impl/schema.rs | 30 +++++++- .../src/internal/schema/from_samples/mod.rs | 5 +- 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 6cb1cb73..b4ffb8c6 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -6,7 +6,7 @@ use crate::{ internal::{ arrow::{DataType, Field, TimeUnit, UnionMode}, error::{fail, Error, Result}, - schema::{validate_field, SchemaLike, Sealed, SerdeArrowSchema}, + schema::{validate_field, DataTypeDisplay, SchemaLike, Sealed, SerdeArrowSchema}, }, }; @@ -151,7 +151,7 @@ impl TryFrom<&DataType> for ArrowDataType { type Error = Error; fn try_from(value: &DataType) -> std::result::Result { - use {ArrowDataType as AT, DataType as T}; + use {ArrowDataType as AT, ArrowField as AF, DataType as T, IntegerType as I}; match value { T::Null => Ok(AT::Null), T::Boolean => Ok(AT::Boolean), @@ -178,6 +178,70 @@ impl TryFrom<&DataType> for ArrowDataType { } Ok(AT::Decimal((*precision).try_into()?, (*scale).try_into()?)) } + T::Binary => Ok(AT::Binary), + T::LargeBinary => Ok(AT::LargeBinary), + T::Utf8 => Ok(AT::Utf8), + T::LargeUtf8 => Ok(AT::LargeUtf8), + T::Dictionary(key, value, sorted) => match key.as_ref() { + 
T::Int8 => Ok(AT::Dictionary( + I::Int8, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::Int16 => Ok(AT::Dictionary( + I::Int16, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::Int32 => Ok(AT::Dictionary( + I::Int32, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::Int64 => Ok(AT::Dictionary( + I::Int64, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::UInt8 => Ok(AT::Dictionary( + I::UInt8, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::UInt16 => Ok(AT::Dictionary( + I::UInt16, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::UInt32 => Ok(AT::Dictionary( + I::UInt32, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + T::UInt64 => Ok(AT::Dictionary( + I::UInt64, + AT::try_from(value.as_ref())?.into(), + *sorted, + )), + dt => fail!( + "unsupported dictionary key type {dt}", + dt = DataTypeDisplay(dt) + ), + }, + T::List(field) => Ok(AT::List(AF::try_from(field.as_ref())?.into())), + T::LargeList(field) => Ok(AT::LargeList(AF::try_from(field.as_ref())?.into())), + T::FixedSizeList(field, n) => Ok(AT::FixedSizeList( + AF::try_from(field.as_ref())?.into(), + (*n).try_into()?, + )), + T::Map(field, sorted) => Ok(AT::Map(AF::try_from(field.as_ref())?.into(), *sorted)), + T::Struct(in_fields) => { + let mut fields = Vec::new(); + for field in in_fields { + fields.push(AF::try_from(field)?); + } + Ok(AT::Struct(fields)) + } _ => todo!(), } } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index f119d7e7..a8a1424b 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -121,7 +121,7 @@ impl TryFrom<&ArrowDataType> for DataType { type Error = Error; fn try_from(value: &ArrowDataType) -> Result { - use {ArrowDataType as AT, DataType as T}; + use {ArrowDataType as AT, DataType as T, Field as F}; match value { AT::Boolean => Ok(T::Boolean), AT::Null => Ok(T::Null), @@ -151,6 +151,19 @@ impl 
TryFrom<&ArrowDataType> for DataType { AT::Binary => Ok(T::Binary), AT::LargeBinary => Ok(T::LargeBinary), AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), + AT::List(field) => Ok(T::List(F::try_from(field.as_ref())?.into())), + AT::LargeList(field) => Ok(T::LargeList(F::try_from(field.as_ref())?.into())), + AT::FixedSizeList(field, n) => { + Ok(T::FixedSizeList(F::try_from(field.as_ref())?.into(), *n)) + } + AT::Map(field, sorted) => Ok(T::Map(F::try_from(field.as_ref())?.into(), *sorted)), + AT::Struct(in_fields) => { + let mut fields = Vec::new(); + for field in in_fields { + fields.push(field.as_ref().try_into()?); + } + Ok(T::Struct(fields)) + } _ => fail!("Only primitive data types can be converted to T"), } } @@ -175,7 +188,7 @@ impl TryFrom<&DataType> for ArrowDataType { type Error = Error; fn try_from(value: &DataType) -> std::result::Result { - use {ArrowDataType as AT, DataType as T}; + use {ArrowDataType as AT, ArrowField as AF, DataType as T}; match value { T::Boolean => Ok(AT::Boolean), T::Null => Ok(AT::Null), @@ -205,6 +218,19 @@ impl TryFrom<&DataType> for ArrowDataType { T::Binary => Ok(AT::Binary), T::LargeBinary => Ok(AT::LargeBinary), T::FixedSizeBinary(n) => Ok(AT::FixedSizeBinary(*n)), + T::List(field) => Ok(AT::List(AF::try_from(field.as_ref())?.into())), + T::LargeList(field) => Ok(AT::LargeList(AF::try_from(field.as_ref())?.into())), + T::FixedSizeList(field, n) => { + Ok(AT::FixedSizeList(AF::try_from(field.as_ref())?.into(), *n)) + } + T::Map(field, sorted) => Ok(AT::Map(AF::try_from(field.as_ref())?.into(), *sorted)), + T::Struct(in_fields) => { + let mut fields: Vec = Vec::new(); + for field in in_fields { + fields.push(AF::try_from(field)?.into()); + } + Ok(AT::Struct(fields.into())) + } _ => fail!("Only primitive data types can be converted to T"), } } diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index a8428ae3..0ac7f286 100644 --- 
a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -591,7 +591,7 @@ mod test { use serde_json::{json, Value}; use crate::{ - internal::schema::{SerdeArrowSchema, TracingOptions}, + internal::schema::{transmute_field, TracingOptions}, schema::SchemaLike, }; @@ -600,8 +600,7 @@ mod test { fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); - let expected = SerdeArrowSchema::from_value(&[expected]).unwrap(); - let expected = expected.fields.into_iter().next().unwrap(); + let expected = transmute_field(expected).unwrap(); assert_eq!(field, expected); } From 904164d07040a1bf99fae5c2c7164f8c25a9491a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 09:54:15 +0200 Subject: [PATCH 091/178] Fully implement Dictionary --- serde_arrow/src/arrow_impl/schema.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index a8a1424b..b3476a70 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -164,6 +164,11 @@ impl TryFrom<&ArrowDataType> for DataType { } Ok(T::Struct(fields)) } + AT::Dictionary(key, value) => Ok(T::Dictionary( + T::try_from(key.as_ref())?.into(), + T::try_from(value.as_ref())?.into(), + false, + )), _ => fail!("Only primitive data types can be converted to T"), } } @@ -231,6 +236,10 @@ impl TryFrom<&DataType> for ArrowDataType { } Ok(AT::Struct(fields.into())) } + T::Dictionary(key, value, _sorted) => Ok(AT::Dictionary( + AT::try_from(key.as_ref())?.into(), + AT::try_from(value.as_ref())?.into(), + )), _ => fail!("Only primitive data types can be converted to T"), } } From ad38848abba2a1e18ac381e43632d18b92630c45 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 10:02:01 +0200 Subject: [PATCH 092/178] Fix schema 
serialization tests --- .../src/internal/schema/serde/serialize.rs | 22 +-- serde_arrow/src/internal/schema/serde/test.rs | 126 ++++++++++-------- 2 files changed, 81 insertions(+), 67 deletions(-) diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index dca618f1..3ef9a643 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -83,17 +83,17 @@ impl<'a> serde::Serialize for SerializableDataType<'a> { match self.0 { T::Null => "Null".serialize(serializer), T::Boolean => "Boolean".serialize(serializer), - T::Int8 => "Int8".serialize(serializer), - T::Int16 => "Int16".serialize(serializer), - T::Int32 => "Int32".serialize(serializer), - T::Int64 => "Int64".serialize(serializer), - T::UInt8 => "UInt8".serialize(serializer), - T::UInt16 => "UInt16".serialize(serializer), - T::UInt32 => "UInt32".serialize(serializer), - T::UInt64 => "UInt64".serialize(serializer), - T::Float16 => "Float16".serialize(serializer), - T::Float32 => "Float32".serialize(serializer), - T::Float64 => "Float64".serialize(serializer), + T::Int8 => "I8".serialize(serializer), + T::Int16 => "I16".serialize(serializer), + T::Int32 => "I32".serialize(serializer), + T::Int64 => "I64".serialize(serializer), + T::UInt8 => "U8".serialize(serializer), + T::UInt16 => "U16".serialize(serializer), + T::UInt32 => "U32".serialize(serializer), + T::UInt64 => "U64".serialize(serializer), + T::Float16 => "F16".serialize(serializer), + T::Float32 => "F32".serialize(serializer), + T::Float64 => "F64".serialize(serializer), T::Utf8 => "Utf8".serialize(serializer), T::LargeUtf8 => "LargeUtf8".serialize(serializer), T::Binary => "Binary".serialize(serializer), diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index 874857f5..a88a7582 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ 
b/serde_arrow/src/internal/schema/serde/test.rs @@ -1,94 +1,108 @@ use serde_json::json; -use crate::internal::{ +use crate::{internal::{ arrow::{DataType, Field}, error::PanicOnError, - schema::{transmute_field, STRATEGY_KEY}, + schema::STRATEGY_KEY, testing::hash_map, -}; - -use super::serialize::SerializableField; +}, schema::{SchemaLike, SerdeArrowSchema}}; #[test] fn i16_field_simple() -> PanicOnError<()> { - let field = Field { - name: String::from("my_field_name"), - data_type: DataType::Int16, - metadata: hash_map!(), - nullable: false, - }; + let schema = SerdeArrowSchema { fields: vec![ + Field { + name: String::from("my_field_name"), + data_type: DataType::Int16, + metadata: hash_map!(), + nullable: false, + }, + ]}; let expected = json!({ - "name": "my_field_name", - "data_type": "I16", + "fields": [ + { + "name": "my_field_name", + "data_type": "I16", + } + ], }); - let actual = serde_json::to_value(&SerializableField(&field))?; + + + let actual = serde_json::to_value(&schema)?; assert_eq!(actual, expected); - let roundtripped = transmute_field(&actual)?; - assert_eq!(roundtripped, field); + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); Ok(()) } #[test] fn date64_field_complex() -> PanicOnError<()> { - let field = Field { - name: String::from("my_field_name"), - data_type: DataType::Date64, - metadata: hash_map!( - "foo" => "bar", - STRATEGY_KEY => "NaiveStrAsDate64", - ), - nullable: true, - }; - let expected = json!({ - "name": "my_field_name", - "data_type": "Date64", - "metadata": { - "foo": "bar", + let schema = SerdeArrowSchema {fields: vec![ + Field { + name: String::from("my_field_name"), + data_type: DataType::Date64, + metadata: hash_map!( + "foo" => "bar", + STRATEGY_KEY => "NaiveStrAsDate64", + ), + nullable: true, }, - "strategy": "NaiveStrAsDate64", - "nullable": true, + ]}; + let expected = json!({ + "fields": [{ + "name": "my_field_name", + "data_type": "Date64", + "metadata": { + "foo": 
"bar", + }, + "strategy": "NaiveStrAsDate64", + "nullable": true, + }], }); - let actual = serde_json::to_value(&field)?; + let actual = serde_json::to_value(&schema)?; assert_eq!(actual, expected); - let roundtripped = transmute_field(&actual)?; - assert_eq!(roundtripped, field); + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); Ok(()) } #[test] fn list_field_complex() -> PanicOnError<()> { - let field = Field { - name: String::from("my_field_name"), - data_type: DataType::List(Box::new(Field { - name: String::from("element"), - data_type: DataType::Int64, - metadata: hash_map!(), - nullable: false, - })), - metadata: hash_map!("foo" => "bar"), - nullable: true, - }; + let schema = SerdeArrowSchema {fields: vec![ + Field { + name: String::from("my_field_name"), + data_type: DataType::List(Box::new(Field { + name: String::from("element"), + data_type: DataType::Int64, + metadata: hash_map!(), + nullable: false, + })), + metadata: hash_map!("foo" => "bar"), + nullable: true, + }, + ]}; let expected = json!({ - "name": "my_field_name", - "data_type": "List", - "metadata": {"foo": "bar"}, - "nullable": true, - "children": [ - {"name": "element", "data_type": "I64"}, - ] + "fields": [{ + "name": "my_field_name", + "data_type": "List", + "metadata": {"foo": "bar"}, + "nullable": true, + "children": [ + {"name": "element", "data_type": "I64"}, + ], + }], }); - let actual = serde_json::to_value(&field)?; + let actual = serde_json::to_value(&schema)?; assert_eq!(actual, expected); - let roundtripped = transmute_field(&actual)?; - assert_eq!(roundtripped, field); + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); Ok(()) } From d6f6df8795805ccc2e0e001aae2ca192c599a101 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 10:21:43 +0200 Subject: [PATCH 093/178] Fix various copy and paste errors --- .../src/internal/schema/serde/deserialize.rs | 8 +- 
serde_arrow/src/internal/schema/serde/test.rs | 41 +++--- serde_arrow/src/internal/schema/test.rs | 123 ++++++++++-------- .../impls/issue_90_type_tracing.rs | 36 +++-- serde_arrow/src/test_with_arrow/impls/list.rs | 2 +- serde_arrow/src/test_with_arrow/impls/map.rs | 16 +-- .../src/test_with_arrow/impls/struct.rs | 3 +- 7 files changed, 126 insertions(+), 103 deletions(-) diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs index 6d7d3740..dd70089e 100644 --- a/serde_arrow/src/internal/schema/serde/deserialize.rs +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -5,7 +5,7 @@ use serde::{de::Visitor, Deserialize}; use crate::internal::{ arrow::{DataType, Field, TimeUnit, UnionMode}, error::{fail, Result}, - schema::{SerdeArrowSchema, Strategy, STRATEGY_KEY}, + schema::{validate_field, SerdeArrowSchema, Strategy, STRATEGY_KEY}, utils::dsl::Term, }; @@ -94,12 +94,14 @@ impl CustomField { let data_type = self.data_type.into_data_type(children)?; let metadata = merge_strategy_with_metadata(self.metadata, self.strategy)?; - Ok(Field { + let field = Field { name: self.name, nullable: self.nullable, data_type, metadata, - }) + }; + validate_field(&field)?; + Ok(field) } } diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index a88a7582..e207a011 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -1,22 +1,25 @@ use serde_json::json; -use crate::{internal::{ - arrow::{DataType, Field}, - error::PanicOnError, - schema::STRATEGY_KEY, - testing::hash_map, -}, schema::{SchemaLike, SerdeArrowSchema}}; +use crate::{ + internal::{ + arrow::{DataType, Field}, + error::PanicOnError, + schema::STRATEGY_KEY, + testing::hash_map, + }, + schema::{SchemaLike, SerdeArrowSchema}, +}; #[test] fn i16_field_simple() -> PanicOnError<()> { - let schema = SerdeArrowSchema { fields: vec![ - Field { + 
let schema = SerdeArrowSchema { + fields: vec![Field { name: String::from("my_field_name"), data_type: DataType::Int16, metadata: hash_map!(), nullable: false, - }, - ]}; + }], + }; let expected = json!({ "fields": [ { @@ -26,8 +29,6 @@ fn i16_field_simple() -> PanicOnError<()> { ], }); - - let actual = serde_json::to_value(&schema)?; assert_eq!(actual, expected); @@ -39,8 +40,8 @@ fn i16_field_simple() -> PanicOnError<()> { #[test] fn date64_field_complex() -> PanicOnError<()> { - let schema = SerdeArrowSchema {fields: vec![ - Field { + let schema = SerdeArrowSchema { + fields: vec![Field { name: String::from("my_field_name"), data_type: DataType::Date64, metadata: hash_map!( @@ -48,8 +49,8 @@ fn date64_field_complex() -> PanicOnError<()> { STRATEGY_KEY => "NaiveStrAsDate64", ), nullable: true, - }, - ]}; + }], + }; let expected = json!({ "fields": [{ "name": "my_field_name", @@ -73,8 +74,8 @@ fn date64_field_complex() -> PanicOnError<()> { #[test] fn list_field_complex() -> PanicOnError<()> { - let schema = SerdeArrowSchema {fields: vec![ - Field { + let schema = SerdeArrowSchema { + fields: vec![Field { name: String::from("my_field_name"), data_type: DataType::List(Box::new(Field { name: String::from("element"), @@ -84,8 +85,8 @@ fn list_field_complex() -> PanicOnError<()> { })), metadata: hash_map!("foo" => "bar"), nullable: true, - }, - ]}; + }], + }; let expected = json!({ "fields": [{ "name": "my_field_name", diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index 8bd0bff9..58051204 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -10,19 +10,22 @@ use crate::internal::{ #[test] fn example() { - let mut schema = SerdeArrowSchema::default(); - schema.fields.push(Field { - name: String::from("foo"), - data_type: DataType::UInt8, - nullable: false, - metadata: HashMap::new(), - }); - schema.fields.push(Field { - name: String::from("bar"), - data_type: 
DataType::Utf8, - nullable: false, - metadata: Default::default(), - }); + let schema = SerdeArrowSchema { + fields: vec![ + Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }, + Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }, + ], + }; let actual = serde_json::to_string(&schema).unwrap(); assert_eq!( @@ -36,19 +39,22 @@ fn example() { #[test] fn example_without_wrapper() { - let mut expected = SerdeArrowSchema::default(); - expected.fields.push(Field { - name: String::from("foo"), - data_type: DataType::UInt8, - nullable: false, - metadata: HashMap::new(), - }); - expected.fields.push(Field { - name: String::from("bar"), - data_type: DataType::Utf8, - nullable: false, - metadata: Default::default(), - }); + let expected = SerdeArrowSchema { + fields: vec![ + Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }, + Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }, + ], + }; let input = r#"[{"name":"foo","data_type":"U8"},{"name":"bar","data_type":"Utf8"}]"#; let actual: SerdeArrowSchema = serde_json::from_str(&input).unwrap(); @@ -57,18 +63,19 @@ fn example_without_wrapper() { #[test] fn list() { - let mut schema = SerdeArrowSchema::default(); - schema.fields.push(Field { - name: String::from("value"), - data_type: DataType::List(Box::new(Field { - name: String::from("element"), - data_type: DataType::Int32, + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("value"), + data_type: DataType::List(Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: Default::default(), + })), nullable: false, metadata: Default::default(), - })), - nullable: false, - metadata: Default::default(), - }); + }], + }; let actual = 
serde_json::to_string(&schema).unwrap(); assert_eq!( @@ -91,32 +98,36 @@ fn doc_schema() { let actual: SerdeArrowSchema = serde_json::from_str(&schema).unwrap(); - let mut expected = SerdeArrowSchema::default(); - expected.fields.push(Field { - name: String::from("foo"), - data_type: DataType::UInt8, - nullable: false, - metadata: HashMap::new(), - }); - expected.fields.push(Field { - name: String::from("bar"), - data_type: DataType::Utf8, - nullable: false, - metadata: Default::default(), - }); + let expected = SerdeArrowSchema { + fields: vec![ + Field { + name: String::from("foo"), + data_type: DataType::UInt8, + nullable: false, + metadata: HashMap::new(), + }, + Field { + name: String::from("bar"), + data_type: DataType::Utf8, + nullable: false, + metadata: Default::default(), + }, + ], + }; assert_eq!(actual, expected); } #[test] fn date64_with_strategy() { - let mut schema = SerdeArrowSchema::default(); - schema.fields.push(Field { - name: String::from("item"), - data_type: DataType::Date64, - nullable: false, - metadata: hash_map!( STRATEGY_KEY => Strategy::NaiveStrAsDate64 ), - }); + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("item"), + data_type: DataType::Date64, + nullable: false, + metadata: hash_map!( STRATEGY_KEY => Strategy::NaiveStrAsDate64 ), + }], + }; let actual = serde_json::to_string(&schema).unwrap(); assert_eq!( diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index 4264c4e0..e9d2b88a 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; use serde::Deserialize; +use serde_json::json; use crate::internal::{ arrow::{DataType, Field, UnionMode}, - schema::{tracer::Tracer, Strategy, TracingOptions, STRATEGY_KEY}, + schema::{tracer::Tracer, transmute_field, Strategy, TracingOptions, 
STRATEGY_KEY}, testing::assert_error, utils::Item, }; @@ -40,19 +41,26 @@ fn issue_90() { } let actual = trace_type::(TracingOptions::default()); - let expected = new_field( - "item", - false, - DataType::Struct(vec![ - new_field( - "distribution", - true, - DataType::Struct(vec![new_field("element", false, DataType::Float64)]), - ), - new_field("statistic", false, DataType::LargeUtf8), - ]), - ); - + let expected = transmute_field(json!({ + "name": "item", + "data_type": "Struct", + "children": [ + { + "name": "distribution", + "nullable": true, + "data_type": "Struct", + "children": [ + { + "name": "samples", + "data_type": "LargeList", + "children": [{"name": "element", "data_type": "F64"}], + }, + {"name": "statistic", "data_type": "LargeUtf8"}, + ], + }, + ], + })) + .unwrap(); assert_eq!(actual, expected); } diff --git a/serde_arrow/src/test_with_arrow/impls/list.rs b/serde_arrow/src/test_with_arrow/impls/list.rs index 4e6d48e5..a5b31b5c 100644 --- a/serde_arrow/src/test_with_arrow/impls/list.rs +++ b/serde_arrow/src/test_with_arrow/impls/list.rs @@ -33,7 +33,7 @@ fn large_list_nullable_u64() { .with_schema(json!([{ "name": "item", "data_type": "LargeList", - "children": [{"name": "element", "data_type": "U32", "nullable": true}], + "children": [{"name": "element", "data_type": "U64", "nullable": true}], }])) .trace_schema_from_type::>>>(TracingOptions::default()) .trace_schema_from_samples(&items, TracingOptions::default()) diff --git a/serde_arrow/src/test_with_arrow/impls/map.rs b/serde_arrow/src/test_with_arrow/impls/map.rs index bf9ee470..0a5905c3 100644 --- a/serde_arrow/src/test_with_arrow/impls/map.rs +++ b/serde_arrow/src/test_with_arrow/impls/map.rs @@ -214,7 +214,7 @@ fn map_as_map() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "LargeUtf8"}, {"name": "value", "data_type": "U32"}, @@ -247,7 +247,7 @@ fn map_as_map_empty() { "children": [ { "name": "entries", - "type": 
"Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "LargeUtf8"}, {"name": "value", "data_type": "U32"}, @@ -279,7 +279,7 @@ fn map_as_map_int_keys() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I32"}, {"name": "value", "data_type": "U32"}, @@ -312,7 +312,7 @@ fn hash_maps() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I64"}, {"name": "value", "data_type": "Bool"}, @@ -346,7 +346,7 @@ fn hash_maps_nullable() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I64"}, {"name": "value", "data_type": "Bool"}, @@ -379,7 +379,7 @@ fn hash_maps_nullable_keys() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I64", "nullable": true}, {"name": "value", "data_type": "Bool"}, @@ -412,7 +412,7 @@ fn hash_maps_nullable_values() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I64"}, {"name": "value", "data_type": "Bool", "nullable": true}, @@ -445,7 +445,7 @@ fn btree_maps() { "children": [ { "name": "entries", - "type": "Struct", + "data_type": "Struct", "children": [ {"name": "key", "data_type": "I64"}, {"name": "value", "data_type": "Bool"}, diff --git a/serde_arrow/src/test_with_arrow/impls/struct.rs b/serde_arrow/src/test_with_arrow/impls/struct.rs index 81b7c134..7c4df27d 100644 --- a/serde_arrow/src/test_with_arrow/impls/struct.rs +++ b/serde_arrow/src/test_with_arrow/impls/struct.rs @@ -352,6 +352,7 @@ fn flattened_structures() { { "name": "item", "data_type": "Struct", + "strategy": "MapAsStruct", "children": [ {"name": "a", "data_type": "I64"}, {"name": "b", "data_type": "F32"}, @@ -458,7 +459,7 @@ fn struct_nullable_nested() { "children": [ { "name": "inner", 
- "type": "Struct", + "data_type": "Struct", "children": [ {"name": "a", "data_type": "Bool"}, {"name": "b", "data_type": "I64"}, From dde9ed1e1a070dd2469b582214bbe882869c2b6e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:00:33 +0200 Subject: [PATCH 094/178] Fix copy and paste error --- serde_arrow/src/test_with_arrow/impls/list.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/test_with_arrow/impls/list.rs b/serde_arrow/src/test_with_arrow/impls/list.rs index a5b31b5c..c19e7e29 100644 --- a/serde_arrow/src/test_with_arrow/impls/list.rs +++ b/serde_arrow/src/test_with_arrow/impls/list.rs @@ -91,7 +91,7 @@ fn nested_large_list_u32() { "data_type": "LargeList", "children": [{ "name": "element", - "data_type": "List", + "data_type": "LargeList", "children": [{ "name": "element", "data_type": "U32", From 2ac156f4f0ffd835472bcb85434ac4b06d3730e0 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:00:49 +0200 Subject: [PATCH 095/178] Fix type id handling for UnionArrays --- serde_arrow/src/arrow2_impl/array.rs | 17 +++++-- serde_arrow/src/arrow2_impl/schema.rs | 49 ++++++++++++++----- serde_arrow/src/arrow_impl/schema.rs | 17 ++++++- .../serialization/outer_sequence_builder.rs | 5 +- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index b133a818..51a57a0b 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -131,8 +131,12 @@ impl TryFrom for ArrayRef { } A::DenseUnion(arr) => { let (values, fields) = array_with_meta_to_array_and_fields(arr.fields)?; + let mut type_ids = Vec::new(); + for type_id in 0..fields.len() { + type_ids.push(type_id.try_into()?); + } Ok(Box::new(UnionArray::try_new( - T::Union(fields, None, UnionMode::Dense), + T::Union(fields, Some(type_ids), UnionMode::Dense), arr.types.into(), values, Some(arr.offsets.into()), @@ -329,11 +333,18 
@@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { offsets: array.offsets().as_slice(), })) } else if let Some(array) = any.downcast_ref::() { - // TODO: check type ids - let T::Union(union_fields, _, UnionMode::Dense) = array.data_type() else { + let T::Union(union_fields, type_ids, UnionMode::Dense) = array.data_type() else { fail!("Invalid data type: only dense unions are supported"); }; + if let Some(type_ids) = type_ids.as_ref() { + for (idx, type_id) in type_ids.iter().enumerate() { + if usize::try_from(*type_id) != Ok(idx) { + fail!("Only consecutive type ids are supported"); + } + } + } + let types = array.types().as_slice(); let Some(offsets) = array.offsets() else { fail!("DenseUnion array without offsets are not supported"); diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index b4ffb8c6..c08e2cc0 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -63,7 +63,7 @@ impl TryFrom<&ArrowDataType> for DataType { type Error = Error; fn try_from(value: &ArrowDataType) -> Result { - use {ArrowDataType as AT, DataType as T}; + use {ArrowDataType as AT, DataType as T, Field as F, IntegerType as I}; match value { AT::Null => Ok(T::Null), AT::Boolean => Ok(T::Boolean), @@ -111,14 +111,14 @@ impl TryFrom<&ArrowDataType> for DataType { } AT::Dictionary(key, value, sorted) => { let key = match key { - IntegerType::Int8 => T::Int8, - IntegerType::Int16 => T::Int16, - IntegerType::Int32 => T::Int32, - IntegerType::Int64 => T::Int64, - IntegerType::UInt8 => T::UInt8, - IntegerType::UInt16 => T::UInt16, - IntegerType::UInt32 => T::UInt32, - IntegerType::UInt64 => T::UInt64, + I::Int8 => T::Int8, + I::Int16 => T::Int16, + I::Int32 => T::Int32, + I::Int64 => T::Int64, + I::UInt8 => T::UInt8, + I::UInt16 => T::UInt16, + I::UInt32 => T::UInt32, + I::UInt64 => T::UInt64, }; Ok(T::Dictionary( Box::new(key), @@ -126,7 +126,24 @@ impl TryFrom<&ArrowDataType> for DataType { *sorted, )) } - 
AT::Union(fields, type_ids, mode) => todo!(), + AT::Union(in_fields, in_type_ids, mode) => { + let in_type_ids = match in_type_ids { + Some(in_type_ids) => in_type_ids.clone(), + None => { + let mut type_ids = Vec::new(); + for id in 0..in_fields.len() { + type_ids.push(id.try_into()?); + } + type_ids + } + }; + + let mut fields = Vec::new(); + for (type_id, field) in in_type_ids.iter().zip(in_fields) { + fields.push(((*type_id).try_into()?, F::try_from(field)?)); + } + Ok(T::Union(fields, (*mode).into())) + } dt => fail!("Cannot convert data type {dt:?} to internal data type"), } } @@ -180,6 +197,7 @@ impl TryFrom<&DataType> for ArrowDataType { } T::Binary => Ok(AT::Binary), T::LargeBinary => Ok(AT::LargeBinary), + T::FixedSizeBinary(n) => Ok(AT::FixedSizeBinary((*n).try_into()?)), T::Utf8 => Ok(AT::Utf8), T::LargeUtf8 => Ok(AT::LargeUtf8), T::Dictionary(key, value, sorted) => match key.as_ref() { @@ -242,7 +260,16 @@ impl TryFrom<&DataType> for ArrowDataType { } Ok(AT::Struct(fields)) } - _ => todo!(), + T::Union(in_fields, mode) => { + let mut fields = Vec::new(); + let mut type_ids = Vec::new(); + + for (type_id, field) in in_fields { + fields.push(AF::try_from(field)?); + type_ids.push((*type_id).try_into()?); + } + Ok(AT::Union(fields, Some(type_ids), (*mode).into())) + } } } } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index b3476a70..338e51dc 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -169,7 +169,14 @@ impl TryFrom<&ArrowDataType> for DataType { T::try_from(value.as_ref())?.into(), false, )), - _ => fail!("Only primitive data types can be converted to T"), + AT::Union(in_fields, mode) => { + let mut fields = Vec::new(); + for (type_id, field) in in_fields.iter() { + fields.push((type_id, F::try_from(field.as_ref())?)); + } + Ok(T::Union(fields, (*mode).into())) + } + data_type => fail!("Unsupported arrow data type {data_type}"), } } } @@ -240,7 +247,13 @@ impl 
TryFrom<&DataType> for ArrowDataType { AT::try_from(key.as_ref())?.into(), AT::try_from(value.as_ref())?.into(), )), - _ => fail!("Only primitive data types can be converted to T"), + T::Union(in_fields, mode) => { + let mut fields = Vec::new(); + for (type_id, field) in in_fields { + fields.push((*type_id, Arc::new(AF::try_from(field)?))); + } + Ok(AT::Union(fields.into_iter().collect(), (*mode).into())) + } } } } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 1022db58..a76a6851 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -136,7 +136,10 @@ impl OuterSequenceBuilder { } T::Union(union_fields, _) => { let mut fields = Vec::new(); - for (_, field) in union_fields { + for (idx, (type_id, field)) in union_fields.iter().enumerate() { + if usize::try_from(*type_id) != Ok(idx) { + fail!("non consecutive type ids are not supported"); + } fields.push((build_builder(field)?, meta_from_field(field.clone())?)); } From f9e90aba6d5dc465dc53f37ad7aecdaecca1a66c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:06:11 +0200 Subject: [PATCH 096/178] Ensure nullable fields are nullable, even without explicit nullable --- .../src/internal/schema/serde/deserialize.rs | 7 +++++- serde_arrow/src/internal/schema/serde/test.rs | 25 +++++++++++++++++++ .../src/test_with_arrow/impls/union.rs | 10 ++++---- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/serde_arrow/src/internal/schema/serde/deserialize.rs b/serde_arrow/src/internal/schema/serde/deserialize.rs index dd70089e..8fc2e3fe 100644 --- a/serde_arrow/src/internal/schema/serde/deserialize.rs +++ b/serde_arrow/src/internal/schema/serde/deserialize.rs @@ -94,9 +94,14 @@ impl CustomField { let data_type = self.data_type.into_data_type(children)?; let metadata = 
merge_strategy_with_metadata(self.metadata, self.strategy)?; + let nullable = match &data_type { + DataType::Null => true, + _ => self.nullable, + }; + let field = Field { name: self.name, - nullable: self.nullable, + nullable, data_type, metadata, }; diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs index e207a011..234b11fe 100644 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ b/serde_arrow/src/internal/schema/serde/test.rs @@ -107,3 +107,28 @@ fn list_field_complex() -> PanicOnError<()> { Ok(()) } + +#[test] +fn null_fields_are_nullable_implicitly() -> PanicOnError<()> { + let expected = SerdeArrowSchema { + fields: vec![Field { + name: String::from("item"), + data_type: DataType::Null, + metadata: hash_map!(), + nullable: true, + }], + }; + let schema = json!({ + "fields": [ + { + "name": "item", + "data_type": "Null", + } + ], + }); + + let actual = SerdeArrowSchema::from_value(&schema)?; + assert_eq!(actual, expected); + + Ok(()) +} diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index 75f84690..6b63c0cd 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -60,9 +60,9 @@ fn fieldless_union_out_of_order() { "name": "item", "data_type": "Union", "children": [ - {"name": "A", "data_type": "Null"}, - {"name": "B", "data_type": "Null"}, - {"name": "C", "data_type": "Null"}, + {"name": "A", "data_type": "Null", "nullable": true}, + {"name": "B", "data_type": "Null", "nullable": true}, + {"name": "C", "data_type": "Null", "nullable": true}, ], }, ])) @@ -346,8 +346,8 @@ fn enums_union() { "name": "item", "data_type": "Union", "children": [ - {"name": "A", "data_type": "Null"}, - {"name": "B", "data_type": "Null"}, + {"name": "A", "data_type": "Null", "nullable": true}, + {"name": "B", "data_type": "Null", "nullable": true}, ], }, ])) From 894ab0c2b1a61b2e4ed259c65763b4559d6855bf 
Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:17:32 +0200 Subject: [PATCH 097/178] Fix more copy and paste errors --- serde_arrow/src/test_with_arrow/impls/examples.rs | 1 + serde_arrow/src/test_with_arrow/impls/tuple.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/test_with_arrow/impls/examples.rs b/serde_arrow/src/test_with_arrow/impls/examples.rs index c2fd5900..f478cefa 100644 --- a/serde_arrow/src/test_with_arrow/impls/examples.rs +++ b/serde_arrow/src/test_with_arrow/impls/examples.rs @@ -369,6 +369,7 @@ fn issue_57() { "data_type": "Union", "children": [ {"name": "", "data_type": "Null", "strategy": "UnknownVariant"}, + {"name": "Deduced", "data_type": "Null"}, ], }, {"name": "file_index", "data_type": "U64"}, diff --git a/serde_arrow/src/test_with_arrow/impls/tuple.rs b/serde_arrow/src/test_with_arrow/impls/tuple.rs index adabae7f..30bc79ae 100644 --- a/serde_arrow/src/test_with_arrow/impls/tuple.rs +++ b/serde_arrow/src/test_with_arrow/impls/tuple.rs @@ -189,7 +189,7 @@ fn tuple_nullable_nested() { {"name": "1", "data_type": "I64"}, ], }, - {"name": "1", "data_type": "Bool"}, + {"name": "1", "data_type": "I64"}, ], } ])) From 6602a14f646ed045481cb465517a7bd2acb8bdb6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:23:36 +0200 Subject: [PATCH 098/178] Fix tensor field extension types --- serde_arrow/src/internal/schema/extensions.rs | 18 ++++------ .../src/internal/schema/from_samples/mod.rs | 5 +-- serde_arrow/src/internal/schema/mod.rs | 1 + .../src/internal/schema/serde/serialize.rs | 34 ++++++++++--------- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/serde_arrow/src/internal/schema/extensions.rs b/serde_arrow/src/internal/schema/extensions.rs index 6aa1b6ee..db558960 100644 --- a/serde_arrow/src/internal/schema/extensions.rs +++ b/serde_arrow/src/internal/schema/extensions.rs @@ -5,7 +5,7 @@ use serde::Serialize; use crate::internal::{ 
arrow::{DataType, Field}, error::{fail, Error, Result}, - schema::transmute_field, + schema::{transmute_field, PrettyField}, }; /// Easily construct a field for tensors with fixed shape @@ -145,11 +145,9 @@ impl TryFrom<&FixedShapeTensorField> for Field { impl serde::ser::Serialize for FixedShapeTensorField { fn serialize(&self, serializer: S) -> Result { - // use serde::ser::Error; - // Field::try_from(self) - // .map_err(S::Error::custom)? - // .serialize(serializer) - todo!() + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) } } @@ -312,11 +310,9 @@ impl TryFrom<&VariableShapeTensorField> for Field { impl serde::ser::Serialize for VariableShapeTensorField { fn serialize(&self, serializer: S) -> Result { - // use serde::ser::Error; - // GenericField::try_from(self) - // .map_err(S::Error::custom)? - // .serialize(serializer) - todo!() + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) } } diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 0ac7f286..52ec026c 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -590,10 +590,7 @@ mod test { use serde::Serialize; use serde_json::{json, Value}; - use crate::{ - internal::schema::{transmute_field, TracingOptions}, - schema::SchemaLike, - }; + use crate::internal::schema::{transmute_field, TracingOptions}; use super::*; diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index e48e152d..cb7ebcec 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -17,6 +17,7 @@ use crate::internal::{ use ::serde::{Deserialize, Serialize}; +pub use serde::serialize::PrettyField; pub use strategy::{get_strategy_from_metadata, Strategy, 
STRATEGY_KEY}; use tracer::Tracer; pub use tracing_options::{Overwrites, TracingMode, TracingOptions}; diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index 3ef9a643..892531f0 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -12,27 +12,29 @@ use crate::{ impl serde::Serialize for SerdeArrowSchema { fn serialize(&self, serializer: S) -> Result { let mut s = serializer.serialize_struct("SerdeArrowSchema", 1)?; - s.serialize_field("fields", &SerializableFields(&self.fields))?; + s.serialize_field("fields", &PrettyFields(&self.fields))?; s.end() } } -pub struct SerializableFields<'a>(pub &'a [Field]); +/// A wrapper around fields to serialize into a more compact format +pub struct PrettyFields<'a>(pub &'a [Field]); -impl<'a> serde::Serialize for SerializableFields<'a> { +impl<'a> serde::Serialize for PrettyFields<'a> { fn serialize(&self, serializer: S) -> Result { let mut s = serializer.serialize_seq(Some(self.0.len()))?; for field in self.0 { - s.serialize_element(&SerializableField(field))?; + s.serialize_element(&PrettyField(field))?; } s.end() } } -pub struct SerializableField<'a>(pub &'a Field); +/// A wrapper around a single field to serialize into a more compact format +pub struct PrettyField<'a>(pub &'a Field); -impl<'a> serde::Serialize for SerializableField<'a> { +impl<'a> serde::Serialize for PrettyField<'a> { fn serialize(&self, serializer: S) -> Result { let non_strategy_metadata = self .0 @@ -57,7 +59,7 @@ impl<'a> serde::Serialize for SerializableField<'a> { let mut s = serializer.serialize_struct("Field", num_fields)?; s.serialize_field("name", &self.0.name)?; - s.serialize_field("data_type", &SerializableDataType(&self.0.data_type))?; + s.serialize_field("data_type", &PrettyFieldDataType(&self.0.data_type))?; if self.0.nullable { s.serialize_field("nullable", &self.0.nullable)?; @@ -69,15 +71,15 @@ impl<'a> 
serde::Serialize for SerializableField<'a> { s.serialize_field("strategy", strategy)?; } if is_data_type_with_children(&self.0.data_type) { - s.serialize_field("children", &SerializableDataTypeChildren(&self.0.data_type))?; + s.serialize_field("children", &PrettyFieldChildren(&self.0.data_type))?; } s.end() } } -pub struct SerializableDataType<'a>(pub &'a DataType); +struct PrettyFieldDataType<'a>(pub &'a DataType); -impl<'a> serde::Serialize for SerializableDataType<'a> { +impl<'a> serde::Serialize for PrettyFieldDataType<'a> { fn serialize(&self, serializer: S) -> Result { use DataType as T; match self.0 { @@ -119,9 +121,9 @@ impl<'a> serde::Serialize for SerializableDataType<'a> { } } -pub struct SerializableDataTypeChildren<'a>(pub &'a DataType); +struct PrettyFieldChildren<'a>(pub &'a DataType); -impl<'a> serde::Serialize for SerializableDataTypeChildren<'a> { +impl<'a> serde::Serialize for PrettyFieldChildren<'a> { fn serialize(&self, serializer: S) -> Result { use DataType as T; @@ -131,20 +133,20 @@ impl<'a> serde::Serialize for SerializableDataTypeChildren<'a> { | T::LargeList(entry) | T::List(entry) => { let mut s = serializer.serialize_seq(Some(1))?; - s.serialize_element(&SerializableField(entry.as_ref()))?; + s.serialize_element(&PrettyField(entry.as_ref()))?; s.end() } T::Struct(fields) => { let mut s = serializer.serialize_seq(Some(fields.len()))?; for field in fields { - s.serialize_element(&SerializableField(field))?; + s.serialize_element(&PrettyField(field))?; } s.end() } T::Union(fields, _) => { let mut s = serializer.serialize_seq(Some(fields.len()))?; for (_, field) in fields { - s.serialize_element(&SerializableField(field))?; + s.serialize_element(&PrettyField(field))?; } s.end() } @@ -166,7 +168,7 @@ impl<'a> serde::Serialize for DictionaryField<'a> { fn serialize(&self, serializer: S) -> Result { let mut s = serializer.serialize_struct("Field", 2)?; s.serialize_field("name", self.0)?; - s.serialize_field("data_type", 
&SerializableDataType(self.1))?; + s.serialize_field("data_type", &PrettyFieldDataType(self.1))?; s.end() } } From 8c6991bbd6ebf65ff18a0ab11b454140cc50f04e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:45:39 +0200 Subject: [PATCH 099/178] Fix schema tests (+ Bug in serialize) --- .../src/internal/schema/serde/serialize.rs | 2 +- serde_arrow/src/internal/schema/test.rs | 240 +++++++++++------- 2 files changed, 144 insertions(+), 98 deletions(-) diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index 892531f0..8256a246 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -84,7 +84,7 @@ impl<'a> serde::Serialize for PrettyFieldDataType<'a> { use DataType as T; match self.0 { T::Null => "Null".serialize(serializer), - T::Boolean => "Boolean".serialize(serializer), + T::Boolean => "Bool".serialize(serializer), T::Int8 => "I8".serialize(serializer), T::Int16 => "I16".serialize(serializer), T::Int32 => "I32".serialize(serializer), diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index 58051204..bef4bd2b 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use serde_json::json; +use serde_json::{json, Value}; use crate::internal::{ arrow::{DataType, Field, TimeUnit}, @@ -8,6 +8,35 @@ use crate::internal::{ testing::{assert_error, hash_map}, }; +fn type_from_str(s: &str) -> DataType { + let schema = SerdeArrowSchema::from_value(&json!([{"name": "item", "data_type": s}])).unwrap(); + schema.fields[0].data_type.clone() +} + +fn pretty_str_from_type(data_type: &DataType) -> String { + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("item"), + data_type: data_type.clone(), + nullable: false, + metadata: Default::default(), + }], + }; + let json = 
serde_json::to_value(schema).unwrap(); + + let Value::String(data_type) = json + .get("fields") + .unwrap() + .get(0) + .unwrap() + .get("data_type") + .unwrap() + else { + panic!("data type must be string"); + }; + data_type.clone() +} + #[test] fn example() { let schema = SerdeArrowSchema { @@ -148,102 +177,6 @@ fn date64_with_strategy() { assert_eq!(from_json, schema); } -// TODO: fix these tests (or move them somewhere else) -// #[test] -// fn timestamp_second_serialization() { -// let dt = super::GenericDataType::Timestamp(TimeUnit::Second, None); - -// let s = serde_json::to_string(&dt).unwrap(); -// assert_eq!(s, r#""Timestamp(Second, None)""#); - -// let rt = serde_json::from_str(&s).unwrap(); -// assert_eq!(dt, rt); -// } -// -// #[test] -// fn timestamp_second_utc_serialization() { -// let dt = super::GenericDataType::Timestamp(TimeUnit::Second, Some(String::from("Utc"))); - -// let s = serde_json::to_string(&dt).unwrap(); -// assert_eq!(s, r#""Timestamp(Second, Some(\"Utc\"))""#); - -// let rt = serde_json::from_str(&s).unwrap(); -// assert_eq!(dt, rt); -// } - -// #[test] -// fn test_date32() { -// use GenericDataType as DT; - -// assert_eq!(DT::Date32.to_string(), "Date32"); -// assert_eq!("Date32".parse::
().unwrap(), DT::Date32); -// } - -// #[test] -// fn time64_data_type_format() { -// use {GenericDataType as DT, TimeUnit as TU}; - -// for (dt, s) in [ -// (DT::Time64(TU::Microsecond), "Time64(Microsecond)"), -// (DT::Time64(TU::Nanosecond), "Time64(Nanosecond)"), -// ] { -// assert_eq!(dt.to_string(), s); -// assert_eq!(s.parse::
().unwrap(), dt); -// } -// } - -// #[test] -// fn test_long_form_types() { -// use super::GenericDataType as DT; -// use std::str::FromStr; - -// assert_eq!(DT::from_str("Boolean").unwrap(), DT::Bool); -// assert_eq!(DT::from_str("Int8").unwrap(), DT::I8); -// assert_eq!(DT::from_str("Int16").unwrap(), DT::I16); -// assert_eq!(DT::from_str("Int32").unwrap(), DT::I32); -// assert_eq!(DT::from_str("Int64").unwrap(), DT::I64); -// assert_eq!(DT::from_str("UInt8").unwrap(), DT::U8); -// assert_eq!(DT::from_str("UInt16").unwrap(), DT::U16); -// assert_eq!(DT::from_str("UInt32").unwrap(), DT::U32); -// assert_eq!(DT::from_str("UInt64").unwrap(), DT::U64); -// assert_eq!(DT::from_str("Float16").unwrap(), DT::F16); -// assert_eq!(DT::from_str("Float32").unwrap(), DT::F32); -// assert_eq!(DT::from_str("Float64").unwrap(), DT::F64); -// assert_eq!( -// DT::from_str("Decimal128(8,-2)").unwrap(), -// DT::Decimal128(8, -2) -// ); -// assert_eq!( -// DT::from_str("Decimal128( 8 , -2 )").unwrap(), -// DT::Decimal128(8, -2) -// ); -// } - -// macro_rules! 
test_data_type { -// ($($variant:ident,)*) => { -// mod test_data_type { -// $( -// #[allow(non_snake_case)] -// #[test] -// fn $variant() { -// let ty = super::super::GenericDataType::$variant; - -// let s = serde_json::to_string(&ty).unwrap(); -// assert_eq!(s, concat!("\"", stringify!($variant), "\"")); - -// let rt = serde_json::from_str(&s).unwrap(); -// assert_eq!(ty, rt); -// } -// )* -// } -// }; -// } - -// test_data_type!( -// Null, Bool, I8, I16, I32, I64, U8, U16, U32, U64, F16, F32, F64, Utf8, LargeUtf8, List, -// LargeList, Struct, Dictionary, Union, Map, Date64, -// ); - #[test] fn test_metadata_strategy_from_explicit() { let schema = SerdeArrowSchema::from_value(&json!([ @@ -337,3 +270,116 @@ fn test_invalid_metadata() { assert_error(&res, "Duplicate strategy"); } + +#[test] +fn test_long_form_types() { + assert_eq!(type_from_str("Boolean"), DataType::Boolean); + assert_eq!(type_from_str("Int8"), DataType::Int8); + assert_eq!(type_from_str("Int16"), DataType::Int16); + assert_eq!(type_from_str("Int32"), DataType::Int32); + assert_eq!(type_from_str("Int64"), DataType::Int64); + assert_eq!(type_from_str("UInt8"), DataType::UInt8); + assert_eq!(type_from_str("UInt16"), DataType::UInt16); + assert_eq!(type_from_str("UInt32"), DataType::UInt32); + assert_eq!(type_from_str("UInt64"), DataType::UInt64); + assert_eq!(type_from_str("Float16"), DataType::Float16); + assert_eq!(type_from_str("Float32"), DataType::Float32); + assert_eq!(type_from_str("Float64"), DataType::Float64); + assert_eq!( + type_from_str("Decimal128(8,-2)"), + DataType::Decimal128(8, -2) + ); + assert_eq!( + type_from_str("Decimal128( 8 , -2 )"), + DataType::Decimal128(8, -2) + ); +} + +macro_rules! 
test_short_form_type { + ($name:ident, $data_type:expr, $s:expr) => { + #[test] + fn $name() { + let data_type: DataType = $data_type; + let s: &str = $s; + assert_eq!(pretty_str_from_type(&data_type), s); + assert_eq!(type_from_str(s), data_type); + } + }; +} + +test_short_form_type!(test_null, DataType::Null, "Null"); +test_short_form_type!(test_boolean, DataType::Boolean, "Bool"); +test_short_form_type!(test_int8, DataType::Int8, "I8"); +test_short_form_type!(test_int16, DataType::Int16, "I16"); +test_short_form_type!(test_int32, DataType::Int32, "I32"); +test_short_form_type!(test_int64, DataType::Int64, "I64"); +test_short_form_type!(test_uint8, DataType::UInt8, "U8"); +test_short_form_type!(test_uint16, DataType::UInt16, "U16"); +test_short_form_type!(test_uint32, DataType::UInt32, "U32"); +test_short_form_type!(test_uint64, DataType::UInt64, "U64"); +test_short_form_type!(test_float16, DataType::Float16, "F16"); +test_short_form_type!(test_float32, DataType::Float32, "F32"); +test_short_form_type!(test_float64, DataType::Float64, "F64"); +test_short_form_type!(test_date_32, DataType::Date32, "Date32"); +test_short_form_type!(test_date_64, DataType::Date64, "Date64"); + +test_short_form_type!(test_utf8, DataType::Utf8, "Utf8"); +test_short_form_type!(test_large_utf8, DataType::LargeUtf8, "LargeUtf8"); + +test_short_form_type!(test_binary, DataType::Binary, "Binary"); +test_short_form_type!(test_large_binary, DataType::LargeBinary, "LargeBinary"); + +test_short_form_type!( + test_fixed_size_binary, + DataType::FixedSizeBinary(32), + "FixedSizeBinary(32)" +); +test_short_form_type!( + test_decimal_128, + DataType::Decimal128(2, -2), + "Decimal128(2, -2)" +); + +test_short_form_type!( + test_timestamp_no_tz, + DataType::Timestamp(TimeUnit::Second, None), + "Timestamp(Second, None)" +); +test_short_form_type!( + test_timestamp_utc, + DataType::Timestamp(TimeUnit::Millisecond, Some(String::from("Utc"))), + "Timestamp(Millisecond, Some(\"Utc\"))" +); + 
+test_short_form_type!( + test_time32_second, + DataType::Time32(TimeUnit::Second), + "Time32(Second)" +); +test_short_form_type!( + test_time32_millisecond, + DataType::Time32(TimeUnit::Millisecond), + "Time32(Millisecond)" +); + +test_short_form_type!( + test_time64_microsecond, + DataType::Time64(TimeUnit::Microsecond), + "Time64(Microsecond)" +); +test_short_form_type!( + test_time64_nanosecond, + DataType::Time64(TimeUnit::Nanosecond), + "Time64(Nanosecond)" +); + +test_short_form_type!( + test_duration_second, + DataType::Duration(TimeUnit::Second), + "Duration(Second)" +); +test_short_form_type!( + test_duration_nanosecond, + DataType::Duration(TimeUnit::Nanosecond), + "Duration(Nanosecond)" +); From 86173720ac2c8108380b0b1abcee6c32fe43783c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 11:46:41 +0200 Subject: [PATCH 100/178] Delete unused code --- .../src/internal/schema/deserialization.rs | 319 ------------------ 1 file changed, 319 deletions(-) delete mode 100644 serde_arrow/src/internal/schema/deserialization.rs diff --git a/serde_arrow/src/internal/schema/deserialization.rs b/serde_arrow/src/internal/schema/deserialization.rs deleted file mode 100644 index 476f58ae..00000000 --- a/serde_arrow/src/internal/schema/deserialization.rs +++ /dev/null @@ -1,319 +0,0 @@ -//! Deserialization of SchemaLike objects with explicit support to deserialize -//! 
from arrow-rs types - -// TODO: delete me - -use std::{collections::HashMap, str::FromStr}; - -use serde::{de::Visitor, Deserialize}; - -use crate::internal::{ - arrow::TimeUnit, - error::{fail, Error, Result}, - schema::{ - merge_strategy_with_metadata, GenericDataType, SerdeArrowSchema, Strategy, - }, -}; - -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct ArrowField { - name: String, - data_type: ArrowDataType, - nullable: bool, - metadata: HashMap, -} - -impl ArrowField { - pub fn new(name: &str, data_type: ArrowDataType, nullable: bool) -> Self { - Self { - name: name.to_string(), - data_type, - nullable, - metadata: HashMap::new(), - } - } -} - -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub enum ArrowTimeUnit { - Second, - Millisecond, - Microsecond, - Nanosecond, -} - -impl From for TimeUnit { - fn from(value: ArrowTimeUnit) -> Self { - match value { - ArrowTimeUnit::Second => Self::Second, - ArrowTimeUnit::Millisecond => Self::Millisecond, - ArrowTimeUnit::Microsecond => Self::Microsecond, - ArrowTimeUnit::Nanosecond => Self::Nanosecond, - } - } -} - -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub enum ArrowUnionMode { - Sparse, - Dense, -} - -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub enum ArrowDataType { - Null, - Boolean, - Int8, - Int16, - Int32, - Int64, - UInt8, - UInt16, - UInt32, - UInt64, - Float16, - Float32, - Float64, - Utf8, - LargeUtf8, - Date32, - Date64, - Time64(ArrowTimeUnit), - Struct(Vec), - List(Box), - LargeList(Box), - FixedSizeList(Box, i32), - Map(Box), - // TODO: - // Union, - Dictionary(Box, Box), - Decimal128(u8, i8), - Timestamp(ArrowTimeUnit, Option), - Union(Vec<(i8, ArrowField)>, ArrowUnionMode), -} - -impl ArrowDataType { - pub fn into_generic(self) -> Result<(GenericDataType, Vec)> { - use GenericDataType as T; - - let (data_type, children) = match self { - Self::Null => (T::Null, vec![]), - Self::Boolean => (T::Bool, vec![]), - Self::Int8 => (T::I8, vec![]), - Self::Int16 => (T::I16, 
vec![]), - Self::Int32 => (T::I32, vec![]), - Self::Int64 => (T::I64, vec![]), - Self::UInt8 => (T::U8, vec![]), - Self::UInt16 => (T::U16, vec![]), - Self::UInt32 => (T::U32, vec![]), - Self::UInt64 => (T::U64, vec![]), - Self::Float16 => (T::F16, vec![]), - Self::Float32 => (T::F32, vec![]), - Self::Float64 => (T::F64, vec![]), - Self::Utf8 => (T::Utf8, vec![]), - Self::LargeUtf8 => (T::LargeUtf8, vec![]), - Self::Date32 => (T::Date32, vec![]), - Self::Date64 => (T::Date64, vec![]), - Self::Time64(unit) => (T::Time64(unit.into()), vec![]), - Self::Decimal128(precision, scale) => (T::Decimal128(precision, scale), vec![]), - Self::Struct(fields) => (T::Struct, fields), - Self::List(field) => (T::List, vec![*field]), - Self::LargeList(field) => (T::LargeList, vec![*field]), - Self::FixedSizeList(field, n) => (T::FixedSizeList(n), vec![*field]), - Self::Map(field) => (T::Map, vec![*field]), - Self::Dictionary(key, value) => ( - T::Map, - vec![ - ArrowField::new("", *key, false), - ArrowField::new("", *value, false), - ], - ), - Self::Timestamp(unit, timezone) => (T::Timestamp(unit.into(), timezone), vec![]), - Self::Union(variants, mode) => { - let mut children = Vec::new(); - - if !matches!(mode, ArrowUnionMode::Dense) { - fail!("Only dense unions are supported at the moment"); - } - - for (pos, (idx, variant)) in variants.into_iter().enumerate() { - if pos as i8 != idx { - fail!("Union types with explicit field indices are not supported"); - } - children.push(variant); - } - - (T::Union, children) - } - }; - let children = children - .into_iter() - .map(GenericField::try_from) - .collect::>>()?; - Ok((data_type, children)) - } -} - -impl TryFrom for GenericField { - type Error = Error; - - fn try_from(value: ArrowField) -> Result { - let (data_type, children) = value.data_type.into_generic()?; - Ok(GenericField { - name: value.name, - nullable: value.nullable, - metadata: value.metadata, - data_type, - children, - }) - } -} - -#[derive(Debug)] -enum 
GenericOrArrowDataType { - Generic(GenericDataType), - Arrow(ArrowDataType), -} - -impl<'de> Deserialize<'de> for GenericOrArrowDataType { - fn deserialize>(deserializer: D) -> Result { - struct VisitorImpl; - - impl<'de> Visitor<'de> for VisitorImpl { - type Value = GenericOrArrowDataType; - - fn visit_newtype_struct>( - self, - deserializer: D, - ) -> Result { - GenericOrArrowDataType::deserialize(deserializer) - } - - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "string or DataType variant") - } - - fn visit_str(self, v: &str) -> Result { - match GenericDataType::from_str(v) { - Ok(res) => Ok(GenericOrArrowDataType::Generic(res)), - Err(err) => Err(E::custom(err.to_string())), - } - } - - fn visit_enum>( - self, - data: A, - ) -> Result { - let field = ArrowDataType::deserialize(EnumDeserializer(data))?; - Ok(GenericOrArrowDataType::Arrow(field)) - } - } - - deserializer.deserialize_any(VisitorImpl) - } -} - -struct EnumDeserializer(A); - -impl<'de, A: serde::de::EnumAccess<'de>> serde::de::Deserializer<'de> for EnumDeserializer { - type Error = A::Error; - - fn deserialize_any>(self, visitor: V) -> Result { - visitor.visit_enum(self.0) - } - - serde::forward_to_deserialize_any! 
{ - bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map struct enum identifier ignored_any - } -} - -impl<'de> Deserialize<'de> for GenericField { - fn deserialize>(deserializer: D) -> Result { - use serde::de::Error; - - struct VisitorImpl; - - impl<'de> Visitor<'de> for VisitorImpl { - type Value = GenericField; - - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "a struct with keys 'name', 'data_type', ...") - } - - fn visit_map>( - self, - mut map: A, - ) -> Result { - let mut name = None; - let mut nullable = None; - let mut strategy = None; - let mut metadata = None; - let mut data_type = None; - let mut children = None; - - while let Some(key) = map.next_key::()? { - match key.as_str() { - "name" => { - name = Some(map.next_value::()?); - } - "nullable" => { - nullable = Some(map.next_value::()?); - } - "metadata" => { - metadata = Some(map.next_value::>()?); - } - "strategy" => { - strategy = Some(map.next_value::>()?); - } - "data_type" => { - data_type = Some(map.next_value::()?); - } - "children" => { - children = Some(map.next_value::>()?); - } - _ => { - map.next_value::()?; - } - } - } - - let Some(data_type) = data_type else { - return Err(A::Error::custom("missing field `data_type`")); - }; - let (data_type, children) = match data_type { - GenericOrArrowDataType::Generic(data_type) => { - (data_type, children.unwrap_or_default()) - } - GenericOrArrowDataType::Arrow(data_type) => { - if children.is_some() { - return Err(A::Error::custom( - "cannot mix `children` with arrow-rs-style data types", - )); - } - data_type - .into_generic() - .map_err(|err| A::Error::custom(err.to_string()))? 
- } - }; - - let metadata = - merge_strategy_with_metadata(metadata.unwrap_or_default(), strategy.flatten()) - .map_err(A::Error::custom)?; - - Ok(GenericField { - name: name.ok_or_else(|| A::Error::custom("missing field `name`"))?, - data_type, - children, - nullable: nullable.unwrap_or_default(), - metadata, - }) - } - } - - let res = deserializer.deserialize_map(VisitorImpl)?; - res.validate().map_err(D::Error::custom)?; - Ok(res) - } -} From 5659f7466eb9193c7f84ceaa18f2b2cf74c4e32f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 20:04:18 +0200 Subject: [PATCH 101/178] Move all tests into a single file --- serde_arrow/src/internal/schema/serde/mod.rs | 3 - serde_arrow/src/internal/schema/serde/test.rs | 134 ------------- serde_arrow/src/internal/schema/test.rs | 181 ++++++++++++++++++ 3 files changed, 181 insertions(+), 137 deletions(-) delete mode 100644 serde_arrow/src/internal/schema/serde/test.rs diff --git a/serde_arrow/src/internal/schema/serde/mod.rs b/serde_arrow/src/internal/schema/serde/mod.rs index 494ec4cf..4d5f0642 100644 --- a/serde_arrow/src/internal/schema/serde/mod.rs +++ b/serde_arrow/src/internal/schema/serde/mod.rs @@ -2,6 +2,3 @@ //! 
pub mod deserialize; pub mod serialize; - -#[cfg(test)] -mod test; diff --git a/serde_arrow/src/internal/schema/serde/test.rs b/serde_arrow/src/internal/schema/serde/test.rs deleted file mode 100644 index 234b11fe..00000000 --- a/serde_arrow/src/internal/schema/serde/test.rs +++ /dev/null @@ -1,134 +0,0 @@ -use serde_json::json; - -use crate::{ - internal::{ - arrow::{DataType, Field}, - error::PanicOnError, - schema::STRATEGY_KEY, - testing::hash_map, - }, - schema::{SchemaLike, SerdeArrowSchema}, -}; - -#[test] -fn i16_field_simple() -> PanicOnError<()> { - let schema = SerdeArrowSchema { - fields: vec![Field { - name: String::from("my_field_name"), - data_type: DataType::Int16, - metadata: hash_map!(), - nullable: false, - }], - }; - let expected = json!({ - "fields": [ - { - "name": "my_field_name", - "data_type": "I16", - } - ], - }); - - let actual = serde_json::to_value(&schema)?; - assert_eq!(actual, expected); - - let roundtripped = SerdeArrowSchema::from_value(&actual)?; - assert_eq!(roundtripped, schema); - - Ok(()) -} - -#[test] -fn date64_field_complex() -> PanicOnError<()> { - let schema = SerdeArrowSchema { - fields: vec![Field { - name: String::from("my_field_name"), - data_type: DataType::Date64, - metadata: hash_map!( - "foo" => "bar", - STRATEGY_KEY => "NaiveStrAsDate64", - ), - nullable: true, - }], - }; - let expected = json!({ - "fields": [{ - "name": "my_field_name", - "data_type": "Date64", - "metadata": { - "foo": "bar", - }, - "strategy": "NaiveStrAsDate64", - "nullable": true, - }], - }); - - let actual = serde_json::to_value(&schema)?; - assert_eq!(actual, expected); - - let roundtripped = SerdeArrowSchema::from_value(&actual)?; - assert_eq!(roundtripped, schema); - - Ok(()) -} - -#[test] -fn list_field_complex() -> PanicOnError<()> { - let schema = SerdeArrowSchema { - fields: vec![Field { - name: String::from("my_field_name"), - data_type: DataType::List(Box::new(Field { - name: String::from("element"), - data_type: DataType::Int64, - 
metadata: hash_map!(), - nullable: false, - })), - metadata: hash_map!("foo" => "bar"), - nullable: true, - }], - }; - let expected = json!({ - "fields": [{ - "name": "my_field_name", - "data_type": "List", - "metadata": {"foo": "bar"}, - "nullable": true, - "children": [ - {"name": "element", "data_type": "I64"}, - ], - }], - }); - - let actual = serde_json::to_value(&schema)?; - assert_eq!(actual, expected); - - let roundtripped = SerdeArrowSchema::from_value(&actual)?; - assert_eq!(roundtripped, schema); - - Ok(()) -} - -#[test] -fn null_fields_are_nullable_implicitly() -> PanicOnError<()> { - let expected = SerdeArrowSchema { - fields: vec![Field { - name: String::from("item"), - data_type: DataType::Null, - metadata: hash_map!(), - nullable: true, - }], - }; - let schema = json!({ - "fields": [ - { - "name": "item", - "data_type": "Null", - } - ], - }); - - let actual = SerdeArrowSchema::from_value(&schema)?; - assert_eq!(actual, expected); - - Ok(()) -} diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index bef4bd2b..2d651543 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -4,6 +4,7 @@ use serde_json::{json, Value}; use crate::internal::{ arrow::{DataType, Field, TimeUnit}, + error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, testing::{assert_error, hash_map}, }; @@ -37,6 +38,186 @@ fn pretty_str_from_type(data_type: &DataType) -> String { data_type.clone() } +#[test] +fn i16_field_simple() -> PanicOnError<()> { + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("my_field_name"), + data_type: DataType::Int16, + metadata: hash_map!(), + nullable: false, + }], + }; + let expected = json!({ + "fields": [ + { + "name": "my_field_name", + "data_type": "I16", + } + ], + }); + + let actual = serde_json::to_value(&schema)?; + assert_eq!(actual, expected); + + let roundtripped = 
SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); + + Ok(()) +} + +#[test] +fn date64_field_complex() -> PanicOnError<()> { + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("my_field_name"), + data_type: DataType::Date64, + metadata: hash_map!( + "foo" => "bar", + STRATEGY_KEY => "NaiveStrAsDate64", + ), + nullable: true, + }], + }; + let expected = json!({ + "fields": [{ + "name": "my_field_name", + "data_type": "Date64", + "metadata": { + "foo": "bar", + }, + "strategy": "NaiveStrAsDate64", + "nullable": true, + }], + }); + + let actual = serde_json::to_value(&schema)?; + assert_eq!(actual, expected); + + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); + + Ok(()) +} + +#[test] +fn list_field_complex() -> PanicOnError<()> { + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("my_field_name"), + data_type: DataType::List(Box::new(Field { + name: String::from("element"), + data_type: DataType::Int64, + metadata: hash_map!(), + nullable: false, + })), + metadata: hash_map!("foo" => "bar"), + nullable: true, + }], + }; + let expected = json!({ + "fields": [{ + "name": "my_field_name", + "data_type": "List", + "metadata": {"foo": "bar"}, + "nullable": true, + "children": [ + {"name": "element", "data_type": "I64"}, + ], + }], + }); + + let actual = serde_json::to_value(&schema)?; + assert_eq!(actual, expected); + + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); + + Ok(()) +} + +#[test] +fn map_field_complex() -> PanicOnError<()> { + let schema = SerdeArrowSchema { + fields: vec![Field { + name: String::from("my_field_name"), + metadata: Default::default(), + nullable: false, + data_type: DataType::Map( + Box::new(Field { + name: String::from("entry"), + data_type: DataType::Struct(vec![ + Field { + name: String::from("key"), + data_type: DataType::Utf8, + nullable: false, + metadata: 
Default::default(), + }, + Field { + name: String::from("value"), + data_type: DataType::Int32, + nullable: false, + metadata: Default::default(), + }, + ]), + metadata: Default::default(), + nullable: false, + }), + false, + ), + }], + }; + let expected = json!({ + "fields": [{ + "name": "my_field_name", + "data_type": "Map", + "children": [ + { + "name": "entry", + "data_type": "Struct", + "children": [ + {"name": "key", "data_type": "Utf8"}, + {"name": "value", "data_type": "I32"}, + ] + }, + ], + }], + }); + + let actual = serde_json::to_value(&schema)?; + assert_eq!(actual, expected); + + let roundtripped = SerdeArrowSchema::from_value(&actual)?; + assert_eq!(roundtripped, schema); + + Ok(()) +} + +#[test] +fn null_fields_are_nullable_implicitly() -> PanicOnError<()> { + let expected = SerdeArrowSchema { + fields: vec![Field { + name: String::from("item"), + data_type: DataType::Null, + metadata: hash_map!(), + nullable: true, + }], + }; + let schema = json!({ + "fields": [ + { + "name": "item", + "data_type": "Null", + } + ], + }); + + let actual = SerdeArrowSchema::from_value(&schema)?; + assert_eq!(actual, expected); + + Ok(()) +} + #[test] fn example() { let schema = SerdeArrowSchema { From bb74c189361bbbaf12e0dc1496f7aa971383fd89 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 20:16:18 +0200 Subject: [PATCH 102/178] Include type ids in union arrays, similar to arrow --- serde_arrow/src/arrow2_impl/array.rs | 61 +++++++++++-------- serde_arrow/src/arrow_impl/array.rs | 12 ++-- serde_arrow/src/internal/arrow/array.rs | 4 +- .../deserialization/array_deserializer.rs | 6 +- .../internal/serialization/union_builder.rs | 4 +- 5 files changed, 47 insertions(+), 40 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 51a57a0b..1150a739 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::{ 
_impl::arrow2::{ array::{ @@ -109,7 +111,16 @@ impl TryFrom for ArrayRef { arr.validity, ), A::Struct(arr) => { - let (values, fields) = array_with_meta_to_array_and_fields(arr.fields)?; + let mut values = Vec::new(); + let mut fields = Vec::new(); + + for (child, meta) in arr.fields { + let child: ArrayRef = child.try_into()?; + let field = field_from_array_and_meta(child.as_ref(), meta); + + values.push(child); + fields.push(field); + } Ok(Box::new(StructArray::new( T::Struct(fields), values, @@ -130,11 +141,19 @@ impl TryFrom for ArrayRef { ))) } A::DenseUnion(arr) => { - let (values, fields) = array_with_meta_to_array_and_fields(arr.fields)?; + let mut values = Vec::new(); + let mut fields = Vec::new(); let mut type_ids = Vec::new(); - for type_id in 0..fields.len() { + + for (type_id, child, meta) in arr.fields { + let child: ArrayRef = child.try_into()?; + let field = field_from_array_and_meta(child.as_ref(), meta); + type_ids.push(type_id.try_into()?); + values.push(child); + fields.push(field); } + Ok(Box::new(UnionArray::try_new( T::Union(fields, Some(type_ids), UnionMode::Dense), arr.types.into(), @@ -337,13 +356,15 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { fail!("Invalid data type: only dense unions are supported"); }; - if let Some(type_ids) = type_ids.as_ref() { - for (idx, type_id) in type_ids.iter().enumerate() { - if usize::try_from(*type_id) != Ok(idx) { - fail!("Only consecutive type ids are supported"); - } + let type_ids = if let Some(type_ids) = type_ids.as_ref() { + Cow::Borrowed(type_ids) + } else { + let mut type_ids = Vec::new(); + for idx in 0..union_fields.len() { + type_ids.push(idx.try_into()?); } - } + Cow::Owned(type_ids) + }; let types = array.types().as_slice(); let Some(offsets) = array.offsets() else { @@ -351,8 +372,11 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { }; let mut fields = Vec::new(); - for (child, child_field) in array.fields().iter().zip(union_fields) { + for ((type_id, child), child_field) 
in + type_ids.iter().zip(array.fields().iter()).zip(union_fields) + { fields.push(( + (*type_id).try_into()?, child.as_ref().try_into()?, meta_from_field(child_field.try_into()?)?, )); @@ -435,23 +459,6 @@ fn field_from_array_and_meta(arr: &dyn A2Array, meta: FieldMeta) -> Field { .with_metadata(meta.metadata.into_iter().collect()) } -fn array_with_meta_to_array_and_fields( - arrays: Vec<(Array, FieldMeta)>, -) -> Result<(Vec, Vec)> { - let mut res_fields = Vec::new(); - let mut res_arrays = Vec::new(); - - for (child, meta) in arrays { - let child: ArrayRef = child.try_into()?; - let field = field_from_array_and_meta(child.as_ref(), meta); - - res_arrays.push(child); - res_fields.push(field); - } - - Ok((res_arrays, res_fields)) -} - fn build_dictionary_array( indices_type: IntegerType, indices: InternalPrimitiveArray, diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index 8c7a2d6b..8340ff57 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -206,11 +206,11 @@ impl TryFrom for ArrayData { let mut fields = Vec::new(); let mut child_data = Vec::new(); - for (idx, (array, meta)) in arr.fields.into_iter().enumerate() { + for (type_id, array, meta) in arr.fields { let child: ArrayData = array.try_into()?; let field = field_from_data_and_meta(&child, meta); - fields.push((idx as i8, Arc::new(field))); + fields.push((type_id, Arc::new(field))); child_data.push(child); } @@ -508,14 +508,10 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { }; let mut fields = Vec::new(); - for (type_idx, (type_id, field)) in union_fields.iter().enumerate() { - if type_id < 0 || usize::try_from(type_id)? 
!= type_idx { - fail!("invalid union, only unions with consecutive variants are supported"); - } - + for (type_id, field) in union_fields.iter() { let meta = meta_from_field(Field::try_from(field.as_ref())?)?; let view: ArrayView = array.child(type_id).as_ref().try_into()?; - fields.push((view, meta)); + fields.push((type_id, view, meta)); } let Some(offsets) = array.offsets() else { fail!("Dense unions must have an offset array"); diff --git a/serde_arrow/src/internal/arrow/array.rs b/serde_arrow/src/internal/arrow/array.rs index 2521fd8e..f1727d6f 100644 --- a/serde_arrow/src/internal/arrow/array.rs +++ b/serde_arrow/src/internal/arrow/array.rs @@ -273,12 +273,12 @@ pub struct DictionaryArrayView<'a> { pub struct DenseUnionArray { pub types: Vec, pub offsets: Vec, - pub fields: Vec<(Array, FieldMeta)>, + pub fields: Vec<(i8, Array, FieldMeta)>, } #[derive(Clone, Debug)] pub struct DenseUnionArrayView<'a> { pub types: &'a [i8], pub offsets: &'a [i32], - pub fields: Vec<(ArrayView<'a>, FieldMeta)>, + pub fields: Vec<(i8, ArrayView<'a>, FieldMeta)>, } diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 6c8b86c6..9ba76024 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -231,7 +231,11 @@ impl<'a> ArrayDeserializer<'a> { }, ArrayView::DenseUnion(view) => { let mut fields = Vec::new(); - for (field_view, field_meta) in view.fields { + for (idx, (type_id, field_view, field_meta)) in view.fields.into_iter().enumerate() + { + if usize::try_from(type_id) != Ok(idx) { + fail!("Only unions with consecutive type ids are currently supported in arrow2"); + } let field_deserializer = ArrayDeserializer::new(get_strategy(&field_meta)?.as_ref(), field_view)?; fields.push((field_meta.name, field_deserializer)) diff --git a/serde_arrow/src/internal/serialization/union_builder.rs 
b/serde_arrow/src/internal/serialization/union_builder.rs index ade767f4..ea3f66b6 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -43,8 +43,8 @@ impl UnionBuilder { pub fn into_array(self) -> Result { let mut fields = Vec::new(); - for (builder, meta) in self.fields { - fields.push((builder.into_array()?, meta)); + for (idx, (builder, meta)) in self.fields.into_iter().enumerate() { + fields.push((idx.try_into()?, builder.into_array()?, meta)); } Ok(Array::DenseUnion(DenseUnionArray { From 3102fcfd17b4ae5dffd54c09ae58863e2f5503a2 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 20:34:51 +0200 Subject: [PATCH 103/178] Move ext types into own modules, add serialization test for FixedShapeTensor --- serde_arrow/src/internal/schema/extensions.rs | 371 ------------------ .../extensions/fixed_shape_tensor_field.rs | 180 +++++++++ .../src/internal/schema/extensions/mod.rs | 6 + .../src/internal/schema/extensions/utils.rs | 58 +++ .../extensions/variable_shape_tensor_field.rs | 174 ++++++++ 5 files changed, 418 insertions(+), 371 deletions(-) delete mode 100644 serde_arrow/src/internal/schema/extensions.rs create mode 100644 serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs create mode 100644 serde_arrow/src/internal/schema/extensions/mod.rs create mode 100644 serde_arrow/src/internal/schema/extensions/utils.rs create mode 100644 serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs diff --git a/serde_arrow/src/internal/schema/extensions.rs b/serde_arrow/src/internal/schema/extensions.rs deleted file mode 100644 index db558960..00000000 --- a/serde_arrow/src/internal/schema/extensions.rs +++ /dev/null @@ -1,371 +0,0 @@ -use std::collections::HashMap; - -use serde::Serialize; - -use crate::internal::{ - arrow::{DataType, Field}, - error::{fail, Error, Result}, - schema::{transmute_field, PrettyField}, -}; - -/// Easily 
construct a field for tensors with fixed shape -/// -/// See the [arrow docs][fixed-shape-tensor-docs] for details on the different -/// fields. -/// -/// The Rust value must serialize to a fixed size list that contains the -/// flattened tensor elements in C order. To support different orders, set the -/// [`permutation`][FixedShapeTensorField::permutation]. -/// -/// This struct is designed to be used with -/// [`TracingOptions::overwrite`][crate::schema::TracingOptions::overwrite]: -/// -/// ```rust -/// # use serde_json::json; -/// # use serde_arrow::{Result, schema::{TracingOptions, ext::FixedShapeTensorField}}; -/// # fn main() -> Result<()> { -/// TracingOptions::default().overwrite( -/// "tensor", -/// FixedShapeTensorField::new( -/// "tensor", -/// json!({"name": "element", "data_type": "I32"}), -/// vec![2, 2], -/// )?, -/// )? -/// # ; -/// # Ok(()) -/// # } -/// ``` -/// -/// [fixed-shape-tensor-docs]: -/// https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor -/// -#[derive(Clone, Debug, PartialEq)] -pub struct FixedShapeTensorField { - name: String, - nullable: bool, - element: Field, - shape: Vec, - dim_names: Option>, - permutation: Option>, -} - -impl FixedShapeTensorField { - /// Construct a new instance - /// - /// Note the element parameter must serialize into a valid field definition - /// with the the name `"element"`. The field type can be any valid Arrow - /// type. 
- pub fn new(name: &str, element: impl Serialize, shape: Vec) -> Result { - let element = transmute_field(element)?; - if element.name != "element" { - fail!("The element field of FixedShapeTensorField must be named \"element\""); - } - - Ok(Self { - name: name.to_owned(), - shape, - element, - nullable: false, - dim_names: None, - permutation: None, - }) - } - - /// Set the nullability of the field - pub fn nullable(mut self, value: bool) -> Self { - self.nullable = value; - self - } - - /// Set the permutation of the dimension - pub fn permutation(mut self, value: Vec) -> Result { - check_permutation(self.shape.len(), &value)?; - self.permutation = Some(value); - Ok(self) - } - - /// Set the dimension names - pub fn dim_names(mut self, value: Vec) -> Result { - check_dim_names(self.shape.len(), &value)?; - self.dim_names = Some(value); - Ok(self) - } -} - -impl FixedShapeTensorField { - fn get_ext_metadata(&self) -> Result { - use std::fmt::Write; - - let mut ext_metadata = String::new(); - write!(&mut ext_metadata, "{{")?; - - write!(&mut ext_metadata, "\"shape\":")?; - write_list(&mut ext_metadata, self.shape.iter())?; - - if let Some(permutation) = self.permutation.as_ref() { - write!(&mut ext_metadata, ",\"permutation\":")?; - write_list(&mut ext_metadata, permutation.iter())?; - } - - if let Some(dim_names) = self.dim_names.as_ref() { - write!(&mut ext_metadata, ",\"dim_names\":")?; - write_list(&mut ext_metadata, dim_names.iter().map(DebugRepr))?; - } - - write!(&mut ext_metadata, "}}")?; - Ok(ext_metadata) - } -} - -impl TryFrom<&FixedShapeTensorField> for Field { - type Error = Error; - - fn try_from(value: &FixedShapeTensorField) -> Result { - let mut n = 1; - for s in &value.shape { - n *= *s; - } - - let mut metadata = HashMap::new(); - metadata.insert( - "ARROW:extension:name".into(), - "arrow.fixed_shape_tensor".into(), - ); - metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - - Ok(Field { - name: value.name.to_owned(), 
- nullable: value.nullable, - data_type: DataType::FixedSizeList(Box::new(value.element.clone()), n.try_into()?), - metadata, - }) - } -} - -impl serde::ser::Serialize for FixedShapeTensorField { - fn serialize(&self, serializer: S) -> Result { - use serde::ser::Error; - let field = Field::try_from(self).map_err(S::Error::custom)?; - PrettyField(&field).serialize(serializer) - } -} - -/// Helper to build fields for tensors with variable shape -/// -/// See the [arrow docs][variable-shape-tensor-field-docs] for details on the -/// different fields. -/// -/// [variable-shape-tensor-field-docs]: -/// https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor -pub struct VariableShapeTensorField { - name: String, - element: Field, - ndim: usize, - nullable: bool, - dim_names: Option>, - permutation: Option>, - uniform_shape: Option>>, -} - -impl VariableShapeTensorField { - pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { - let element = transmute_field(element)?; - if element.name != "element" { - fail!("The element field of FixedShapeTensorField must be named \"element\""); - } - - Ok(Self { - name: name.to_owned(), - element, - ndim, - nullable: false, - dim_names: None, - permutation: None, - uniform_shape: None, - }) - } - - /// Set the nullability of the field - pub fn nullable(mut self, value: bool) -> Self { - self.nullable = value; - self - } - - /// Set the permutation of the dimension - pub fn permutation(mut self, value: Vec) -> Result { - check_permutation(self.ndim, &value)?; - self.permutation = Some(value); - Ok(self) - } - - /// Set the dimension names - pub fn dim_names(mut self, value: Vec) -> Result { - check_dim_names(self.ndim, &value)?; - self.dim_names = Some(value); - Ok(self) - } - - /// Set the uniform shape - pub fn uniform_shape(mut self, value: Vec>) -> Result { - if value.len() != self.ndim { - fail!("Invalid uniform_shape value"); - } - self.uniform_shape = Some(value); - 
Ok(self) - } -} - -impl VariableShapeTensorField { - fn get_ext_metadata(&self) -> Result { - use std::fmt::Write; - - let mut first_field = true; - - let mut ext_metadata = String::new(); - write!(&mut ext_metadata, "{{")?; - - if let Some(permutation) = self.permutation.as_ref() { - if first_field { - first_field = false; - write!(&mut ext_metadata, ",")?; - } - write!(&mut ext_metadata, "\"permutation\":")?; - write_list(&mut ext_metadata, permutation.iter())?; - } - - if let Some(dim_names) = self.dim_names.as_ref() { - if first_field { - first_field = false; - write!(&mut ext_metadata, ",")?; - } - write!(&mut ext_metadata, "\"dim_names\":")?; - write_list(&mut ext_metadata, dim_names.iter().map(DebugRepr))?; - } - - if let Some(uniform_shape) = self.uniform_shape.as_ref() { - if first_field { - first_field = false; - write!(&mut ext_metadata, ",")?; - } - write!(&mut ext_metadata, "\"uniform_shape\":")?; - write_list( - &mut ext_metadata, - uniform_shape.iter().map(|val| match val { - Some(val) => format!("{val}"), - None => String::from("null"), - }), - )?; - } - - // silence "value not read" warning - let _ = first_field; - - write!(&mut ext_metadata, "}}")?; - Ok(ext_metadata) - } -} - -impl TryFrom<&VariableShapeTensorField> for Field { - type Error = Error; - - fn try_from(value: &VariableShapeTensorField) -> Result { - let mut metadata = HashMap::new(); - metadata.insert( - "ARROW:extension:name".into(), - "arrow.variable_shape_tensor".into(), - ); - metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - - let mut fields = Vec::new(); - fields.push(Field { - name: String::from("data"), - data_type: DataType::List(Box::new(value.element.clone())), - nullable: false, - metadata: HashMap::new(), - }); - fields.push(Field { - name: String::from("shape"), - data_type: DataType::FixedSizeList( - Box::new(Field { - name: String::from("element"), - data_type: DataType::Int32, - nullable: false, - metadata: HashMap::new(), - }), - 
value.ndim.try_into()?, - ), - nullable: false, - metadata: HashMap::new(), - }); - - Ok(Field { - name: value.name.clone(), - nullable: value.nullable, - data_type: DataType::Struct(fields), - metadata, - }) - } -} - -impl serde::ser::Serialize for VariableShapeTensorField { - fn serialize(&self, serializer: S) -> Result { - use serde::ser::Error; - let field = Field::try_from(self).map_err(S::Error::custom)?; - PrettyField(&field).serialize(serializer) - } -} - -fn check_dim_names(ndim: usize, dim_names: &[String]) -> Result<()> { - if dim_names.len() != ndim { - fail!("Number of dim names must be equal to the number of dimensions"); - } - Ok(()) -} - -fn check_permutation(ndim: usize, permutation: &[usize]) -> Result<()> { - if permutation.len() != ndim { - fail!("Number of permutation entries must be equal to the number of dimensions"); - } - let seen = vec![false; permutation.len()]; - for &i in permutation { - if i >= seen.len() { - fail!( - "Invalid permutation: index {i} is not in range 0..{len}", - len = seen.len() - ); - } - if seen[i] { - fail!("Invalid permutation: index {i} found multiple times"); - } - } - for (i, seen) in seen.into_iter().enumerate() { - if !seen { - fail!("Invalid permutation: index {i} is not present"); - } - } - Ok(()) -} - -fn write_list(s: &mut String, items: impl Iterator) -> Result<()> { - use std::fmt::Write; - - write!(s, "[")?; - for (idx, val) in items.enumerate() { - if idx != 0 { - write!(s, ",{val}")?; - } else { - write!(s, "{val}")?; - } - } - write!(s, "]")?; - Ok(()) -} - -struct DebugRepr(T); - -impl std::fmt::Display for DebugRepr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } -} diff --git a/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs new file mode 100644 index 00000000..c5239a55 --- /dev/null +++ 
b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs @@ -0,0 +1,180 @@ +use std::collections::HashMap; + +use crate::internal::{ + arrow::{DataType, Field}, + error::{fail, Error, Result}, + schema::{transmute_field, PrettyField}, +}; + +use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; + +/// Easily construct a field for tensors with fixed shape +/// +/// See the [arrow docs][fixed-shape-tensor-docs] for details on the different +/// fields. +/// +/// The Rust value must serialize to a fixed size list that contains the +/// flattened tensor elements in C order. To support different orders, set the +/// [`permutation`][FixedShapeTensorField::permutation]. +/// +/// This struct is designed to be used with +/// [`TracingOptions::overwrite`][crate::schema::TracingOptions::overwrite]: +/// +/// ```rust +/// # use serde_json::json; +/// # use serde_arrow::{Result, schema::{TracingOptions, ext::FixedShapeTensorField}}; +/// # fn main() -> Result<()> { +/// TracingOptions::default().overwrite( +/// "tensor", +/// FixedShapeTensorField::new( +/// "tensor", +/// json!({"name": "element", "data_type": "I32"}), +/// vec![2, 2], +/// )?, +/// )? +/// # ; +/// # Ok(()) +/// # } +/// ``` +/// +/// [fixed-shape-tensor-docs]: +/// https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor +/// +#[derive(Clone, Debug, PartialEq)] +pub struct FixedShapeTensorField { + name: String, + nullable: bool, + element: Field, + shape: Vec, + dim_names: Option>, + permutation: Option>, +} + +impl FixedShapeTensorField { + /// Construct a new instance + /// + /// Note the element parameter must serialize into a valid field definition + /// with the the name `"element"`. The field type can be any valid Arrow + /// type. 
+ pub fn new(name: &str, element: impl serde::ser::Serialize, shape: Vec) -> Result { + let element = transmute_field(element)?; + if element.name != "element" { + fail!("The element field of FixedShapeTensorField must be named \"element\""); + } + + Ok(Self { + name: name.to_owned(), + shape, + element, + nullable: false, + dim_names: None, + permutation: None, + }) + } + + /// Set the nullability of the field + pub fn nullable(mut self, value: bool) -> Self { + self.nullable = value; + self + } + + /// Set the permutation of the dimension + pub fn permutation(mut self, value: Vec) -> Result { + check_permutation(self.shape.len(), &value)?; + self.permutation = Some(value); + Ok(self) + } + + /// Set the dimension names + pub fn dim_names(mut self, value: Vec) -> Result { + check_dim_names(self.shape.len(), &value)?; + self.dim_names = Some(value); + Ok(self) + } +} + +impl FixedShapeTensorField { + fn get_ext_metadata(&self) -> Result { + use std::fmt::Write; + + let mut ext_metadata = String::new(); + write!(&mut ext_metadata, "{{")?; + + write!(&mut ext_metadata, "\"shape\":")?; + write_list(&mut ext_metadata, self.shape.iter())?; + + if let Some(permutation) = self.permutation.as_ref() { + write!(&mut ext_metadata, ",\"permutation\":")?; + write_list(&mut ext_metadata, permutation.iter())?; + } + + if let Some(dim_names) = self.dim_names.as_ref() { + write!(&mut ext_metadata, ",\"dim_names\":")?; + write_list(&mut ext_metadata, dim_names.iter().map(DebugRepr))?; + } + + write!(&mut ext_metadata, "}}")?; + Ok(ext_metadata) + } +} + +impl TryFrom<&FixedShapeTensorField> for Field { + type Error = Error; + + fn try_from(value: &FixedShapeTensorField) -> Result { + let mut n = 1; + for s in &value.shape { + n *= *s; + } + + let mut metadata = HashMap::new(); + metadata.insert( + "ARROW:extension:name".into(), + "arrow.fixed_shape_tensor".into(), + ); + metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); + + Ok(Field { + name: 
value.name.to_owned(), + nullable: value.nullable, + data_type: DataType::FixedSizeList(Box::new(value.element.clone()), n.try_into()?), + metadata, + }) + } +} + +impl serde::ser::Serialize for FixedShapeTensorField { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) + } +} + +#[test] +fn fixed_shape_tensor_field_repr() -> crate::internal::error::PanicOnError<()> { + use serde_json::json; + + let field = FixedShapeTensorField::new( + "hello", + json!({"name": "element", "data_type": "F32"}), + vec![2, 3], + )?; + let field = Field::try_from(&field)?; + let actual = serde_json::to_value(&PrettyField(&field))?; + let expected = json!({ + "name": "hello", + "data_type": "FixedSizeList(6)", + "children": [{ + "name": "element", + "data_type": "F32", + }], + "metadata": { + "ARROW:extension:metadata": "{\"shape\":[2,3]}", + "ARROW:extension:name": "arrow.fixed_shape_tensor", + }, + }); + + assert_eq!(actual, expected); + Ok(()) +} diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs new file mode 100644 index 00000000..abf2db9d --- /dev/null +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -0,0 +1,6 @@ +mod fixed_shape_tensor_field; +mod utils; +mod variable_shape_tensor_field; + +pub use fixed_shape_tensor_field::FixedShapeTensorField; +pub use variable_shape_tensor_field::VariableShapeTensorField; diff --git a/serde_arrow/src/internal/schema/extensions/utils.rs b/serde_arrow/src/internal/schema/extensions/utils.rs new file mode 100644 index 00000000..aae4d1c0 --- /dev/null +++ b/serde_arrow/src/internal/schema/extensions/utils.rs @@ -0,0 +1,58 @@ +use crate::internal::error::{fail, Result}; + +pub fn check_dim_names(ndim: usize, dim_names: &[String]) -> Result<()> { + if dim_names.len() != ndim { + fail!("Number of dim names must be equal to the number of 
dimensions"); + } + Ok(()) +} + +pub fn check_permutation(ndim: usize, permutation: &[usize]) -> Result<()> { + if permutation.len() != ndim { + fail!("Number of permutation entries must be equal to the number of dimensions"); + } + let seen = vec![false; permutation.len()]; + for &i in permutation { + if i >= seen.len() { + fail!( + "Invalid permutation: index {i} is not in range 0..{len}", + len = seen.len() + ); + } + if seen[i] { + fail!("Invalid permutation: index {i} found multiple times"); + } + } + for (i, seen) in seen.into_iter().enumerate() { + if !seen { + fail!("Invalid permutation: index {i} is not present"); + } + } + Ok(()) +} + +pub fn write_list( + s: &mut String, + items: impl Iterator, +) -> Result<()> { + use std::fmt::Write; + + write!(s, "[")?; + for (idx, val) in items.enumerate() { + if idx != 0 { + write!(s, ",{val}")?; + } else { + write!(s, "{val}")?; + } + } + write!(s, "]")?; + Ok(()) +} + +pub struct DebugRepr(pub T); + +impl std::fmt::Display for DebugRepr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} diff --git a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs new file mode 100644 index 00000000..a45d9f1c --- /dev/null +++ b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs @@ -0,0 +1,174 @@ +use std::collections::HashMap; + +use crate::internal::{ + arrow::{DataType, Field}, + error::{fail, Error, Result}, + schema::{transmute_field, PrettyField}, +}; + +use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; + +/// Helper to build fields for tensors with variable shape +/// +/// See the [arrow docs][variable-shape-tensor-field-docs] for details on the +/// different fields. 
+/// +/// [variable-shape-tensor-field-docs]: +/// https://arrow.apache.org/docs/format/CanonicalExtensions.html#variable-shape-tensor +pub struct VariableShapeTensorField { + name: String, + element: Field, + ndim: usize, + nullable: bool, + dim_names: Option>, + permutation: Option>, + uniform_shape: Option>>, +} + +impl VariableShapeTensorField { + pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { + let element = transmute_field(element)?; + if element.name != "element" { + fail!("The element field of FixedShapeTensorField must be named \"element\""); + } + + Ok(Self { + name: name.to_owned(), + element, + ndim, + nullable: false, + dim_names: None, + permutation: None, + uniform_shape: None, + }) + } + + /// Set the nullability of the field + pub fn nullable(mut self, value: bool) -> Self { + self.nullable = value; + self + } + + /// Set the permutation of the dimension + pub fn permutation(mut self, value: Vec) -> Result { + check_permutation(self.ndim, &value)?; + self.permutation = Some(value); + Ok(self) + } + + /// Set the dimension names + pub fn dim_names(mut self, value: Vec) -> Result { + check_dim_names(self.ndim, &value)?; + self.dim_names = Some(value); + Ok(self) + } + + /// Set the uniform shape + pub fn uniform_shape(mut self, value: Vec>) -> Result { + if value.len() != self.ndim { + fail!("Invalid uniform_shape value"); + } + self.uniform_shape = Some(value); + Ok(self) + } +} + +impl VariableShapeTensorField { + fn get_ext_metadata(&self) -> Result { + use std::fmt::Write; + + let mut first_field = true; + + let mut ext_metadata = String::new(); + write!(&mut ext_metadata, "{{")?; + + if let Some(permutation) = self.permutation.as_ref() { + if first_field { + first_field = false; + write!(&mut ext_metadata, ",")?; + } + write!(&mut ext_metadata, "\"permutation\":")?; + write_list(&mut ext_metadata, permutation.iter())?; + } + + if let Some(dim_names) = self.dim_names.as_ref() { + if first_field { + 
first_field = false; + write!(&mut ext_metadata, ",")?; + } + write!(&mut ext_metadata, "\"dim_names\":")?; + write_list(&mut ext_metadata, dim_names.iter().map(DebugRepr))?; + } + + if let Some(uniform_shape) = self.uniform_shape.as_ref() { + if first_field { + first_field = false; + write!(&mut ext_metadata, ",")?; + } + write!(&mut ext_metadata, "\"uniform_shape\":")?; + write_list( + &mut ext_metadata, + uniform_shape.iter().map(|val| match val { + Some(val) => format!("{val}"), + None => String::from("null"), + }), + )?; + } + + // silence "value not read" warning + let _ = first_field; + + write!(&mut ext_metadata, "}}")?; + Ok(ext_metadata) + } +} + +impl TryFrom<&VariableShapeTensorField> for Field { + type Error = Error; + + fn try_from(value: &VariableShapeTensorField) -> Result { + let mut metadata = HashMap::new(); + metadata.insert( + "ARROW:extension:name".into(), + "arrow.variable_shape_tensor".into(), + ); + metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); + + let mut fields = Vec::new(); + fields.push(Field { + name: String::from("data"), + data_type: DataType::List(Box::new(value.element.clone())), + nullable: false, + metadata: HashMap::new(), + }); + fields.push(Field { + name: String::from("shape"), + data_type: DataType::FixedSizeList( + Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: HashMap::new(), + }), + value.ndim.try_into()?, + ), + nullable: false, + metadata: HashMap::new(), + }); + + Ok(Field { + name: value.name.clone(), + nullable: value.nullable, + data_type: DataType::Struct(fields), + metadata, + }) + } +} + +impl serde::ser::Serialize for VariableShapeTensorField { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) + } +} From deaf588e5ab53969854ef975f28caa2981034784 Mon Sep 17 00:00:00 2001 From: Christopher 
Prohm Date: Wed, 7 Aug 2024 20:45:00 +0200 Subject: [PATCH 104/178] Add serialization test for VariableShapeTensorField --- .../extensions/variable_shape_tensor_field.rs | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs index a45d9f1c..6f7849ab 100644 --- a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs @@ -172,3 +172,36 @@ impl serde::ser::Serialize for VariableShapeTensorField { PrettyField(&field).serialize(serializer) } } + +#[test] +fn test_serialization() -> crate::internal::error::PanicOnError<()> { + use serde_json::json; + + let field = VariableShapeTensorField::new( + "foo bar", + json!({"name": "element", "data_type": "Bool"}), + 2, + )?; + let field = Field::try_from(&field)?; + let actual = serde_json::to_value(PrettyField(&field))?; + + let expected = json!({ + "name": "foo bar", + "data_type": "Struct", + "children": [ + { + "name": "data", + "data_type": "List", + "children": [{"name": "element", "data_type": "Bool"}], + }, + {"name": "shape", "data_type": "FixedSizeList(2)", "children": [{"name": "element", "data_type": "I32"}]} + ], + "metadata": { + "ARROW:extension:metadata": "{}", + "ARROW:extension:name": "arrow.variable_shape_tensor", + }, + }); + + assert_eq!(actual, expected); + Ok(()) +} From 31d544c2402b9e1f4091dcd2afe17122759f4822 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 20:47:27 +0200 Subject: [PATCH 105/178] Update changelog --- Changes.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/Changes.md b/Changes.md index df9cd12b..25f93094 100644 --- a/Changes.md +++ b/Changes.md @@ -2,11 +2,21 @@ ## 0.12 +Refactor the underlying implementation to prepare for further development + +New features + - Add 
`Binary` / `LargeBinary` support for `arrow2` -- Remove `serde_arrow::schema::Schema` -- Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` + +API changes + - Use `impl serde::Serialize` instead of `&(impl serde::Serialize + ?Sized)` - Use `&[FieldRef]` instead of `&[Field]` in arrow APIs + +Removed deprecated API + +- Remove `serde_arrow::schema::Schema` +- Remove `serde_arrow::ArrowBuilder` and `serde_arrow::Arrow2Builder` - Remove `from_arrow_fields` / `to_arrow_fields` for `SerdeArrowSchema`, use the `TryFrom` conversions to convert between fields and `SerdeArrowSchema` - Remove `SerdeArrowSchema::new()`, `Overwrites::new()` From c95d5f06887a9435751ff69e707119f8da376a3e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 20:54:01 +0200 Subject: [PATCH 106/178] Validate map entries --- serde_arrow/src/internal/schema/mod.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index cb7ebcec..78ebba3c 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -484,8 +484,15 @@ fn validate_map_field(field: &Field, _entry: &Field) -> Result<()> { if let Some(strategy) = get_strategy_from_metadata(&field.metadata)? 
{ fail!("invalid strategy for Map field: {strategy}"); } - // TODO: validate entry - + let DataType::Map(entry, _) = &field.data_type else { + fail!("Invalid data type for map child, expected a map"); + }; + let DataType::Struct(entry_fields) = &entry.data_type else { + fail!("Invalid child data type for map, expected struct with 2 fields"); + }; + if entry_fields.len() != 2 { + fail!("Invalid child data type for map, expected struct with 2 fields"); + } Ok(()) } From cca0efe350c80962ba394d9cacb59c00c9a98f25 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 21:13:55 +0200 Subject: [PATCH 107/178] Fix todos, remove unused code --- .../src/internal/schema/serde/serialize.rs | 1 - .../src/test/decimal_representations.rs | 91 ------------------- serde_arrow/src/test/mod.rs | 1 - .../src/test_with_arrow/impls/utils.rs | 15 ++- 4 files changed, 11 insertions(+), 97 deletions(-) delete mode 100644 serde_arrow/src/test/decimal_representations.rs diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index 8256a246..f31cfeed 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -151,7 +151,6 @@ impl<'a> serde::Serialize for PrettyFieldChildren<'a> { s.end() } T::Dictionary(key, value, _) => { - // TODO: this is incorrect, serialize as struct let mut s = serializer.serialize_seq(Some(2))?; s.serialize_element(&DictionaryField("key", key))?; s.serialize_element(&DictionaryField("value", value))?; diff --git a/serde_arrow/src/test/decimal_representations.rs b/serde_arrow/src/test/decimal_representations.rs deleted file mode 100644 index f88c2f54..00000000 --- a/serde_arrow/src/test/decimal_representations.rs +++ /dev/null @@ -1,91 +0,0 @@ -// TODO: use to_value -/* -use std::str::FromStr; - -use rust_decimal::Decimal; -use serde::Serialize; - -use crate::internal::{event::Event, sink::serialize_into_sink}; - -#[test] -fn 
example_str() { - #[derive(Serialize)] - struct Wrapper { - #[serde(with = "rust_decimal::serde::str")] - value: Decimal, - } - - let mut events = Vec::>::new(); - serialize_into_sink( - &mut events, - &Wrapper { - value: Decimal::from_str("0.20").unwrap(), - }, - ) - .unwrap(); - - assert_eq!( - events, - vec![ - Event::StartStruct, - Event::Str("value").to_owned(), - Event::Str("0.20").to_owned(), - Event::EndStruct, - ], - ); -} - -#[test] -fn example_float() { - #[derive(Serialize)] - struct Wrapper { - #[serde(with = "rust_decimal::serde::float")] - value: Decimal, - } - - let mut events = Vec::>::new(); - serialize_into_sink( - &mut events, - &Wrapper { - value: Decimal::from_str("0.20").unwrap(), - }, - ) - .unwrap(); - - assert_eq!( - events, - vec![ - Event::StartStruct, - Event::Str("value").to_owned(), - Event::F64(0.2), - Event::EndStruct, - ], - ); -} - */ -/* -#[test] -fn example_arbitrary_precision() { - #[derive(Serialize)] - struct Wrapper { - #[serde(with = "rust_decimal::serde::arbitrary_precision")] - value: Decimal, - } - - let mut events = Vec::>::new(); - serialize_into_sink(&mut events, &Wrapper { value: Decimal::from_str("0.20").unwrap() }).unwrap(); - - assert_eq!( - events, - vec![ - Event::StartStruct, - Event::Str("value").to_owned(), - Event::StartStruct, - Event::Str("$serde_json::private::Number").to_owned(), - Event::Str("0.20"), - Event::EndStruct, - Event::EndStruct, - ], - ); -} -*/ diff --git a/serde_arrow/src/test/mod.rs b/serde_arrow/src/test/mod.rs index 93168059..24690804 100644 --- a/serde_arrow/src/test/mod.rs +++ b/serde_arrow/src/test/mod.rs @@ -1,4 +1,3 @@ mod api_chrono; -mod decimal_representations; mod error; mod schema_like; diff --git a/serde_arrow/src/test_with_arrow/impls/utils.rs b/serde_arrow/src/test_with_arrow/impls/utils.rs index 4721f8f2..ca1cdb7b 100644 --- a/serde_arrow/src/test_with_arrow/impls/utils.rs +++ b/serde_arrow/src/test_with_arrow/impls/utils.rs @@ -165,7 +165,12 @@ impl Test { let mut builder 
= crate::ArrayBuilder::from_arrow(&fields)?; builder.extend(items)?; let arrays = builder.to_arrow()?; - assert_eq!(self.arrays.arrow, Some(arrays)); + assert_eq!(self.arrays.arrow.as_ref(), Some(&arrays)); + + assert_eq!(fields.len(), arrays.len()); + for (field, array) in std::iter::zip(&fields, &arrays) { + assert_eq!(field.data_type(), array.data_type()); + } Ok(()) } @@ -184,10 +189,12 @@ impl Test { let mut builder = crate::ArrayBuilder::from_arrow2(&fields)?; builder.extend(items)?; let arrays = builder.to_arrow2()?; - assert_eq!(self.arrays.arrow2, Some(arrays)); - - // TODO: test that the result arrays has the fields as the schema + assert_eq!(self.arrays.arrow2.as_ref(), Some(&arrays)); + assert_eq!(fields.len(), arrays.len()); + for (field, array) in std::iter::zip(&fields, &arrays) { + assert_eq!(field.data_type(), array.data_type()); + } Ok(()) } From 0f25d6e3a751b46f7091179b27060478ba2591a0 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 21:17:25 +0200 Subject: [PATCH 108/178] fix format --- serde_arrow/src/test_with_arrow/impls/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/test_with_arrow/impls/utils.rs b/serde_arrow/src/test_with_arrow/impls/utils.rs index ca1cdb7b..a0349924 100644 --- a/serde_arrow/src/test_with_arrow/impls/utils.rs +++ b/serde_arrow/src/test_with_arrow/impls/utils.rs @@ -170,7 +170,7 @@ impl Test { assert_eq!(fields.len(), arrays.len()); for (field, array) in std::iter::zip(&fields, &arrays) { assert_eq!(field.data_type(), array.data_type()); - } + } Ok(()) } From e2231b4033fb8dbd2347096fdff6c62287481536 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 21:36:25 +0200 Subject: [PATCH 109/178] Ensure I8 can be used to serialize and deserialize bool --- .../src/internal/serialization/int_builder.rs | 5 ++++ .../src/test_with_arrow/impls/bool8.rs | 25 +++++++++++++++++++ serde_arrow/src/test_with_arrow/impls/mod.rs | 1 + 3 files changed, 31 
insertions(+) create mode 100644 serde_arrow/src/test_with_arrow/impls/bool8.rs diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index acea49ff..f7857632 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -77,6 +77,11 @@ where self.0.push_scalar_none() } + fn serialize_bool(&mut self, v: bool) -> Result<()> { + let v: u8 = if v { 1 } else { 0 }; + self.0.push_scalar_value(I::try_from(v)?) + } + fn serialize_i8(&mut self, v: i8) -> Result<()> { self.0.push_scalar_value(I::try_from(v)?) } diff --git a/serde_arrow/src/test_with_arrow/impls/bool8.rs b/serde_arrow/src/test_with_arrow/impls/bool8.rs new file mode 100644 index 00000000..d055e67d --- /dev/null +++ b/serde_arrow/src/test_with_arrow/impls/bool8.rs @@ -0,0 +1,25 @@ +use serde_json::json; + +use crate::internal::utils::Item; + +use super::utils::Test; + +#[test] +fn bool_as_int8() { + let items = &[Item(true), Item(false)]; + Test::new() + .with_schema(json!([{"name": "item", "data_type": "U8"}])) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, false]]); +} + +#[test] +fn nullable_bool_as_int8() { + let items = &[Item(Some(true)), Item(None), Item(Some(false))]; + Test::new() + .with_schema(json!([{"name": "item", "data_type": "U8", "nullable": true}])) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, true, false]]); +} diff --git a/serde_arrow/src/test_with_arrow/impls/mod.rs b/serde_arrow/src/test_with_arrow/impls/mod.rs index 90695d09..72db9255 100644 --- a/serde_arrow/src/test_with_arrow/impls/mod.rs +++ b/serde_arrow/src/test_with_arrow/impls/mod.rs @@ -1,5 +1,6 @@ mod utils; +mod bool8; mod bytes; mod chrono; mod dictionary; From 03a7949b9651e30320e3da497ef05655b44f2442 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 22:17:03 +0200 Subject: [PATCH 110/178] Move array-ext into utils --- 
.../src/internal/serialization/binary_builder.rs | 10 +++++----- serde_arrow/src/internal/serialization/bool_builder.rs | 6 ++---- .../src/internal/serialization/date32_builder.rs | 6 ++---- .../src/internal/serialization/date64_builder.rs | 6 ++---- .../src/internal/serialization/decimal_builder.rs | 6 ++---- .../src/internal/serialization/duration_builder.rs | 6 ++---- .../serialization/fixed_size_binary_builder.rs | 6 ++---- .../internal/serialization/fixed_size_list_builder.rs | 7 ++----- .../src/internal/serialization/float_builder.rs | 6 ++---- serde_arrow/src/internal/serialization/int_builder.rs | 6 ++---- serde_arrow/src/internal/serialization/list_builder.rs | 7 ++----- serde_arrow/src/internal/serialization/map_builder.rs | 7 ++----- serde_arrow/src/internal/serialization/mod.rs | 1 - .../src/internal/serialization/struct_builder.rs | 7 ++----- serde_arrow/src/internal/serialization/time_builder.rs | 6 ++---- serde_arrow/src/internal/serialization/utf8_builder.rs | 6 ++---- .../src/internal/{serialization => utils}/array_ext.rs | 0 serde_arrow/src/internal/utils/mod.rs | 1 + 18 files changed, 34 insertions(+), 66 deletions(-) rename serde_arrow/src/internal/{serialization => utils}/array_ext.rs (100%) diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 2bc1ea5a..8bb2adc5 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -3,13 +3,13 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, error::Result, - utils::{Mut, Offset}, + utils::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, + Mut, Offset, + }, }; -use super::{ - array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] diff --git 
a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 1614bcfe..c1d23683 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, BooleanArray}, error::Result, + utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, }; -use super::{ - array_ext::{set_bit_buffer, set_validity, set_validity_default}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct BoolBuilder(BooleanArray); diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 1aaa3079..d4feb160 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -3,12 +3,10 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Date32Builder(PrimitiveArray); diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 576229c3..7001b558 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, error::{fail, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use 
super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Date64Builder { diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 1caf8802..2f851d47 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,13 +1,11 @@ use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::decimal::{self, DecimalParser}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DecimalBuilder { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 17ecc99a..1b835a69 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DurationBuilder { diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 2c82b46e..329d4fc2 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -3,13 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, error::{fail, Result}, + 
utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 89007756..1c222304 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -3,14 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index a9be17b2..f54ed275 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -3,13 +3,11 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, error::Result, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::Mut, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct FloatBuilder(PrimitiveArray); diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index f7857632..f385667f 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs 
@@ -1,12 +1,10 @@ use crate::internal::{ arrow::{Array, PrimitiveArray}, error::{Error, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct IntBuilder(PrimitiveArray); diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 8a2bdad1..7578433e 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -3,14 +3,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, error::Result, + utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, utils::{Mut, Offset}, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 8c0499d1..737d1df6 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -3,13 +3,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct MapBuilder { diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index ac137b17..f6af48eb 100644 --- 
a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -1,7 +1,6 @@ //! A serialization implementation without the event model pub mod array_builder; -pub mod array_ext; pub mod binary_builder; pub mod bool_builder; pub mod date32_builder; diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 264a7b86..68753f13 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -5,14 +5,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, error::{fail, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; -use super::{ - array_builder::ArrayBuilder, - array_ext::{ArrayExt, CountArray, SeqArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; const UNKNOWN_KEY: usize = usize::MAX; diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index ed6dae67..550cc17a 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -3,12 +3,10 @@ use chrono::Timelike; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, error::{Error, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; -use super::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct TimeBuilder { diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 788123cf..b811b8aa 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ 
b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,13 +1,11 @@ use crate::internal::{ arrow::{Array, BytesArray}, error::{fail, Result}, + utils::array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, utils::Offset, }; -use super::{ - array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - simple_serializer::SimpleSerializer, -}; +use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Utf8Builder(BytesArray); diff --git a/serde_arrow/src/internal/serialization/array_ext.rs b/serde_arrow/src/internal/utils/array_ext.rs similarity index 100% rename from serde_arrow/src/internal/serialization/array_ext.rs rename to serde_arrow/src/internal/utils/array_ext.rs diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 6995f332..36b3a8cd 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod array_ext; pub mod decimal; pub mod dsl; pub mod value; From 635b176e64326bd6b165329b8675b289101de5b2 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 22:17:27 +0200 Subject: [PATCH 111/178] Allow to construct a Deserializer directly from internal arrays --- serde_arrow/src/arrow2_impl/api.rs | 38 +++++---------- serde_arrow/src/arrow_impl/api.rs | 30 +++--------- serde_arrow/src/internal/deserializer.rs | 32 ++++++++++++- .../src/internal/utils/array_view_ext.rs | 48 +++++++++++++++++++ serde_arrow/src/internal/utils/mod.rs | 1 + 5 files changed, 98 insertions(+), 51 deletions(-) create mode 100644 serde_arrow/src/internal/utils/array_view_ext.rs diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index 10a75f1f..6127d956 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -10,13 +10,9 @@ use crate::{ internal::{ array_builder::ArrayBuilder, arrow::Field, - deserialization::{ - array_deserializer::ArrayDeserializer, - 
outer_sequence_deserializer::OuterSequenceDeserializer, - }, deserializer::Deserializer, error::{fail, Result}, - schema::{get_strategy_from_metadata, SerdeArrowSchema}, + schema::SerdeArrowSchema, serializer::Serializer, }, }; @@ -153,14 +149,7 @@ impl<'de> Deserializer<'de> { where A: AsRef, { - let fields = fields - .iter() - .map(Field::try_from) - .collect::>>()?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); + use crate::internal::arrow::ArrayView; if fields.len() != arrays.len() { fail!( @@ -169,21 +158,16 @@ impl<'de> Deserializer<'de> { arrays.len() ); } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - - let mut deserializers = Vec::new(); - for (field, array) in std::iter::zip(fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } - let strategy = get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); - } - let deserializer = OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); + let fields = fields + .iter() + .map(Field::try_from) + .collect::>>()?; + let views = arrays + .iter() + .map(|array| ArrayView::try_from(array.as_ref())) + .collect::>>()?; - Ok(deserializer) + Deserializer::new(&fields, views) } } diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 852a0198..92644fa7 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -10,13 +10,9 @@ use crate::{ }, internal::{ array_builder::ArrayBuilder, - deserialization::{ - array_deserializer::ArrayDeserializer, - outer_sequence_deserializer::OuterSequenceDeserializer, - }, deserializer::Deserializer, error::{fail, Result}, - schema::{get_strategy_from_metadata, SerdeArrowSchema}, + schema::SerdeArrowSchema, serializer::Serializer, }, }; @@ 
-241,11 +237,7 @@ impl<'de> Deserializer<'de> { where A: AsRef, { - let fields = fields_from_field_refs(fields)?; - let arrays = arrays - .iter() - .map(|array| array.as_ref()) - .collect::>(); + use crate::internal::arrow::ArrayView; if fields.len() != arrays.len() { fail!( @@ -254,23 +246,15 @@ impl<'de> Deserializer<'de> { arrays.len() ); } - let len = arrays.first().map(|array| array.len()).unwrap_or_default(); - let mut deserializers = Vec::new(); - for (field, array) in std::iter::zip(&fields, arrays) { - if array.len() != len { - fail!("arrays of different lengths are not supported"); - } + let fields = fields_from_field_refs(fields)?; - let strategy = get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(strategy.as_ref(), array.try_into()?)?; - deserializers.push((field.name.clone(), deserializer)); + let mut views = Vec::new(); + for array in arrays { + views.push(ArrayView::try_from(array.as_ref())?); } - let deserializer = OuterSequenceDeserializer::new(deserializers, len); - let deserializer = Deserializer(deserializer); - - Ok(deserializer) + Deserializer::new(&fields, views) } /// Construct a new deserializer from a record batch (*requires one of the diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index b716c286..723615b3 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -1,8 +1,14 @@ use serde::de::Visitor; use crate::internal::{ - deserialization::outer_sequence_deserializer::OuterSequenceDeserializer, + arrow::{ArrayView, Field}, + deserialization::{ + array_deserializer::ArrayDeserializer, + outer_sequence_deserializer::OuterSequenceDeserializer, + }, error::{fail, Error, Result}, + schema::get_strategy_from_metadata, + utils::array_view_ext::ArrayViewExt, }; /// A structure to deserialize Arrow arrays into Rust objects @@ -14,6 +20,30 @@ use crate::internal::{ #[cfg_attr(has_arrow2, doc = r"- 
[`Deserializer::from_arrow2`]")] pub struct Deserializer<'de>(pub(crate) OuterSequenceDeserializer<'de>); +impl<'de> Deserializer<'de> { + pub(crate) fn new(fields: &[Field], views: Vec>) -> Result { + let len = match views.first() { + Some(view) => view.len(), + None => 0, + }; + + let mut deserializers = Vec::new(); + for (field, view) in std::iter::zip(fields, views) { + if view.len() != len { + fail!("Cannot deserialize from arrays with different lengths"); + } + let strategy = get_strategy_from_metadata(&field.metadata)?; + let deserializer = ArrayDeserializer::new(strategy.as_ref(), view)?; + deserializers.push((field.name.clone(), deserializer)); + } + + let deserializer = OuterSequenceDeserializer::new(deserializers, len); + let deserializer = Deserializer(deserializer); + + Ok(deserializer) + } +} + impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { type Error = Error; diff --git a/serde_arrow/src/internal/utils/array_view_ext.rs b/serde_arrow/src/internal/utils/array_view_ext.rs new file mode 100644 index 00000000..df26ae13 --- /dev/null +++ b/serde_arrow/src/internal/utils/array_view_ext.rs @@ -0,0 +1,48 @@ +use crate::internal::arrow::ArrayView; + +pub trait ArrayViewExt { + fn len(&self) -> usize; +} + +impl<'a> ArrayViewExt for ArrayView<'a> { + fn len(&self) -> usize { + use ArrayView as V; + match self { + V::Null(view) => view.len, + V::Boolean(view) => view.len, + V::Int8(view) => view.values.len(), + V::Int16(view) => view.values.len(), + V::Int32(view) => view.values.len(), + V::Int64(view) => view.values.len(), + V::UInt8(view) => view.values.len(), + V::UInt16(view) => view.values.len(), + V::UInt32(view) => view.values.len(), + V::UInt64(view) => view.values.len(), + V::Float16(view) => view.values.len(), + V::Float32(view) => view.values.len(), + V::Float64(view) => view.values.len(), + V::Date32(view) => view.values.len(), + V::Date64(view) => view.values.len(), + V::Time32(view) => view.values.len(), + V::Time64(view) => 
view.values.len(), + V::Timestamp(view) => view.values.len(), + V::Duration(view) => view.values.len(), + V::Decimal128(view) => view.values.len(), + V::Utf8(view) => view.offsets.len().saturating_sub(1), + V::LargeUtf8(view) => view.offsets.len().saturating_sub(1), + V::Binary(view) => view.offsets.len().saturating_sub(1), + V::LargeBinary(view) => view.offsets.len().saturating_sub(1), + V::FixedSizeBinary(view) => match usize::try_from(view.n) { + Ok(n) if n > 0 => view.data.len() / n, + _ => 0, + }, + V::FixedSizeList(view) => view.len, + V::List(view) => view.offsets.len().saturating_sub(1), + V::LargeList(view) => view.offsets.len().saturating_sub(1), + V::DenseUnion(view) => view.types.len(), + V::Map(view) => view.offsets.len().saturating_sub(1), + V::Struct(view) => view.len, + V::Dictionary(view) => view.indices.len(), + } + } +} diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 36b3a8cd..cee187fe 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -1,4 +1,5 @@ pub mod array_ext; +pub mod array_view_ext; pub mod decimal; pub mod dsl; pub mod value; From ed138a25aa96b6e90e28953bc41ed268ff95a819 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 22:22:41 +0200 Subject: [PATCH 112/178] Test that bools serialized from ints are true for any non-zero value --- .../src/test_with_arrow/impls/bool8.rs | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/test_with_arrow/impls/bool8.rs b/serde_arrow/src/test_with_arrow/impls/bool8.rs index d055e67d..11f97247 100644 --- a/serde_arrow/src/test_with_arrow/impls/bool8.rs +++ b/serde_arrow/src/test_with_arrow/impls/bool8.rs @@ -1,6 +1,11 @@ +use serde::Deserialize; use serde_json::json; -use crate::internal::utils::Item; +use crate::internal::{ + arrow::{ArrayView, DataType, Field, PrimitiveArrayView}, + deserializer::Deserializer, + utils::{Item, Items}, +}; use 
super::utils::Test; @@ -23,3 +28,26 @@ fn nullable_bool_as_int8() { .deserialize(items) .check_nulls(&[&[false, true, false]]); } + +// from the bool8 specs: false is denoted by the value 0. true can be specified using any non-zero +// value. Preferably 1. +#[test] +fn deserialize_from_not_01_ints() -> crate::internal::error::PanicOnError<()> { + let field = Field { + name: String::from("item"), + data_type: DataType::Int8, + nullable: false, + metadata: Default::default(), + }; + let view = ArrayView::Int8(PrimitiveArrayView { + validity: None, + values: &[0, -1, 2, 3, -31, 100, 0, 0], + }); + let deserializer = Deserializer::new(&[field], vec![view])?; + + let Items(actual) = Items::>::deserialize(deserializer)?; + let expected = vec![false, true, true, true, true, true, false, false]; + assert_eq!(actual, expected); + + Ok(()) +} From 079184f319f520367412414f5da1d655dfba3658 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 22:33:52 +0200 Subject: [PATCH 113/178] Add Bool8 extension type --- .../internal/schema/extensions/bool8_field.rs | 76 +++++++++++++++++++ .../src/internal/schema/extensions/mod.rs | 2 + serde_arrow/src/lib.rs | 2 +- .../src/test_with_arrow/impls/bool8.rs | 28 ++++++- 4 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 serde_arrow/src/internal/schema/extensions/bool8_field.rs diff --git a/serde_arrow/src/internal/schema/extensions/bool8_field.rs b/serde_arrow/src/internal/schema/extensions/bool8_field.rs new file mode 100644 index 00000000..1bd8dd60 --- /dev/null +++ b/serde_arrow/src/internal/schema/extensions/bool8_field.rs @@ -0,0 +1,76 @@ +use std::collections::HashMap; + +use crate::internal::{ + arrow::{DataType, Field}, + error::{Error, Result}, + schema::PrettyField, +}; + +/// A helper to construct fields with the Bool8 extension type +pub struct Bool8Field { + name: String, + nullable: bool, +} + +impl Bool8Field { + /// Construct a new `Bool8Field`` + pub fn new(name: &str) -> Self { + Self { 
+ name: name.into(), + nullable: false, + } + } + + /// Set the nullability of the field + pub fn nullable(mut self, value: bool) -> Self { + self.nullable = value; + self + } +} + +impl TryFrom<&Bool8Field> for Field { + type Error = Error; + + fn try_from(value: &Bool8Field) -> Result { + let mut metadata = HashMap::new(); + metadata.insert("ARROW:extension:name".into(), "arrow.bool8".into()); + metadata.insert("ARROW:extension:metadata".into(), String::new()); + + Ok(Field { + name: value.name.to_owned(), + nullable: value.nullable, + data_type: DataType::Int8, + metadata, + }) + } +} + +impl serde::ser::Serialize for Bool8Field { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error; + let field = Field::try_from(self).map_err(S::Error::custom)?; + PrettyField(&field).serialize(serializer) + } +} + +#[test] +fn bool8_repr() -> crate::internal::error::PanicOnError<()> { + use serde_json::json; + + let field = Bool8Field::new("hello"); + + let field = Field::try_from(&field)?; + let actual = serde_json::to_value(&PrettyField(&field))?; + + let expected = json!({ + "name": "hello", + "data_type": "I8", + "metadata": { + "ARROW:extension:name": "arrow.bool8", + "ARROW:extension:metadata": "", + }, + }); + + assert_eq!(actual, expected); + Ok(()) +} diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs index abf2db9d..fa879d7c 100644 --- a/serde_arrow/src/internal/schema/extensions/mod.rs +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -1,6 +1,8 @@ +mod bool8_field; mod fixed_shape_tensor_field; mod utils; mod variable_shape_tensor_field; +pub use bool8_field::Bool8Field; pub use fixed_shape_tensor_field::FixedShapeTensorField; pub use variable_shape_tensor_field::VariableShapeTensorField; diff --git a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index b3a85ce5..1ce4d392 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -391,7 +391,7 @@ pub mod schema { 
/// [ext-docs]: https://arrow.apache.org/docs/format/CanonicalExtensions.html pub mod ext { pub use crate::internal::schema::extensions::{ - FixedShapeTensorField, VariableShapeTensorField, + Bool8Field, FixedShapeTensorField, VariableShapeTensorField, }; } } diff --git a/serde_arrow/src/test_with_arrow/impls/bool8.rs b/serde_arrow/src/test_with_arrow/impls/bool8.rs index 11f97247..14fb52c2 100644 --- a/serde_arrow/src/test_with_arrow/impls/bool8.rs +++ b/serde_arrow/src/test_with_arrow/impls/bool8.rs @@ -4,6 +4,7 @@ use serde_json::json; use crate::internal::{ arrow::{ArrayView, DataType, Field, PrimitiveArrayView}, deserializer::Deserializer, + schema::{extensions::Bool8Field, TracingOptions}, utils::{Item, Items}, }; @@ -13,7 +14,7 @@ use super::utils::Test; fn bool_as_int8() { let items = &[Item(true), Item(false)]; Test::new() - .with_schema(json!([{"name": "item", "data_type": "U8"}])) + .with_schema(json!([{"name": "item", "data_type": "I8"}])) .serialize(items) .deserialize(items) .check_nulls(&[&[false, false]]); @@ -23,7 +24,7 @@ fn bool_as_int8() { fn nullable_bool_as_int8() { let items = &[Item(Some(true)), Item(None), Item(Some(false))]; Test::new() - .with_schema(json!([{"name": "item", "data_type": "U8", "nullable": true}])) + .with_schema(json!([{"name": "item", "data_type": "I8", "nullable": true}])) .serialize(items) .deserialize(items) .check_nulls(&[&[false, true, false]]); @@ -51,3 +52,26 @@ fn deserialize_from_not_01_ints() -> crate::internal::error::PanicOnError<()> { Ok(()) } + +#[test] +fn overwrites() -> crate::internal::error::PanicOnError<()> { + let tracing_options = TracingOptions::new().overwrite("item", Bool8Field::new("item"))?; + + let items = &[Item(true), Item(false)]; + Test::new() + .with_schema(json!([{ + "name": "item", + "data_type": "I8", + "metadata": { + "ARROW:extension:name": "arrow.bool8", + "ARROW:extension:metadata": "", + }, + }])) + .trace_schema_from_samples(&items, tracing_options.clone()) + 
.trace_schema_from_type::>(tracing_options) + .serialize(items) + .deserialize(items) + .check_nulls(&[&[false, false]]); + + Ok(()) +} From 949e5bc5b06b193fef21950662b9e4da7aff2ddb Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 7 Aug 2024 22:39:51 +0200 Subject: [PATCH 114/178] Fix cross references --- serde_arrow/src/internal/schema/mod.rs | 33 +++++++++++--------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 78ebba3c..8ff90182 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -135,22 +135,20 @@ pub trait SchemaLike: Sized + Sealed { /// fn from_value(value: &T) -> Result; - /// Determine the schema from the given record type. See [`TracingOptions`] - /// for customization options. + /// Determine the schema from the given record type. See [`TracingOptions`] for customization + /// options. /// - /// This approach requires the type `T` to implement - /// [`Deserialize`][serde::Deserialize]. As only type information is used, - /// it is not possible to detect data dependent properties. Examples of - /// unsupported features: + /// This approach requires the type `T` to implement [`Deserialize`][::serde::Deserialize]. As + /// only type information is used, it is not possible to detect data dependent properties. + /// Examples of unsupported features: /// /// - auto detection of date time strings /// - non self-describing types such as `serde_json::Value` /// - flattened structure (`#[serde(flatten)]`) - /// - types that require specific data to be deserialized, such as the - /// `DateTime` type of `chrono` or the `Uuid` type of the `uuid` package + /// - types that require specific data to be deserialized, such as the `DateTime` type of + /// `chrono` or the `Uuid` type of the `uuid` package /// - /// Consider using [`from_samples`][SchemaLike::from_samples] in these - /// cases. 
+ /// Consider using [`from_samples`][SchemaLike::from_samples] in these cases. /// /// ```rust /// # #[cfg(has_arrow)] @@ -199,20 +197,17 @@ pub trait SchemaLike: Sized + Sealed { /// ``` fn from_type<'de, T: Deserialize<'de> + ?Sized>(options: TracingOptions) -> Result; - /// Determine the schema from samples. See [`TracingOptions`] for - /// customization options. + /// Determine the schema from samples. See [`TracingOptions`] for customization options. /// - /// This approach requires the type `T` to implement - /// [`Serialize`][serde::Serialize] and the samples to include all relevant - /// values. It uses only the information encoded in the samples to generate - /// the schema. Therefore, the following requirements must be met: + /// This approach requires the type `T` to implement [`Serialize`][::serde::Serialize] and the + /// samples to include all relevant values. It uses only the information encoded in the samples + /// to generate the schema. Therefore, the following requirements must be met: /// /// - at least one `Some` value for `Option<..>` fields /// - all variants of enum fields /// - at least one element for sequence fields (e.g., `Vec<..>`) - /// - at least one example for map types (e.g., `HashMap<.., ..>`). All - /// possible keys must be given, if [`options.map_as_struct == - /// true`][TracingOptions::map_as_struct]). + /// - at least one example for map types (e.g., `HashMap<.., ..>`). All possible keys must be + /// given, if [`options.map_as_struct == true`][TracingOptions::map_as_struct]). 
/// /// ```rust /// # #[cfg(has_arrow)] From 580c539bcb63077505aacdcfb54ab8766bd5202a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 09:58:52 +0200 Subject: [PATCH 115/178] Add doc example, update extension type docs --- serde_arrow/src/arrow_impl/type_support.rs | 32 ++++++++++------- .../internal/schema/extensions/bool8_field.rs | 34 +++++++++++++++++-- .../extensions/fixed_shape_tensor_field.rs | 4 +-- .../extensions/variable_shape_tensor_field.rs | 3 +- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/serde_arrow/src/arrow_impl/type_support.rs b/serde_arrow/src/arrow_impl/type_support.rs index 7f34f4c3..8330bb80 100644 --- a/serde_arrow/src/arrow_impl/type_support.rs +++ b/serde_arrow/src/arrow_impl/type_support.rs @@ -6,7 +6,7 @@ use crate::_impl::arrow::{ use crate::internal::{ arrow::Field, error::{Error, Result}, - schema::extensions::FixedShapeTensorField, + schema::extensions::{Bool8Field, FixedShapeTensorField, VariableShapeTensorField}, }; impl From for Error { @@ -15,22 +15,30 @@ impl From for Error { } } -impl TryFrom<&FixedShapeTensorField> for ArrowField { - type Error = Error; +macro_rules! impl_try_from_ext_type { + ($ty:ty) => { + impl TryFrom<&$ty> for ArrowField { + type Error = Error; - fn try_from(value: &FixedShapeTensorField) -> Result { - Self::try_from(&Field::try_from(value)?) - } -} + fn try_from(value: &$ty) -> Result { + Self::try_from(&Field::try_from(value)?) 
+ } + } -impl TryFrom for ArrowField { - type Error = Error; + impl TryFrom<$ty> for ArrowField { + type Error = Error; - fn try_from(value: FixedShapeTensorField) -> Result { - Self::try_from(&value) - } + fn try_from(value: $ty) -> Result { + Self::try_from(&value) + } + } + }; } +impl_try_from_ext_type!(Bool8Field); +impl_try_from_ext_type!(FixedShapeTensorField); +impl_try_from_ext_type!(VariableShapeTensorField); + pub fn fields_from_field_refs(fields: &[FieldRef]) -> Result> { fields .iter() diff --git a/serde_arrow/src/internal/schema/extensions/bool8_field.rs b/serde_arrow/src/internal/schema/extensions/bool8_field.rs index 1bd8dd60..56ba81af 100644 --- a/serde_arrow/src/internal/schema/extensions/bool8_field.rs +++ b/serde_arrow/src/internal/schema/extensions/bool8_field.rs @@ -6,14 +6,44 @@ use crate::internal::{ schema::PrettyField, }; -/// A helper to construct fields with the Bool8 extension type +/// A helper to construct new `Bool8` fields (`arrow.bool8`) +/// +/// This extension type can be used with `overwrites` in schema tracing: +/// +/// ```rust +/// # use serde_json::json; +/// # use serde_arrow::{Result, schema::{SerdeArrowSchema, SchemaLike, TracingOptions, ext::Bool8Field}}; +/// # use serde::Deserialize; +/// # fn main() -> Result<()> { +/// ##[derive(Deserialize)] +/// struct Record { +/// int_field: i32, +/// nested: Nested, +/// } +/// +/// ##[derive(Deserialize)] +/// struct Nested { +/// bool_field: bool, +/// } +/// +/// let tracing_options = TracingOptions::default() +/// .overwrite("nested.bool_field", Bool8Field::new("bool_field"))?; +/// +/// let schema = SerdeArrowSchema::from_type::(tracing_options)?; +/// # std::mem::drop(schema); +/// # Ok(()) +/// # } +/// ``` +/// +/// It can also be converted to a `arrow` `Field` for manual schema manipulation. 
+/// pub struct Bool8Field { name: String, nullable: bool, } impl Bool8Field { - /// Construct a new `Bool8Field`` + /// Construct a new non-nullable `Bool8Field` pub fn new(name: &str) -> Self { Self { name: name.into(), diff --git a/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs index c5239a55..b7f0ca22 100644 --- a/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/fixed_shape_tensor_field.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; -/// Easily construct a field for tensors with fixed shape +/// Easily construct a fixed shape tensor fields (`arrow.fixed_shape_tensor`) /// /// See the [arrow docs][fixed-shape-tensor-docs] for details on the different /// fields. @@ -51,7 +51,7 @@ pub struct FixedShapeTensorField { } impl FixedShapeTensorField { - /// Construct a new instance + /// Construct a new non-nullable `FixedShapeTensorField` /// /// Note the element parameter must serialize into a valid field definition /// with the the name `"element"`. The field type can be any valid Arrow diff --git a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs index 6f7849ab..1dfa3c83 100644 --- a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs @@ -8,7 +8,7 @@ use crate::internal::{ use super::utils::{check_dim_names, check_permutation, write_list, DebugRepr}; -/// Helper to build fields for tensors with variable shape +/// Helper to build variable shape tensor fields (`arrow.variable_shape_tensor`) /// /// See the [arrow docs][variable-shape-tensor-field-docs] for details on the /// different fields. 
@@ -26,6 +26,7 @@ pub struct VariableShapeTensorField { } impl VariableShapeTensorField { + /// Create a new non-nullable `VariableShapeTensorField` pub fn new(name: &str, element: impl serde::ser::Serialize, ndim: usize) -> Result { let element = transmute_field(element)?; if element.name != "element" { From e89b0b4185b540e3f263ffe5a0c5a1da5b2569db Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 10:02:30 +0200 Subject: [PATCH 116/178] Fix clippy --- serde_arrow/src/arrow2_impl/array.rs | 2 +- serde_arrow/src/arrow2_impl/schema.rs | 4 +- .../extensions/variable_shape_tensor_field.rs | 43 ++++++++++--------- serde_arrow/src/internal/schema/tracer.rs | 4 +- .../serialization/outer_sequence_builder.rs | 2 +- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 1150a739..7df8f857 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -149,7 +149,7 @@ impl TryFrom for ArrayRef { let child: ArrayRef = child.try_into()?; let field = field_from_array_and_meta(child.as_ref(), meta); - type_ids.push(type_id.try_into()?); + type_ids.push(type_id.into()); values.push(child); fields.push(field); } diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index c08e2cc0..08aa8930 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -193,7 +193,7 @@ impl TryFrom<&DataType> for ArrowDataType { if *scale < 0 { fail!("arrow2 does not support decimals with negative scale"); } - Ok(AT::Decimal((*precision).try_into()?, (*scale).try_into()?)) + Ok(AT::Decimal((*precision).into(), (*scale).try_into()?)) } T::Binary => Ok(AT::Binary), T::LargeBinary => Ok(AT::LargeBinary), @@ -266,7 +266,7 @@ impl TryFrom<&DataType> for ArrowDataType { for (type_id, field) in in_fields { fields.push(AF::try_from(field)?); - type_ids.push((*type_id).try_into()?); + 
type_ids.push((*type_id).into()); } Ok(AT::Union(fields, Some(type_ids), (*mode).into())) } diff --git a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs index 1dfa3c83..3cc1936c 100644 --- a/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs +++ b/serde_arrow/src/internal/schema/extensions/variable_shape_tensor_field.rs @@ -135,27 +135,28 @@ impl TryFrom<&VariableShapeTensorField> for Field { ); metadata.insert("ARROW:extension:metadata".into(), value.get_ext_metadata()?); - let mut fields = Vec::new(); - fields.push(Field { - name: String::from("data"), - data_type: DataType::List(Box::new(value.element.clone())), - nullable: false, - metadata: HashMap::new(), - }); - fields.push(Field { - name: String::from("shape"), - data_type: DataType::FixedSizeList( - Box::new(Field { - name: String::from("element"), - data_type: DataType::Int32, - nullable: false, - metadata: HashMap::new(), - }), - value.ndim.try_into()?, - ), - nullable: false, - metadata: HashMap::new(), - }); + let fields = vec![ + Field { + name: String::from("data"), + data_type: DataType::List(Box::new(value.element.clone())), + nullable: false, + metadata: HashMap::new(), + }, + Field { + name: String::from("shape"), + data_type: DataType::FixedSizeList( + Box::new(Field { + name: String::from("element"), + data_type: DataType::Int32, + nullable: false, + metadata: HashMap::new(), + }), + value.ndim.try_into()?, + ), + nullable: false, + metadata: HashMap::new(), + }, + ]; Ok(Field { name: value.name.clone(), diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 861b386c..3f2a5e1a 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -20,7 +20,7 @@ const RECURSIVE_TYPE_WARNING: &str = fn default_dictionary_field(name: &str, nullable: bool) -> Field { Field { name: 
name.to_owned(), - nullable: nullable, + nullable, metadata: HashMap::new(), data_type: DataType::Dictionary( Box::new(DataType::UInt32), @@ -560,7 +560,7 @@ impl Tracer { (ty, ev) => fail!( "Cannot accept event {ev} for tracer of primitive type {ty}", ev = DataTypeDisplay(&ev), - ty = DataTypeDisplay(&ty), + ty = DataTypeDisplay(ty), ), }; tracer.item_type = item_type; diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index a76a6851..c9e837ae 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -114,7 +114,7 @@ impl OuterSequenceBuilder { build_builder(entry_field.as_ref())?, field.nullable, )?), - T::Struct(children) => A::Struct(build_struct(&children, field.nullable)?), + T::Struct(children) => A::Struct(build_struct(children, field.nullable)?), T::Dictionary(key, value, _) => { let key_field = Field { name: "key".to_string(), From 292482408bf3b295f84ed86fa9e00671a925ab1d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 10:05:56 +0200 Subject: [PATCH 117/178] Rework crate-internal imports --- .../src/internal/schema/from_type/test_error_messages.rs | 4 ++-- serde_arrow/src/internal/schema/serde/serialize.rs | 4 ++-- .../internal/serialization/unknown_variant_builder.rs | 9 +++------ serde_arrow/src/internal/utils/value.rs | 2 +- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs index 5c9b0fef..b0e8c007 100644 --- a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs @@ -5,9 +5,9 @@ use std::collections::HashMap; use serde::Deserialize; use serde_json::json; -use crate::{ - internal::testing::assert_error, +use 
crate::internal::{ schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, + testing::assert_error, }; #[test] diff --git a/serde_arrow/src/internal/schema/serde/serialize.rs b/serde_arrow/src/internal/schema/serde/serialize.rs index f31cfeed..4d3937b6 100644 --- a/serde_arrow/src/internal/schema/serde/serialize.rs +++ b/serde_arrow/src/internal/schema/serde/serialize.rs @@ -4,8 +4,8 @@ use std::collections::HashMap; use serde::ser::{SerializeSeq, SerializeStruct}; -use crate::{ - internal::arrow::{DataType, Field}, +use crate::internal::{ + arrow::{DataType, Field}, schema::{SerdeArrowSchema, STRATEGY_KEY}, }; diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index b6b6c0f3..c1c4b4bb 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -1,11 +1,8 @@ use serde::Serialize; -use crate::{ - internal::{ - arrow::{Array, NullArray}, - error::fail, - }, - Result, +use crate::internal::{ + arrow::{Array, NullArray}, + error::{fail, Result}, }; use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; diff --git a/serde_arrow/src/internal/utils/value.rs b/serde_arrow/src/internal/utils/value.rs index 91bdb1e2..9f7eea00 100644 --- a/serde_arrow/src/internal/utils/value.rs +++ b/serde_arrow/src/internal/utils/value.rs @@ -1,7 +1,7 @@ //! 
Serialize values into a in-memory representation use serde::{de::DeserializeOwned, forward_to_deserialize_any, Serialize}; -use crate::{internal::error::fail, Error, Result}; +use crate::internal::error::{fail, Error, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct Variant(u32, &'static str); From aa147bbc21aa3e8f7f24cb20d50dcb8bfafc8207 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 10:10:44 +0200 Subject: [PATCH 118/178] Add missed change --- .../src/internal/schema/from_samples/test_error_messages.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs index f044c098..8c7d67fe 100644 --- a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs @@ -1,8 +1,8 @@ use serde::Serialize; -use crate::{ - internal::testing::assert_error, +use crate::internal::{ schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, + testing::assert_error, }; #[test] From 1eda9f84d15376cddb8a1dc84e81ac91fa6b9100 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 10:10:48 +0200 Subject: [PATCH 119/178] Update changelog --- Changes.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Changes.md b/Changes.md index 25f93094..1d0a4ba9 100644 --- a/Changes.md +++ b/Changes.md @@ -7,6 +7,8 @@ Refactor the underlying implementation to prepare for further development New features - Add `Binary` / `LargeBinary` support for `arrow2` +- Add support to serialize / deserialize `bool` from integer arrays +- Add a helper to construct `Bool8` arrays API changes From e84f901b304c04a4c339a4181f4cf138dac5a78d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 14:59:28 +0200 Subject: [PATCH 120/178] Add test that checks for the inclusion of the field path --- 
serde_arrow/src/internal/error.rs | 9 +++++ serde_arrow/src/test/error.rs | 9 ----- serde_arrow/src/test/error_messages/mod.rs | 1 + .../src/test/error_messages/push_validity.rs | 37 +++++++++++++++++++ serde_arrow/src/test/mod.rs | 2 +- 5 files changed, 48 insertions(+), 10 deletions(-) delete mode 100644 serde_arrow/src/test/error.rs create mode 100644 serde_arrow/src/test/error_messages/mod.rs create mode 100644 serde_arrow/src/test/error_messages/push_validity.rs diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 610cd1ab..060fde71 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -211,3 +211,12 @@ impl From for PanicOnErrorError { panic!("{value}"); } } + +#[test] +fn error_can_be_converted_to_anyhow() { + fn func() -> anyhow::Result<()> { + Err(error!("dummy"))?; + Ok(()) + } + assert!(func().is_err()); +} diff --git a/serde_arrow/src/test/error.rs b/serde_arrow/src/test/error.rs deleted file mode 100644 index 8cc2d52a..00000000 --- a/serde_arrow/src/test/error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::internal::error::error; - -#[test] -fn error_can_be_converted_to_anyhow() { - fn _func() -> anyhow::Result<()> { - Err(error!("dummy"))?; - Ok(()) - } -} diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs new file mode 100644 index 00000000..e1c999c8 --- /dev/null +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -0,0 +1 @@ +mod push_validity; \ No newline at end of file diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs new file mode 100644 index 00000000..f7ca8863 --- /dev/null +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -0,0 +1,37 @@ +use serde::Serialize; +use serde_json::json; + +use crate::internal::{array_builder::ArrayBuilder, error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema}, testing::assert_error}; + + +#[test] +fn 
push_validity_issue_202() -> PanicOnError<()> { + let schema = SerdeArrowSchema::from_value(&json!([ + { + "name": "nested", + "data_type": "Struct", + "children": [ + {"name": "field", "data_type": "U32"}, + ], + }, + ]))?; + + #[derive(Serialize)] + struct Record { + nested: Nested, + } + + #[derive(Serialize)] + struct Nested { + field: Option, + } + + let mut array_builder = ArrayBuilder::new(schema)?; + let res = array_builder.push(&Record { nested: Nested { field: Some(5) }}); + assert_eq!(res, Ok(())); + + let res = array_builder.push(&Record { nested: Nested { field: None }}); + assert_error(&res, "field: \"nested.field\""); + + Ok(()) +} \ No newline at end of file diff --git a/serde_arrow/src/test/mod.rs b/serde_arrow/src/test/mod.rs index 24690804..60e0d320 100644 --- a/serde_arrow/src/test/mod.rs +++ b/serde_arrow/src/test/mod.rs @@ -1,3 +1,3 @@ mod api_chrono; -mod error; +mod error_messages; mod schema_like; From 564c266b249ab8f29e987656dceff47a9ea71efd Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 14:59:37 +0200 Subject: [PATCH 121/178] Start to document the Error format --- Development.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Development.md b/Development.md index a9b6f275..9551a2f3 100644 --- a/Development.md +++ b/Development.md @@ -28,3 +28,7 @@ modules can can be run without installing further packages. 1. `python x.py add-arrow-version {VERSION}` 2. 
`python x.py precommit` + +## Error format + +- Include the path to the field where sensible From 0df5d56c646cd9ce6b0368361dbcf076e3124fdc Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Thu, 8 Aug 2024 22:19:35 +0200 Subject: [PATCH 122/178] Add machinery to annotate errors with extra information --- serde_arrow/src/internal/error.rs | 83 ++++++- .../serialization/simple_serializer.rs | 209 ++++++++++++++---- .../internal/serialization/utf8_builder.rs | 6 +- serde_arrow/src/test/error_messages/mod.rs | 2 +- .../src/test/error_messages/push_validity.rs | 26 ++- 5 files changed, 259 insertions(+), 67 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 060fde71..24b3db5f 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -1,5 +1,6 @@ use std::{ backtrace::{Backtrace, BacktraceStatus}, + collections::BTreeMap, convert::Infallible, }; @@ -21,6 +22,7 @@ pub type Result = std::result::Result; #[non_exhaustive] pub enum Error { Custom(CustomError), + Annotated(AnnotatedError), } impl Error { @@ -45,17 +47,52 @@ impl Error { } impl Error { + pub(crate) fn empty() -> Self { + Self::Custom(CustomError { + message: String::new(), + backtrace: Backtrace::disabled(), + cause: None, + }) + } + pub fn message(&self) -> &str { match self { Self::Custom(err) => &err.message, + Self::Annotated(err) => err.error.message(), } } pub fn backtrace(&self) -> &Backtrace { match self { Self::Custom(err) => &err.backtrace, + Self::Annotated(err) => &err.error.backtrace(), + } + } + + pub(crate) fn annotations(&self) -> Option<&BTreeMap> { + match self { + Self::Custom(_) => None, + Self::Annotated(err) => Some(&err.annotations), } } + + /// Ensure the error is annotated and return a mutable reference to the annotations + pub(crate) fn annotations_mut(&mut self) -> &mut BTreeMap { + if !matches!(self, Self::Annotated(_)) { + let mut this = Error::empty(); + std::mem::swap(self, &mut this); + + 
*self = Self::Annotated(AnnotatedError { + error: Box::new(this), + annotations: BTreeMap::new(), + }); + } + + let Self::Annotated(err) = self else { + unreachable!(); + }; + &mut err.annotations + } } pub struct CustomError { @@ -70,6 +107,17 @@ impl std::cmp::PartialEq for CustomError { } } +pub struct AnnotatedError { + pub(crate) error: Box, + pub(crate) annotations: BTreeMap, +} + +impl std::cmp::PartialEq for AnnotatedError { + fn eq(&self, other: &Self) -> bool { + self.error.eq(&other.error) && self.annotations == other.annotations + } +} + impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "<{self}>") @@ -78,14 +126,35 @@ impl std::fmt::Debug for Error { impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Error::Custom(e) => write!( - f, - "Error: {msg}\n{bt}", - msg = e.message, - bt = BacktraceDisplay(&e.backtrace), - ), + write!( + f, + "Error: {msg}{annotations}\n{bt}", + msg = self.message(), + annotations = AnnotationsDisplay(self.annotations()), + bt = BacktraceDisplay(self.backtrace()), + ) + } +} + +struct AnnotationsDisplay<'a>(Option<&'a BTreeMap>); + +impl<'a> std::fmt::Display for AnnotationsDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Some(annotations) = self.0 else { + return Ok(()); + }; + if annotations.is_empty() { + return Ok(()); + } + + write!(f, "(")?; + for (idx, (key, value)) in annotations.iter().enumerate() { + if idx != 0 { + write!(f, ", ")?; + } + write!(f, "{key}: {value:?}")?; } + write!(f, ")") } } diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 52438114..416ae157 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -25,6 +25,11 @@ use super::ArrayBuilder; pub trait 
SimpleSerializer: Sized { fn name(&self) -> &str; + // TODO: remove default + fn annotate_error(&self, err: Error) -> Error { + err + } + fn serialize_default(&mut self) -> Result<()> { fail!("serialize_default is not supported for {}", self.name()); } @@ -275,71 +280,105 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { type SerializeTupleVariant = Mut<'a, ArrayBuilder>; fn serialize_unit(self) -> Result<()> { - self.0.serialize_unit() + self.0 + .serialize_unit() + .map_err(|err| self.0.annotate_error(err)) } fn serialize_none(self) -> Result<()> { - self.0.serialize_none() + self.0 + .serialize_none() + .map_err(|err| self.0.annotate_error(err)) } fn serialize_some(self, value: &V) -> Result<()> { - self.0.serialize_some(value) + self.0 + .serialize_some(value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_bool(self, v: bool) -> Result<()> { - self.0.serialize_bool(v) + self.0 + .serialize_bool(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_char(self, v: char) -> Result<()> { - self.0.serialize_char(v) + self.0 + .serialize_char(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u8(self, v: u8) -> Result<()> { - self.0.serialize_u8(v) + self.0 + .serialize_u8(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u16(self, v: u16) -> Result<()> { - self.0.serialize_u16(v) + self.0 + .serialize_u16(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u32(self, v: u32) -> Result<()> { - self.0.serialize_u32(v) + self.0 + .serialize_u32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_u64(self, v: u64) -> Result<()> { - self.0.serialize_u64(v) + self.0 + .serialize_u64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i8(self, v: i8) -> Result<()> { - self.0.serialize_i8(v) + self.0 + .serialize_i8(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i16(self, v: i16) -> Result<()> { - self.0.serialize_i16(v) + self.0 + .serialize_i16(v) + .map_err(|err| 
self.0.annotate_error(err)) } fn serialize_i32(self, v: i32) -> Result<()> { - self.0.serialize_i32(v) + self.0 + .serialize_i32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_i64(self, v: i64) -> Result<()> { - self.0.serialize_i64(v) + self.0 + .serialize_i64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_f32(self, v: f32) -> Result<()> { - self.0.serialize_f32(v) + self.0 + .serialize_f32(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_f64(self, v: f64) -> Result<()> { - self.0.serialize_f64(v) + self.0 + .serialize_f64(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_bytes(self, v: &[u8]) -> Result<()> { - self.0.serialize_bytes(v) + self.0 + .serialize_bytes(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_str(self, v: &str) -> Result<()> { - self.0.serialize_str(v) + self.0 + .serialize_str(v) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_newtype_struct( @@ -347,7 +386,9 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_newtype_struct(name, value) + self.0 + .serialize_newtype_struct(name, value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_newtype_variant( @@ -359,10 +400,13 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { ) -> Result<()> { self.0 .serialize_newtype_variant(name, variant_index, variant, value) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_unit_struct(self, name: &'static str) -> Result<()> { - self.0.serialize_unit_struct(name) + self.0 + .serialize_unit_struct(name) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_unit_variant( @@ -371,26 +415,36 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant_index: u32, variant: &'static str, ) -> Result<()> { - self.0.serialize_unit_variant(name, variant_index, variant) + self.0 + .serialize_unit_variant(name, variant_index, variant) + .map_err(|err| 
self.0.annotate_error(err)) } fn serialize_map(self, len: Option) -> Result { - self.0.serialize_map_start(len)?; + self.0 + .serialize_map_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_seq(self, len: Option) -> Result { - self.0.serialize_seq_start(len)?; + self.0 + .serialize_seq_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_struct(self, name: &'static str, len: usize) -> Result { - self.0.serialize_struct_start(name, len)?; + self.0 + .serialize_struct_start(name, len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } fn serialize_tuple(self, len: usize) -> Result { - self.0.serialize_tuple_start(len)?; + self.0 + .serialize_tuple_start(len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } @@ -399,7 +453,9 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, len: usize, ) -> Result { - self.0.serialize_tuple_struct_start(name, len)?; + self.0 + .serialize_tuple_struct_start(name, len) + .map_err(|err| self.0.annotate_error(err))?; Ok(Mut(&mut *self.0)) } @@ -410,10 +466,16 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - let variant_builder = - self.0 - .serialize_struct_variant_start(name, variant_index, variant, len)?; - Ok(Mut(variant_builder)) + // cannot borrow self immutably, as the result will keep self.0 borrowed mutably + // TODO: figure out how to remove this hack + let annotations_error = self.0.annotate_error(Error::empty()); + match self + .0 + .serialize_struct_variant_start(name, variant_index, variant, len) + { + Ok(variant_builder) => Ok(Mut(variant_builder)), + Err(err) => Err(merge_annotations(err, annotations_error)), + } } fn serialize_tuple_variant( @@ -423,11 +485,32 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - let variant_builder = - self.0 - 
.serialize_tuple_variant_start(name, variant_index, variant, len)?; - Ok(Mut(variant_builder)) + // cannot borrow self immutably, as the result will keep self.0 borrowed mutably + // TODO: figure out how to remove this hack + let annotations_error = self.0.annotate_error(Error::empty()); + match self + .0 + .serialize_tuple_variant_start(name, variant_index, variant, len) + { + Ok(variant_builder) => Ok(Mut(variant_builder)), + Err(err) => Err(merge_annotations(err, annotations_error)), + } + } +} + +fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { + let extra_annotations = std::mem::take(annotations_err.annotations_mut()); + if extra_annotations.is_empty() { + return err; } + + let result_annotations = err.annotations_mut(); + for (key, value) in extra_annotations { + if !result_annotations.contains_key(&key) { + result_annotations.insert(key, value); + } + } + err } impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { @@ -435,15 +518,21 @@ impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { type Error = Error; fn serialize_key(&mut self, key: &V) -> Result<()> { - self.0.serialize_map_key(key) + self.0 + .serialize_map_key(key) + .map_err(|err| self.0.annotate_error(err)) } fn serialize_value(&mut self, value: &V) -> Result<()> { - self.0.serialize_map_value(value) + self.0 + .serialize_map_value(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_map_end() + self.0 + .serialize_map_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -452,11 +541,15 @@ impl<'a, T: SimpleSerializer> SerializeSeq for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0.serialize_seq_element(value) + self.0 + .serialize_seq_element(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_seq_end() + self.0 + .serialize_seq_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -469,11 +562,15 
@@ impl<'a, T: SimpleSerializer> SerializeStruct for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_struct_field(key, value) + self.0 + .serialize_struct_field(key, value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_struct_end() + self.0 + .serialize_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -482,11 +579,15 @@ impl<'a, T: SimpleSerializer> SerializeTuple for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_element(value) + self.0 + .serialize_tuple_element(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_end() + self.0 + .serialize_tuple_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -495,11 +596,15 @@ impl<'a, T: SimpleSerializer> SerializeTupleStruct for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_struct_field(value) + self.0 + .serialize_tuple_struct_field(value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_struct_end() + self.0 + .serialize_tuple_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -512,11 +617,15 @@ impl<'a, T: SimpleSerializer> SerializeStructVariant for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0.serialize_struct_field(key, value) + self.0 + .serialize_struct_field(key, value) + .map_err(|err| self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_struct_end() + self.0 + .serialize_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } @@ -525,10 +634,14 @@ impl<'a, T: SimpleSerializer> SerializeTupleVariant for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0.serialize_tuple_struct_field(value) + self.0 + .serialize_tuple_struct_field(value) + .map_err(|err| 
self.0.annotate_error(err)) } fn end(self) -> Result<()> { - self.0.serialize_tuple_struct_end() + self.0 + .serialize_tuple_struct_end() + .map_err(|err| self.0.annotate_error(err)) } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index b811b8aa..69f78efd 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,8 +1,10 @@ use crate::internal::{ arrow::{Array, BytesArray}, error::{fail, Result}, - utils::array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - utils::Offset, + utils::{ + array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, + Offset, + }, }; use super::simple_serializer::SimpleSerializer; diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index e1c999c8..7730e5aa 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1 +1 @@ -mod push_validity; \ No newline at end of file +mod push_validity; diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs index f7ca8863..77097aeb 100644 --- a/serde_arrow/src/test/error_messages/push_validity.rs +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -1,21 +1,25 @@ use serde::Serialize; use serde_json::json; -use crate::internal::{array_builder::ArrayBuilder, error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema}, testing::assert_error}; - +use crate::internal::{ + array_builder::ArrayBuilder, + error::PanicOnError, + schema::{SchemaLike, SerdeArrowSchema}, + testing::assert_error, +}; #[test] fn push_validity_issue_202() -> PanicOnError<()> { let schema = SerdeArrowSchema::from_value(&json!([ { - "name": "nested", - "data_type": "Struct", + "name": "nested", + "data_type": "Struct", "children": [ {"name": "field", "data_type": "U32"}, ], }, ]))?; - + #[derive(Serialize)] 
struct Record { nested: Nested, @@ -27,11 +31,15 @@ fn push_validity_issue_202() -> PanicOnError<()> { } let mut array_builder = ArrayBuilder::new(schema)?; - let res = array_builder.push(&Record { nested: Nested { field: Some(5) }}); + let res = array_builder.push(&Record { + nested: Nested { field: Some(5) }, + }); assert_eq!(res, Ok(())); - let res = array_builder.push(&Record { nested: Nested { field: None }}); + let res = array_builder.push(&Record { + nested: Nested { field: None }, + }); assert_error(&res, "field: \"nested.field\""); - + Ok(()) -} \ No newline at end of file +} From fbc28f4a08843736bf83e43acbbce77cacc7c3f2 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 17:26:36 +0200 Subject: [PATCH 123/178] Reformat code --- serde_arrow/src/internal/serialization/simple_serializer.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 416ae157..bea84fb5 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -473,7 +473,7 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { .0 .serialize_struct_variant_start(name, variant_index, variant, len) { - Ok(variant_builder) => Ok(Mut(variant_builder)), + Ok(variant_builder) => Ok(Mut(variant_builder)), Err(err) => Err(merge_annotations(err, annotations_error)), } } @@ -504,8 +504,8 @@ fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { return err; } - let result_annotations = err.annotations_mut(); - for (key, value) in extra_annotations { + let result_annotations = err.annotations_mut(); + for (key, value) in extra_annotations { if !result_annotations.contains_key(&key) { result_annotations.insert(key, value); } From c9209870d67dc2bb304fc437da6be84c346a3d71 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 
17:29:44 +0200 Subject: [PATCH 124/178] Rename assert_error -> assert_error_contains --- .../from_samples/test_error_messages.rs | 12 ++++----- .../schema/from_type/test_error_messages.rs | 26 +++++++++---------- serde_arrow/src/internal/schema/test.rs | 4 +-- serde_arrow/src/internal/testing.rs | 2 +- .../src/test/error_messages/push_validity.rs | 4 +-- .../src/test_with_arrow/impls/chrono.rs | 12 ++++----- .../test_with_arrow/impls/fixed_size_list.rs | 4 +-- .../test_with_arrow/impls/issue_203_uuid.rs | 4 +-- .../impls/issue_90_type_tracing.rs | 4 +-- .../src/test_with_arrow/impls/union.rs | 4 +-- 10 files changed, 38 insertions(+), 38 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs index 8c7d67fe..f9411eba 100644 --- a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs @@ -2,17 +2,17 @@ use serde::Serialize; use crate::internal::{ schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, - testing::assert_error, + testing::assert_error_contains, }; #[test] fn outer_struct() { let res = SerdeArrowSchema::from_samples(&[1_u32, 2_u32, 3_u32], TracingOptions::default()); - assert_error( + assert_error_contains( &res, "Only struct-like types are supported as root types in schema tracing.", ); - assert_error(&res, "Consider using the `Items` wrapper,"); + assert_error_contains(&res, "Consider using the `Items` wrapper,"); } /// See: https://github.com/chmp/serde_arrow/issues/97 @@ -31,8 +31,8 @@ fn outer_sequence_issue_97() { }; let res = SerdeArrowSchema::from_samples(&b, TracingOptions::default()); - assert_error(&res, "Cannot trace non-sequences with `from_samples`."); - assert_error(&res, "Consider wrapping the argument in an array."); + assert_error_contains(&res, "Cannot trace non-sequences with `from_samples`."); + assert_error_contains(&res, "Consider 
wrapping the argument in an array."); } #[test] @@ -44,5 +44,5 @@ fn enums_without_data() { } let res = SerdeArrowSchema::from_samples(&[E::A, E::B], TracingOptions::default()); - assert_error(&res, "by setting `enums_without_data_as_strings` to `true`"); + assert_error_contains(&res, "by setting `enums_without_data_as_strings` to `true`"); } diff --git a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs index b0e8c007..f2fc5ac6 100644 --- a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs @@ -7,17 +7,17 @@ use serde_json::json; use crate::internal::{ schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, - testing::assert_error, + testing::assert_error_contains, }; #[test] fn from_type_budget() { let res = SerdeArrowSchema::from_type::(TracingOptions::default().from_type_budget(0)); - assert_error( + assert_error_contains( &res, "Could not determine schema from the type after 0 iterations.", ); - assert_error( + assert_error_contains( &res, "Consider increasing the budget option or using `from_samples`.", ); @@ -26,11 +26,11 @@ fn from_type_budget() { #[test] fn non_self_describing_types() { let res = SerdeArrowSchema::from_type::(TracingOptions::default()); - assert_error( + assert_error_contains( &res, "Non self describing types cannot be traced with `from_type`.", ); - assert_error(&res, "Consider using `from_samples`."); + assert_error_contains(&res, "Consider using `from_samples`."); } #[test] @@ -38,18 +38,18 @@ fn map_as_struct() { let res = SerdeArrowSchema::from_type::>( TracingOptions::default().map_as_struct(true), ); - assert_error(&res, "Cannot trace maps as structs with `from_type`"); - assert_error(&res, "Consider using `from_samples`."); + assert_error_contains(&res, "Cannot trace maps as structs with `from_type`"); + assert_error_contains(&res, "Consider using `from_samples`."); } 
#[test] fn outer_struct() { let res = SerdeArrowSchema::from_type::(TracingOptions::default()); - assert_error( + assert_error_contains( &res, "Only struct-like types are supported as root types in schema tracing.", ); - assert_error(&res, "Consider using the `Item` wrapper,"); + assert_error_contains(&res, "Consider using the `Item` wrapper,"); } #[test] @@ -61,7 +61,7 @@ fn enums_without_data() { } let res = SerdeArrowSchema::from_type::(TracingOptions::default()); - assert_error(&res, "by setting `enums_without_data_as_strings` to `true`"); + assert_error_contains(&res, "by setting `enums_without_data_as_strings` to `true`"); } #[test] @@ -77,7 +77,7 @@ fn missing_overwrites() { .overwrite("b", json!({"name": "b", "data_type": "I64"})) .unwrap(), ); - assert_error(&res, "Overwritten fields could not be found:"); + assert_error_contains(&res, "Overwritten fields could not be found:"); } #[test] @@ -93,7 +93,7 @@ fn mismatched_overwrite_name() { .overwrite("a", json!({"name": "b", "data_type": "I64"})) .unwrap(), ); - assert_error(&res, "Invalid name for overwritten field"); + assert_error_contains(&res, "Invalid name for overwritten field"); } #[test] @@ -109,7 +109,7 @@ fn overwrite_invalid_name() { .overwrite("a", json!({"name": "b", "data_type": "I64"})) .unwrap(), ); - assert_error( + assert_error_contains( &res, "Invalid name for overwritten field \"a\": found \"b\", expected \"a\"", ); diff --git a/serde_arrow/src/internal/schema/test.rs b/serde_arrow/src/internal/schema/test.rs index 2d651543..eaf3170e 100644 --- a/serde_arrow/src/internal/schema/test.rs +++ b/serde_arrow/src/internal/schema/test.rs @@ -6,7 +6,7 @@ use crate::internal::{ arrow::{DataType, Field, TimeUnit}, error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema, Strategy, STRATEGY_KEY}, - testing::{assert_error, hash_map}, + testing::{assert_error_contains, hash_map}, }; fn type_from_str(s: &str) -> DataType { @@ -449,7 +449,7 @@ fn test_invalid_metadata() { }, ])); - assert_error(&res, 
"Duplicate strategy"); + assert_error_contains(&res, "Duplicate strategy"); } #[test] diff --git a/serde_arrow/src/internal/testing.rs b/serde_arrow/src/internal/testing.rs index b0140ab7..06179ec7 100644 --- a/serde_arrow/src/internal/testing.rs +++ b/serde_arrow/src/internal/testing.rs @@ -1,5 +1,5 @@ //! Support for tests -pub fn assert_error(actual: &Result, expected: &str) { +pub fn assert_error_contains(actual: &Result, expected: &str) { let Err(actual) = actual else { panic!("expected an error, but no error was raised"); }; diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs index 77097aeb..5af0584d 100644 --- a/serde_arrow/src/test/error_messages/push_validity.rs +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -5,7 +5,7 @@ use crate::internal::{ array_builder::ArrayBuilder, error::PanicOnError, schema::{SchemaLike, SerdeArrowSchema}, - testing::assert_error, + testing::assert_error_contains, }; #[test] @@ -39,7 +39,7 @@ fn push_validity_issue_202() -> PanicOnError<()> { let res = array_builder.push(&Record { nested: Nested { field: None }, }); - assert_error(&res, "field: \"nested.field\""); + assert_error_contains(&res, "field: \"nested.field\""); Ok(()) } diff --git a/serde_arrow/src/test_with_arrow/impls/chrono.rs b/serde_arrow/src/test_with_arrow/impls/chrono.rs index dcfe1334..082b422a 100644 --- a/serde_arrow/src/test_with_arrow/impls/chrono.rs +++ b/serde_arrow/src/test_with_arrow/impls/chrono.rs @@ -1,6 +1,6 @@ use super::utils::Test; use crate::{ - internal::testing::assert_error, + internal::testing::assert_error_contains, schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, utils::Item, }; @@ -12,7 +12,7 @@ use serde_json::json; #[test] fn trace_from_type_does_not_work() { let res = SerdeArrowSchema::from_type::>>(TracingOptions::default()); - assert_error(&res, "premature end of input"); + assert_error_contains(&res, "premature end of input"); } #[test] @@ 
-300,14 +300,14 @@ fn time64_type_invalid_units() { // Note: the arrow docs state: that the time unit "[m]ust be either // microseconds or nanoseconds." - assert_error( + assert_error_contains( &SerdeArrowSchema::from_value(&json!([{ "name": "item", "data_type": "Time64(Millisecond)", }])), "Error: Time64 field must have Microsecond or Nanosecond unit", ); - assert_error( + assert_error_contains( &SerdeArrowSchema::from_value(&json!([{ "name": "item", "data_type": "Time64(Second)", @@ -315,14 +315,14 @@ fn time64_type_invalid_units() { "Error: Time64 field must have Microsecond or Nanosecond unit", ); - assert_error( + assert_error_contains( &SerdeArrowSchema::from_value(&json!([{ "name": "item", "data_type": "Time32(Microsecond)", }])), "Error: Time32 field must have Second or Millisecond unit", ); - assert_error( + assert_error_contains( &SerdeArrowSchema::from_value(&json!([{ "name": "item", "data_type": "Time32(Nanosecond)", diff --git a/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs b/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs index 47e4bc86..a714e8ee 100644 --- a/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs +++ b/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs @@ -3,7 +3,7 @@ use serde_json::json; use super::utils::Test; use crate::_impl::arrow::datatypes::FieldRef; -use crate::internal::testing::assert_error; +use crate::internal::testing::assert_error_contains; use crate::internal::utils::Item; use crate::schema::SchemaLike; @@ -72,7 +72,7 @@ fn incorrect_number_of_elements() { .unwrap(); let res = crate::to_record_batch(&fields, &items); - assert_error(&res, "Invalid number of elements for FixedSizedList(2)."); + assert_error_contains(&res, "Invalid number of elements for FixedSizedList(2)."); } #[test] diff --git a/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs b/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs index ed4d4a81..945f0801 100644 --- 
a/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_203_uuid.rs @@ -2,7 +2,7 @@ use serde_json::json; use uuid::Uuid; use crate::{ - internal::testing::assert_error, + internal::testing::assert_error_contains, schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, utils::Item, }; @@ -30,5 +30,5 @@ fn example_as_list() { #[test] fn trace_from_type_does_not_work() { let res = SerdeArrowSchema::from_type::>(TracingOptions::default()); - assert_error(&res, "UUID parsing failed"); + assert_error_contains(&res, "UUID parsing failed"); } diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index e9d2b88a..69c0bf6a 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -6,7 +6,7 @@ use serde_json::json; use crate::internal::{ arrow::{DataType, Field, UnionMode}, schema::{tracer::Tracer, transmute_field, Strategy, TracingOptions, STRATEGY_KEY}, - testing::assert_error, + testing::assert_error_contains, utils::Item, }; @@ -236,5 +236,5 @@ fn unsupported_recursive_types() { } let res = Tracer::from_type::(TracingOptions::default()); - assert_error(&res, "too deeply nested type detected"); + assert_error_contains(&res, "too deeply nested type detected"); } diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index 6b63c0cd..80894af2 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -361,7 +361,7 @@ fn enums_union() { fn missing_union_variants() { use crate::_impl::arrow::datatypes::FieldRef; - use crate::internal::testing::assert_error; + use crate::internal::testing::assert_error_contains; use crate::schema::TracingOptions; use serde::{Deserialize, Serialize}; @@ -377,7 +377,7 @@ fn missing_union_variants() { // NOTE: 
variant B was never encountered during tracing let res = crate::to_arrow(&fields, &Items(&[U::A, U::B, U::C])); - assert_error(&res, "Serialization failed: an unknown variant"); + assert_error_contains(&res, "Serialization failed: an unknown variant"); } #[test] From 04414f97d97ddec33d3e48acaea079c6dfdb4a53 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 17:29:48 +0200 Subject: [PATCH 125/178] Reformat code --- serde_arrow/src/internal/schema/tracer.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index f0e97ec7..37189b91 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -460,11 +460,7 @@ impl Tracer { Ok(()) } - pub fn ensure_utf8( - &mut self, - item_type: DataType, - strategy: Option, - ) -> Result<()> { + pub fn ensure_utf8(&mut self, item_type: DataType, strategy: Option) -> Result<()> { self.ensure_primitive_with_strategy(item_type, strategy) } @@ -502,7 +498,10 @@ impl Tracer { if matches!(item_type, DataType::Null) { dispatch_tracer!(this, tracer => { tracer.nullable = true }); } else { - fail!("Cannot merge {ty:?} with {item_type:?}", ty = this.get_type()); + fail!( + "Cannot merge {ty:?} with {item_type:?}", + ty = this.get_type() + ); } } Self::Primitive(tracer) => { @@ -527,7 +526,8 @@ fn coerce_primitive_type( options: &TracingOptions, ) -> Result<(DataType, bool, Option)> { use DataType::{ - Date64, LargeUtf8, Null, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, + Date64, Float32, Float64, Int16, Int32, Int64, Int8, LargeUtf8, Null, UInt16, UInt32, + UInt64, UInt8, }; let res = match (prev, curr) { From 78b96895e2c7e855dac2bf2091360b82501543da Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 17:30:22 +0200 Subject: [PATCH 126/178] Format code --- serde_arrow/src/internal/schema/tracer.rs | 14 
+++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index f0e97ec7..37189b91 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -460,11 +460,7 @@ impl Tracer { Ok(()) } - pub fn ensure_utf8( - &mut self, - item_type: DataType, - strategy: Option, - ) -> Result<()> { + pub fn ensure_utf8(&mut self, item_type: DataType, strategy: Option) -> Result<()> { self.ensure_primitive_with_strategy(item_type, strategy) } @@ -502,7 +498,10 @@ impl Tracer { if matches!(item_type, DataType::Null) { dispatch_tracer!(this, tracer => { tracer.nullable = true }); } else { - fail!("Cannot merge {ty:?} with {item_type:?}", ty = this.get_type()); + fail!( + "Cannot merge {ty:?} with {item_type:?}", + ty = this.get_type() + ); } } Self::Primitive(tracer) => { @@ -527,7 +526,8 @@ fn coerce_primitive_type( options: &TracingOptions, ) -> Result<(DataType, bool, Option)> { use DataType::{ - Date64, LargeUtf8, Null, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, + Date64, Float32, Float64, Int16, Int32, Int64, Int8, LargeUtf8, Null, UInt16, UInt32, + UInt64, UInt8, }; let res = match (prev, curr) { From 50731d5052f70a437ab7ef54eea92ea5145633eb Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 20:01:46 +0200 Subject: [PATCH 127/178] Implement annotated error message for int / struct builders, include path --- serde_arrow/src/internal/error.rs | 16 +- .../internal/serialization/array_builder.rs | 131 ++++----- .../src/internal/serialization/int_builder.rs | 52 ++-- .../serialization/outer_sequence_builder.rs | 264 ++++++++++-------- .../serialization/simple_serializer.rs | 2 +- .../internal/serialization/struct_builder.rs | 17 +- .../src/test/error_messages/push_validity.rs | 21 +- 7 files changed, 262 insertions(+), 241 deletions(-) diff --git a/serde_arrow/src/internal/error.rs 
b/serde_arrow/src/internal/error.rs index 24b3db5f..432cf02d 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -69,6 +69,20 @@ impl Error { } } + /// Turn the error into an annotated error and call the provided function with a mutable + /// reference to the annotations + pub(crate) fn annotate_unannotated)>( + mut self, + func: F, + ) -> Self { + if matches!(self, Self::Annotated(_)) { + self + } else { + func(self.annotations_mut()); + self + } + } + pub(crate) fn annotations(&self) -> Option<&BTreeMap> { match self { Self::Custom(_) => None, @@ -147,7 +161,7 @@ impl<'a> std::fmt::Display for AnnotationsDisplay<'a> { return Ok(()); } - write!(f, "(")?; + write!(f, " (")?; for (idx, (key, value)) in annotations.iter().enumerate() { if idx != 0 { write!(f, ", ")?; diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 368261c0..eb7f0e5c 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -1,7 +1,10 @@ use half::f16; use serde::Serialize; -use crate::internal::{arrow::Array, error::Result}; +use crate::internal::{ + arrow::Array, + error::{Error, Result}, +}; use super::{ binary_builder::BinaryBuilder, bool_builder::BoolBuilder, date32_builder::Date32Builder, @@ -10,9 +13,10 @@ use super::{ fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder, - null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder, - time_builder::TimeBuilder, union_builder::UnionBuilder, - unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, + null_builder::NullBuilder, simple_serializer::merge_annotations, + simple_serializer::SimpleSerializer, struct_builder::StructBuilder, 
time_builder::TimeBuilder, + union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, + utf8_builder::Utf8Builder, }; #[derive(Debug, Clone)] @@ -91,43 +95,6 @@ macro_rules! dispatch { } impl ArrayBuilder { - pub fn name(&self) -> &'static str { - match self { - Self::Null(_) => "Null", - Self::Bool(_) => "Bool", - Self::I8(_) => "I8", - Self::I16(_) => "I16", - Self::I32(_) => "I32", - Self::I64(_) => "I64", - Self::U8(_) => "U8", - Self::U16(_) => "U16", - Self::U32(_) => "U32", - Self::U64(_) => "U64", - Self::F16(_) => "F16", - Self::F32(_) => "F32", - Self::F64(_) => "F64", - Self::Date32(_) => "Date32", - Self::Date64(_) => "Date64", - Self::Time32(_) => "Time32", - Self::Time64(_) => "Time64", - Self::Duration(_) => "Duration", - Self::Decimal128(_) => "Decimal128", - Self::Utf8(_) => "Utf8", - Self::LargeUtf8(_) => "LargeUtf8", - Self::List(_) => "List", - Self::LargeList(_) => "LargeList", - Self::FixedSizedList(_) => "FixedSizeList", - Self::Binary(_) => "Binary", - Self::LargeBinary(_) => "LargeBinary", - Self::FixedSizeBinary(_) => "FixedSizeBinary", - Self::Struct(_) => "Struct", - Self::Map(_) => "Map", - Self::DictionaryUtf8(_) => "DictionaryUtf8", - Self::Union(_) => "Union", - Self::UnknownVariant(_) => "UnknownVariant", - } - } - pub fn is_nullable(&self) -> bool { dispatch!(self, Self(builder) => builder.is_nullable()) } @@ -185,162 +152,164 @@ impl SimpleSerializer for ArrayBuilder { } fn serialize_default(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_default()) + dispatch!(self, Self(builder) => builder.serialize_default().map_err(|err| builder.annotate_error(err))) } fn serialize_unit_struct(&mut self, name: &'static str) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit_struct(name)) + dispatch!(self, Self(builder) => builder.serialize_unit_struct(name).map_err(|err| builder.annotate_error(err))) } fn serialize_none(&mut self) -> Result<()> { - dispatch!(self, 
Self(builder) => builder.serialize_none()) + dispatch!(self, Self(builder) => builder.serialize_none().map_err(|err| builder.annotate_error(err))) } fn serialize_some(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_some(value)) + dispatch!(self, Self(builder) => builder.serialize_some(value).map_err(|err| builder.annotate_error(err))) } fn serialize_unit(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit()) + dispatch!(self, Self(builder) => builder.serialize_unit().map_err(|err| builder.annotate_error(err))) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_bool(v)) + dispatch!(self, Self(builder) => builder.serialize_bool(v).map_err(|err| builder.annotate_error(err))) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i8(v)) + dispatch!(self, Self(builder) => builder.serialize_i8(v).map_err(|err| builder.annotate_error(err))) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i16(v)) + dispatch!(self, Self(builder) => builder.serialize_i16(v).map_err(|err| builder.annotate_error(err))) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i32(v)) + dispatch!(self, Self(builder) => builder.serialize_i32(v).map_err(|err| builder.annotate_error(err))) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i64(v)) + dispatch!(self, Self(builder) => builder.serialize_i64(v).map_err(|err| builder.annotate_error(err))) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u8(v)) + dispatch!(self, Self(builder) => builder.serialize_u8(v).map_err(|err| builder.annotate_error(err))) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - dispatch!(self, Self(builder) => 
builder.serialize_u16(v)) + dispatch!(self, Self(builder) => builder.serialize_u16(v).map_err(|err| builder.annotate_error(err))) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u32(v)) + dispatch!(self, Self(builder) => builder.serialize_u32(v).map_err(|err| builder.annotate_error(err))) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u64(v)) + dispatch!(self, Self(builder) => builder.serialize_u64(v).map_err(|err| builder.annotate_error(err))) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_f32(v)) + dispatch!(self, Self(builder) => builder.serialize_f32(v).map_err(|err| builder.annotate_error(err))) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_f64(v)) + dispatch!(self, Self(builder) => builder.serialize_f64(v).map_err(|err| builder.annotate_error(err))) } fn serialize_char(&mut self, v: char) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_char(v)) + dispatch!(self, Self(builder) => builder.serialize_char(v).map_err(|err| builder.annotate_error(err))) } fn serialize_str(&mut self, v: &str) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_str(v)) + dispatch!(self, Self(builder) => builder.serialize_str(v).map_err(|err| builder.annotate_error(err))) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_bytes(v)) + dispatch!(self, Self(builder) => builder.serialize_bytes(v).map_err(|err| builder.annotate_error(err))) } fn serialize_seq_start(&mut self, len: Option) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_seq_start(len)) + dispatch!(self, Self(builder) => builder.serialize_seq_start(len).map_err(|err| builder.annotate_error(err))) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - 
dispatch!(self, Self(builder) => builder.serialize_seq_element(value)) + dispatch!(self, Self(builder) => builder.serialize_seq_element(value).map_err(|err| builder.annotate_error(err))) } fn serialize_seq_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_seq_end()) + dispatch!(self, Self(builder) => builder.serialize_seq_end().map_err(|err| builder.annotate_error(err))) } fn serialize_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_start(name, len)) + dispatch!(self, Self(builder) => builder.serialize_struct_start(name, len).map_err(|err| builder.annotate_error(err))) } fn serialize_struct_field(&mut self, key: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_field(key, value)) + dispatch!(self, Self(builder) => builder.serialize_struct_field(key, value).map_err(|err| builder.annotate_error(err))) } fn serialize_struct_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_end()) + dispatch!(self, Self(builder) => builder.serialize_struct_end().map_err(|err| builder.annotate_error(err))) } fn serialize_map_start(&mut self, len: Option) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_start(len)) + dispatch!(self, Self(builder) => builder.serialize_map_start(len).map_err(|err| builder.annotate_error(err))) } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_key(key)) + dispatch!(self, Self(builder) => builder.serialize_map_key(key).map_err(|err| builder.annotate_error(err))) } fn serialize_map_value(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_value(value)) + dispatch!(self, Self(builder) => builder.serialize_map_value(value).map_err(|err| builder.annotate_error(err))) } fn serialize_map_end(&mut self) -> Result<()> { - 
dispatch!(self, Self(builder) => builder.serialize_map_end()) + dispatch!(self, Self(builder) => builder.serialize_map_end().map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_start(&mut self, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_start(len)) + dispatch!(self, Self(builder) => builder.serialize_tuple_start(len).map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_element(value)) + dispatch!(self, Self(builder) => builder.serialize_tuple_element(value).map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_end()) + dispatch!(self, Self(builder) => builder.serialize_tuple_end().map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_struct_start(name, len)) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_start(name, len).map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_struct_field(value)) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_field(value).map_err(|err| builder.annotate_error(err))) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_struct_end()) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_end().map_err(|err| builder.annotate_error(err))) } fn serialize_newtype_struct(&mut self, name: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_newtype_struct(name, value)) + dispatch!(self, Self(builder) => builder.serialize_newtype_struct(name, value).map_err(|err| 
builder.annotate_error(err))) } fn serialize_newtype_variant(&mut self, name: &'static str, variant_index: u32, variant: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_newtype_variant(name, variant_index, variant, value)) + dispatch!(self, Self(builder) => builder.serialize_newtype_variant(name, variant_index, variant, value).map_err(|err| builder.annotate_error(err))) } fn serialize_unit_variant(&mut self, name: &'static str, variant_index: u32, variant: &'static str) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit_variant(name, variant_index, variant)) + dispatch!(self, Self(builder) => builder.serialize_unit_variant(name, variant_index, variant).map_err(|err| builder.annotate_error(err))) } fn serialize_struct_variant_start<'this>(&'this mut self, name: &'static str, variant_index: u32, variant: &'static str, len: usize) -> Result<&'this mut ArrayBuilder> { - dispatch!(self, Self(builder) => builder.serialize_struct_variant_start(name, variant_index, variant, len)) + let annotations_err = dispatch!(self, Self(builder) => builder.annotate_error(Error::empty())); + dispatch!(self, Self(builder) => builder.serialize_struct_variant_start(name, variant_index, variant, len).map_err(|err| merge_annotations(err, annotations_err))) } fn serialize_tuple_variant_start<'this> (&'this mut self, name: &'static str, variant_index: u32, variant: &'static str, len: usize) -> Result<&'this mut ArrayBuilder> { - dispatch!(self, Self(builder) => builder.serialize_tuple_variant_start(name, variant_index, variant, len)) + let annotations_err = dispatch!(self, Self(builder) => builder.annotate_error(Error::empty())); + dispatch!(self, Self(builder) => builder.serialize_tuple_variant_start(name, variant_index, variant, len).map_err(|err| merge_annotations(err, annotations_err))) } } diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 
f385667f..204a74ab 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -7,19 +7,29 @@ use crate::internal::{ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct IntBuilder(PrimitiveArray); +pub struct IntBuilder { + path: String, + array: PrimitiveArray, +} impl IntBuilder { - pub fn new(is_nullable: bool) -> Self { - Self(new_primitive_array(is_nullable)) + pub fn new(path: String, is_nullable: bool) -> Self { + println!("new IntBuilder ({path}"); + Self { + path, + array: new_primitive_array(is_nullable), + } } pub fn take(&mut self) -> Self { - Self(self.0.take()) + Self { + path: self.path.clone(), + array: self.array.take(), + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } } @@ -27,7 +37,7 @@ macro_rules! impl_into_array { ($ty:ty, $var:ident) => { impl IntBuilder<$ty> { pub fn into_array(self) -> Result { - Ok(Array::$var(self.0)) + Ok(Array::$var(self.array)) } } }; @@ -67,52 +77,58 @@ where "IntBuilder<()>" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_bool(&mut self, v: bool) -> Result<()> { let v: u8 = if v { 1 } else { 0 }; - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) 
} fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.0.push_scalar_value(I::try_from(v)?) + self.array.push_scalar_value(I::try_from(v)?) } fn serialize_char(&mut self, v: char) -> Result<()> { - self.0.push_scalar_value(I::try_from(u32::from(v))?) + self.array.push_scalar_value(I::try_from(u32::from(v))?) 
} } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index c9e837ae..e0c48a28 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -28,126 +28,11 @@ pub struct OuterSequenceBuilder(StructBuilder); impl OuterSequenceBuilder { pub fn new(schema: &SerdeArrowSchema) -> Result { - return Ok(Self(build_struct(&schema.fields, false)?)); - - fn build_struct(struct_fields: &[Field], nullable: bool) -> Result { - let mut fields = Vec::new(); - for field in struct_fields { - fields.push((build_builder(field)?, meta_from_field(field.clone())?)); - } - StructBuilder::new(fields, nullable) - } - - fn build_builder(field: &Field) -> Result { - use {ArrayBuilder as A, DataType as T}; - - let builder = match &field.data_type { - T::Null => match get_strategy_from_metadata(&field.metadata)? { - Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), - _ => A::Null(NullBuilder::new()), - }, - T::Boolean => A::Bool(BoolBuilder::new(field.nullable)), - T::Int8 => A::I8(IntBuilder::new(field.nullable)), - T::Int16 => A::I16(IntBuilder::new(field.nullable)), - T::Int32 => A::I32(IntBuilder::new(field.nullable)), - T::Int64 => A::I64(IntBuilder::new(field.nullable)), - T::UInt8 => A::U8(IntBuilder::new(field.nullable)), - T::UInt16 => A::U16(IntBuilder::new(field.nullable)), - T::UInt32 => A::U32(IntBuilder::new(field.nullable)), - T::UInt64 => A::U64(IntBuilder::new(field.nullable)), - T::Float16 => A::F16(FloatBuilder::new(field.nullable)), - T::Float32 => A::F32(FloatBuilder::new(field.nullable)), - T::Float64 => A::F64(FloatBuilder::new(field.nullable)), - T::Date32 => A::Date32(Date32Builder::new(field.nullable)), - T::Date64 => A::Date64(Date64Builder::new( - None, - is_utc_strategy(get_strategy_from_metadata(&field.metadata)?.as_ref())?, - field.nullable, - )), - 
T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( - Some((*unit, tz.clone())), - is_utc_tz(tz.as_deref())?, - field.nullable, - )), - T::Time32(unit) => { - if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { - fail!("Only timestamps with second or millisecond unit are supported"); - } - A::Time32(TimeBuilder::new(*unit, field.nullable)) - } - T::Time64(unit) => { - if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { - fail!("Only timestamps with nanosecond or microsecond unit are supported"); - } - A::Time64(TimeBuilder::new(*unit, field.nullable)) - } - T::Duration(unit) => A::Duration(DurationBuilder::new(*unit, field.nullable)), - T::Decimal128(precision, scale) => { - A::Decimal128(DecimalBuilder::new(*precision, *scale, field.nullable)) - } - T::Utf8 => A::Utf8(Utf8Builder::new(field.nullable)), - T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(field.nullable)), - T::List(child) => A::List(ListBuilder::new( - meta_from_field(*child.clone())?, - build_builder(child.as_ref())?, - field.nullable, - )?), - T::LargeList(child) => A::LargeList(ListBuilder::new( - meta_from_field(*child.clone())?, - build_builder(child.as_ref())?, - field.nullable, - )?), - T::FixedSizeList(child, n) => A::FixedSizedList(FixedSizeListBuilder::new( - meta_from_field(*child.clone())?, - build_builder(child.as_ref())?, - (*n).try_into()?, - field.nullable, - )), - T::Binary => A::Binary(BinaryBuilder::new(field.nullable)), - T::LargeBinary => A::LargeBinary(BinaryBuilder::new(field.nullable)), - T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( - (*n).try_into()?, - field.nullable, - )), - T::Map(entry_field, _) => A::Map(MapBuilder::new( - meta_from_field(*entry_field.clone())?, - build_builder(entry_field.as_ref())?, - field.nullable, - )?), - T::Struct(children) => A::Struct(build_struct(children, field.nullable)?), - T::Dictionary(key, value, _) => { - let key_field = Field { - name: "key".to_string(), - data_type: *key.clone(), 
- nullable: field.nullable, - metadata: HashMap::new(), - }; - let value_field = Field { - name: "value".to_string(), - data_type: *value.clone(), - nullable: false, - metadata: HashMap::new(), - }; - - A::DictionaryUtf8(DictionaryUtf8Builder::new( - build_builder(&key_field)?, - build_builder(&value_field)?, - )) - } - T::Union(union_fields, _) => { - let mut fields = Vec::new(); - for (idx, (type_id, field)) in union_fields.iter().enumerate() { - if usize::try_from(*type_id) != Ok(idx) { - fail!("non consecutive type ids are not supported"); - } - fields.push((build_builder(field)?, meta_from_field(field.clone())?)); - } - - A::Union(UnionBuilder::new(fields)) - } - }; - Ok(builder) - } + Ok(Self(build_struct( + String::from("$"), + &schema.fields, + false, + )?)) } /// Extract the contained struct fields @@ -222,6 +107,145 @@ impl SimpleSerializer for OuterSequenceBuilder { } } +fn build_struct(path: String, struct_fields: &[Field], nullable: bool) -> Result { + let mut fields = Vec::new(); + for field in struct_fields { + let field_path = format!("{path}.{field_name}", field_name = field.name); + fields.push(( + build_builder(field_path, field)?, + meta_from_field(field.clone())?, + )); + } + StructBuilder::new(path, fields, nullable) +} + +fn build_builder(path: String, field: &Field) -> Result { + use {ArrayBuilder as A, DataType as T}; + + let builder = match &field.data_type { + T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ + Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), + _ => A::Null(NullBuilder::new()), + }, + T::Boolean => A::Bool(BoolBuilder::new(field.nullable)), + T::Int8 => A::I8(IntBuilder::new(path, field.nullable)), + T::Int16 => A::I16(IntBuilder::new(path, field.nullable)), + T::Int32 => A::I32(IntBuilder::new(path, field.nullable)), + T::Int64 => A::I64(IntBuilder::new(path, field.nullable)), + T::UInt8 => A::U8(IntBuilder::new(path, field.nullable)), + T::UInt16 => A::U16(IntBuilder::new(path, field.nullable)), + T::UInt32 => A::U32(IntBuilder::new(path, field.nullable)), + T::UInt64 => A::U64(IntBuilder::new(path, field.nullable)), + T::Float16 => A::F16(FloatBuilder::new(field.nullable)), + T::Float32 => A::F32(FloatBuilder::new(field.nullable)), + T::Float64 => A::F64(FloatBuilder::new(field.nullable)), + T::Date32 => A::Date32(Date32Builder::new(field.nullable)), + T::Date64 => A::Date64(Date64Builder::new( + None, + is_utc_strategy(get_strategy_from_metadata(&field.metadata)?.as_ref())?, + field.nullable, + )), + T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( + Some((*unit, tz.clone())), + is_utc_tz(tz.as_deref())?, + field.nullable, + )), + T::Time32(unit) => { + if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { + fail!("Only timestamps with second or millisecond unit are supported"); + } + A::Time32(TimeBuilder::new(*unit, field.nullable)) + } + T::Time64(unit) => { + if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { + fail!("Only timestamps with nanosecond or microsecond unit are supported"); + } + A::Time64(TimeBuilder::new(*unit, field.nullable)) + } + T::Duration(unit) => A::Duration(DurationBuilder::new(*unit, field.nullable)), + T::Decimal128(precision, scale) => { + A::Decimal128(DecimalBuilder::new(*precision, *scale, field.nullable)) + } + T::Utf8 => A::Utf8(Utf8Builder::new(field.nullable)), + T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(field.nullable)), + T::List(child) => 
A::List(ListBuilder::new( + meta_from_field(*child.clone())?, + build_builder( + format!("{path}.{child_name}", child_name = child.name), + child.as_ref(), + )?, + field.nullable, + )?), + T::LargeList(child) => A::LargeList(ListBuilder::new( + meta_from_field(*child.clone())?, + build_builder( + format!("{path}.{child_name}", child_name = child.name), + child.as_ref(), + )?, + field.nullable, + )?), + T::FixedSizeList(child, n) => A::FixedSizedList(FixedSizeListBuilder::new( + meta_from_field(*child.clone())?, + build_builder( + format!("{path}.{child_name}", child_name = child.name), + child.as_ref(), + )?, + (*n).try_into()?, + field.nullable, + )), + T::Binary => A::Binary(BinaryBuilder::new(field.nullable)), + T::LargeBinary => A::LargeBinary(BinaryBuilder::new(field.nullable)), + T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( + (*n).try_into()?, + field.nullable, + )), + T::Map(entry_field, _) => A::Map(MapBuilder::new( + meta_from_field(*entry_field.clone())?, + build_builder( + format!("{path}.{child_name}", child_name = entry_field.name), + entry_field.as_ref(), + )?, + field.nullable, + )?), + T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?), + T::Dictionary(key, value, _) => { + let key_field = Field { + name: "key".to_string(), + data_type: *key.clone(), + nullable: field.nullable, + metadata: HashMap::new(), + }; + let value_field = Field { + name: "value".to_string(), + data_type: *value.clone(), + nullable: false, + metadata: HashMap::new(), + }; + + A::DictionaryUtf8(DictionaryUtf8Builder::new( + build_builder(format!("{path}.key"), &key_field)?, + build_builder(format!("{path}.value"), &value_field)?, + )) + } + T::Union(union_fields, _) => { + let mut fields = Vec::new(); + for (idx, (type_id, field)) in union_fields.iter().enumerate() { + if usize::try_from(*type_id) != Ok(idx) { + fail!("non consecutive type ids are not supported"); + } + let field_path = format!("{path}.{field_name}", 
field_name = field.name); + fields.push(( + build_builder(field_path, field)?, + meta_from_field(field.clone())?, + )); + } + + A::Union(UnionBuilder::new(fields)) + } + }; + Ok(builder) +} + fn is_utc_tz(tz: Option<&str>) -> Result { match tz { None => Ok(false), diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index bea84fb5..e4dcb0bb 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -498,7 +498,7 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { } } -fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { +pub fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { let extra_annotations = std::mem::take(annotations_err.annotations_mut()); if extra_annotations.is_empty() { return err; diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 68753f13..4914ea1e 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; @@ -15,6 +15,7 @@ const UNKNOWN_KEY: usize = usize::MAX; #[derive(Debug, Clone)] pub struct StructBuilder { + pub path: String, pub fields: Vec<(ArrayBuilder, FieldMeta)>, pub lookup: FieldLookup, pub next: usize, @@ -23,10 +24,15 @@ pub struct StructBuilder { } impl StructBuilder { - pub fn new(fields: Vec<(ArrayBuilder, FieldMeta)>, is_nullable: bool) -> Result { + pub fn new( + path: String, + fields: Vec<(ArrayBuilder, FieldMeta)>, + is_nullable: bool, + ) -> Result { let lookup = FieldLookup::new(fields.iter().map(|(_, meta)| 
meta.name.clone()).collect())?; Ok(Self { + path, seq: CountArray::new(is_nullable), seen: vec![false; fields.len()], next: 0, @@ -37,6 +43,7 @@ impl StructBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), fields: self .fields .iter_mut() @@ -114,6 +121,12 @@ impl SimpleSerializer for StructBuilder { "StructBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.seq.push_seq_default()?; for (builder, _) in &mut self.fields { diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs index 5af0584d..5c9bf256 100644 --- a/serde_arrow/src/test/error_messages/push_validity.rs +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -1,4 +1,3 @@ -use serde::Serialize; use serde_json::json; use crate::internal::{ @@ -20,26 +19,12 @@ fn push_validity_issue_202() -> PanicOnError<()> { }, ]))?; - #[derive(Serialize)] - struct Record { - nested: Nested, - } - - #[derive(Serialize)] - struct Nested { - field: Option, - } - let mut array_builder = ArrayBuilder::new(schema)?; - let res = array_builder.push(&Record { - nested: Nested { field: Some(5) }, - }); + let res = array_builder.push(&json!({"nested": {"field": 32}})); assert_eq!(res, Ok(())); - let res = array_builder.push(&Record { - nested: Nested { field: None }, - }); - assert_error_contains(&res, "field: \"nested.field\""); + let res = array_builder.push(&json!({"nested": {"field": null}})); + assert_error_contains(&res, "field: \"$.nested.field\""); Ok(()) } From e84a8132132793a65b985db085cf259db44b1093 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 20:09:43 +0200 Subject: [PATCH 128/178] Add tests for top-level and struct fields --- .../src/test/error_messages/push_validity.rs | 56 ++++++++++++++++++- 1 file changed, 55 
insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/test/error_messages/push_validity.rs b/serde_arrow/src/test/error_messages/push_validity.rs index 5c9bf256..e3b5c2eb 100644 --- a/serde_arrow/src/test/error_messages/push_validity.rs +++ b/serde_arrow/src/test/error_messages/push_validity.rs @@ -8,7 +8,7 @@ use crate::internal::{ }; #[test] -fn push_validity_issue_202() -> PanicOnError<()> { +fn int_nested() -> PanicOnError<()> { let schema = SerdeArrowSchema::from_value(&json!([ { "name": "nested", @@ -28,3 +28,57 @@ fn push_validity_issue_202() -> PanicOnError<()> { Ok(()) } + +#[test] +fn int_top_level() -> PanicOnError<()> { + let schema = SerdeArrowSchema::from_value(&json!([ + {"name": "field", "data_type": "U32"}, + ]))?; + + let mut array_builder = ArrayBuilder::new(schema)?; + let res = array_builder.push(&json!({"field": 32})); + assert_eq!(res, Ok(())); + + let res = array_builder.push(&json!({"field": null})); + assert_error_contains(&res, "field: \"$.field\""); + + Ok(()) +} + +#[test] +fn struct_nested() -> PanicOnError<()> { + let schema = SerdeArrowSchema::from_value(&json!([ + { + "name": "nested", + "data_type": "Struct", + "children": [ + {"name": "field", "data_type": "Struct", "children": []}, + ], + }, + ]))?; + + let mut array_builder = ArrayBuilder::new(schema)?; + let res = array_builder.push(&json!({"nested": {"field": {}}})); + assert_eq!(res, Ok(())); + + let res = array_builder.push(&json!({"nested": {"field": null}})); + assert_error_contains(&res, "field: \"$.nested.field\""); + + Ok(()) +} + +#[test] +fn struct_top_level() -> PanicOnError<()> { + let schema = SerdeArrowSchema::from_value(&json!([ + {"name": "field", "data_type": "Struct", "children": []}, + ]))?; + + let mut array_builder = ArrayBuilder::new(schema)?; + let res = array_builder.push(&json!({"field": {}})); + assert_eq!(res, Ok(())); + + let res = array_builder.push(&json!({"field": null})); + assert_error_contains(&res, "field: \"$.field\""); + + Ok(()) +} From 
d5fba86b1ab4c9be0f158b2fee5f7242d9a3f47b Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 20:18:01 +0200 Subject: [PATCH 129/178] Simplify annotated error impl --- serde_arrow/src/internal/error.rs | 35 ++++++------------- .../serialization/simple_serializer.rs | 18 ++++------ 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 432cf02d..18950749 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -72,14 +72,19 @@ impl Error { /// Turn the error into an annotated error and call the provided function with a mutable /// reference to the annotations pub(crate) fn annotate_unannotated)>( - mut self, + self, func: F, ) -> Self { - if matches!(self, Self::Annotated(_)) { - self - } else { - func(self.annotations_mut()); - self + match self { + Self::Annotated(err) => Self::Annotated(err), + non_annotated_err => { + let mut annotations = BTreeMap::new(); + func(&mut annotations); + Self::Annotated(AnnotatedError { + error: Box::new(non_annotated_err), + annotations, + }) + } } } @@ -89,24 +94,6 @@ impl Error { Self::Annotated(err) => Some(&err.annotations), } } - - /// Ensure the error is annotated and return a mutable reference to the annotations - pub(crate) fn annotations_mut(&mut self) -> &mut BTreeMap { - if !matches!(self, Self::Annotated(_)) { - let mut this = Error::empty(); - std::mem::swap(self, &mut this); - - *self = Self::Annotated(AnnotatedError { - error: Box::new(this), - annotations: BTreeMap::new(), - }); - } - - let Self::Annotated(err) = self else { - unreachable!(); - }; - &mut err.annotations - } } pub struct CustomError { diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index e4dcb0bb..57d956f7 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ 
b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -498,19 +498,13 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { } } -pub fn merge_annotations(mut err: Error, mut annotations_err: Error) -> Error { - let extra_annotations = std::mem::take(annotations_err.annotations_mut()); - if extra_annotations.is_empty() { +pub fn merge_annotations(err: Error, annotations_err: Error) -> Error { + let Error::Annotated(annotations_err) = annotations_err else { return err; - } - - let result_annotations = err.annotations_mut(); - for (key, value) in extra_annotations { - if !result_annotations.contains_key(&key) { - result_annotations.insert(key, value); - } - } - err + }; + err.annotate_unannotated(|annotations| { + *annotations = annotations_err.annotations; + }) } impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { From 4a5ca144244d170e28a138fdab359b5fac431dee Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sat, 31 Aug 2024 20:31:35 +0200 Subject: [PATCH 130/178] Add annotations to binary, bool, date32, date64, decimal, dictionary, duration, float, time builders --- .../internal/serialization/array_builder.rs | 4 + .../internal/serialization/binary_builder.rs | 47 ++++++--- .../internal/serialization/bool_builder.rs | 63 +++++++----- .../internal/serialization/date32_builder.rs | 37 ++++--- .../internal/serialization/date64_builder.rs | 18 +++- .../internal/serialization/decimal_builder.rs | 13 ++- .../serialization/dictionary_utf8_builder.rs | 13 ++- .../serialization/duration_builder.rs | 13 ++- .../internal/serialization/float_builder.rs | 97 ++++++++++++------- .../serialization/outer_sequence_builder.rs | 39 +++++--- .../internal/serialization/struct_builder.rs | 4 + .../internal/serialization/time_builder.rs | 11 ++- 12 files changed, 251 insertions(+), 108 deletions(-) diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index eb7f0e5c..60df05fa 100644 --- 
a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -151,6 +151,10 @@ impl SimpleSerializer for ArrayBuilder { "ArrayBuilder" } + fn annotate_error(&self, err: Error) -> Error { + dispatch!(self, Self(builder) => builder.annotate_error(err)) + } + fn serialize_default(&mut self) -> Result<()> { dispatch!(self, Self(builder) => builder.serialize_default().map_err(|err| builder.annotate_error(err))) } diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 8bb2adc5..ed4b0624 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, - error::Result, + error::{Error, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, Mut, Offset, @@ -13,45 +13,54 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct BinaryBuilder(BytesArray); +pub struct BinaryBuilder { + path: String, + array: BytesArray, +} impl BinaryBuilder { - pub fn new(is_nullable: bool) -> Self { - Self(new_bytes_array(is_nullable)) + pub fn new(path: String, is_nullable: bool) -> Self { + Self { + path, + array: new_bytes_array(is_nullable), + } } pub fn take(&mut self) -> Self { - Self(self.0.take()) + Self { + path: self.path.clone(), + array: self.array.take(), + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } } impl BinaryBuilder { pub fn into_array(self) -> Result { - Ok(Array::Binary(self.0)) + Ok(Array::Binary(self.array)) } } impl BinaryBuilder { pub fn into_array(self) -> Result { - Ok(Array::LargeBinary(self.0)) + Ok(Array::LargeBinary(self.array)) } } impl BinaryBuilder { fn start(&mut self) -> Result<()> { - self.0.start_seq() + self.array.start_seq() } fn element(&mut 
self, value: &V) -> Result<()> { let mut u8_serializer = U8Serializer(0); value.serialize(Mut(&mut u8_serializer))?; - self.0.data.push(u8_serializer.0); - self.0.push_seq_elements(1) + self.array.data.push(u8_serializer.0); + self.array.push_seq_elements(1) } fn end(&mut self) -> Result<()> { @@ -64,12 +73,18 @@ impl SimpleSerializer for BinaryBuilder { "BinaryBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { @@ -109,7 +124,7 @@ impl SimpleSerializer for BinaryBuilder { } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - self.0.push_scalar_value(v) + self.array.push_scalar_value(v) } } @@ -120,6 +135,10 @@ impl SimpleSerializer for U8Serializer { "SerializeU8" } + fn annotate_error(&self, err: Error) -> Error { + err + } + fn serialize_u8(&mut self, v: u8) -> Result<()> { self.0 = v; Ok(()) diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index c1d23683..b6d7c983 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,37 +1,46 @@ use crate::internal::{ arrow::{Array, BooleanArray}, - error::Result, + error::{Error, Result}, utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, }; use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct BoolBuilder(BooleanArray); +pub struct BoolBuilder { + path: String, + array: BooleanArray, +} impl BoolBuilder { - pub fn new(is_nullable: bool) -> Self { - Self(BooleanArray { - len: 0, - validity: 
is_nullable.then(Vec::new), - values: Vec::new(), - }) + pub fn new(path: String, is_nullable: bool) -> Self { + Self { + path: path, + array: BooleanArray { + len: 0, + validity: is_nullable.then(Vec::new), + values: Vec::new(), + }, + } } pub fn take(&mut self) -> Self { - Self(BooleanArray { - len: std::mem::take(&mut self.0.len), - validity: self.0.validity.as_mut().map(std::mem::take), - values: std::mem::take(&mut self.0.values), - }) + Self { + path: self.path.clone(), + array: BooleanArray { + len: std::mem::take(&mut self.array.len), + validity: self.array.validity.as_mut().map(std::mem::take), + values: std::mem::take(&mut self.array.values), + }, + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } pub fn into_array(self) -> Result { - Ok(Array::Boolean(self.0)) + Ok(Array::Boolean(self.array)) } } @@ -40,24 +49,30 @@ impl SimpleSerializer for BoolBuilder { "BoolBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - set_validity_default(self.0.validity.as_mut(), self.0.len); - set_bit_buffer(&mut self.0.values, self.0.len, false); - self.0.len += 1; + set_validity_default(self.array.validity.as_mut(), self.array.len); + set_bit_buffer(&mut self.array.values, self.array.len, false); + self.array.len += 1; Ok(()) } fn serialize_none(&mut self) -> Result<()> { - set_validity(self.0.validity.as_mut(), self.0.len, false)?; - set_bit_buffer(&mut self.0.values, self.0.len, false); - self.0.len += 1; + set_validity(self.array.validity.as_mut(), self.array.len, false)?; + set_bit_buffer(&mut self.array.values, self.array.len, false); + self.array.len += 1; Ok(()) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - set_validity(self.0.validity.as_mut(), self.0.len, true)?; - set_bit_buffer(&mut self.0.values, self.0.len, v); - self.0.len 
+= 1; + set_validity(self.array.validity.as_mut(), self.array.len, true)?; + set_bit_buffer(&mut self.array.values, self.array.len, v); + self.array.len += 1; Ok(()) } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index d4feb160..3c564d04 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -2,30 +2,39 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::Result, + error::{Error, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct Date32Builder(PrimitiveArray); +pub struct Date32Builder { + path: String, + array: PrimitiveArray, +} impl Date32Builder { - pub fn new(is_nullable: bool) -> Self { - Self(new_primitive_array(is_nullable)) + pub fn new(path: String, is_nullable: bool) -> Self { + Self { + path, + array: new_primitive_array(is_nullable), + } } pub fn take(&mut self) -> Self { - Self(self.0.take()) + Self { + path: self.path.clone(), + array: self.array.take(), + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } pub fn into_array(self) -> Result { - Ok(Array::Date32(self.0)) + Ok(Array::Date32(self.array)) } } @@ -34,12 +43,18 @@ impl SimpleSerializer for Date32Builder { "Date32Builder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -49,10 +64,10 @@ impl SimpleSerializer for 
Date32Builder { let duration_since_epoch = date.signed_duration_since(UNIX_EPOCH); let days_since_epoch = duration_since_epoch.num_days().try_into()?; - self.0.push_scalar_value(days_since_epoch) + self.array.push_scalar_value(days_since_epoch) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.0.push_scalar_value(v) + self.array.push_scalar_value(v) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 7001b558..fb2df64b 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; @@ -8,14 +8,21 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct Date64Builder { + path: String, pub meta: Option<(TimeUnit, Option)>, pub utc: bool, pub array: PrimitiveArray, } impl Date64Builder { - pub fn new(meta: Option<(TimeUnit, Option)>, utc: bool, is_nullable: bool) -> Self { + pub fn new( + path: String, + meta: Option<(TimeUnit, Option)>, + utc: bool, + is_nullable: bool, + ) -> Self { Self { + path, meta, utc, array: new_primitive_array(is_nullable), @@ -24,6 +31,7 @@ impl Date64Builder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), meta: self.meta.clone(), utc: self.utc, array: self.array.take(), @@ -56,6 +64,12 @@ impl SimpleSerializer for Date64Builder { "Date64Builder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default() } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs 
b/serde_arrow/src/internal/serialization/decimal_builder.rs index 2f851d47..2f35093d 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, - error::Result, + error::{Error, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::decimal::{self, DecimalParser}, }; @@ -9,6 +9,7 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DecimalBuilder { + path: String, pub precision: u8, pub scale: i8, pub f32_factor: f32, @@ -18,8 +19,9 @@ pub struct DecimalBuilder { } impl DecimalBuilder { - pub fn new(precision: u8, scale: i8, is_nullable: bool) -> Self { + pub fn new(path: String, precision: u8, scale: i8, is_nullable: bool) -> Self { Self { + path, precision, scale, f32_factor: (10.0_f32).powi(scale as i32), @@ -31,6 +33,7 @@ impl DecimalBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), precision: self.precision, scale: self.scale, f32_factor: self.f32_factor, @@ -59,6 +62,12 @@ impl SimpleSerializer for DecimalBuilder { "DecimalBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default() } diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index 491c6976..a6ce94e5 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::Mut, }; @@ -12,14 +12,16 @@ use 
super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct DictionaryUtf8Builder { + path: String, pub indices: Box, pub values: Box, pub index: HashMap, } impl DictionaryUtf8Builder { - pub fn new(indices: ArrayBuilder, values: ArrayBuilder) -> Self { + pub fn new(path: String, indices: ArrayBuilder, values: ArrayBuilder) -> Self { Self { + path, indices: Box::new(indices), values: Box::new(values), index: HashMap::new(), @@ -28,6 +30,7 @@ impl DictionaryUtf8Builder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), indices: Box::new(self.indices.take()), values: Box::new(self.values.take()), index: std::mem::take(&mut self.index), @@ -51,6 +54,12 @@ impl SimpleSerializer for DictionaryUtf8Builder { "DictionaryUtf8" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.indices.serialize_none() } diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 1b835a69..c18053c7 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::Result, + error::{Error, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; @@ -8,13 +8,15 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct DurationBuilder { + path: String, pub unit: TimeUnit, pub array: PrimitiveArray, } impl DurationBuilder { - pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { + pub fn new(path: String, unit: TimeUnit, is_nullable: bool) -> Self { Self { + path, unit, array: new_primitive_array(is_nullable), } @@ -22,6 +24,7 @@ impl DurationBuilder { pub fn 
take(&mut self) -> Self { Self { + path: self.path.clone(), unit: self.unit, array: self.array.take(), } @@ -45,6 +48,12 @@ impl SimpleSerializer for DurationBuilder { "DurationBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default() } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index f54ed275..d747c6ec 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -2,7 +2,7 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::Result, + error::{Error, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, utils::Mut, }; @@ -10,19 +10,28 @@ use crate::internal::{ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct FloatBuilder(PrimitiveArray); +pub struct FloatBuilder { + path: String, + array: PrimitiveArray, +} impl FloatBuilder { - pub fn new(is_nullable: bool) -> Self { - Self(new_primitive_array(is_nullable)) + pub fn new(path: String, is_nullable: bool) -> Self { + Self { + path, + array: new_primitive_array(is_nullable), + } } pub fn take(&mut self) -> Self { - Self(self.0.take()) + Self { + path: self.path.clone(), + array: self.array.take(), + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } } @@ -30,7 +39,7 @@ macro_rules! 
impl_into_array { ($ty:ty, $var:ident) => { impl FloatBuilder<$ty> { pub fn into_array(self) -> Result { - Ok(Array::$var(self.0)) + Ok(Array::$var(self.array)) } } }; @@ -45,12 +54,18 @@ impl SimpleSerializer for FloatBuilder { "FloatBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_some(&mut self, value: &V) -> Result<()> { @@ -58,43 +73,43 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.0.push_scalar_value(v) + self.array.push_scalar_value(v) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - 
self.0.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32) } } @@ -103,52 +118,58 @@ impl SimpleSerializer for FloatBuilder { "FloatBuilder<64>" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.0.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.0.push_scalar_value(v) + self.array.push_scalar_value(v) } } @@ -157,19 +178,25 @@ impl SimpleSerializer for FloatBuilder { "FloatBuilder" } + fn annotate_error(&self, err: Error) -> Error { + 
err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.0.push_scalar_value(f16::from_f32(v)) + self.array.push_scalar_value(f16::from_f32(v)) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.0.push_scalar_value(f16::from_f64(v)) + self.array.push_scalar_value(f16::from_f64(v)) } } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index e0c48a28..33d058f4 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -127,7 +127,7 @@ fn build_builder(path: String, field: &Field) -> Result { Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), _ => A::Null(NullBuilder::new()), }, - T::Boolean => A::Bool(BoolBuilder::new(field.nullable)), + T::Boolean => A::Bool(BoolBuilder::new(path, field.nullable)), T::Int8 => A::I8(IntBuilder::new(path, field.nullable)), T::Int16 => A::I16(IntBuilder::new(path, field.nullable)), T::Int32 => A::I32(IntBuilder::new(path, field.nullable)), @@ -136,16 +136,18 @@ fn build_builder(path: String, field: &Field) -> Result { T::UInt16 => A::U16(IntBuilder::new(path, field.nullable)), T::UInt32 => A::U32(IntBuilder::new(path, field.nullable)), T::UInt64 => A::U64(IntBuilder::new(path, field.nullable)), - T::Float16 => A::F16(FloatBuilder::new(field.nullable)), - T::Float32 => A::F32(FloatBuilder::new(field.nullable)), - T::Float64 => A::F64(FloatBuilder::new(field.nullable)), - T::Date32 => A::Date32(Date32Builder::new(field.nullable)), + T::Float16 => A::F16(FloatBuilder::new(path, 
field.nullable)), + T::Float32 => A::F32(FloatBuilder::new(path, field.nullable)), + T::Float64 => A::F64(FloatBuilder::new(path, field.nullable)), + T::Date32 => A::Date32(Date32Builder::new(path, field.nullable)), T::Date64 => A::Date64(Date64Builder::new( + path, None, is_utc_strategy(get_strategy_from_metadata(&field.metadata)?.as_ref())?, field.nullable, )), T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( + path, Some((*unit, tz.clone())), is_utc_tz(tz.as_deref())?, field.nullable, @@ -154,18 +156,21 @@ fn build_builder(path: String, field: &Field) -> Result { if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { fail!("Only timestamps with second or millisecond unit are supported"); } - A::Time32(TimeBuilder::new(*unit, field.nullable)) + A::Time32(TimeBuilder::new(path, *unit, field.nullable)) } T::Time64(unit) => { if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { fail!("Only timestamps with nanosecond or microsecond unit are supported"); } - A::Time64(TimeBuilder::new(*unit, field.nullable)) - } - T::Duration(unit) => A::Duration(DurationBuilder::new(*unit, field.nullable)), - T::Decimal128(precision, scale) => { - A::Decimal128(DecimalBuilder::new(*precision, *scale, field.nullable)) + A::Time64(TimeBuilder::new(path, *unit, field.nullable)) } + T::Duration(unit) => A::Duration(DurationBuilder::new(path, *unit, field.nullable)), + T::Decimal128(precision, scale) => A::Decimal128(DecimalBuilder::new( + path, + *precision, + *scale, + field.nullable, + )), T::Utf8 => A::Utf8(Utf8Builder::new(field.nullable)), T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(field.nullable)), T::List(child) => A::List(ListBuilder::new( @@ -193,8 +198,8 @@ fn build_builder(path: String, field: &Field) -> Result { (*n).try_into()?, field.nullable, )), - T::Binary => A::Binary(BinaryBuilder::new(field.nullable)), - T::LargeBinary => A::LargeBinary(BinaryBuilder::new(field.nullable)), + T::Binary => A::Binary(BinaryBuilder::new(path, 
field.nullable)), + T::LargeBinary => A::LargeBinary(BinaryBuilder::new(path, field.nullable)), T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( (*n).try_into()?, field.nullable, @@ -209,12 +214,15 @@ fn build_builder(path: String, field: &Field) -> Result { )?), T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?), T::Dictionary(key, value, _) => { + let key_path = format!("{path}.key"); let key_field = Field { name: "key".to_string(), data_type: *key.clone(), nullable: field.nullable, metadata: HashMap::new(), }; + + let value_path = format!("{path}.value"); let value_field = Field { name: "value".to_string(), data_type: *value.clone(), @@ -223,8 +231,9 @@ fn build_builder(path: String, field: &Field) -> Result { }; A::DictionaryUtf8(DictionaryUtf8Builder::new( - build_builder(format!("{path}.key"), &key_field)?, - build_builder(format!("{path}.value"), &value_field)?, + path, + build_builder(key_path, &key_field)?, + build_builder(value_path, &value_field)?, )) } T::Union(union_fields, _) => { diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 4914ea1e..8d557f3d 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -289,6 +289,10 @@ impl<'a> SimpleSerializer for KeyLookupSerializer<'a> { "KeyLookupSerializer" } + fn annotate_error(&self, err: Error) -> Error { + err + } + fn serialize_str(&mut self, v: &str) -> Result<()> { self.result = self.index.get(v).copied(); Ok(()) diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 550cc17a..20a04e04 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -10,13 +10,15 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct 
TimeBuilder { + path: String, pub unit: TimeUnit, pub array: PrimitiveArray, } impl TimeBuilder { - pub fn new(unit: TimeUnit, is_nullable: bool) -> Self { + pub fn new(path: String, unit: TimeUnit, is_nullable: bool) -> Self { Self { + path, unit, array: new_primitive_array(is_nullable), } @@ -24,6 +26,7 @@ impl TimeBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), unit: self.unit, array: self.array.take(), } @@ -64,6 +67,12 @@ where "Time64Builder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default() } From 07316123a26ba0a7e0554885015284c4494b67e8 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:18:46 +0200 Subject: [PATCH 131/178] Rewrite FieldLookup helper to ensure ArrayBuilder is Send + Sync --- serde_arrow/src/internal/array_builder.rs | 8 ++++++++ .../internal/serialization/struct_builder.rs | 18 ++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs index 3e65a6e4..5b86e198 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -95,3 +95,11 @@ impl std::convert::AsMut for ArrayBuilder { self } } + + +const _: () = { + const fn assert_send_sync() { + () + } + assert_send_sync::() +}; diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 68753f13..228f3b1f 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -205,12 +205,23 @@ impl SimpleSerializer for StructBuilder { } } +/// Optimize field lookups for static names #[derive(Debug, Clone)] pub struct FieldLookup { - pub cached_names: Vec>, + 
pub cached_names: Vec>, pub index: BTreeMap, } +/// A wrapper around a static field name that compares using ptr and length +#[derive(Debug, Clone)] +pub struct StaticFieldName(&'static str); + +impl std::cmp::PartialEq for StaticFieldName { + fn eq(&self, other: &Self) -> bool { + (self.0.as_ptr(), self.0.len()) == (other.0.as_ptr(), other.0.len()) + } +} + impl FieldLookup { pub fn new(field_names: Vec) -> Result { let mut index = BTreeMap::new(); @@ -234,13 +245,12 @@ impl FieldLookup { } pub fn lookup(&mut self, guess: usize, key: &'static str) -> Option { - let fast_key = (key.as_ptr(), key.len()); - if self.cached_names.get(guess) == Some(&Some(fast_key)) { + if self.cached_names.get(guess) == Some(&Some(StaticFieldName(key))) { Some(guess) } else { let &idx = self.index.get(key)?; if self.cached_names[idx].is_none() { - self.cached_names[idx] = Some(fast_key); + self.cached_names[idx] = Some(StaticFieldName(key)); } Some(idx) } From ac8bcafdd2f230cfdb84237c7bd60a9ae538102e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:42:33 +0200 Subject: [PATCH 132/178] Assert that all public types implement Send + Sync --- serde_arrow/src/internal/array_builder.rs | 7 ++----- serde_arrow/src/internal/deserializer.rs | 5 +++++ serde_arrow/src/internal/error.rs | 6 ++++++ serde_arrow/src/internal/schema/extensions/mod.rs | 7 +++++++ serde_arrow/src/internal/schema/mod.rs | 8 ++++++++ serde_arrow/src/internal/serializer.rs | 5 +++++ 6 files changed, 33 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs index 5b86e198..6b07eec8 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -96,10 +96,7 @@ impl std::convert::AsMut for ArrayBuilder { } } - const _: () = { - const fn assert_send_sync() { - () - } - assert_send_sync::() + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for ArrayBuilder {} }; diff --git 
a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index 723615b3..3fdaaa2c 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -194,3 +194,8 @@ impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { false } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl<'de> AssertSendSync for Deserializer<'de> {} +}; diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 610cd1ab..66b22f58 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -211,3 +211,9 @@ impl From for PanicOnErrorError { panic!("{value}"); } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Error {} + impl AssertSendSync for Result {} +}; diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs index fa879d7c..9b11016e 100644 --- a/serde_arrow/src/internal/schema/extensions/mod.rs +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -6,3 +6,10 @@ mod variable_shape_tensor_field; pub use bool8_field::Bool8Field; pub use fixed_shape_tensor_field::FixedShapeTensorField; pub use variable_shape_tensor_field::VariableShapeTensorField; + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Bool8Field {} + impl AssertSendSync for FixedShapeTensorField {} + impl AssertSendSync for VariableShapeTensorField {} +}; diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 8ff90182..7537b268 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -546,3 +546,11 @@ impl<'a> std::fmt::Display for DataTypeDisplay<'a> { } } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for SerdeArrowSchema {} + impl AssertSendSync for TracingOptions {} + impl AssertSendSync for Strategy {} + impl AssertSendSync for 
Overwrites {} +}; diff --git a/serde_arrow/src/internal/serializer.rs b/serde_arrow/src/internal/serializer.rs index add07604..70927cf2 100644 --- a/serde_arrow/src/internal/serializer.rs +++ b/serde_arrow/src/internal/serializer.rs @@ -281,3 +281,8 @@ impl> serde::ser::SerializeTupleVariant for CollectionSer Ok(Serializer(self.0)) } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Serializer {} +}; From d99cbee8a52f144452627693a1ba42934e33a98c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:47:08 +0200 Subject: [PATCH 133/178] Execute test workflow on PRs for develop branches --- .github/workflows/test.yml | 3 ++- x.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1a9a1cfc..a69e321e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,8 @@ "workflow_dispatch": {}, "pull_request": { "branches": [ - "main" + "main", + "develop-*" ], "types": [ "opened", diff --git a/x.py b/x.py index cbcfc488..df6d900e 100644 --- a/x.py +++ b/x.py @@ -35,7 +35,7 @@ "on": { "workflow_dispatch": {}, "pull_request": { - "branches": ["main"], + "branches": ["main", "develop-*"], "types": [ "opened", "edited", From 1a22e6e0f505dbbfed956520b2d9c3f10d0ecd85 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 09:11:05 +0200 Subject: [PATCH 134/178] Simplify error impl --- serde_arrow/src/internal/error.rs | 47 ++++++------------- .../serialization/simple_serializer.rs | 4 +- 2 files changed, 16 insertions(+), 35 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index fdc80392..56cee212 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -22,7 +22,6 @@ pub type Result = std::result::Result; #[non_exhaustive] pub enum Error { Custom(CustomError), - Annotated(AnnotatedError), } impl Error { @@ -31,6 +30,7 @@ impl 
Error { message, backtrace: Backtrace::capture(), cause: None, + annotations: BTreeMap::new(), }) } @@ -42,6 +42,7 @@ impl Error { message, backtrace: Backtrace::capture(), cause: Some(Box::new(cause)), + annotations: BTreeMap::new(), }) } } @@ -52,20 +53,19 @@ impl Error { message: String::new(), backtrace: Backtrace::disabled(), cause: None, + annotations: BTreeMap::new(), }) } pub fn message(&self) -> &str { match self { Self::Custom(err) => &err.message, - Self::Annotated(err) => err.error.message(), } } pub fn backtrace(&self) -> &Backtrace { match self { Self::Custom(err) => &err.backtrace, - Self::Annotated(err) => &err.error.backtrace(), } } @@ -75,23 +75,16 @@ impl Error { self, func: F, ) -> Self { - match self { - Self::Annotated(err) => Self::Annotated(err), - non_annotated_err => { - let mut annotations = BTreeMap::new(); - func(&mut annotations); - Self::Annotated(AnnotatedError { - error: Box::new(non_annotated_err), - annotations, - }) - } + let Self::Custom(mut this) = self; + if this.annotations.is_empty() { + func(&mut this.annotations); } + Self::Custom(this) } pub(crate) fn annotations(&self) -> Option<&BTreeMap> { match self { - Self::Custom(_) => None, - Self::Annotated(err) => Some(&err.annotations), + Self::Custom(err) => Some(&err.annotations), } } } @@ -100,22 +93,12 @@ pub struct CustomError { message: String, backtrace: Backtrace, cause: Option>, -} - -impl std::cmp::PartialEq for CustomError { - fn eq(&self, other: &Self) -> bool { - self.message == other.message - } -} - -pub struct AnnotatedError { - pub(crate) error: Box, pub(crate) annotations: BTreeMap, } -impl std::cmp::PartialEq for AnnotatedError { +impl std::cmp::PartialEq for CustomError { fn eq(&self, other: &Self) -> bool { - self.error.eq(&other.error) && self.annotations == other.annotations + self.message == other.message && self.annotations == other.annotations } } @@ -173,11 +156,10 @@ impl<'a> std::fmt::Display for BacktraceDisplay<'a> { impl std::error::Error for 
Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::Custom(CustomError { - cause: Some(err), .. - }) => Some(err.as_ref()), - _ => None, + let Self::Custom(this) = self; + match this.cause.as_ref() { + Some(cause) => Some(cause.as_ref()), + None => None, } } } @@ -266,6 +248,7 @@ impl From for Error { impl From for Error { fn from(err: bytemuck::PodCastError) -> Self { + // Note: bytemuck::PodCastError does not implement std::error::Error Self::custom(format!("bytemuck::PodCastError: {err}")) } } diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 57d956f7..6cb489aa 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -499,10 +499,8 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { } pub fn merge_annotations(err: Error, annotations_err: Error) -> Error { - let Error::Annotated(annotations_err) = annotations_err else { - return err; - }; err.annotate_unannotated(|annotations| { + let Error::Custom(annotations_err) = annotations_err; *annotations = annotations_err.annotations; }) } From bd93301aaa724d4f52ed70c780ce0f44ca0e21a6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 09:11:53 +0200 Subject: [PATCH 135/178] Make the original arrow error available as source() --- serde_arrow/src/arrow_impl/type_support.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/arrow_impl/type_support.rs b/serde_arrow/src/arrow_impl/type_support.rs index 8330bb80..a6940811 100644 --- a/serde_arrow/src/arrow_impl/type_support.rs +++ b/serde_arrow/src/arrow_impl/type_support.rs @@ -11,7 +11,7 @@ use crate::internal::{ impl From for Error { fn from(err: ArrowError) -> Self { - Self::custom(err.to_string()) + Self::custom_from(err.to_string(), err) } } From f08c276c324a1948ae4090e4f533fa632598c25b Mon Sep 
17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 09:31:22 +0200 Subject: [PATCH 136/178] Implement annotations for all builders --- Development.md | 4 + .../fixed_size_binary_builder.rs | 17 +++- .../serialization/fixed_size_list_builder.rs | 19 ++++- .../internal/serialization/list_builder.rs | 18 +++- .../src/internal/serialization/map_builder.rs | 18 +++- .../internal/serialization/null_builder.rs | 16 +++- .../serialization/outer_sequence_builder.rs | 85 ++++++++++--------- .../serialization/simple_serializer.rs | 5 +- .../internal/serialization/union_builder.rs | 13 ++- .../serialization/unknown_variant_builder.rs | 20 ++++- .../internal/serialization/utf8_builder.rs | 39 ++++++--- 11 files changed, 183 insertions(+), 71 deletions(-) diff --git a/Development.md b/Development.md index 9551a2f3..60d67a61 100644 --- a/Development.md +++ b/Development.md @@ -32,3 +32,7 @@ modules can can be run without installing further packages. ## Error format - Include the path to the field where sensible + +Common annotations: + +- `field`: the path of the field affected by the error \ No newline at end of file diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 329d4fc2..06cb6a71 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; @@ -12,6 +12,7 @@ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] pub struct FixedSizeBinaryBuilder { + pub path: String, pub seq: CountArray, pub buffer: Vec, pub current_n: usize, @@ -19,8 +20,9 @@ pub struct FixedSizeBinaryBuilder { } impl FixedSizeBinaryBuilder { - pub fn new(n: 
usize, is_nullable: bool) -> Self { + pub fn new(path: String, n: usize, is_nullable: bool) -> Self { Self { + path, seq: CountArray::new(is_nullable), buffer: Vec::new(), n, @@ -30,6 +32,7 @@ impl FixedSizeBinaryBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), seq: self.seq.take(), buffer: std::mem::take(&mut self.buffer), current_n: std::mem::take(&mut self.current_n), @@ -83,6 +86,12 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { "FixedSizeBinaryBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.seq.push_seq_default()?; for _ in 0..self.n { @@ -157,6 +166,10 @@ impl SimpleSerializer for U8Serializer { "SerializeU8" } + fn annotate_error(&self, err: Error) -> Error { + err + } + fn serialize_u8(&mut self, v: u8) -> Result<()> { self.0 = v; Ok(()) diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 1c222304..1035796d 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, utils::Mut, }; @@ -12,6 +12,7 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct FixedSizeListBuilder { + pub path: String, pub seq: CountArray, pub meta: FieldMeta, pub n: usize, @@ -20,8 +21,15 @@ pub struct FixedSizeListBuilder { } impl FixedSizeListBuilder { - pub fn new(meta: FieldMeta, element: ArrayBuilder, n: usize, is_nullable: bool) -> Self { + pub fn new( + path: String, + meta: FieldMeta, + element: 
ArrayBuilder, + n: usize, + is_nullable: bool, + ) -> Self { Self { + path, seq: CountArray::new(is_nullable), meta, n, @@ -32,6 +40,7 @@ impl FixedSizeListBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), seq: self.seq.take(), meta: self.meta.clone(), n: self.n, @@ -85,6 +94,12 @@ impl SimpleSerializer for FixedSizeListBuilder { "FixedSizeListBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.seq.push_seq_default()?; for _ in 0..self.n { diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 7578433e..6abf5f24 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::Result, + error::{Error, Result}, utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, utils::{Mut, Offset}, }; @@ -12,14 +12,21 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct ListBuilder { + pub path: String, pub meta: FieldMeta, pub element: Box, pub offsets: OffsetsArray, } impl ListBuilder { - pub fn new(meta: FieldMeta, element: ArrayBuilder, is_nullable: bool) -> Result { + pub fn new( + path: String, + meta: FieldMeta, + element: ArrayBuilder, + is_nullable: bool, + ) -> Result { Ok(Self { + path, meta, element: Box::new(element), offsets: OffsetsArray::new(is_nullable), @@ -28,6 +35,7 @@ impl ListBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), meta: self.meta.clone(), offsets: self.offsets.take(), element: Box::new(self.element.take()), @@ -81,6 +89,12 @@ impl SimpleSerializer for ListBuilder { "ListBuilder" } + fn annotate_error(&self, 
err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.offsets.push_seq_default() } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 737d1df6..9dd4fc37 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -2,7 +2,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, }; @@ -10,15 +10,22 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct MapBuilder { + pub path: String, pub meta: FieldMeta, pub entry: Box, pub offsets: OffsetsArray, } impl MapBuilder { - pub fn new(meta: FieldMeta, entry: ArrayBuilder, is_nullable: bool) -> Result { + pub fn new( + path: String, + meta: FieldMeta, + entry: ArrayBuilder, + is_nullable: bool, + ) -> Result { Self::validate_entry(&entry)?; Ok(Self { + path, meta, offsets: OffsetsArray::new(is_nullable), entry: Box::new(entry), @@ -37,6 +44,7 @@ impl MapBuilder { pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), meta: self.meta.clone(), offsets: self.offsets.take(), entry: Box::new(self.entry.take()), @@ -62,6 +70,12 @@ impl SimpleSerializer for MapBuilder { "MapBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.offsets.push_seq_default() } diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index 60850d94..a06863e8 100644 --- 
a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -1,22 +1,24 @@ use crate::internal::{ arrow::{Array, NullArray}, - error::Result, + error::{Error, Result}, }; use super::simple_serializer::SimpleSerializer; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct NullBuilder { + pub path: String, pub count: usize, } impl NullBuilder { - pub fn new() -> Self { - Self::default() + pub fn new(path: String) -> Self { + Self { path, count: 0 } } pub fn take(&mut self) -> Self { Self { + path: self.path.clone(), count: std::mem::take(&mut self.count), } } @@ -35,6 +37,12 @@ impl SimpleSerializer for NullBuilder { "NullBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { self.count += 1; Ok(()) diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 33d058f4..8db26fc0 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -66,6 +66,10 @@ impl SimpleSerializer for OuterSequenceBuilder { "OuterSequenceBuilder" } + fn annotate_error(&self, err: crate::Error) -> crate::Error { + self.0.annotate_error(err) + } + fn serialize_none(&mut self) -> Result<()> { self.0.serialize_none() } @@ -124,8 +128,8 @@ fn build_builder(path: String, field: &Field) -> Result { let builder = match &field.data_type { T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ - Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder), - _ => A::Null(NullBuilder::new()), + Some(Strategy::UnknownVariant) => A::UnknownVariant(UnknownVariantBuilder::new(path)), + _ => A::Null(NullBuilder::new(path)), }, T::Boolean => A::Bool(BoolBuilder::new(path, field.nullable)), T::Int8 => A::I8(IntBuilder::new(path, field.nullable)), @@ -171,47 +175,52 @@ fn build_builder(path: String, field: &Field) -> Result { *scale, field.nullable, )), - T::Utf8 => A::Utf8(Utf8Builder::new(field.nullable)), - T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(field.nullable)), - T::List(child) => A::List(ListBuilder::new( - meta_from_field(*child.clone())?, - build_builder( - format!("{path}.{child_name}", child_name = child.name), - child.as_ref(), - )?, - field.nullable, - )?), - T::LargeList(child) => A::LargeList(ListBuilder::new( - meta_from_field(*child.clone())?, - build_builder( - format!("{path}.{child_name}", child_name = child.name), - child.as_ref(), - )?, - field.nullable, - )?), - T::FixedSizeList(child, n) => A::FixedSizedList(FixedSizeListBuilder::new( - meta_from_field(*child.clone())?, - build_builder( - format!("{path}.{child_name}", child_name = child.name), - child.as_ref(), - )?, - (*n).try_into()?, - field.nullable, - )), + T::Utf8 => A::Utf8(Utf8Builder::new(path, field.nullable)), + T::LargeUtf8 => A::LargeUtf8(Utf8Builder::new(path, field.nullable)), + T::List(child) => { + let child_path = format!("{path}.{child_name}", child_name = child.name); + A::List(ListBuilder::new( + path, + meta_from_field(*child.clone())?, + build_builder(child_path, child.as_ref())?, + field.nullable, + )?) + } + T::LargeList(child) => { + let child_path = format!("{path}.{child_name}", child_name = child.name); + A::LargeList(ListBuilder::new( + path, + meta_from_field(*child.clone())?, + build_builder(child_path, child.as_ref())?, + field.nullable, + )?) 
+ } + T::FixedSizeList(child, n) => { + let child_path = format!("{path}.{child_name}", child_name = child.name); + A::FixedSizedList(FixedSizeListBuilder::new( + path, + meta_from_field(*child.clone())?, + build_builder(child_path, child.as_ref())?, + (*n).try_into()?, + field.nullable, + )) + } T::Binary => A::Binary(BinaryBuilder::new(path, field.nullable)), T::LargeBinary => A::LargeBinary(BinaryBuilder::new(path, field.nullable)), T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( + path, (*n).try_into()?, field.nullable, )), - T::Map(entry_field, _) => A::Map(MapBuilder::new( - meta_from_field(*entry_field.clone())?, - build_builder( - format!("{path}.{child_name}", child_name = entry_field.name), - entry_field.as_ref(), - )?, - field.nullable, - )?), + T::Map(entry_field, _) => { + let child_path = format!("{path}.{child_name}", child_name = entry_field.name); + A::Map(MapBuilder::new( + path, + meta_from_field(*entry_field.clone())?, + build_builder(child_path, entry_field.as_ref())?, + field.nullable, + )?) 
+ } T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?), T::Dictionary(key, value, _) => { let key_path = format!("{path}.key"); @@ -249,7 +258,7 @@ fn build_builder(path: String, field: &Field) -> Result { )); } - A::Union(UnionBuilder::new(fields)) + A::Union(UnionBuilder::new(path, fields)) } }; Ok(builder) diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 6cb489aa..c0034d4e 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -25,10 +25,7 @@ use super::ArrayBuilder; pub trait SimpleSerializer: Sized { fn name(&self) -> &str; - // TODO: remove default - fn annotate_error(&self, err: Error) -> Error { - err - } + fn annotate_error(&self, err: Error) -> Error; fn serialize_default(&mut self) -> Result<()> { fail!("serialize_default is not supported for {}", self.name()); diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index ea3f66b6..9b108ae2 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::Mut, }; @@ -8,6 +8,7 @@ use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; #[derive(Debug, Clone)] pub struct UnionBuilder { + pub path: String, pub fields: Vec<(ArrayBuilder, FieldMeta)>, pub types: Vec, pub offsets: Vec, @@ -15,8 +16,9 @@ pub struct UnionBuilder { } impl UnionBuilder { - pub fn new(fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self { + pub fn new(path: String, fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self { Self { + path, current_offset: vec![0; fields.len()], types: Vec::new(), offsets: Vec::new(), @@ -26,6 +28,7 @@ impl UnionBuilder { pub 
fn take(&mut self) -> Self { Self { + path: self.path.clone(), fields: self .fields .iter_mut() @@ -75,6 +78,12 @@ impl SimpleSerializer for UnionBuilder { "UnionBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_unit_variant( &mut self, _: &'static str, diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index c1c4b4bb..59eb4387 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -2,17 +2,25 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, NullArray}, - error::{fail, Result}, + error::{fail, Error, Result}, }; use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; #[derive(Debug, Clone)] -pub struct UnknownVariantBuilder; +pub struct UnknownVariantBuilder { + path: String, +} impl UnknownVariantBuilder { + pub fn new(path: String) -> Self { + UnknownVariantBuilder { path } + } + pub fn take(&mut self) -> Self { - UnknownVariantBuilder + UnknownVariantBuilder { + path: self.path.clone(), + } } pub fn is_nullable(&self) -> bool { @@ -29,6 +37,12 @@ impl SimpleSerializer for UnknownVariantBuilder { "UnknownVariantBuilder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { fail!("Serialization failed: an unknown variant") } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 69f78efd..9e8dea4b 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::{Array, 
BytesArray}, - error::{fail, Result}, + error::{fail, Error, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, Offset, @@ -10,31 +10,40 @@ use crate::internal::{ use super::simple_serializer::SimpleSerializer; #[derive(Debug, Clone)] -pub struct Utf8Builder(BytesArray); +pub struct Utf8Builder { + path: String, + array: BytesArray, +} impl Utf8Builder { - pub fn new(is_nullable: bool) -> Self { - Self(new_bytes_array(is_nullable)) + pub fn new(path: String, is_nullable: bool) -> Self { + Self { + path, + array: new_bytes_array(is_nullable), + } } pub fn take(&mut self) -> Self { - Self(self.0.take()) + Self { + path: self.path.clone(), + array: self.array.take(), + } } pub fn is_nullable(&self) -> bool { - self.0.validity.is_some() + self.array.validity.is_some() } } impl Utf8Builder { pub fn into_array(self) -> Result { - Ok(Array::Utf8(self.0)) + Ok(Array::Utf8(self.array)) } } impl Utf8Builder { pub fn into_array(self) -> Result { - Ok(Array::LargeUtf8(self.0)) + Ok(Array::LargeUtf8(self.array)) } } @@ -43,16 +52,22 @@ impl SimpleSerializer for Utf8Builder { "Utf8Builder" } + fn annotate_error(&self, err: Error) -> Error { + err.annotate_unannotated(|annotations| { + annotations.insert(String::from("field"), self.path.clone()); + }) + } + fn serialize_default(&mut self) -> Result<()> { - self.0.push_scalar_default() + self.array.push_scalar_default() } fn serialize_none(&mut self) -> Result<()> { - self.0.push_scalar_none() + self.array.push_scalar_none() } fn serialize_str(&mut self, v: &str) -> Result<()> { - self.0.push_scalar_value(v.as_bytes()) + self.array.push_scalar_value(v.as_bytes()) } fn serialize_unit_variant( @@ -61,7 +76,7 @@ impl SimpleSerializer for Utf8Builder { _: u32, variant: &'static str, ) -> Result<()> { - self.0.push_scalar_value(variant.as_bytes()) + self.array.push_scalar_value(variant.as_bytes()) } fn serialize_tuple_variant_start<'this>( From 3e971399551e2dcfebfff8ed4c42c71b3cafb18b Mon Sep 17 00:00:00 2001 
From: Christopher Prohm Date: Sun, 1 Sep 2024 11:50:52 +0200 Subject: [PATCH 137/178] Introduce error context to add additional info --- serde_arrow/src/internal/error.rs | 80 ++++++++----- .../internal/serialization/array_builder.rs | 10 +- .../internal/serialization/binary_builder.rs | 18 ++- .../internal/serialization/bool_builder.rs | 15 ++- .../internal/serialization/date32_builder.rs | 15 ++- .../internal/serialization/date64_builder.rs | 15 ++- .../internal/serialization/decimal_builder.rs | 17 ++- .../serialization/dictionary_utf8_builder.rs | 12 +- .../serialization/duration_builder.rs | 15 ++- .../fixed_size_binary_builder.rs | 22 +++- .../serialization/fixed_size_list_builder.rs | 16 ++- .../internal/serialization/float_builder.rs | 16 ++- .../src/internal/serialization/int_builder.rs | 15 ++- .../internal/serialization/list_builder.rs | 16 ++- .../src/internal/serialization/map_builder.rs | 15 ++- .../internal/serialization/null_builder.rs | 11 +- .../serialization/outer_sequence_builder.rs | 10 +- .../serialization/simple_serializer.rs | 111 +++++++++--------- .../internal/serialization/struct_builder.rs | 20 +++- .../internal/serialization/time_builder.rs | 15 ++- .../internal/serialization/union_builder.rs | 12 +- .../serialization/unknown_variant_builder.rs | 11 +- .../internal/serialization/utf8_builder.rs | 12 +- serde_arrow/src/internal/testing.rs | 15 --- serde_arrow/src/internal/utils/mod.rs | 15 +++ serde_arrow/src/test_with_arrow/impls/map.rs | 4 +- 26 files changed, 386 insertions(+), 147 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 56cee212..e3bc0866 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -4,6 +4,10 @@ use std::{ convert::Infallible, }; +pub trait Context { + fn annotations(&self) -> BTreeMap; +} + /// A Result type that defaults to `serde_arrow`'s [Error] type /// pub type Result = std::result::Result; @@ -24,79 +28,93 @@ pub enum 
Error { Custom(CustomError), } +/// Error creation impl Error { pub fn custom(message: String) -> Self { - Self::Custom(CustomError { + Self::Custom(CustomError(Box::new(CustomErrorImpl { message, backtrace: Backtrace::capture(), cause: None, annotations: BTreeMap::new(), - }) + }))) } pub fn custom_from( message: String, cause: E, ) -> Self { - Self::Custom(CustomError { + Self::Custom(CustomError(Box::new(CustomErrorImpl { message, backtrace: Backtrace::capture(), cause: Some(Box::new(cause)), annotations: BTreeMap::new(), - }) + }))) } -} -impl Error { pub(crate) fn empty() -> Self { - Self::Custom(CustomError { + Self::Custom(CustomError(Box::new(CustomErrorImpl { message: String::new(), backtrace: Backtrace::disabled(), cause: None, annotations: BTreeMap::new(), - }) + }))) } +} - pub fn message(&self) -> &str { - match self { - Self::Custom(err) => &err.message, +impl Error { + /// Call the function with a mutable reference to this errors annotations, if the error was not + /// annotated before + pub(crate) fn annotate_unannotated)>( + self, + func: F, + ) -> Self { + let Self::Custom(mut this) = self; + if this.0.annotations.is_empty() { + func(&mut this.0.annotations); } + Self::Custom(this) } - pub fn backtrace(&self) -> &Backtrace { + pub(crate) fn with_annotations(self, annotations: BTreeMap) -> Self { + let Self::Custom(mut this) = self; + this.0.annotations = annotations; + Self::Custom(this) + } +} + +/// Access information about the error +impl Error { + pub fn message(&self) -> &str { match self { - Self::Custom(err) => &err.backtrace, + Self::Custom(err) => &err.0.message, } } - /// Turn the error into an annotated error and call the provided function with a mutable - /// reference to the annotations - pub(crate) fn annotate_unannotated)>( - self, - func: F, - ) -> Self { - let Self::Custom(mut this) = self; - if this.annotations.is_empty() { - func(&mut this.annotations); + pub fn backtrace(&self) -> &Backtrace { + match self { + Self::Custom(err) 
=> &err.0.backtrace, } - Self::Custom(this) } + /// Get a reference to the annotations of this error pub(crate) fn annotations(&self) -> Option<&BTreeMap> { match self { - Self::Custom(err) => Some(&err.annotations), + Self::Custom(err) => Some(&err.0.annotations), } } } -pub struct CustomError { +#[derive(PartialEq)] +pub struct CustomError(pub(crate) Box); + +pub struct CustomErrorImpl { message: String, backtrace: Backtrace, cause: Option>, pub(crate) annotations: BTreeMap, } -impl std::cmp::PartialEq for CustomError { +impl std::cmp::PartialEq for CustomErrorImpl { fn eq(&self, other: &Self) -> bool { self.message == other.message && self.annotations == other.annotations } @@ -157,7 +175,7 @@ impl<'a> std::fmt::Display for BacktraceDisplay<'a> { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { let Self::Custom(this) = self; - match this.cause.as_ref() { + match this.0.cause.as_ref() { Some(cause) => Some(cause.as_ref()), None => None, } @@ -191,6 +209,14 @@ macro_rules! error { pub(crate) use error; macro_rules! 
fail { + (in $context:expr, $($tt:tt)*) => { + { + #[allow(unused)] + use $crate::internal::error::Context; + let annotations = $context.annotations(); + return Err($crate::internal::error::error!($($tt)*).with_annotations(annotations)) + } + }; ($($tt:tt)*) => { return Err($crate::internal::error::error!($($tt)*)) }; diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 60df05fa..17a4a4a1 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -1,9 +1,11 @@ +use std::collections::BTreeMap; + use half::f16; use serde::Serialize; use crate::internal::{ arrow::Array, - error::{Error, Result}, + error::{Context, Error, Result}, }; use super::{ @@ -145,6 +147,12 @@ impl ArrayBuilder { } } +impl Context for ArrayBuilder { + fn annotations(&self) -> BTreeMap { + dispatch!(self, Self(builder) => builder.annotations()) + } +} + #[rustfmt::skip] impl SimpleSerializer for ArrayBuilder { fn name(&self) -> &str { diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index ed4b0624..560347e4 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -1,11 +1,13 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, - error::{Error, Result}, + error::{Context, Error, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - Mut, Offset, + btree_map, Mut, Offset, }, }; @@ -68,6 +70,12 @@ impl BinaryBuilder { } } +impl Context for BinaryBuilder { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for BinaryBuilder { fn name(&self) -> &str { "BinaryBuilder" @@ -130,6 +138,12 @@ impl SimpleSerializer for 
BinaryBuilder { struct U8Serializer(u8); +impl Context for U8Serializer { + fn annotations(&self) -> BTreeMap { + Default::default() + } +} + impl SimpleSerializer for U8Serializer { fn name(&self) -> &str { "SerializeU8" diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index b6d7c983..2e2341ba 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -1,7 +1,12 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, BooleanArray}, - error::{Error, Result}, - utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, + error::{Context, Error, Result}, + utils::{ + array_ext::{set_bit_buffer, set_validity, set_validity_default}, + btree_map, + }, }; use super::simple_serializer::SimpleSerializer; @@ -44,6 +49,12 @@ impl BoolBuilder { } } +impl Context for BoolBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for BoolBuilder { fn name(&self) -> &str { "BoolBuilder" diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 3c564d04..603ad829 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -1,9 +1,14 @@ +use std::collections::BTreeMap; + use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + }, }; use super::simple_serializer::SimpleSerializer; @@ -38,6 +43,12 @@ impl Date32Builder { } } +impl Context for Date32Builder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } 
+} + impl SimpleSerializer for Date32Builder { fn name(&self) -> &str { "Date32Builder" diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index fb2df64b..0e884396 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -1,7 +1,12 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{fail, Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + error::{fail, Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + }, }; use super::simple_serializer::SimpleSerializer; @@ -59,6 +64,12 @@ impl Date64Builder { } } +impl Context for Date64Builder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for Date64Builder { fn name(&self) -> &str { "Date64Builder" diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 2f35093d..810977f1 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -1,8 +1,13 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::decimal::{self, DecimalParser}, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + decimal::{self, DecimalParser}, + }, }; use super::simple_serializer::SimpleSerializer; @@ -57,6 +62,12 @@ impl DecimalBuilder { } } +impl Context for DecimalBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for 
DecimalBuilder { fn name(&self) -> &str { "DecimalBuilder" diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index a6ce94e5..cc8e6ff9 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -49,6 +49,12 @@ impl DictionaryUtf8Builder { } } +impl Context for DictionaryUtf8Builder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for DictionaryUtf8Builder { fn name(&self) -> &str { "DictionaryUtf8" diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index c18053c7..009bded8 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -1,7 +1,12 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + }, }; use super::simple_serializer::SimpleSerializer; @@ -43,6 +48,12 @@ impl DurationBuilder { } } +impl Context for DurationBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for DurationBuilder { fn name(&self) -> &str { "DurationBuilder" diff --git 
a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 06cb6a71..1065fc26 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -1,10 +1,14 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, - error::{fail, Error, Result}, - utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{ + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + btree_map, Mut, + }, }; use super::simple_serializer::SimpleSerializer; @@ -81,6 +85,12 @@ impl FixedSizeBinaryBuilder { } } +impl Context for FixedSizeBinaryBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for FixedSizeBinaryBuilder { fn name(&self) -> &str { "FixedSizeBinaryBuilder" @@ -161,6 +171,12 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { struct U8Serializer(u8); +impl Context for U8Serializer { + fn annotations(&self) -> BTreeMap { + btree_map!() + } +} + impl SimpleSerializer for U8Serializer { fn name(&self) -> &str { "SerializeU8" diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 1035796d..9ba95eb3 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -1,10 +1,14 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, - error::{fail, Error, Result}, - utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{ + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + btree_map, Mut, + }, }; use 
super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -89,6 +93,12 @@ impl FixedSizeListBuilder { } } +impl Context for FixedSizeListBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for FixedSizeListBuilder { fn name(&self) -> &str { "FixedSizeListBuilder" diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index d747c6ec..017f397a 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -1,10 +1,14 @@ +use std::collections::BTreeMap; + use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - utils::Mut, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, Mut, + }, }; use super::simple_serializer::SimpleSerializer; @@ -49,6 +53,12 @@ impl_into_array!(f16, Float16); impl_into_array!(f32, Float32); impl_into_array!(f64, Float64); +impl Context for FloatBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for FloatBuilder { fn name(&self) -> &str { "FloatBuilder" diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 204a74ab..78632b56 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -1,7 +1,12 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + }, }; use 
super::simple_serializer::SimpleSerializer; @@ -52,6 +57,12 @@ impl_into_array!(u16, UInt16); impl_into_array!(u32, UInt32); impl_into_array!(u64, UInt64); +impl Context for IntBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for IntBuilder where I: Default diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 6abf5f24..1f816715 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -1,10 +1,14 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{Error, Result}, - utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - utils::{Mut, Offset}, + error::{Context, Error, Result}, + utils::{ + array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, + btree_map, Mut, Offset, + }, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -84,6 +88,12 @@ impl ListBuilder { } } +impl Context for ListBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for ListBuilder { fn name(&self) -> &str { "ListBuilder" diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 9dd4fc37..0f740d22 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -1,9 +1,14 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{fail, Error, Result}, - utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, + error::{fail, Context, Error, Result}, + utils::{ + array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, + btree_map, + }, }; use super::{array_builder::ArrayBuilder, 
simple_serializer::SimpleSerializer}; @@ -65,6 +70,12 @@ impl MapBuilder { } } +impl Context for MapBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for MapBuilder { fn name(&self) -> &str { "MapBuilder" diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index a06863e8..0acb8166 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -1,6 +1,9 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, NullArray}, - error::{Error, Result}, + error::{Context, Error, Result}, + utils::btree_map, }; use super::simple_serializer::SimpleSerializer; @@ -32,6 +35,12 @@ impl NullBuilder { } } +impl Context for NullBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for NullBuilder { fn name(&self) -> &str { "NullBuilder" diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 8db26fc0..c120faf3 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use serde::Serialize; use crate::internal::{ arrow::{DataType, Field, TimeUnit}, - error::{fail, Result}, + error::{fail, Context, Result}, schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy}, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, @@ -61,6 +61,12 @@ impl OuterSequenceBuilder { } } +impl Context for OuterSequenceBuilder { + fn annotations(&self) -> BTreeMap { + self.0.annotations() + } +} + impl SimpleSerializer for OuterSequenceBuilder { fn name(&self) -> &str { "OuterSequenceBuilder" diff 
--git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index c0034d4e..9a303460 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -7,7 +7,7 @@ use serde::{ }; use crate::internal::{ - error::{fail, Error, Result}, + error::{fail, Context, Error, Result}, utils::Mut, }; @@ -22,13 +22,13 @@ use super::ArrayBuilder; /// start call. /// #[allow(unused_variables)] -pub trait SimpleSerializer: Sized { +pub trait SimpleSerializer: Sized + Context { fn name(&self) -> &str; fn annotate_error(&self, err: Error) -> Error; fn serialize_default(&mut self) -> Result<()> { - fail!("serialize_default is not supported for {}", self.name()); + fail!(in self, "serialize_default is not supported"); } fn serialize_unit(&mut self) -> Result<()> { @@ -36,10 +36,7 @@ pub trait SimpleSerializer: Sized { } fn serialize_none(&mut self) -> Result<()> { - fail!( - "serialize_unit/serialize_none is not supported for {}", - self.name() - ); + fail!(in self, "serialize_unit/serialize_none is not supported"); } fn serialize_some(&mut self, value: &V) -> Result<()> { @@ -47,59 +44,59 @@ pub trait SimpleSerializer: Sized { } fn serialize_bool(&mut self, v: bool) -> Result<()> { - fail!("serialize_bool is not implemented for {}", self.name()) + fail!(in self, "serialize_bool is not supported") } fn serialize_char(&mut self, v: char) -> Result<()> { - fail!("serialize_char is not implemented for {}", self.name()) + fail!(in self, "serialize_char is not supported ") } fn serialize_u8(&mut self, v: u8) -> Result<()> { - fail!("serialize_u8 is not implemented for {}", self.name()) + fail!(in self, "serialize_u8 is not supported ") } fn serialize_u16(&mut self, v: u16) -> Result<()> { - fail!("serialize_u16 is not implemented for {}", self.name()) + fail!(in self, "serialize_u16 is not supported ") } fn serialize_u32(&mut self, v: u32) -> 
Result<()> { - fail!("serialize_u32 is not implemented for {}", self.name()) + fail!(in self, "serialize_u32 is not supported ") } fn serialize_u64(&mut self, v: u64) -> Result<()> { - fail!("serialize_u64 is not implemented for {}", self.name()) + fail!(in self, "serialize_u64 is not supported ") } fn serialize_i8(&mut self, v: i8) -> Result<()> { - fail!("serialize_i8 is not implemented for {}", self.name()) + fail!(in self, "serialize_i8 is not supported ") } fn serialize_i16(&mut self, v: i16) -> Result<()> { - fail!("serialize_i16 is not implemented for {}", self.name()) + fail!(in self, "serialize_i16 is not supported ") } fn serialize_i32(&mut self, v: i32) -> Result<()> { - fail!("serialize_i32 is not implemented for {}", self.name()) + fail!(in self, "serialize_i32 is not supported ") } fn serialize_i64(&mut self, v: i64) -> Result<()> { - fail!("serialize_i64 is not implemented for {}", self.name()) + fail!(in self, "serialize_i64 is not supported ") } fn serialize_f32(&mut self, v: f32) -> Result<()> { - fail!("serialize_f32 is not implemented for {}", self.name()) + fail!(in self, "serialize_f32 is not supported ") } fn serialize_f64(&mut self, v: f64) -> Result<()> { - fail!("serialize_f64 is not implemented for {}", self.name()) + fail!(in self, "serialize_f64 is not supported ") } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - fail!("serialize_bytes is not implemented for {}", self.name()) + fail!(in self, "serialize_bytes is not supported ") } fn serialize_str(&mut self, v: &str) -> Result<()> { - fail!("serialize_str is not implemented for {}", self.name()) + fail!(in self, "serialize_str is not supported ") } fn serialize_newtype_struct( @@ -118,15 +115,15 @@ pub trait SimpleSerializer: Sized { value: &V, ) -> Result<()> { fail!( - "serialize_newtype_variant is not implemented for {}", - self.name() + in self, + "serialize_newtype_variant is not supported", ) } fn serialize_unit_struct(&mut self, name: &'static str) -> Result<()> { 
fail!( - "serialize_unit_struct is not implemented for {}", - self.name() + in self, + "serialize_unit_struct is not supported", ) } @@ -137,46 +134,46 @@ pub trait SimpleSerializer: Sized { variant: &'static str, ) -> Result<()> { fail!( - "serialize_unit_variant is not implemented for {}", - self.name() + in self, + "serialize_unit_variant is not supported", ) } fn serialize_map_start(&mut self, len: Option) -> Result<()> { - fail!("serialize_map_start is not implemented for {}", self.name()) + fail!(in self, "serialize_map_start is not supported ") } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - fail!("serialize_map_key is not implemented for {}", self.name()); + fail!(in self, "serialize_map_key is not supported "); } fn serialize_map_value(&mut self, value: &V) -> Result<()> { - fail!("serialize_map_value is not implemented for {}", self.name()) + fail!(in self, "serialize_map_value is not supported ") } fn serialize_map_end(&mut self) -> Result<()> { - fail!("serialize_map_end is not implemented for {}", self.name()) + fail!(in self, "serialize_map_end is not supported ") } fn serialize_seq_start(&mut self, len: Option) -> Result<()> { - fail!("serialize_seq_start is not implemented for {}", self.name()) + fail!(in self, "serialize_seq_start is not supported ") } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { fail!( - "serialize_seq_element is not implemented for {}", - self.name() + in self, + "serialize_seq_element is not supported", ); } fn serialize_seq_end(&mut self) -> Result<()> { - fail!("serialize_seq_end is not implemented for {}", self.name()); + fail!(in self, "serialize_seq_end is not supported "); } fn serialize_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { fail!( - "serialize_start_start is not implemented for {}", - self.name() + in self, + "serialize_start_start is not supported", ) } @@ -186,54 +183,54 @@ pub trait SimpleSerializer: Sized { value: &V, ) -> Result<()> { fail!( - 
"serialize_struct_field is not implemented for {}", - self.name() + in self, + "serialize_struct_field is not supported", ); } fn serialize_struct_end(&mut self) -> Result<()> { fail!( - "serialize_struct_end is not implemented for {}", - self.name() + in self, + "serialize_struct_end is not supported", ); } fn serialize_tuple_start(&mut self, len: usize) -> Result<()> { fail!( - "serialize_tuple_start is not implemented for {}", - self.name() + in self, + "serialize_tuple_start is not supported", ) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { fail!( - "serialize_tuple_element is not implemented for {}", - self.name() + in self, + "serialize_tuple_element is not supported", ); } fn serialize_tuple_end(&mut self) -> Result<()> { - fail!("serialize_tuple_end is not implemented for {}", self.name()) + fail!(in self, "serialize_tuple_end is not supported ") } fn serialize_tuple_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { fail!( - "serialize_tuple_struct_start is not implemented for {}", - self.name() + in self, + "serialize_tuple_struct_start is not supported", ) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { fail!( - "serialize_tuple_struct_field is not implemented for {}", - self.name() + in self, + "serialize_tuple_struct_field is not supported", ); } fn serialize_tuple_struct_end(&mut self) -> Result<()> { fail!( - "serialize_tuple_struct_end is not implemented for {}", - self.name() + in self, + "serialize_tuple_struct_end is not supported", ); } @@ -245,8 +242,8 @@ pub trait SimpleSerializer: Sized { len: usize, ) -> Result<&'this mut ArrayBuilder> { fail!( - "serialize_struct_variant_start is not implemented for {}", - self.name() + in self, + "serialize_struct_variant_start is not supported", ) } @@ -258,8 +255,8 @@ pub trait SimpleSerializer: Sized { len: usize, ) -> Result<&'this mut ArrayBuilder> { fail!( - "serialize_tuple_variant_start is not implemented for {}", - self.name() + in 
self, + "serialize_tuple_variant_start is not supported", ) } } @@ -498,7 +495,7 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { pub fn merge_annotations(err: Error, annotations_err: Error) -> Error { err.annotate_unannotated(|annotations| { let Error::Custom(annotations_err) = annotations_err; - *annotations = annotations_err.annotations; + *annotations = annotations_err.0.annotations; }) } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index af1d0b69..994d27f3 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -4,9 +4,11 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, - error::{fail, Error, Result}, - utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{ + array_ext::{ArrayExt, CountArray, SeqArrayExt}, + btree_map, Mut, + }, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -116,6 +118,12 @@ impl StructBuilder { } } +impl Context for StructBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for StructBuilder { fn name(&self) -> &str { "StructBuilder" @@ -294,6 +302,12 @@ impl<'a> KeyLookupSerializer<'a> { } } +impl<'a> Context for KeyLookupSerializer<'a> { + fn annotations(&self) -> BTreeMap { + btree_map!() + } +} + impl<'a> SimpleSerializer for KeyLookupSerializer<'a> { fn name(&self) -> &str { "KeyLookupSerializer" diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 20a04e04..f9c6c703 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -1,9 +1,14 @@ +use std::collections::BTreeMap; + use chrono::Timelike; use 
crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Error, Result}, - utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + error::{Context, Error, Result}, + utils::{ + array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, + btree_map, + }, }; use super::simple_serializer::SimpleSerializer; @@ -57,6 +62,12 @@ impl TimeBuilder { } } +impl Context for TimeBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for TimeBuilder where I: TryFrom + TryFrom + Default + 'static, diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 9b108ae2..9f8fa675 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -1,7 +1,9 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; @@ -73,6 +75,12 @@ impl UnionBuilder { } } +impl Context for UnionBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for UnionBuilder { fn name(&self) -> &str { "UnionBuilder" diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 59eb4387..b20476f6 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -1,8 +1,11 @@ +use std::collections::BTreeMap; + use serde::Serialize; use crate::internal::{ arrow::{Array, NullArray}, - error::{fail, Error, Result}, + error::{fail, Context, Error, Result}, + utils::btree_map, }; use 
super::{simple_serializer::SimpleSerializer, ArrayBuilder}; @@ -32,6 +35,12 @@ impl UnknownVariantBuilder { } } +impl Context for UnknownVariantBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for UnknownVariantBuilder { fn name(&self) -> &str { "UnknownVariantBuilder" diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 9e8dea4b..0bf6321f 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -1,9 +1,11 @@ +use std::collections::BTreeMap; + use crate::internal::{ arrow::{Array, BytesArray}, - error::{fail, Error, Result}, + error::{fail, Context, Error, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - Offset, + btree_map, Offset, }, }; @@ -47,6 +49,12 @@ impl Utf8Builder { } } +impl Context for Utf8Builder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + impl SimpleSerializer for Utf8Builder { fn name(&self) -> &str { "Utf8Builder" diff --git a/serde_arrow/src/internal/testing.rs b/serde_arrow/src/internal/testing.rs index 06179ec7..7f922a7c 100644 --- a/serde_arrow/src/internal/testing.rs +++ b/serde_arrow/src/internal/testing.rs @@ -10,21 +10,6 @@ pub fn assert_error_contains(actual: &Result, exp } } -macro_rules! btree_map { - () => { - ::std::collections::BTreeMap::new() - }; - ($($key:expr => $value:expr),* $(,)?) => { - { - let mut m = ::std::collections::BTreeMap::new(); - $(m.insert($key.into(), $value.into());)* - m - } - }; -} - -pub(crate) use btree_map; - macro_rules! 
hash_map { () => { ::std::collections::HashMap::new() diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index cee187fe..36586c1b 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -187,3 +187,18 @@ pub fn meta_from_field(field: Field) -> Result { metadata: field.metadata, }) } + +macro_rules! btree_map { + () => { + ::std::collections::BTreeMap::new() + }; + ($($key:expr => $value:expr),* $(,)?) => { + { + let mut m = ::std::collections::BTreeMap::new(); + $(m.insert($key.into(), $value.into());)* + m + } + }; +} + +pub(crate) use btree_map; diff --git a/serde_arrow/src/test_with_arrow/impls/map.rs b/serde_arrow/src/test_with_arrow/impls/map.rs index 0a5905c3..349f2e19 100644 --- a/serde_arrow/src/test_with_arrow/impls/map.rs +++ b/serde_arrow/src/test_with_arrow/impls/map.rs @@ -4,8 +4,8 @@ use serde_json::json; use crate::internal::{ schema::TracingOptions, - testing::{btree_map, hash_map}, - utils::Item, + testing::hash_map, + utils::{btree_map, Item}, }; use super::utils::Test; From dd70221a854d2530035a43a376fa663fc7c945e6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 14:54:51 +0200 Subject: [PATCH 138/178] Use explicit context calls instead of relying on wrapper classes --- serde_arrow/src/internal/error.rs | 63 ++++-- .../internal/serialization/array_builder.rs | 99 ++++----- .../internal/serialization/binary_builder.rs | 44 ++-- .../internal/serialization/bool_builder.rs | 16 +- .../internal/serialization/date32_builder.rs | 37 ++-- .../internal/serialization/date64_builder.rs | 69 +++--- .../internal/serialization/decimal_builder.rs | 29 ++- .../serialization/dictionary_utf8_builder.rs | 24 +-- .../serialization/duration_builder.rs | 34 ++- .../fixed_size_binary_builder.rs | 47 ++-- .../serialization/fixed_size_list_builder.rs | 30 +-- .../internal/serialization/float_builder.rs | 88 +++----- .../src/internal/serialization/int_builder.rs | 47 
++-- .../internal/serialization/list_builder.rs | 34 ++- .../src/internal/serialization/map_builder.rs | 26 +-- .../internal/serialization/null_builder.rs | 12 +- .../serialization/outer_sequence_builder.rs | 8 - .../serialization/simple_serializer.rs | 200 +++++------------- .../internal/serialization/struct_builder.rs | 48 ++--- .../internal/serialization/time_builder.rs | 34 ++- .../internal/serialization/union_builder.rs | 26 +-- .../serialization/unknown_variant_builder.rs | 88 ++++---- .../internal/serialization/utf8_builder.rs | 26 +-- 23 files changed, 418 insertions(+), 711 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index e3bc0866..ac11d30b 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -4,10 +4,51 @@ use std::{ convert::Infallible, }; +/// An object that offers additional context to an error pub trait Context { fn annotations(&self) -> BTreeMap; } +pub struct StaticContext(BTreeMap); + +impl StaticContext { + pub fn from_context(context: &C) -> Self { + Self(context.annotations()) + } +} + +impl Context for StaticContext { + fn annotations(&self) -> BTreeMap { + self.0.clone() + } +} + +/// Helpers to attach the metadata associated with a context to an error +pub trait ContextSupport { + type Output; + + fn ctx(self, context: &C) -> Self::Output; +} + +impl> ContextSupport for Result { + type Output = Result; + + fn ctx(self, context: &C) -> Self::Output { + match self { + Ok(value) => Ok(value), + Err(err) => Err(err.ctx(context)), + } + } +} + +impl> ContextSupport for E { + type Output = Error; + + fn ctx(self, context: &C) -> Self::Output { + self.into().with_annotations(context.annotations()) + } +} + /// A Result type that defaults to `serde_arrow`'s [Error] type /// pub type Result = std::result::Result; @@ -50,31 +91,9 @@ impl Error { annotations: BTreeMap::new(), }))) } - - pub(crate) fn empty() -> Self { - 
Self::Custom(CustomError(Box::new(CustomErrorImpl { - message: String::new(), - backtrace: Backtrace::disabled(), - cause: None, - annotations: BTreeMap::new(), - }))) - } } impl Error { - /// Call the function with a mutable reference to this errors annotations, if the error was not - /// annotated before - pub(crate) fn annotate_unannotated)>( - self, - func: F, - ) -> Self { - let Self::Custom(mut this) = self; - if this.0.annotations.is_empty() { - func(&mut this.0.annotations); - } - Self::Custom(this) - } - pub(crate) fn with_annotations(self, annotations: BTreeMap) -> Self { let Self::Custom(mut this) = self; this.0.annotations = annotations; diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 17a4a4a1..0bdcc2b5 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -5,7 +5,7 @@ use serde::Serialize; use crate::internal::{ arrow::Array, - error::{Context, Error, Result}, + error::{Context, Result}, }; use super::{ @@ -15,10 +15,9 @@ use super::{ fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder, - null_builder::NullBuilder, simple_serializer::merge_annotations, - simple_serializer::SimpleSerializer, struct_builder::StructBuilder, time_builder::TimeBuilder, - union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, - utf8_builder::Utf8Builder, + null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder, + time_builder::TimeBuilder, union_builder::UnionBuilder, + unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, }; #[derive(Debug, Clone)] @@ -155,173 +154,163 @@ impl Context for ArrayBuilder { #[rustfmt::skip] impl SimpleSerializer for ArrayBuilder { - fn name(&self) -> 
&str { - "ArrayBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - dispatch!(self, Self(builder) => builder.annotate_error(err)) - } - fn serialize_default(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_default().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_default()) } fn serialize_unit_struct(&mut self, name: &'static str) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit_struct(name).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_unit_struct(name)) } fn serialize_none(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_none().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_none()) } fn serialize_some(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_some(value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_some(value)) } fn serialize_unit(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_unit()) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_bool(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_bool(v)) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i8(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_i8(v)) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i16(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_i16(v)) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - 
dispatch!(self, Self(builder) => builder.serialize_i32(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_i32(v)) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_i64(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_i64(v)) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u8(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_u8(v)) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u16(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_u16(v)) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u32(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_u32(v)) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_u64(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_u64(v)) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_f32(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_f32(v)) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_f64(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_f64(v)) } fn serialize_char(&mut self, v: char) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_char(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_char(v)) } fn serialize_str(&mut self, v: &str) -> Result<()> { - dispatch!(self, 
Self(builder) => builder.serialize_str(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_str(v)) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_bytes(v).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_bytes(v)) } fn serialize_seq_start(&mut self, len: Option) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_seq_start(len).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_seq_start(len)) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_seq_element(value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_seq_element(value)) } fn serialize_seq_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_seq_end().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_seq_end()) } fn serialize_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_start(name, len).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_struct_start(name, len)) } fn serialize_struct_field(&mut self, key: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_field(key, value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_struct_field(key, value)) } fn serialize_struct_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_struct_end().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_struct_end()) } fn serialize_map_start(&mut self, len: Option) -> Result<()> { - dispatch!(self, Self(builder) => 
builder.serialize_map_start(len).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_map_start(len)) } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_key(key).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_map_key(key)) } fn serialize_map_value(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_value(value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_map_value(value)) } fn serialize_map_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_map_end().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_map_end()) } fn serialize_tuple_start(&mut self, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_start(len).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_start(len)) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_element(value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_element(value)) } fn serialize_tuple_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_end().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_end()) } fn serialize_tuple_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_struct_start(name, len).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_start(name, len)) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => 
builder.serialize_tuple_struct_field(value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_field(value)) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_tuple_struct_end().map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_struct_end()) } fn serialize_newtype_struct(&mut self, name: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_newtype_struct(name, value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_newtype_struct(name, value)) } fn serialize_newtype_variant(&mut self, name: &'static str, variant_index: u32, variant: &'static str, value: &V) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_newtype_variant(name, variant_index, variant, value).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_newtype_variant(name, variant_index, variant, value)) } fn serialize_unit_variant(&mut self, name: &'static str, variant_index: u32, variant: &'static str) -> Result<()> { - dispatch!(self, Self(builder) => builder.serialize_unit_variant(name, variant_index, variant).map_err(|err| builder.annotate_error(err))) + dispatch!(self, Self(builder) => builder.serialize_unit_variant(name, variant_index, variant)) } fn serialize_struct_variant_start<'this>(&'this mut self, name: &'static str, variant_index: u32, variant: &'static str, len: usize) -> Result<&'this mut ArrayBuilder> { - let annotations_err = dispatch!(self, Self(builder) => builder.annotate_error(Error::empty())); - dispatch!(self, Self(builder) => builder.serialize_struct_variant_start(name, variant_index, variant, len).map_err(|err| merge_annotations(err, annotations_err))) + dispatch!(self, Self(builder) => builder.serialize_struct_variant_start(name, variant_index, variant, len)) 
} fn serialize_tuple_variant_start<'this> (&'this mut self, name: &'static str, variant_index: u32, variant: &'static str, len: usize) -> Result<&'this mut ArrayBuilder> { - let annotations_err = dispatch!(self, Self(builder) => builder.annotate_error(Error::empty())); - dispatch!(self, Self(builder) => builder.serialize_tuple_variant_start(name, variant_index, variant, len).map_err(|err| merge_annotations(err, annotations_err))) + dispatch!(self, Self(builder) => builder.serialize_tuple_variant_start(name, variant_index, variant, len)) } } diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 560347e4..de98e655 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, btree_map, Mut, Offset, @@ -77,62 +77,52 @@ impl Context for BinaryBuilder { } impl SimpleSerializer for BinaryBuilder { - fn name(&self) -> &str { - "BinaryBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_seq_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn 
serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - self.array.push_scalar_value(v) + self.array.push_scalar_value(v).ctx(self) } } @@ -145,14 +135,6 @@ impl Context for U8Serializer { } impl SimpleSerializer for U8Serializer { - fn name(&self) -> &str { - "SerializeU8" - } - - fn annotate_error(&self, err: Error) -> Error { - err - } - fn serialize_u8(&mut self, v: u8) -> Result<()> { self.0 = v; Ok(()) diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 2e2341ba..2d9c0b40 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BooleanArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{set_bit_buffer, set_validity, set_validity_default}, btree_map, @@ -56,16 +56,6 @@ impl Context for BoolBuilder { } impl SimpleSerializer for BoolBuilder { - fn name(&self) -> &str { - "BoolBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { 
set_validity_default(self.array.validity.as_mut(), self.array.len); set_bit_buffer(&mut self.array.values, self.array.len, false); @@ -74,14 +64,14 @@ impl SimpleSerializer for BoolBuilder { } fn serialize_none(&mut self) -> Result<()> { - set_validity(self.array.validity.as_mut(), self.array.len, false)?; + set_validity(self.array.validity.as_mut(), self.array.len, false).ctx(self)?; set_bit_buffer(&mut self.array.values, self.array.len, false); self.array.len += 1; Ok(()) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - set_validity(self.array.validity.as_mut(), self.array.len, true)?; + set_validity(self.array.validity.as_mut(), self.array.len, true).ctx(self)?; set_bit_buffer(&mut self.array.values, self.array.len, v); self.array.len += 1; Ok(()) diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 603ad829..d6d1b40c 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -41,6 +41,16 @@ impl Date32Builder { pub fn into_array(self) -> Result { Ok(Array::Date32(self.array)) } + + fn parse_str_to_days_since_epoch(&self, s: &str) -> Result { + const UNIX_EPOCH: NaiveDate = NaiveDateTime::UNIX_EPOCH.date(); + + let date = s.parse::()?; + let duration_since_epoch = date.signed_duration_since(UNIX_EPOCH); + let days_since_epoch = duration_since_epoch.num_days().try_into()?; + + Ok(days_since_epoch) + } } impl Context for Date32Builder { @@ -50,35 +60,20 @@ impl Context for Date32Builder { } impl SimpleSerializer for Date32Builder { - fn name(&self) -> &str { - "Date32Builder" - } - - fn annotate_error(&self, err: Error) -> Error { - 
err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - const UNIX_EPOCH: NaiveDate = NaiveDateTime::UNIX_EPOCH.date(); - - let date = v.parse::()?; - let duration_since_epoch = date.signed_duration_since(UNIX_EPOCH); - let days_since_epoch = duration_since_epoch.num_days().try_into()?; - - self.array.push_scalar_value(days_since_epoch) + let days_since_epoch = self.parse_str_to_days_since_epoch(v).ctx(self)?; + self.array.push_scalar_value(days_since_epoch).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(v) + self.array.push_scalar_value(v).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 0e884396..65c8bf9a 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -64,43 +64,19 @@ impl Date64Builder { } } -impl Context for Date64Builder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) - } -} - -impl SimpleSerializer for Date64Builder { - fn name(&self) -> &str { - "Date64Builder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - - fn 
serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() - } - - fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() - } +impl Date64Builder { + fn parse_str_to_timestamp(&self, s: &str) -> Result { + use chrono::{DateTime, NaiveDateTime, Utc}; - fn serialize_str(&mut self, v: &str) -> Result<()> { let date_time = if self.utc { - use chrono::{DateTime, Utc}; - v.parse::>()? + s.parse::>()? } else { - use chrono::NaiveDateTime; - v.parse::()?.and_utc() + s.parse::()?.and_utc() }; - let timestamp = match self.meta.as_ref() { + match self.meta.as_ref() { Some((TimeUnit::Nanosecond, _)) => match date_time.timestamp_nanos_opt() { - Some(timestamp) => timestamp, + Some(timestamp) => Ok(timestamp), _ => fail!( concat!( "Timestamp '{date_time}' cannot be converted to nanoseconds. ", @@ -110,15 +86,34 @@ impl SimpleSerializer for Date64Builder { date_time = date_time, ), }, - Some((TimeUnit::Microsecond, _)) => date_time.timestamp_micros(), - Some((TimeUnit::Millisecond, _)) | None => date_time.timestamp_millis(), - Some((TimeUnit::Second, _)) => date_time.timestamp(), - }; + Some((TimeUnit::Microsecond, _)) => Ok(date_time.timestamp_micros()), + Some((TimeUnit::Millisecond, _)) | None => Ok(date_time.timestamp_millis()), + Some((TimeUnit::Second, _)) => Ok(date_time.timestamp()), + } + } +} + +impl Context for Date64Builder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone()) + } +} + +impl SimpleSerializer for Date64Builder { + fn serialize_default(&mut self) -> Result<()> { + self.array.push_scalar_default().ctx(self) + } + fn serialize_none(&mut self) -> Result<()> { + self.array.push_scalar_none().ctx(self) + } + + fn serialize_str(&mut self, v: &str) -> Result<()> { + let timestamp = self.parse_str_to_timestamp(v).ctx(self)?; self.array.push_scalar_value(timestamp) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v) + 
self.array.push_scalar_value(v).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 810977f1..173451d8 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -69,38 +69,33 @@ impl Context for DecimalBuilder { } impl SimpleSerializer for DecimalBuilder { - fn name(&self) -> &str { - "DecimalBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value((v * self.f32_factor) as i128) + self.array + .push_scalar_value((v * self.f32_factor) as i128) + .ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value((v * self.f64_factor) as i128) + self.array + .push_scalar_value((v * self.f64_factor) as i128) + .ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { let mut parse_buffer = [0; decimal::BUFFER_SIZE_I128]; let val = self .parser - .parse_decimal128(&mut parse_buffer, v.as_bytes())?; + .parse_decimal128(&mut parse_buffer, v.as_bytes()) + .ctx(self)?; - self.array.push_scalar_value(val) + self.array.push_scalar_value(val).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs 
b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index cc8e6ff9..48fb23e2 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{btree_map, Mut}, }; @@ -56,25 +56,16 @@ impl Context for DictionaryUtf8Builder { } impl SimpleSerializer for DictionaryUtf8Builder { - fn name(&self) -> &str { - "DictionaryUtf8" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.indices.serialize_none() + self.indices.serialize_none().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.indices.serialize_none() + self.indices.serialize_none().ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { + // the only faillible operations concern children: do not apply the context let idx = match self.index.get(v) { Some(idx) => *idx, None => { @@ -93,6 +84,7 @@ impl SimpleSerializer for DictionaryUtf8Builder { _: u32, variant: &'static str, ) -> Result<()> { + // NOTE: context logic is implemented in serialize_str self.serialize_str(variant) } @@ -103,7 +95,7 @@ impl SimpleSerializer for DictionaryUtf8Builder { _: &'static str, _: usize, ) -> Result<&'this mut super::ArrayBuilder> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } fn serialize_struct_variant_start<'this>( @@ -113,7 +105,7 @@ impl SimpleSerializer for DictionaryUtf8Builder { _: &'static str, _: usize, ) -> Result<&'this mut super::ArrayBuilder> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } 
fn serialize_newtype_variant( @@ -123,6 +115,6 @@ impl SimpleSerializer for DictionaryUtf8Builder { _: &'static str, _: &V, ) -> Result<()> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } } diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 009bded8..353c29e9 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -55,53 +55,45 @@ impl Context for DurationBuilder { } impl SimpleSerializer for DurationBuilder { - fn name(&self) -> &str { - "DurationBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v) + 
self.array.push_scalar_value(v).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(i64::from(v)) + self.array.push_scalar_value(i64::from(v)).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(i64::try_from(v)?) + self.array + .push_scalar_value(i64::try_from(v).ctx(self)?) + .ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 1065fc26..4de21459 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, btree_map, Mut, @@ -92,18 +92,8 @@ impl Context for FixedSizeBinaryBuilder { } impl SimpleSerializer for FixedSizeBinaryBuilder { - fn name(&self) -> &str { - "FixedSizeBinaryBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default()?; + self.seq.push_seq_default().ctx(self)?; for _ in 0..self.n { self.buffer.push(0); } @@ -111,7 +101,7 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none()?; + self.seq.push_seq_none().ctx(self)?; for _ in 0..self.n { 
self.buffer.push(0); } @@ -119,53 +109,54 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_seq_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - self.element(value) + self.element(value).ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { if v.len() != self.n { fail!( + in self, "Invalid number of elements for fixed size binary: got {actual}, expected {expected}", actual = v.len(), expected = self.n, ); } - self.seq.start_seq()?; + self.seq.start_seq().ctx(self)?; self.buffer.extend(v); - self.seq.end_seq() + self.seq.end_seq().ctx(self) } } @@ -178,14 +169,6 @@ impl Context for U8Serializer { } impl SimpleSerializer for U8Serializer { - fn name(&self) -> &str { - "SerializeU8" - } - - fn annotate_error(&self, err: Error) -> Error { - err - } - fn serialize_u8(&mut self, v: u8) -> Result<()> { self.0 = v; Ok(()) diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 9ba95eb3..e2307f9c 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ 
b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, btree_map, Mut, @@ -76,7 +76,7 @@ impl FixedSizeListBuilder { fn element(&mut self, value: &V) -> Result<()> { self.current_count += 1; - self.seq.push_seq_elements(1)?; + self.seq.push_seq_elements(1).ctx(self)?; value.serialize(Mut(self.element.as_mut())) } @@ -100,18 +100,8 @@ impl Context for FixedSizeListBuilder { } impl SimpleSerializer for FixedSizeListBuilder { - fn name(&self) -> &str { - "FixedSizeListBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default()?; + self.seq.push_seq_default().ctx(self)?; for _ in 0..self.n { self.element.serialize_default()?; } @@ -119,7 +109,7 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none()?; + self.seq.push_seq_none().ctx(self)?; for _ in 0..self.n { self.element.serialize_default()?; } @@ -127,7 +117,7 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { @@ -135,11 +125,11 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn serialize_seq_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { @@ -147,11 +137,11 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn 
serialize_tuple_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { @@ -159,6 +149,6 @@ impl SimpleSerializer for FixedSizeListBuilder { } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 017f397a..5d9bc519 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -4,7 +4,7 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, Mut, @@ -60,22 +60,12 @@ impl Context for FloatBuilder { } impl SimpleSerializer for FloatBuilder { - fn name(&self) -> &str { - "FloatBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_some(&mut self, value: &V) -> Result<()> { @@ -83,130 +73,110 @@ impl SimpleSerializer for FloatBuilder { } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_i32(&mut self, v: i32) -> 
Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(v) + self.array.push_scalar_value(v).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(v as f32) + self.array.push_scalar_value(v as f32).ctx(self) } } impl SimpleSerializer for FloatBuilder { - fn name(&self) -> &str { - "FloatBuilder<64>" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - 
self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(v as f64) + self.array.push_scalar_value(v as f64).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(v) + self.array.push_scalar_value(v).ctx(self) } } impl SimpleSerializer for FloatBuilder { - fn name(&self) -> &str { - "FloatBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(f16::from_f32(v)) + self.array.push_scalar_value(f16::from_f32(v)).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(f16::from_f64(v)) + self.array.push_scalar_value(f16::from_f64(v)).ctx(self) } } diff --git 
a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 78632b56..2f621e39 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -19,7 +19,6 @@ pub struct IntBuilder { impl IntBuilder { pub fn new(path: String, is_nullable: bool) -> Self { - println!("new IntBuilder ({path}"); Self { path, array: new_primitive_array(is_nullable), @@ -63,6 +62,16 @@ impl Context for IntBuilder { } } +impl IntBuilder { + fn push_value(&mut self, v: J) -> Result<()> + where + I: Default + TryFrom + 'static, + Error: From<>::Error>, + { + self.array.push_scalar_value(I::try_from(v)?) + } +} + impl SimpleSerializer for IntBuilder where I: Default @@ -84,62 +93,52 @@ where Error: From<>::Error>, Error: From<>::Error>, { - fn name(&self) -> &str { - "IntBuilder<()>" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_bool(&mut self, v: bool) -> Result<()> { let v: u8 = if v { 1 } else { 0 }; - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) 
+ self.push_value(v).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(I::try_from(v)?) + self.push_value(v).ctx(self) } fn serialize_char(&mut self, v: char) -> Result<()> { - self.array.push_scalar_value(I::try_from(u32::from(v))?) + self.push_value(u32::from(v)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 1f816715..05bf2ddd 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, btree_map, Mut, Offset, @@ -79,7 +79,7 @@ impl ListBuilder { } fn element(&mut self, value: &V) -> Result<()> { - self.offsets.push_seq_elements(1)?; + self.offsets.push_seq_elements(1).ctx(self)?; value.serialize(Mut(self.element.as_mut())) } @@ -95,26 +95,16 @@ impl Context for ListBuilder { } impl SimpleSerializer for ListBuilder { - fn name(&self) -> &str { - "ListBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - 
annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.offsets.push_seq_default() + self.offsets.push_seq_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_seq_none() + self.offsets.push_seq_none().ctx(self) } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { @@ -122,11 +112,11 @@ impl SimpleSerializer for ListBuilder { } fn serialize_seq_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { @@ -134,11 +124,11 @@ impl SimpleSerializer for ListBuilder { } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { @@ -146,14 +136,14 @@ impl SimpleSerializer for ListBuilder { } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - self.start()?; + self.start().ctx(self)?; for item in v { self.element(item)?; } - self.end() + self.end().ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 0f740d22..e803dc1b 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, 
OffsetsArray, SeqArrayExt}, btree_map, @@ -77,40 +77,30 @@ impl Context for MapBuilder { } impl SimpleSerializer for MapBuilder { - fn name(&self) -> &str { - "MapBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.offsets.push_seq_default() + self.offsets.push_seq_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_seq_none() + self.offsets.push_seq_none().ctx(self) } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - self.offsets.start_seq() + self.offsets.start_seq().ctx(self) } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - self.offsets.push_seq_elements(1)?; - self.entry.serialize_tuple_start(2)?; + self.offsets.push_seq_elements(1).ctx(self)?; + self.entry.serialize_tuple_start(2).ctx(self)?; self.entry.serialize_tuple_element(key) } fn serialize_map_value(&mut self, value: &V) -> Result<()> { self.entry.serialize_tuple_element(value)?; - self.entry.serialize_tuple_end() + self.entry.serialize_tuple_end().ctx(self) } fn serialize_map_end(&mut self) -> Result<()> { - self.offsets.end_seq() + self.offsets.end_seq().ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index 0acb8166..df2280af 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, NullArray}, - error::{Context, Error, Result}, + error::{Context, Result}, utils::btree_map, }; @@ -42,16 +42,6 @@ impl Context for NullBuilder { } impl SimpleSerializer for NullBuilder { - fn name(&self) -> &str { - "NullBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - 
annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { self.count += 1; Ok(()) diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index c120faf3..05db939a 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -68,14 +68,6 @@ impl Context for OuterSequenceBuilder { } impl SimpleSerializer for OuterSequenceBuilder { - fn name(&self) -> &str { - "OuterSequenceBuilder" - } - - fn annotate_error(&self, err: crate::Error) -> crate::Error { - self.0.annotate_error(err) - } - fn serialize_none(&mut self) -> Result<()> { self.0.serialize_none() } diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 9a303460..80183e0b 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -23,10 +23,6 @@ use super::ArrayBuilder; /// #[allow(unused_variables)] pub trait SimpleSerializer: Sized + Context { - fn name(&self) -> &str; - - fn annotate_error(&self, err: Error) -> Error; - fn serialize_default(&mut self) -> Result<()> { fail!(in self, "serialize_default is not supported"); } @@ -274,105 +270,71 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { type SerializeTupleVariant = Mut<'a, ArrayBuilder>; fn serialize_unit(self) -> Result<()> { - self.0 - .serialize_unit() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_unit() } fn serialize_none(self) -> Result<()> { - self.0 - .serialize_none() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_none() } fn serialize_some(self, value: &V) -> Result<()> { - self.0 - .serialize_some(value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_some(value) } fn 
serialize_bool(self, v: bool) -> Result<()> { - self.0 - .serialize_bool(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_bool(v) } fn serialize_char(self, v: char) -> Result<()> { - self.0 - .serialize_char(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_char(v) } fn serialize_u8(self, v: u8) -> Result<()> { - self.0 - .serialize_u8(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_u8(v) } fn serialize_u16(self, v: u16) -> Result<()> { - self.0 - .serialize_u16(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_u16(v) } fn serialize_u32(self, v: u32) -> Result<()> { - self.0 - .serialize_u32(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_u32(v) } fn serialize_u64(self, v: u64) -> Result<()> { - self.0 - .serialize_u64(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_u64(v) } fn serialize_i8(self, v: i8) -> Result<()> { - self.0 - .serialize_i8(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_i8(v) } fn serialize_i16(self, v: i16) -> Result<()> { - self.0 - .serialize_i16(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_i16(v) } fn serialize_i32(self, v: i32) -> Result<()> { - self.0 - .serialize_i32(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_i32(v) } fn serialize_i64(self, v: i64) -> Result<()> { - self.0 - .serialize_i64(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_i64(v) } fn serialize_f32(self, v: f32) -> Result<()> { - self.0 - .serialize_f32(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_f32(v) } fn serialize_f64(self, v: f64) -> Result<()> { - self.0 - .serialize_f64(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_f64(v) } fn serialize_bytes(self, v: &[u8]) -> Result<()> { - self.0 - .serialize_bytes(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_bytes(v) } fn serialize_str(self, v: &str) -> Result<()> { - self.0 
- .serialize_str(v) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_str(v) } fn serialize_newtype_struct( @@ -380,9 +342,7 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, value: &V, ) -> Result<()> { - self.0 - .serialize_newtype_struct(name, value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_newtype_struct(name, value) } fn serialize_newtype_variant( @@ -394,13 +354,10 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { ) -> Result<()> { self.0 .serialize_newtype_variant(name, variant_index, variant, value) - .map_err(|err| self.0.annotate_error(err)) } fn serialize_unit_struct(self, name: &'static str) -> Result<()> { - self.0 - .serialize_unit_struct(name) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_unit_struct(name) } fn serialize_unit_variant( @@ -409,36 +366,26 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant_index: u32, variant: &'static str, ) -> Result<()> { - self.0 - .serialize_unit_variant(name, variant_index, variant) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_unit_variant(name, variant_index, variant) } fn serialize_map(self, len: Option) -> Result { - self.0 - .serialize_map_start(len) - .map_err(|err| self.0.annotate_error(err))?; + self.0.serialize_map_start(len)?; Ok(Mut(&mut *self.0)) } fn serialize_seq(self, len: Option) -> Result { - self.0 - .serialize_seq_start(len) - .map_err(|err| self.0.annotate_error(err))?; + self.0.serialize_seq_start(len)?; Ok(Mut(&mut *self.0)) } fn serialize_struct(self, name: &'static str, len: usize) -> Result { - self.0 - .serialize_struct_start(name, len) - .map_err(|err| self.0.annotate_error(err))?; + self.0.serialize_struct_start(name, len)?; Ok(Mut(&mut *self.0)) } fn serialize_tuple(self, len: usize) -> Result { - self.0 - .serialize_tuple_start(len) - .map_err(|err| self.0.annotate_error(err))?; + self.0.serialize_tuple_start(len)?; Ok(Mut(&mut *self.0)) } @@ -447,9 
+394,7 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { name: &'static str, len: usize, ) -> Result { - self.0 - .serialize_tuple_struct_start(name, len) - .map_err(|err| self.0.annotate_error(err))?; + self.0.serialize_tuple_struct_start(name, len)?; Ok(Mut(&mut *self.0)) } @@ -460,16 +405,10 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - // cannot borrow self immutably, as the result will keep self.0 borrowed mutably - // TODO: figure out how to remove this hack - let annotations_error = self.0.annotate_error(Error::empty()); - match self - .0 - .serialize_struct_variant_start(name, variant_index, variant, len) - { - Ok(variant_builder) => Ok(Mut(variant_builder)), - Err(err) => Err(merge_annotations(err, annotations_error)), - } + let variant_builder = + self.0 + .serialize_struct_variant_start(name, variant_index, variant, len)?; + Ok(Mut(variant_builder)) } fn serialize_tuple_variant( @@ -479,46 +418,27 @@ impl<'a, T: SimpleSerializer> Serializer for Mut<'a, T> { variant: &'static str, len: usize, ) -> Result { - // cannot borrow self immutably, as the result will keep self.0 borrowed mutably - // TODO: figure out how to remove this hack - let annotations_error = self.0.annotate_error(Error::empty()); - match self - .0 - .serialize_tuple_variant_start(name, variant_index, variant, len) - { - Ok(variant_builder) => Ok(Mut(variant_builder)), - Err(err) => Err(merge_annotations(err, annotations_error)), - } + let variant_builder = + self.0 + .serialize_tuple_variant_start(name, variant_index, variant, len)?; + Ok(Mut(variant_builder)) } } -pub fn merge_annotations(err: Error, annotations_err: Error) -> Error { - err.annotate_unannotated(|annotations| { - let Error::Custom(annotations_err) = annotations_err; - *annotations = annotations_err.0.annotations; - }) -} - impl<'a, T: SimpleSerializer> SerializeMap for Mut<'a, T> { type Ok = (); type Error = Error; fn serialize_key(&mut self, 
key: &V) -> Result<()> { - self.0 - .serialize_map_key(key) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_map_key(key) } fn serialize_value(&mut self, value: &V) -> Result<()> { - self.0 - .serialize_map_value(value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_map_value(value) } fn end(self) -> Result<()> { - self.0 - .serialize_map_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_map_end() } } @@ -527,15 +447,11 @@ impl<'a, T: SimpleSerializer> SerializeSeq for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0 - .serialize_seq_element(value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_seq_element(value) } fn end(self) -> Result<()> { - self.0 - .serialize_seq_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_seq_end() } } @@ -548,15 +464,11 @@ impl<'a, T: SimpleSerializer> SerializeStruct for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0 - .serialize_struct_field(key, value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_struct_field(key, value) } fn end(self) -> Result<()> { - self.0 - .serialize_struct_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_struct_end() } } @@ -565,15 +477,11 @@ impl<'a, T: SimpleSerializer> SerializeTuple for Mut<'a, T> { type Error = Error; fn serialize_element(&mut self, value: &V) -> Result<()> { - self.0 - .serialize_tuple_element(value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_tuple_element(value) } fn end(self) -> Result<()> { - self.0 - .serialize_tuple_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_tuple_end() } } @@ -582,15 +490,11 @@ impl<'a, T: SimpleSerializer> SerializeTupleStruct for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0 - .serialize_tuple_struct_field(value) - .map_err(|err| self.0.annotate_error(err)) + 
self.0.serialize_tuple_struct_field(value) } fn end(self) -> Result<()> { - self.0 - .serialize_tuple_struct_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_tuple_struct_end() } } @@ -603,15 +507,11 @@ impl<'a, T: SimpleSerializer> SerializeStructVariant for Mut<'a, T> { key: &'static str, value: &V, ) -> Result<()> { - self.0 - .serialize_struct_field(key, value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_struct_field(key, value) } fn end(self) -> Result<()> { - self.0 - .serialize_struct_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_struct_end() } } @@ -620,14 +520,10 @@ impl<'a, T: SimpleSerializer> SerializeTupleVariant for Mut<'a, T> { type Error = Error; fn serialize_field(&mut self, value: &V) -> Result<()> { - self.0 - .serialize_tuple_struct_field(value) - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_tuple_struct_field(value) } fn end(self) -> Result<()> { - self.0 - .serialize_tuple_struct_end() - .map_err(|err| self.0.annotate_error(err)) + self.0.serialize_tuple_struct_end() } } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 994d27f3..d0a53ac2 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, btree_map, Mut, @@ -106,9 +106,9 @@ impl StructBuilder { } fn element(&mut self, idx: usize, value: &T) -> Result<()> { - self.seq.push_seq_elements(1)?; + self.seq.push_seq_elements(1).ctx(self)?; if self.seen[idx] { - fail!("Duplicate field {key}", key = self.fields[idx].1.name); + fail!(in self, "Duplicate field {key}", key = self.fields[idx].1.name); } 
value.serialize(Mut(&mut self.fields[idx].0))?; @@ -125,18 +125,8 @@ impl Context for StructBuilder { } impl SimpleSerializer for StructBuilder { - fn name(&self) -> &str { - "StructBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default()?; + self.seq.push_seq_default().ctx(self)?; for (builder, _) in &mut self.fields { builder.serialize_default()?; } @@ -145,7 +135,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none()?; + self.seq.push_seq_none().ctx(self)?; for (builder, _) in &mut self.fields { builder.serialize_default()?; } @@ -153,7 +143,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_struct_field( @@ -169,11 +159,11 @@ impl SimpleSerializer for StructBuilder { } fn serialize_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { @@ -181,11 +171,11 @@ impl SimpleSerializer for StructBuilder { } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start() + self.start().ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { @@ -197,7 +187,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } fn serialize_map_start(&mut self, _: Option) -> Result<()> { @@ -208,7 +198,11 @@ impl SimpleSerializer for StructBuilder { } fn 
serialize_map_key(&mut self, key: &V) -> Result<()> { - self.next = self.lookup.lookup_serialize(key)?.unwrap_or(UNKNOWN_KEY); + self.next = self + .lookup + .lookup_serialize(key) + .ctx(self)? + .unwrap_or(UNKNOWN_KEY); Ok(()) } @@ -222,7 +216,7 @@ impl SimpleSerializer for StructBuilder { } fn serialize_map_end(&mut self) -> Result<()> { - self.end() + self.end().ctx(self) } } @@ -309,14 +303,6 @@ impl<'a> Context for KeyLookupSerializer<'a> { } impl<'a> SimpleSerializer for KeyLookupSerializer<'a> { - fn name(&self) -> &str { - "KeyLookupSerializer" - } - - fn annotate_error(&self, err: Error) -> Error { - err - } - fn serialize_str(&mut self, v: &str) -> Result<()> { self.result = self.index.get(v).copied(); Ok(()) diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index f9c6c703..f7a3504a 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -4,7 +4,7 @@ use chrono::Timelike; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Context, Error, Result}, + error::{Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, btree_map, @@ -74,22 +74,12 @@ where Error: From<>::Error>, Error: From<>::Error>, { - fn name(&self) -> &str { - "Time64Builder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { @@ -101,18 +91,24 @@ where }; use chrono::naive::NaiveTime; - let time = v.parse::()?; - let timestamp = 
time.num_seconds_from_midnight() as i64 * seconds_factor - + time.nanosecond() as i64 / nanoseconds_factor; + let time = v.parse::().ctx(self)?; + let timestamp = i64::from(time.num_seconds_from_midnight()) * seconds_factor + + i64::from(time.nanosecond()) / nanoseconds_factor; - self.array.push_scalar_value(timestamp.try_into()?) + self.array + .push_scalar_value(timestamp.try_into().ctx(self)?) + .ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(v.try_into()?) + self.array + .push_scalar_value(v.try_into().ctx(self)?) + .ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v.try_into()?) + self.array + .push_scalar_value(v.try_into().ctx(self)?) + .ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 9f8fa675..944be947 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result, StaticContext}, utils::{btree_map, Mut}, }; @@ -82,23 +82,16 @@ impl Context for UnionBuilder { } impl SimpleSerializer for UnionBuilder { - fn name(&self) -> &str { - "UnionBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_unit_variant( &mut self, _: &'static str, variant_index: u32, _: &'static str, ) -> Result<()> { - self.serialize_variant(variant_index)?.serialize_unit() + let ctx = StaticContext::from_context(self); + self.serialize_variant(variant_index) + .ctx(&ctx)? 
+ .serialize_unit() } fn serialize_newtype_variant( @@ -108,7 +101,8 @@ impl SimpleSerializer for UnionBuilder { _: &'static str, value: &V, ) -> Result<()> { - let variant_builder = self.serialize_variant(variant_index)?; + let ctx = StaticContext::from_context(self); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; value.serialize(Mut(variant_builder)) } @@ -119,7 +113,8 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let variant_builder = self.serialize_variant(variant_index)?; + let ctx = StaticContext::from_context(self); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_struct_start(variant, len)?; Ok(variant_builder) } @@ -131,7 +126,8 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let variant_builder = self.serialize_variant(variant_index)?; + let ctx = StaticContext::from_context(self); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_tuple_struct_start(variant, len)?; Ok(variant_builder) } diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index b20476f6..105eb9ae 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, NullArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, Result}, utils::btree_map, }; @@ -42,82 +42,72 @@ impl Context for UnknownVariantBuilder { } impl SimpleSerializer for UnknownVariantBuilder { - fn name(&self) -> &str { - "UnknownVariantBuilder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), 
self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_unit(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_none(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_bool(&mut self, _: bool) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_char(&mut self, _: char) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_u8(&mut self, _: u8) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_u16(&mut self, _: u16) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_u32(&mut self, _: u32) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_u64(&mut self, _: u64) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_i8(&mut self, _: i8) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_i16(&mut self, _: i16) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_i32(&mut self, _: i32) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an 
unknown variant") } fn serialize_i64(&mut self, _: i64) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_f32(&mut self, _: f32) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_f64(&mut self, _: f64) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_bytes(&mut self, _: &[u8]) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_str(&mut self, _: &str) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_newtype_variant( @@ -127,47 +117,47 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: &V, ) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_unit_struct(&mut self, _: &'static str) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_unit_variant(&mut self, _: &'static str, _: u32, _: &'static str) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_map_key(&mut self, _: &V) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_map_value(&mut self, _: &V) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization 
failed: an unknown variant") } fn serialize_map_end(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_seq_element(&mut self, _: &V) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_seq_end(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_struct_field( @@ -175,35 +165,35 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: &V, ) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_struct_end(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_tuple_element(&mut self, _: &V) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_tuple_end(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - fail!("Serialization failed: an unknown variant") + fail!(in self, 
"Serialization failed: an unknown variant") } fn serialize_tuple_struct_field(&mut self, _: &V) -> Result<()> { - fail!("Serialization failed: an unknown variant"); + fail!(in self, "Serialization failed: an unknown variant"); } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - fail!("Serialization failed: an unknown variant"); + fail!(in self, "Serialization failed: an unknown variant"); } fn serialize_struct_variant_start<'this>( @@ -213,7 +203,7 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: usize, ) -> Result<&'this mut ArrayBuilder> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } fn serialize_tuple_variant_start<'this>( @@ -223,6 +213,6 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: usize, ) -> Result<&'this mut ArrayBuilder> { - fail!("Serialization failed: an unknown variant") + fail!(in self, "Serialization failed: an unknown variant") } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 0bf6321f..b392fdc2 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BytesArray}, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, btree_map, Offset, @@ -56,26 +56,16 @@ impl Context for Utf8Builder { } impl SimpleSerializer for Utf8Builder { - fn name(&self) -> &str { - "Utf8Builder" - } - - fn annotate_error(&self, err: Error) -> Error { - err.annotate_unannotated(|annotations| { - annotations.insert(String::from("field"), self.path.clone()); - }) - } - fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default() + self.array.push_scalar_default().ctx(self) } fn 
serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none() + self.array.push_scalar_none().ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - self.array.push_scalar_value(v.as_bytes()) + self.array.push_scalar_value(v.as_bytes()).ctx(self) } fn serialize_unit_variant( @@ -84,7 +74,7 @@ impl SimpleSerializer for Utf8Builder { _: u32, variant: &'static str, ) -> Result<()> { - self.array.push_scalar_value(variant.as_bytes()) + self.array.push_scalar_value(variant.as_bytes()).ctx(self) } fn serialize_tuple_variant_start<'this>( @@ -94,7 +84,7 @@ impl SimpleSerializer for Utf8Builder { _: &'static str, _: usize, ) -> Result<&'this mut super::ArrayBuilder> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } fn serialize_struct_variant_start<'this>( @@ -104,7 +94,7 @@ impl SimpleSerializer for Utf8Builder { _: &'static str, _: usize, ) -> Result<&'this mut super::ArrayBuilder> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } fn serialize_newtype_variant( @@ -114,6 +104,6 @@ impl SimpleSerializer for Utf8Builder { _: &'static str, _: &V, ) -> Result<()> { - fail!("Cannot serialize enum with data as string"); + fail!(in self, "Cannot serialize enum with data as string"); } } From 1aef722c35514c797cab0cc87b9f109ab202ecaa Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 15:12:21 +0200 Subject: [PATCH 139/178] Use ArrayBuilder as return value of take --- .../internal/serialization/array_builder.rs | 37 +------------------ .../internal/serialization/binary_builder.rs | 12 +++++- .../internal/serialization/bool_builder.rs | 8 ++-- .../internal/serialization/date32_builder.rs | 8 ++-- .../internal/serialization/date64_builder.rs | 8 ++-- .../internal/serialization/decimal_builder.rs | 8 ++-- .../serialization/dictionary_utf8_builder.rs | 6 +-- .../serialization/duration_builder.rs | 8 
++-- .../fixed_size_binary_builder.rs | 8 ++-- .../serialization/fixed_size_list_builder.rs | 6 +-- .../internal/serialization/float_builder.rs | 18 +++++---- .../src/internal/serialization/int_builder.rs | 28 ++++++++------ .../internal/serialization/list_builder.rs | 10 ++++- .../src/internal/serialization/map_builder.rs | 6 +-- .../internal/serialization/null_builder.rs | 8 ++-- .../serialization/outer_sequence_builder.rs | 2 +- .../internal/serialization/struct_builder.rs | 6 ++- .../internal/serialization/time_builder.rs | 12 +++++- .../internal/serialization/union_builder.rs | 8 ++-- .../serialization/unknown_variant_builder.rs | 8 ++-- .../internal/serialization/utf8_builder.rs | 12 +++++- 21 files changed, 118 insertions(+), 109 deletions(-) diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 0bdcc2b5..3f9281aa 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -106,43 +106,8 @@ impl ArrayBuilder { } impl ArrayBuilder { - /// Take the contained array builder, while leaving structure intact - // TODO: use ArrayBuilder as return type for the impls and use dispatch here pub fn take(&mut self) -> ArrayBuilder { - match self { - Self::Null(builder) => Self::Null(builder.take()), - Self::Bool(builder) => Self::Bool(builder.take()), - Self::I8(builder) => Self::I8(builder.take()), - Self::I16(builder) => Self::I16(builder.take()), - Self::I32(builder) => Self::I32(builder.take()), - Self::I64(builder) => Self::I64(builder.take()), - Self::U8(builder) => Self::U8(builder.take()), - Self::U16(builder) => Self::U16(builder.take()), - Self::U32(builder) => Self::U32(builder.take()), - Self::U64(builder) => Self::U64(builder.take()), - Self::F16(builder) => Self::F16(builder.take()), - Self::F32(builder) => Self::F32(builder.take()), - Self::F64(builder) => Self::F64(builder.take()), - Self::Date32(builder) 
=> Self::Date32(builder.take()), - Self::Date64(builder) => Self::Date64(builder.take()), - Self::Time32(builder) => Self::Time32(builder.take()), - Self::Time64(builder) => Self::Time64(builder.take()), - Self::Duration(builder) => Self::Duration(builder.take()), - Self::Decimal128(builder) => Self::Decimal128(builder.take()), - Self::Utf8(builder) => Self::Utf8(builder.take()), - Self::LargeUtf8(builder) => Self::LargeUtf8(builder.take()), - Self::List(builder) => Self::List(builder.take()), - Self::LargeList(builder) => Self::LargeList(builder.take()), - Self::FixedSizedList(builder) => Self::FixedSizedList(builder.take()), - Self::Binary(builder) => Self::Binary(builder.take()), - Self::LargeBinary(builder) => Self::LargeBinary(builder.take()), - Self::FixedSizeBinary(builder) => Self::FixedSizeBinary(builder.take()), - Self::Struct(builder) => Self::Struct(builder.take()), - Self::Map(builder) => Self::Map(builder.take()), - Self::DictionaryUtf8(builder) => Self::DictionaryUtf8(builder.take()), - Self::Union(builder) => Self::Union(builder.take()), - Self::UnknownVariant(builder) => Self::UnknownVariant(builder.take()), - } + dispatch!(self, Self(builder) => builder.take()) } } diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index de98e655..2502312a 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] @@ -28,7 +28,7 @@ impl BinaryBuilder { } } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), array: self.array.take(), @@ -41,12 +41,20 @@ impl BinaryBuilder { } impl BinaryBuilder { + pub fn take(&mut self) -> ArrayBuilder { + 
ArrayBuilder::Binary(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::Binary(self.array)) } } impl BinaryBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::LargeBinary(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::LargeBinary(self.array)) } diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 2d9c0b40..1c8b6ce0 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct BoolBuilder { @@ -29,15 +29,15 @@ impl BoolBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Bool(Self { path: self.path.clone(), array: BooleanArray { len: std::mem::take(&mut self.array.len), validity: self.array.validity.as_mut().map(std::mem::take), values: std::mem::take(&mut self.array.values), }, - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index d6d1b40c..0f1c4859 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct Date32Builder { @@ -27,11 +27,11 @@ impl Date32Builder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Date32(Self { path: self.path.clone(), array: self.array.take(), - } + }) } pub fn is_nullable(&self) -> bool { diff --git 
a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 65c8bf9a..bef9709d 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct Date64Builder { @@ -34,13 +34,13 @@ impl Date64Builder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Date64(Self { path: self.path.clone(), meta: self.meta.clone(), utc: self.utc, array: self.array.take(), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 173451d8..401957bd 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -10,7 +10,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct DecimalBuilder { @@ -36,8 +36,8 @@ impl DecimalBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Decimal128(Self { path: self.path.clone(), precision: self.precision, scale: self.scale, @@ -45,7 +45,7 @@ impl DecimalBuilder { f64_factor: self.f64_factor, parser: self.parser, array: self.array.take(), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index 48fb23e2..22c46489 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ 
b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -28,13 +28,13 @@ impl DictionaryUtf8Builder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::DictionaryUtf8(Self { path: self.path.clone(), indices: Box::new(self.indices.take()), values: Box::new(self.values.take()), index: std::mem::take(&mut self.index), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 353c29e9..01ad93cb 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct DurationBuilder { @@ -27,12 +27,12 @@ impl DurationBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Duration(Self { path: self.path.clone(), unit: self.unit, array: self.array.take(), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 4de21459..29d799a6 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] @@ -34,14 +34,14 @@ impl FixedSizeBinaryBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::FixedSizeBinary(Self { path: self.path.clone(), seq: self.seq.take(), 
buffer: std::mem::take(&mut self.buffer), current_n: std::mem::take(&mut self.current_n), n: self.n, - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index e2307f9c..a3df91a3 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -42,15 +42,15 @@ impl FixedSizeListBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::FixedSizedList(Self { path: self.path.clone(), seq: self.seq.take(), meta: self.meta.clone(), n: self.n, current_count: std::mem::take(&mut self.current_count), element: Box::new(self.element.take()), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 5d9bc519..79b5fa28 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct FloatBuilder { @@ -27,7 +27,7 @@ impl FloatBuilder { } } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), array: self.array.take(), @@ -40,18 +40,22 @@ impl FloatBuilder { } macro_rules! 
impl_into_array { - ($ty:ty, $var:ident) => { + ($ty:ty, $builder_var:ident, $array_var:ident) => { impl FloatBuilder<$ty> { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::$builder_var(self.take_self()) + } + pub fn into_array(self) -> Result { - Ok(Array::$var(self.array)) + Ok(Array::$array_var(self.array)) } } }; } -impl_into_array!(f16, Float16); -impl_into_array!(f32, Float32); -impl_into_array!(f64, Float64); +impl_into_array!(f16, F16, Float16); +impl_into_array!(f32, F32, Float32); +impl_into_array!(f64, F64, Float64); impl Context for FloatBuilder { fn annotations(&self) -> BTreeMap { diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 2f621e39..e92c7ac1 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct IntBuilder { @@ -25,7 +25,7 @@ impl IntBuilder { } } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), array: self.array.take(), @@ -38,23 +38,27 @@ impl IntBuilder { } macro_rules! 
impl_into_array { - ($ty:ty, $var:ident) => { + ($ty:ty, $builder_var: ident, $array_var:ident) => { impl IntBuilder<$ty> { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::$builder_var(self.take_self()) + } + pub fn into_array(self) -> Result { - Ok(Array::$var(self.array)) + Ok(Array::$array_var(self.array)) } } }; } -impl_into_array!(i8, Int8); -impl_into_array!(i16, Int16); -impl_into_array!(i32, Int32); -impl_into_array!(i64, Int64); -impl_into_array!(u8, UInt8); -impl_into_array!(u16, UInt16); -impl_into_array!(u32, UInt32); -impl_into_array!(u64, UInt64); +impl_into_array!(i8, I8, Int8); +impl_into_array!(i16, I16, Int16); +impl_into_array!(i32, I32, Int32); +impl_into_array!(i64, I64, Int64); +impl_into_array!(u8, U8, UInt8); +impl_into_array!(u16, U16, UInt16); +impl_into_array!(u32, U32, UInt32); +impl_into_array!(u64, U64, UInt64); impl Context for IntBuilder { fn annotations(&self) -> BTreeMap { diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 05bf2ddd..21b77003 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -37,7 +37,7 @@ impl ListBuilder { }) } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), meta: self.meta.clone(), @@ -52,6 +52,10 @@ impl ListBuilder { } impl ListBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::List(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::List(ListArray { validity: self.offsets.validity, @@ -63,6 +67,10 @@ impl ListBuilder { } impl ListBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::LargeList(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::LargeList(ListArray { validity: self.offsets.validity, diff --git a/serde_arrow/src/internal/serialization/map_builder.rs 
b/serde_arrow/src/internal/serialization/map_builder.rs index e803dc1b..bcca4563 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -47,13 +47,13 @@ impl MapBuilder { Ok(()) } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Map(Self { path: self.path.clone(), meta: self.meta.clone(), offsets: self.offsets.take(), entry: Box::new(self.entry.take()), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index df2280af..6b6b9be4 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -6,7 +6,7 @@ use crate::internal::{ utils::btree_map, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct NullBuilder { @@ -19,11 +19,11 @@ impl NullBuilder { Self { path, count: 0 } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Null(Self { path: self.path.clone(), count: std::mem::take(&mut self.count), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 05db939a..aff3c8c5 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -38,7 +38,7 @@ impl OuterSequenceBuilder { /// Extract the contained struct fields pub fn take_records(&mut self) -> Result> { let mut result = Vec::new(); - for (builder, _) in self.0.take().fields { + for (builder, _) in self.0.take_self().fields { result.push(builder); } Ok(result) diff --git 
a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index d0a53ac2..31c28144 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -43,7 +43,7 @@ impl StructBuilder { }) } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), fields: self @@ -58,6 +58,10 @@ impl StructBuilder { } } + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Struct(self.take_self()) + } + pub fn is_nullable(&self) -> bool { self.seq.validity.is_some() } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index f7a3504a..b5a3fda7 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ }, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct TimeBuilder { @@ -29,7 +29,7 @@ impl TimeBuilder { } } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), unit: self.unit, @@ -43,6 +43,10 @@ impl TimeBuilder { } impl TimeBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Time32(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::Time32(TimeArray { unit: self.unit, @@ -53,6 +57,10 @@ impl TimeBuilder { } impl TimeBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Time64(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::Time64(TimeArray { unit: self.unit, diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 944be947..98590e60 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ 
b/serde_arrow/src/internal/serialization/union_builder.rs @@ -6,7 +6,7 @@ use crate::internal::{ utils::{btree_map, Mut}, }; -use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct UnionBuilder { @@ -28,8 +28,8 @@ impl UnionBuilder { } } - pub fn take(&mut self) -> Self { - Self { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Union(Self { path: self.path.clone(), fields: self .fields @@ -39,7 +39,7 @@ impl UnionBuilder { types: std::mem::take(&mut self.types), offsets: std::mem::take(&mut self.offsets), current_offset: std::mem::replace(&mut self.current_offset, vec![0; self.fields.len()]), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 105eb9ae..b90472e5 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -8,7 +8,7 @@ use crate::internal::{ utils::btree_map, }; -use super::{simple_serializer::SimpleSerializer, ArrayBuilder}; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct UnknownVariantBuilder { @@ -20,10 +20,10 @@ impl UnknownVariantBuilder { UnknownVariantBuilder { path } } - pub fn take(&mut self) -> Self { - UnknownVariantBuilder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::UnknownVariant(UnknownVariantBuilder { path: self.path.clone(), - } + }) } pub fn is_nullable(&self) -> bool { diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index b392fdc2..0fe2dfe2 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -9,7 +9,7 @@ use crate::internal::{ 
}, }; -use super::simple_serializer::SimpleSerializer; +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct Utf8Builder { @@ -25,7 +25,7 @@ impl Utf8Builder { } } - pub fn take(&mut self) -> Self { + pub fn take_self(&mut self) -> Self { Self { path: self.path.clone(), array: self.array.take(), @@ -38,12 +38,20 @@ impl Utf8Builder { } impl Utf8Builder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::Utf8(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::Utf8(self.array)) } } impl Utf8Builder { + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::LargeUtf8(self.take_self()) + } + pub fn into_array(self) -> Result { Ok(Array::LargeUtf8(self.array)) } From f5434f6f669b637137516d245194c2141d270985 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 15:35:51 +0200 Subject: [PATCH 140/178] Add some style guides --- Development.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Development.md b/Development.md index 60d67a61..4682635c 100644 --- a/Development.md +++ b/Development.md @@ -31,7 +31,10 @@ modules can can be run without installing further packages. 
## Error format -- Include the path to the field where sensible +Style: + +- Use uppercase letters to start the error message +- Do not include trailing punctuation (e.g., "Not supported", not "Not supported.") Common annotations: From 45b1b054a54969067c968869179add5c7954b66f Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 15:36:20 +0200 Subject: [PATCH 141/178] Improve filed path formatting, remove unncessary Result --- serde_arrow/src/arrow2_impl/array.rs | 10 +-- serde_arrow/src/arrow_impl/array.rs | 12 +-- .../serialization/outer_sequence_builder.rs | 38 +++++++--- .../serialization/unknown_variant_builder.rs | 76 +++++++++---------- serde_arrow/src/internal/utils/mod.rs | 6 +- .../src/test_with_arrow/impls/union.rs | 3 +- 6 files changed, 81 insertions(+), 64 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 7df8f857..6a32fd5e 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -298,7 +298,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { ); }; Ok(V::List(ListArrayView { - meta: meta_from_field(field.as_ref().try_into()?)?, + meta: meta_from_field(field.as_ref().try_into()?), validity: bits_with_offset_from_bitmap(array.validity()), offsets: array.offsets().as_slice(), element: Box::new(array.values().as_ref().try_into()?), @@ -311,7 +311,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { ); }; Ok(V::LargeList(ListArrayView { - meta: meta_from_field(field.as_ref().try_into()?)?, + meta: meta_from_field(field.as_ref().try_into()?), validity: bits_with_offset_from_bitmap(array.validity()), offsets: array.offsets().as_slice(), element: Box::new(array.values().as_ref().try_into()?), @@ -327,7 +327,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { for (child_field, child) in child_fields.iter().zip(array.values()) { fields.push(( child.as_ref().try_into()?, - meta_from_field(child_field.try_into()?)?, + 
meta_from_field(child_field.try_into()?), )); } Ok(V::Struct(StructArrayView { @@ -342,7 +342,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { array.data_type(), ); }; - let meta = meta_from_field(field.as_ref().try_into()?)?; + let meta = meta_from_field(field.as_ref().try_into()?); let element: ArrayView<'_> = array.field().as_ref().try_into()?; Ok(V::Map(ListArrayView { @@ -378,7 +378,7 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { fields.push(( (*type_id).try_into()?, child.as_ref().try_into()?, - meta_from_field(child_field.try_into()?)?, + meta_from_field(child_field.try_into()?), )); } diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index 8340ff57..9c03b3dd 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -433,7 +433,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { Ok(ArrayView::List(ListArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - meta: meta_from_field(field.as_ref().try_into()?)?, + meta: meta_from_field(field.as_ref().try_into()?), element: Box::new(array.values().as_ref().try_into()?), })) } else if let Some(array) = any.downcast_ref::>() { @@ -443,7 +443,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { Ok(ArrayView::LargeList(ListArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - meta: meta_from_field(field.as_ref().try_into()?)?, + meta: meta_from_field(field.as_ref().try_into()?), element: Box::new(array.values().as_ref().try_into()?), })) } else if let Some(array) = any.downcast_ref::() { @@ -454,7 +454,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { len: array.len(), n: *n, validity: get_bits_with_offset(array), - meta: meta_from_field(field.as_ref().try_into()?)?, + meta: meta_from_field(field.as_ref().try_into()?), element: Box::new(array.values().as_ref().try_into()?), })) } else if let Some(array) = any.downcast_ref::() { @@ -465,7 +465,7 @@ 
impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { let mut fields = Vec::new(); for (field, array) in std::iter::zip(column_fields, array.columns()) { let view = ArrayView::try_from(array.as_ref())?; - let meta = meta_from_field(Field::try_from(field.as_ref())?)?; + let meta = meta_from_field(Field::try_from(field.as_ref())?); fields.push((view, meta)); } @@ -483,7 +483,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { Ok(ArrayView::Map(ListArrayView { validity: get_bits_with_offset(array), offsets: array.value_offsets(), - meta: meta_from_field(Field::try_from(entries_field.as_ref())?)?, + meta: meta_from_field(Field::try_from(entries_field.as_ref())?), element: Box::new(entries_array.try_into()?), })) } else if let Some(array) = any.downcast_ref::>() { @@ -509,7 +509,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayView<'a> { let mut fields = Vec::new(); for (type_id, field) in union_fields.iter() { - let meta = meta_from_field(Field::try_from(field.as_ref())?)?; + let meta = meta_from_field(Field::try_from(field.as_ref())?); let view: ArrayView = array.child(type_id).as_ref().try_into()?; fields.push((type_id, view, meta)); } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index aff3c8c5..285dde41 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -115,7 +115,7 @@ fn build_struct(path: String, struct_fields: &[Field], nullable: bool) -> Result let field_path = format!("{path}.{field_name}", field_name = field.name); fields.push(( build_builder(field_path, field)?, - meta_from_field(field.clone())?, + meta_from_field(field.clone()), )); } StructBuilder::new(path, fields, nullable) @@ -176,28 +176,28 @@ fn build_builder(path: String, field: &Field) -> Result { T::Utf8 => A::Utf8(Utf8Builder::new(path, field.nullable)), T::LargeUtf8 => 
A::LargeUtf8(Utf8Builder::new(path, field.nullable)), T::List(child) => { - let child_path = format!("{path}.{child_name}", child_name = child.name); + let child_path = format!("{path}.{child_name}", child_name = ChildName(&child.name)); A::List(ListBuilder::new( path, - meta_from_field(*child.clone())?, + meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, field.nullable, )?) } T::LargeList(child) => { - let child_path = format!("{path}.{child_name}", child_name = child.name); + let child_path = format!("{path}.{child_name}", child_name = ChildName(&child.name)); A::LargeList(ListBuilder::new( path, - meta_from_field(*child.clone())?, + meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, field.nullable, )?) } T::FixedSizeList(child, n) => { - let child_path = format!("{path}.{child_name}", child_name = child.name); + let child_path = format!("{path}.{child_name}", child_name = ChildName(&child.name)); A::FixedSizedList(FixedSizeListBuilder::new( path, - meta_from_field(*child.clone())?, + meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, (*n).try_into()?, field.nullable, @@ -211,10 +211,13 @@ fn build_builder(path: String, field: &Field) -> Result { field.nullable, )), T::Map(entry_field, _) => { - let child_path = format!("{path}.{child_name}", child_name = entry_field.name); + let child_path = format!( + "{path}.{child_name}", + child_name = ChildName(&entry_field.name) + ); A::Map(MapBuilder::new( path, - meta_from_field(*entry_field.clone())?, + meta_from_field(*entry_field.clone()), build_builder(child_path, entry_field.as_ref())?, field.nullable, )?) 
@@ -249,10 +252,11 @@ fn build_builder(path: String, field: &Field) -> Result { if usize::try_from(*type_id) != Ok(idx) { fail!("non consecutive type ids are not supported"); } - let field_path = format!("{path}.{field_name}", field_name = field.name); + let field_path = + format!("{path}.{field_name}", field_name = ChildName(&field.name)); fields.push(( build_builder(field_path, field)?, - meta_from_field(field.clone())?, + meta_from_field(field.clone()), )); } @@ -277,3 +281,15 @@ fn is_utc_strategy(strategy: Option<&Strategy>) -> Result { Some(st) => fail!("Cannot builder Date64 builder with strategy {st}"), } } + +struct ChildName<'a>(&'a str); + +impl<'a> std::fmt::Display for ChildName<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if !self.0.is_empty() { + write!(f, "{}", self.0) + } else { + write!(f, "") + } + } +} diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index b90472e5..cf993412 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -43,71 +43,71 @@ impl Context for UnknownVariantBuilder { impl SimpleSerializer for UnknownVariantBuilder { fn serialize_default(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_default") } fn serialize_unit(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_unit") } fn serialize_none(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_none") } fn serialize_bool(&mut self, _: bool) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not 
support serialize_bool") } fn serialize_char(&mut self, _: char) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_char") } fn serialize_u8(&mut self, _: u8) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_u8") } fn serialize_u16(&mut self, _: u16) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_u16") } fn serialize_u32(&mut self, _: u32) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_u32") } fn serialize_u64(&mut self, _: u64) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_u64") } fn serialize_i8(&mut self, _: i8) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_i8") } fn serialize_i16(&mut self, _: i16) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_i16") } fn serialize_i32(&mut self, _: i32) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_i32") } fn serialize_i64(&mut self, _: i64) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_i64") } fn serialize_f32(&mut self, _: f32) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_f32") } fn serialize_f64(&mut self, _: f64) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not 
support serialize_f64") } fn serialize_bytes(&mut self, _: &[u8]) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_bytes") } fn serialize_str(&mut self, _: &str) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_str") } fn serialize_newtype_variant( @@ -117,47 +117,47 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: &V, ) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_newtype_variant") } fn serialize_unit_struct(&mut self, _: &'static str) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_unit_struct") } fn serialize_unit_variant(&mut self, _: &'static str, _: u32, _: &'static str) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_unit_variant") } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_map_start") } fn serialize_map_key(&mut self, _: &V) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_map_jey") } fn serialize_map_value(&mut self, _: &V) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_map_value") } fn serialize_map_end(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_map_end") } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - fail!(in self, "Serialization failed: an unknown 
variant") + fail!(in self, "Unknown variant does not support serialize_seq_start") } fn serialize_seq_element(&mut self, _: &V) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_seq_element") } fn serialize_seq_end(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_seq_end") } fn serialize_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_struct_start") } fn serialize_struct_field( @@ -165,35 +165,35 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: &V, ) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_struct_field") } fn serialize_struct_end(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_struct_end") } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_tuple_start") } fn serialize_tuple_element(&mut self, _: &V) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_tuple_element") } fn serialize_tuple_end(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_tuple_end") } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_tuple_struct_start") } fn 
serialize_tuple_struct_field(&mut self, _: &V) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant"); + fail!(in self, "Unknown variant does not support serialize_tuple_struct_field"); } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - fail!(in self, "Serialization failed: an unknown variant"); + fail!(in self, "Unknown variant does not support serialize_tuple_struct_end"); } fn serialize_struct_variant_start<'this>( @@ -203,7 +203,7 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: usize, ) -> Result<&'this mut ArrayBuilder> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_struct_variant_start") } fn serialize_tuple_variant_start<'this>( @@ -213,6 +213,6 @@ impl SimpleSerializer for UnknownVariantBuilder { _: &'static str, _: usize, ) -> Result<&'this mut ArrayBuilder> { - fail!(in self, "Serialization failed: an unknown variant") + fail!(in self, "Unknown variant does not support serialize_tuple_variant_start") } } diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 36586c1b..abbcebf8 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -180,12 +180,12 @@ impl Offset for i64 { } } -pub fn meta_from_field(field: Field) -> Result { - Ok(FieldMeta { +pub fn meta_from_field(field: Field) -> FieldMeta { + FieldMeta { name: field.name, nullable: field.nullable, metadata: field.metadata, - }) + } } macro_rules! 
btree_map { diff --git a/serde_arrow/src/test_with_arrow/impls/union.rs b/serde_arrow/src/test_with_arrow/impls/union.rs index 80894af2..c3a3e632 100644 --- a/serde_arrow/src/test_with_arrow/impls/union.rs +++ b/serde_arrow/src/test_with_arrow/impls/union.rs @@ -377,7 +377,8 @@ fn missing_union_variants() { // NOTE: variant B was never encountered during tracing let res = crate::to_arrow(&fields, &Items(&[U::A, U::B, U::C])); - assert_error_contains(&res, "Serialization failed: an unknown variant"); + assert_error_contains(&res, "Unknown variant does not support serialize_unit"); + assert_error_contains(&res, "field: \"$.item.\"") } #[test] From 218586b5dfd3fea96611ef7f83fc29d74166d165 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 15:44:17 +0200 Subject: [PATCH 142/178] Add error contexts to OuterSequence builder contruction --- serde_arrow/src/internal/error.rs | 12 +---- .../internal/serialization/list_builder.rs | 11 ++--- .../serialization/outer_sequence_builder.rs | 44 ++++++++++--------- .../internal/serialization/union_builder.rs | 10 ++--- 4 files changed, 34 insertions(+), 43 deletions(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index ac11d30b..136bc62f 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -9,17 +9,9 @@ pub trait Context { fn annotations(&self) -> BTreeMap; } -pub struct StaticContext(BTreeMap); - -impl StaticContext { - pub fn from_context(context: &C) -> Self { - Self(context.annotations()) - } -} - -impl Context for StaticContext { +impl Context for BTreeMap { fn annotations(&self) -> BTreeMap { - self.0.clone() + self.clone() } } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 21b77003..811a33e4 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -23,18 +23,13 @@ pub struct 
ListBuilder { } impl ListBuilder { - pub fn new( - path: String, - meta: FieldMeta, - element: ArrayBuilder, - is_nullable: bool, - ) -> Result { - Ok(Self { + pub fn new(path: String, meta: FieldMeta, element: ArrayBuilder, is_nullable: bool) -> Self { + Self { path, meta, element: Box::new(element), offsets: OffsetsArray::new(is_nullable), - }) + } } pub fn take_self(&mut self) -> Self { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 285dde41..0e5fa071 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -4,14 +4,14 @@ use serde::Serialize; use crate::internal::{ arrow::{DataType, Field, TimeUnit}, - error::{fail, Context, Result}, + error::{fail, Context, ContextSupport, Result}, schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy}, serialization::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, }, - utils::{meta_from_field, Mut}, + utils::{btree_map, meta_from_field, Mut}, }; use super::{ @@ -123,6 +123,7 @@ fn build_struct(path: String, struct_fields: &[Field], nullable: bool) -> Result fn build_builder(path: String, field: &Field) -> Result { use {ArrayBuilder as A, DataType as T}; + let ctx: BTreeMap = btree_map!("path" => path.clone()); let builder = match &field.data_type { T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ @@ -151,18 +152,18 @@ fn build_builder(path: String, field: &Field) -> Result { T::Timestamp(unit, tz) => A::Date64(Date64Builder::new( path, Some((*unit, tz.clone())), - is_utc_tz(tz.as_deref())?, + is_utc_tz(tz.as_deref()).ctx(&ctx)?, field.nullable, )), T::Time32(unit) => { if !matches!(unit, TimeUnit::Second | TimeUnit::Millisecond) { - fail!("Only timestamps with second or millisecond unit are supported"); + fail!(in ctx, "Time32 only supports second or millisecond resolutions"); } A::Time32(TimeBuilder::new(path, *unit, field.nullable)) } T::Time64(unit) => { if !matches!(unit, TimeUnit::Nanosecond | TimeUnit::Microsecond) { - fail!("Only timestamps with nanosecond or microsecond unit are supported"); + fail!(in ctx, "Time64 only supports nanosecond or microsecond resolutions"); } A::Time64(TimeBuilder::new(path, *unit, field.nullable)) } @@ -182,7 +183,7 @@ fn build_builder(path: String, field: &Field) -> Result { meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, field.nullable, - )?) + )) } T::LargeList(child) => { let child_path = format!("{path}.{child_name}", child_name = ChildName(&child.name)); @@ -191,36 +192,39 @@ fn build_builder(path: String, field: &Field) -> Result { meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, field.nullable, - )?) 
+ )) } T::FixedSizeList(child, n) => { let child_path = format!("{path}.{child_name}", child_name = ChildName(&child.name)); + let n = usize::try_from(*n).ctx(&ctx)?; A::FixedSizedList(FixedSizeListBuilder::new( path, meta_from_field(*child.clone()), build_builder(child_path, child.as_ref())?, - (*n).try_into()?, + n, field.nullable, )) } T::Binary => A::Binary(BinaryBuilder::new(path, field.nullable)), T::LargeBinary => A::LargeBinary(BinaryBuilder::new(path, field.nullable)), - T::FixedSizeBinary(n) => A::FixedSizeBinary(FixedSizeBinaryBuilder::new( - path, - (*n).try_into()?, - field.nullable, - )), + T::FixedSizeBinary(n) => { + let n = usize::try_from(*n).ctx(&ctx)?; + A::FixedSizeBinary(FixedSizeBinaryBuilder::new(path, n, field.nullable)) + } T::Map(entry_field, _) => { let child_path = format!( "{path}.{child_name}", child_name = ChildName(&entry_field.name) ); - A::Map(MapBuilder::new( - path, - meta_from_field(*entry_field.clone()), - build_builder(child_path, entry_field.as_ref())?, - field.nullable, - )?) 
+ A::Map( + MapBuilder::new( + path, + meta_from_field(*entry_field.clone()), + build_builder(child_path, entry_field.as_ref())?, + field.nullable, + ) + .ctx(&ctx)?, + ) } T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?), T::Dictionary(key, value, _) => { @@ -250,7 +254,7 @@ fn build_builder(path: String, field: &Field) -> Result { let mut fields = Vec::new(); for (idx, (type_id, field)) in union_fields.iter().enumerate() { if usize::try_from(*type_id) != Ok(idx) { - fail!("non consecutive type ids are not supported"); + fail!(in ctx, "Union with non consecutive type ids are not supported"); } let field_path = format!("{path}.{field_name}", field_name = ChildName(&field.name)); diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 98590e60..44ee388a 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, Context, ContextSupport, Result, StaticContext}, + error::{fail, Context, ContextSupport, Result}, utils::{btree_map, Mut}, }; @@ -88,7 +88,7 @@ impl SimpleSerializer for UnionBuilder { variant_index: u32, _: &'static str, ) -> Result<()> { - let ctx = StaticContext::from_context(self); + let ctx = self.annotations(); self.serialize_variant(variant_index) .ctx(&ctx)? 
.serialize_unit() @@ -101,7 +101,7 @@ impl SimpleSerializer for UnionBuilder { _: &'static str, value: &V, ) -> Result<()> { - let ctx = StaticContext::from_context(self); + let ctx = self.annotations(); let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; value.serialize(Mut(variant_builder)) } @@ -113,7 +113,7 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let ctx = StaticContext::from_context(self); + let ctx = self.annotations(); let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_struct_start(variant, len)?; Ok(variant_builder) @@ -126,7 +126,7 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let ctx = StaticContext::from_context(self); + let ctx = self.annotations(); let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_tuple_struct_start(variant, len)?; Ok(variant_builder) From f7e2225150f5071d3d4e75a7834e54d98e8c18d7 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 15:48:00 +0200 Subject: [PATCH 143/178] Review error messages in serialization --- serde_arrow/src/internal/serialization/map_builder.rs | 4 ++-- serde_arrow/src/internal/serialization/struct_builder.rs | 2 +- serde_arrow/src/internal/serialization/union_builder.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index bcca4563..522684df 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -39,10 +39,10 @@ impl MapBuilder { fn validate_entry(entry: &ArrayBuilder) -> Result<()> { let ArrayBuilder::Struct(entry) = entry else { - fail!("entry field of a map must be a struct field"); + fail!("Entry field of a map must be a 
struct field"); }; if entry.fields.len() != 2 { - fail!("entry field of a map must be a struct field with 2 fields"); + fail!("Entry field of a map must be a struct field with 2 fields"); } Ok(()) } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 31c28144..0a130aaf 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -98,7 +98,7 @@ impl StructBuilder { if !*seen { if !self.fields[idx].1.nullable { fail!( - "missing non-nullable field {:?} in struct", + "Missing non-nullable field {:?} in struct", self.fields[idx].1.name ); } diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 44ee388a..dcf66c64 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -64,7 +64,7 @@ impl UnionBuilder { pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> { let variant_index = variant_index as usize; let Some((variant_builder, _)) = self.fields.get_mut(variant_index) else { - fail!("Unknown variant {variant_index}"); + fail!("Could not find variant {variant_index} in Union"); }; self.offsets.push(self.current_offset[variant_index]); From 8e62280b9adf90b29e13910b43a88662a4e31999 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 16:17:06 +0200 Subject: [PATCH 144/178] Include data types in context for serialization --- .../internal/serialization/binary_builder.rs | 13 ++++++++---- .../internal/serialization/bool_builder.rs | 2 +- .../internal/serialization/date32_builder.rs | 2 +- .../internal/serialization/date64_builder.rs | 7 ++++++- .../internal/serialization/decimal_builder.rs | 2 +- .../serialization/dictionary_utf8_builder.rs | 2 +- .../serialization/duration_builder.rs | 2 +- 
.../fixed_size_binary_builder.rs | 2 +- .../serialization/fixed_size_list_builder.rs | 2 +- .../internal/serialization/float_builder.rs | 16 +++++++++++++-- .../src/internal/serialization/int_builder.rs | 20 +++++++++++++++---- .../internal/serialization/list_builder.rs | 15 +++++++++----- .../src/internal/serialization/map_builder.rs | 2 +- .../internal/serialization/null_builder.rs | 2 +- .../internal/serialization/struct_builder.rs | 2 +- .../internal/serialization/time_builder.rs | 13 ++++++++---- .../internal/serialization/union_builder.rs | 2 +- .../serialization/unknown_variant_builder.rs | 2 +- .../internal/serialization/utf8_builder.rs | 14 +++++++++---- serde_arrow/src/internal/utils/mod.rs | 17 ++++++++++++++++ 20 files changed, 103 insertions(+), 36 deletions(-) diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 2502312a..2126dea6 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ error::{Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - btree_map, Mut, Offset, + btree_map, Mut, NamedType, Offset, }, }; @@ -78,13 +78,18 @@ impl BinaryBuilder { } } -impl Context for BinaryBuilder { +impl Context for BinaryBuilder { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone()) + let data_type = match O::NAME { + "i32" => "Binary", + "i64" => "LargeBinary", + _ => "", + }; + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } -impl SimpleSerializer for BinaryBuilder { +impl SimpleSerializer for BinaryBuilder { fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default().ctx(self) } diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 
1c8b6ce0..cbec6a6b 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -51,7 +51,7 @@ impl BoolBuilder { impl Context for BoolBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Boolean") } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index 0f1c4859..e01a4d8a 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -55,7 +55,7 @@ impl Date32Builder { impl Context for Date32Builder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Date32") } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index bef9709d..7d85feb2 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -95,7 +95,12 @@ impl Date64Builder { impl Context for Date64Builder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + let data_type = if self.meta.is_some() { + "Timestamp(..)" + } else { + "Date64" + }; + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 401957bd..d1616ed3 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -64,7 +64,7 @@ impl DecimalBuilder { impl Context for DecimalBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Decimal128(..)") } 
} diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index 22c46489..61998668 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -51,7 +51,7 @@ impl DictionaryUtf8Builder { impl Context for DictionaryUtf8Builder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Dictionary(..)") } } diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 01ad93cb..a1e37b09 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -50,7 +50,7 @@ impl DurationBuilder { impl Context for DurationBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Duration(..)") } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 29d799a6..32ca9b85 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -87,7 +87,7 @@ impl FixedSizeBinaryBuilder { impl Context for FixedSizeBinaryBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeBinary(..)") } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index a3df91a3..0c986bd7 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ 
b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -95,7 +95,7 @@ impl FixedSizeListBuilder { impl Context for FixedSizeListBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeList(..)") } } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 79b5fa28..7769e8ba 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -57,9 +57,21 @@ impl_into_array!(f16, F16, Float16); impl_into_array!(f32, F32, Float32); impl_into_array!(f64, F64, Float64); -impl Context for FloatBuilder { +impl Context for FloatBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Float16") + } +} + +impl Context for FloatBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone(), "data_type" => "Float32") + } +} + +impl Context for FloatBuilder { + fn annotations(&self) -> BTreeMap { + btree_map!("field" => self.path.clone(), "data_type" => "Float64") } } diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index e92c7ac1..7f65e4f7 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -5,7 +5,7 @@ use crate::internal::{ error::{Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, + btree_map, NamedType, }, }; @@ -60,9 +60,20 @@ impl_into_array!(u16, U16, UInt16); impl_into_array!(u32, U32, UInt32); impl_into_array!(u64, U64, UInt64); -impl Context for IntBuilder { +impl Context for IntBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + let 
data_type = match I::NAME { + "i8" => "Int8", + "i16" => "Int16", + "i32" => "Int32", + "i64" => "Int64", + "u8" => "UInt8", + "u16" => "UInt16", + "u32" => "UInt32", + "u64" => "UInt64", + _ => "", + }; + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } @@ -78,7 +89,8 @@ impl IntBuilder { impl SimpleSerializer for IntBuilder where - I: Default + I: NamedType + + Default + TryFrom + TryFrom + TryFrom diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index 811a33e4..d7efb6d3 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ error::{Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - btree_map, Mut, Offset, + btree_map, Mut, NamedType, Offset, }, }; @@ -76,7 +76,7 @@ impl ListBuilder { } } -impl ListBuilder { +impl ListBuilder { fn start(&mut self) -> Result<()> { self.offsets.start_seq() } @@ -91,13 +91,18 @@ impl ListBuilder { } } -impl Context for ListBuilder { +impl Context for ListBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + let data_type = if O::NAME == "i32" { + "List" + } else { + "LargeList" + }; + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } -impl SimpleSerializer for ListBuilder { +impl SimpleSerializer for ListBuilder { fn serialize_default(&mut self) -> Result<()> { self.offsets.push_seq_default().ctx(self) } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 522684df..09018096 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -72,7 +72,7 @@ impl MapBuilder { impl Context for MapBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + 
btree_map!("field" => self.path.clone(), "data_type" => "Map(..)") } } diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index 6b6b9be4..d8e056ac 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -37,7 +37,7 @@ impl NullBuilder { impl Context for NullBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Null") } } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 0a130aaf..4b0786ea 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -124,7 +124,7 @@ impl StructBuilder { impl Context for StructBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Struct(..)") } } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index b5a3fda7..8728d068 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -7,7 +7,7 @@ use crate::internal::{ error::{Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, + btree_map, NamedType, }, }; @@ -70,15 +70,20 @@ impl TimeBuilder { } } -impl Context for TimeBuilder { +impl Context for TimeBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + let data_type = match I::NAME { + "i32" => "Time32", + "i64" => "Time64", + _ => "", + }; + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } impl SimpleSerializer for TimeBuilder where - I: TryFrom + TryFrom + Default + 
'static, + I: NamedType + TryFrom + TryFrom + Default + 'static, Error: From<>::Error>, Error: From<>::Error>, { diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index dcf66c64..1a9b7483 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -77,7 +77,7 @@ impl UnionBuilder { impl Context for UnionBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "Union(..)") } } diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index cf993412..753cee8c 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -37,7 +37,7 @@ impl UnknownVariantBuilder { impl Context for UnknownVariantBuilder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + btree_map!("field" => self.path.clone(), "data_type" => "") } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 0fe2dfe2..17fed8ce 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -5,7 +5,7 @@ use crate::internal::{ error::{fail, Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - btree_map, Offset, + btree_map, NamedType, Offset, }, }; @@ -57,13 +57,19 @@ impl Utf8Builder { } } -impl Context for Utf8Builder { +impl Context for Utf8Builder { fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone()) + let data_type = if O::NAME == "i32" { + "Utf8" + } else { + "LargeUtf8" + }; + + btree_map!("field" => self.path.clone(), "data_type" => 
data_type) } } -impl SimpleSerializer for Utf8Builder { +impl SimpleSerializer for Utf8Builder { fn serialize_default(&mut self) -> Result<()> { self.array.push_scalar_default().ctx(self) } diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index abbcebf8..8a39ba1e 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -7,6 +7,7 @@ pub mod value; #[cfg(test)] mod test_value; +use half::f16; use serde::{ser::SerializeSeq, Deserialize, Serialize}; use crate::internal::error::Result; @@ -154,6 +155,22 @@ impl<'a, T: Serialize> Serialize for Items<&'a [T]> { /// A wrapper type to allow implementing foreign traits pub struct Mut<'a, T>(pub &'a mut T); +pub trait NamedType { + const NAME: &'static str; +} + +macro_rules! impl_named_type { + ($($ty:ty),*) => { + $( + impl NamedType for $ty { + const NAME: &'static str = stringify!($ty); + } + )* + }; +} + +impl_named_type!(i8, i16, i32, i64, u8, u16, u32, u64, f16, f32, f64); + /// A trait to handle different offset types pub trait Offset: std::ops::Add + Clone + Copy + Default + 'static { fn try_form_usize(val: usize) -> Result; From d77191843a7e13ce186d8d5589584614cc16493d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 16:17:47 +0200 Subject: [PATCH 145/178] Add "data_type" to documentation style guide --- Development.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Development.md b/Development.md index 4682635c..a7658fff 100644 --- a/Development.md +++ b/Development.md @@ -38,4 +38,5 @@ Style: Common annotations: -- `field`: the path of the field affected by the error \ No newline at end of file +- `field`: the path of the field affected by the error +- `data_type`: the Arrow data type of the field affected by the error From b94935c3d76d3c27942da0d7ea7349133e5de976 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 16:17:56 +0200 Subject: [PATCH 146/178] Tune error 
messages --- serde_arrow/src/internal/utils/array_ext.rs | 6 +++--- .../impls/issue_79_declared_but_missing_fields.rs | 15 ++++++++++----- .../src/test_with_arrow/impls/json_values.rs | 15 +++++++++------ serde_arrow/src/test_with_arrow/impls/utils.rs | 13 ------------- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/serde_arrow/src/internal/utils/array_ext.rs b/serde_arrow/src/internal/utils/array_ext.rs index b2df07f5..e00c25ba 100644 --- a/serde_arrow/src/internal/utils/array_ext.rs +++ b/serde_arrow/src/internal/utils/array_ext.rs @@ -262,7 +262,7 @@ impl SeqArrayExt for CountArray { pub fn duplicate_last(vec: &mut Vec) -> Result<()> { let Some(last) = vec.last() else { - fail!("invalid offset array") + fail!("Invalid offset array: expected at least a single element") }; vec.push(last.clone()); Ok(()) @@ -270,7 +270,7 @@ pub fn duplicate_last(vec: &mut Vec) -> Result<()> { pub fn increment_last(vec: &mut [O], inc: usize) -> Result<()> { let Some(last) = vec.last_mut() else { - fail!("invalid offset array") + fail!("Invalid offset array: expected at least a single element") }; *last = *last + O::try_form_usize(inc)?; Ok(()) @@ -283,7 +283,7 @@ pub fn set_validity(buffer: Option<&mut Vec>, idx: usize, value: bool) -> Re } else if value { Ok(()) } else { - fail!("cannot push null for non-nullable array"); + fail!("Cannot push null for non-nullable array"); } } diff --git a/serde_arrow/src/test_with_arrow/impls/issue_79_declared_but_missing_fields.rs b/serde_arrow/src/test_with_arrow/impls/issue_79_declared_but_missing_fields.rs index 7ef887c3..2473e433 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_79_declared_but_missing_fields.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_79_declared_but_missing_fields.rs @@ -1,7 +1,9 @@ use serde::Serialize; use serde_json::json; -use super::utils::{ResultAsserts, Test}; +use crate::internal::testing::assert_error_contains; + +use super::utils::Test; #[test] fn declared_but_missing_fields() { @@ 
-43,8 +45,11 @@ fn declared_but_missing_fields_non_nullable() { {"name": "b", "data_type": "U8"}, ])); - test.try_serialize_arrow(&items) - .assert_error("missing non-nullable field \"b\" in struct"); - test.try_serialize_arrow2(&items) - .assert_error("missing non-nullable field \"b\" in struct"); + let res = test.try_serialize_arrow(&items); + assert_error_contains(&res, "Missing non-nullable field \"b\" in struct"); + assert_error_contains(&res, "field: \"$\""); + + let res = test.try_serialize_arrow2(&items); + assert_error_contains(&res, "Missing non-nullable field \"b\" in struct"); + assert_error_contains(&res, "field: \"$\""); } diff --git a/serde_arrow/src/test_with_arrow/impls/json_values.rs b/serde_arrow/src/test_with_arrow/impls/json_values.rs index 93d0d0a9..ca280945 100644 --- a/serde_arrow/src/test_with_arrow/impls/json_values.rs +++ b/serde_arrow/src/test_with_arrow/impls/json_values.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; use serde_json::{json, Value}; -use crate::schema::TracingOptions; +use crate::{internal::testing::assert_error_contains, schema::TracingOptions}; -use super::utils::{ResultAsserts, Test}; +use super::utils::Test; struct ApproxEq<'a>(&'a Value); @@ -176,8 +176,11 @@ fn serde_json_nullable_strings_non_nullable_field() { {"name": "a", "data_type": "Utf8"}, ])); - test.try_serialize_arrow(&items) - .assert_error("cannot push null for non-nullable array"); - test.try_serialize_arrow2(&items) - .assert_error("cannot push null for non-nullable array"); + let res = test.try_serialize_arrow(&items); + assert_error_contains(&res, "Cannot push null for non-nullable array"); + assert_error_contains(&res, "field: \"$.a\""); + + let res = test.try_serialize_arrow2(&items); + assert_error_contains(&res, "Cannot push null for non-nullable array"); + assert_error_contains(&res, "field: \"$.a\""); } diff --git a/serde_arrow/src/test_with_arrow/impls/utils.rs b/serde_arrow/src/test_with_arrow/impls/utils.rs index a0349924..e2aaecbe 100644 
--- a/serde_arrow/src/test_with_arrow/impls/utils.rs +++ b/serde_arrow/src/test_with_arrow/impls/utils.rs @@ -37,19 +37,6 @@ impl std::default::Default for Impls { } } -pub trait ResultAsserts { - fn assert_error(&self, message: &str); -} - -impl ResultAsserts for Result { - fn assert_error(&self, message: &str) { - let Err(err) = self else { - panic!("Expected error"); - }; - assert!(err.to_string().contains(message), "unexpected error: {err}"); - } -} - #[derive(Default)] pub struct Test { schema: Option, From f525905137f6dfc2d5bb608d52d835f7225944a1 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 16:30:07 +0200 Subject: [PATCH 147/178] Further improve error messages --- .../deserialization/string_deserializer.rs | 9 +-- .../src/internal/deserialization/utils.rs | 7 ++- serde_arrow/src/internal/deserializer.rs | 50 ++++++++-------- serde_arrow/src/internal/error.rs | 14 +---- serde_arrow/src/internal/testing.rs | 2 +- serde_arrow/src/internal/utils/decimal.rs | 6 +- serde_arrow/src/internal/utils/dsl.rs | 16 ++--- serde_arrow/src/internal/utils/value.rs | 60 +++++++++---------- serde_arrow/src/test/api_chrono.rs | 8 +-- 9 files changed, 84 insertions(+), 88 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index ab03e2d8..76f31a77 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::BytesArrayView, - error::{error, fail, Result}, + error::{fail, Result}, utils::{Mut, Offset}, }; @@ -39,9 +39,10 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { } pub fn next_required(&mut self) -> Result<&'a str> { - self.next()?.ok_or_else(|| { - error!("Tried to deserialize a value from StringDeserializer, but value is missing") - }) + let Some(next) = self.next()? 
else { + fail!("Tried to deserialize a value from StringDeserializer, but value is missing") + }; + Ok(next) } pub fn peek_next(&self) -> Result { diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index 6f4e21b4..eb55086c 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ b/serde_arrow/src/internal/deserialization/utils.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::BitsWithOffset, - error::{error, fail, Result}, + error::{fail, Result}, utils::Offset, }; @@ -44,7 +44,10 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { } pub fn next_required(&mut self) -> Result { - self.next()?.ok_or_else(|| error!("missing value")) + let Some(next) = self.next()? else { + fail!("missing value"); + }; + Ok(next) } pub fn peek_next(&self) -> Result { diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index 3fdaaa2c..e4db6f9d 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -81,19 +81,19 @@ impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { } fn deserialize_bool>(self, _: V) -> Result { - fail!("cannot deserialize single bools") + fail!("Cannot deserialize single bools") } fn deserialize_byte_buf>(self, _: V) -> Result { - fail!("cannot deserialize byte buffers") + fail!("Cannot deserialize byte buffers") } fn deserialize_bytes>(self, _: V) -> Result { - fail!("cannot deserialize byte arrays") + fail!("Cannot deserialize byte arrays") } fn deserialize_char>(self, _: V) -> Result { - fail!("cannot deserialize single chars") + fail!("Cannot deserialize single chars") } fn deserialize_enum>( @@ -102,55 +102,55 @@ impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { _: &'static [&'static str], _: V, ) -> Result { - fail!("cannot deserialize single enums") + fail!("Cannot deserialize single enums") } fn deserialize_f32>(self, _: V) -> Result { - fail!("cannot deserialize single 
floats") + fail!("Cannot deserialize single floats") } fn deserialize_f64>(self, _: V) -> Result { - fail!("cannot deserialize single floats") + fail!("Cannot deserialize single floats") } fn deserialize_i128>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_i16>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_i32>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_i64>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_i8>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_identifier>(self, _: V) -> Result { - fail!("cannot deserialize single identifiers") + fail!("Cannot deserialize single identifiers") } fn deserialize_map>(self, _: V) -> Result { - fail!("cannot deserialize single maps") + fail!("Cannot deserialize single maps") } fn deserialize_option>(self, _: V) -> Result { - fail!("cannot deserialize single options") + fail!("Cannot deserialize single options") } fn deserialize_str>(self, _: V) -> Result { - fail!("cannot deserialize single strings") + fail!("Cannot deserialize single strings") } fn deserialize_string>(self, _: V) -> Result { - fail!("cannot deserialize single strings") + fail!("Cannot deserialize single strings") } fn deserialize_struct>( @@ -159,35 +159,35 @@ impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { _: &'static [&'static str], _: V, ) -> Result { - fail!("cannot deserialize single structs") + fail!("Cannot deserialize single structs") } fn deserialize_u128>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_u16>(self, _: V) -> 
Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_u32>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_u64>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_u8>(self, _: V) -> Result { - fail!("cannot deserialize single integers") + fail!("Cannot deserialize single integers") } fn deserialize_unit>(self, _: V) -> Result { - fail!("cannot deserialize single units") + fail!("Cannot deserialize single units") } fn deserialize_unit_struct>(self, _: &'static str, _: V) -> Result { - fail!("cannot deserialize single units") + fail!("Cannot deserialize single units") } fn is_human_readable(&self) -> bool { diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 136bc62f..4cc29514 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -211,25 +211,17 @@ impl serde::de::Error for Error { } } -macro_rules! error { - ($($tt:tt)*) => { - $crate::internal::error::Error::custom(format!($($tt)*)) - }; -} - -pub(crate) use error; - macro_rules! 
fail { (in $context:expr, $($tt:tt)*) => { { #[allow(unused)] use $crate::internal::error::Context; let annotations = $context.annotations(); - return Err($crate::internal::error::error!($($tt)*).with_annotations(annotations)) + return Err($crate::internal::error::Error::custom(format!($($tt)*)).with_annotations(annotations)) } }; ($($tt:tt)*) => { - return Err($crate::internal::error::error!($($tt)*)) + return Err($crate::internal::error::Error::custom(format!($($tt)*))) }; } @@ -305,7 +297,7 @@ impl From for PanicOnErrorError { #[test] fn error_can_be_converted_to_anyhow() { fn func() -> anyhow::Result<()> { - Err(error!("dummy"))?; + Err(Error::custom("dummy".to_string()))?; Ok(()) } assert!(func().is_err()); diff --git a/serde_arrow/src/internal/testing.rs b/serde_arrow/src/internal/testing.rs index 7f922a7c..16d968e2 100644 --- a/serde_arrow/src/internal/testing.rs +++ b/serde_arrow/src/internal/testing.rs @@ -1,7 +1,7 @@ //! Support for tests pub fn assert_error_contains(actual: &Result, expected: &str) { let Err(actual) = actual else { - panic!("expected an error, but no error was raised"); + panic!("Expected an error, but no error was raised"); }; let actual = actual.to_string(); diff --git a/serde_arrow/src/internal/utils/decimal.rs b/serde_arrow/src/internal/utils/decimal.rs index 0d1b5bff..771e58f3 100644 --- a/serde_arrow/src/internal/utils/decimal.rs +++ b/serde_arrow/src/internal/utils/decimal.rs @@ -198,9 +198,9 @@ fn find_period(s: &[u8]) -> (usize, usize) { fn check_all_ascii_zero(s: &[u8], leading: bool) -> Result<()> { if s.iter().any(|c| *c != b'0') { if leading { - fail!("invalid decimal: not enough precision"); + fail!("Invalid decimal: not enough precision"); } else { - fail!("invalid decimal: not enough scale, the given number would be truncated"); + fail!("Invalid decimal: not enough scale, the given number would be truncated"); } } Ok(()) @@ -208,7 +208,7 @@ fn check_all_ascii_zero(s: &[u8], leading: bool) -> Result<()> { fn 
check_all_ascii_digit(s: &[u8]) -> Result<()> { if s.iter().any(|c| *c < b'0' || *c > b'9') { - fail!("invalid decimal"); + fail!("Invalid decimal: only ascii digits are supported"); } Ok(()) } diff --git a/serde_arrow/src/internal/utils/dsl.rs b/serde_arrow/src/internal/utils/dsl.rs index 5493adfc..252b334e 100644 --- a/serde_arrow/src/internal/utils/dsl.rs +++ b/serde_arrow/src/internal/utils/dsl.rs @@ -18,16 +18,16 @@ impl Term { pub fn as_ident(&self) -> Result<&str> { match self.as_parts() { (name, false, []) => Ok(name), - (_, true, _) => fail!("expected identifier, found quoted string"), - (_, _, [_, ..]) => fail!("expected identifier, found call"), + (_, true, _) => fail!("Expected identifier, found quoted string"), + (_, _, [_, ..]) => fail!("Expected identifier, found call"), } } pub fn as_string(&self) -> Result<&str> { match self.as_parts() { (name, true, []) => Ok(name), - (_, false, _) => fail!("expected string, found identifier"), - (_, _, [_, ..]) => fail!("expected identifier, found call"), + (_, false, _) => fail!("Expected string, found identifier"), + (_, _, [_, ..]) => fail!("Expected identifier, found call"), } } @@ -35,14 +35,14 @@ impl Term { match self.as_parts() { ("None", false, []) => Ok(None), ("Some", false, [arg]) => Ok(Some(arg)), - _ => fail!("expected Some(arg) or None found quoted string"), + _ => fail!("Expected Some(arg) or None found quoted string"), } } pub fn as_call(&self) -> Result<(&str, &[Term])> { match self.as_parts() { (name, false, args) => Ok((name, args)), - (_, true, _) => fail!("expected call, found quoted string"), + (_, true, _) => fail!("Expected call, found quoted string"), } } } @@ -166,7 +166,7 @@ fn parse_ident_term_name(s: &str) -> Result<(String, &str)> { let rest = &s[pos..]; if ident.is_empty() { - fail!("no identifier found"); + fail!("No identifier found"); } Ok((ident, rest)) @@ -196,7 +196,7 @@ fn parse_arguments(s: &str) -> Result<(Vec, &str)> { let s = s.trim_start(); let Some(s) = 
s.strip_prefix(')') else { - fail!("mising ')'"); + fail!("Missing ')'"); }; Ok((arguments, s)) diff --git a/serde_arrow/src/internal/utils/value.rs b/serde_arrow/src/internal/utils/value.rs index 9f7eea00..2dc7c1b5 100644 --- a/serde_arrow/src/internal/utils/value.rs +++ b/serde_arrow/src/internal/utils/value.rs @@ -94,7 +94,7 @@ impl<'a> TryFrom<&'a Value> for &'a str { match value { Value::StaticStr(s) => Ok(s), Value::String(s) => Ok(s), - _ => fail!("cannot extract string from non-string value"), + _ => fail!("Cannot extract string from non-string value"), } } } @@ -114,7 +114,7 @@ macro_rules! impl_try_from_value_for_int { &Value::I16(v) => Ok(v.try_into()?), &Value::I32(v) => Ok(v.try_into()?), &Value::I64(v) => Ok(v.try_into()?), - _ => fail!("cannot extract integer from non-integer value"), + _ => fail!("Cannot extract integer from non-integer value"), } } } @@ -577,21 +577,21 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { fn deserialize_byte_buf>(self, visitor: V) -> Result { match self.0 { Value::Bytes(v) => visitor.visit_byte_buf(v.to_owned()), - v => fail!("cannot deserialize bytes from non-bytes value {v:?}"), + v => fail!("Cannot deserialize bytes from non-bytes value {v:?}"), } } fn deserialize_bytes>(self, visitor: V) -> Result { match self.0 { Value::Bytes(v) => visitor.visit_bytes(v), - v => fail!("cannot deserialize bytes from non-bytes value {v:?}"), + v => fail!("Cannot deserialize bytes from non-bytes value {v:?}"), } } fn deserialize_char>(self, visitor: V) -> Result { match self.0 { &Value::Char(v) => visitor.visit_char(v), - v => fail!("cannot deserializer char from non-char value {v:?}"), + v => fail!("Cannot deserializer char from non-char value {v:?}"), } } @@ -612,13 +612,13 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { Value::StructVariant(variant, fields) => { visitor.visit_enum(StructVariantDeserializer(*variant, fields)) } - v => fail!("cannot deserialize enum from non-enum 
value {v:?}"), + v => fail!("Cannot deserialize enum from non-enum value {v:?}"), } } fn deserialize_bool>(self, visitor: V) -> Result { let &Value::Bool(v) = self.0 else { - fail!("cannot deserialize bool from non-bool {:?}", self.0); + fail!("Cannot deserialize bool from non-bool {:?}", self.0); }; visitor.visit_bool(v) } @@ -659,7 +659,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { &Value::F32(v) => visitor.visit_f32(v.0), &Value::F64(v) => visitor.visit_f32(v.0 as f32), - v => fail!("cannot deserialize f32 from non-float value {v:?}"), + v => fail!("Cannot deserialize f32 from non-float value {v:?}"), } } @@ -667,7 +667,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { &Value::F32(v) => visitor.visit_f64(v.0 as f64), &Value::F64(v) => visitor.visit_f64(v.0), - v => fail!("cannot deserialize f64 from non-float value {v:?}"), + v => fail!("Cannot deserialize f64 from non-float value {v:?}"), } } @@ -679,7 +679,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { Value::Map(entries) => visitor.visit_map(MapDeserializer::new(entries)), Value::Struct(_, fields) => visitor.visit_map(StructDeserializer::new(fields)), - v => fail!("cannot deserialize a map from a non-map value {:?}", v), + v => fail!("Cannot deserialize a map from a non-map value {:?}", v), } } @@ -709,7 +709,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { Value::Seq(values) => visitor.visit_seq(SeqDeserializer::new(values)), Value::Tuple(values) => visitor.visit_seq(SeqDeserializer::new(values)), - v => fail!("cannot deserialize sequence from non-sequence value {v:?}"), + v => fail!("Cannot deserialize sequence from non-sequence value {v:?}"), } } @@ -730,7 +730,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { Value::Struct(_name, fields) => visitor.visit_map(StructDeserializer::new(fields)), 
Value::Map(entries) => visitor.visit_map(MapDeserializer::new(entries)), - v => fail!("cannot deserialize struct from non-struct value {v:?}"), + v => fail!("Cannot deserialize struct from non-struct value {v:?}"), } } @@ -742,7 +742,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { match self.0 { Value::Seq(values) => visitor.visit_seq(SeqDeserializer::new(values)), Value::Tuple(values) => visitor.visit_seq(SeqDeserializer::new(values)), - v => fail!("cannot deserialize tuple from non-sequence value {v:?}"), + v => fail!("Cannot deserialize tuple from non-sequence value {v:?}"), } } @@ -754,14 +754,14 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { ) -> Result { match self.0 { Value::TupleStruct(_, values) => visitor.visit_seq(SeqDeserializer::new(values)), - v => fail!("cannot deserialize tuple struct from non-tuple-struct value {v:?}"), + v => fail!("Cannot deserialize tuple struct from non-tuple-struct value {v:?}"), } } fn deserialize_unit>(self, visitor: V) -> Result { match self.0 { Value::Unit => visitor.visit_unit(), - v => fail!("cannot deserialize unit from non-unit value {v:?}"), + v => fail!("Cannot deserialize unit from non-unit value {v:?}"), } } @@ -772,7 +772,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for ValueDeserializer<'a> { ) -> Result { match self.0 { Value::UnitStruct(_) => visitor.visit_unit(), - v => fail!("cannot deserialize unit from non-unit value {v:?}"), + v => fail!("Cannot deserialize unit from non-unit value {v:?}"), } } @@ -809,7 +809,7 @@ impl<'de, 'a> serde::de::MapAccess<'de> for StructDeserializer<'a> { fn next_value_seed>(&mut self, seed: V) -> Result { let Some(value) = self.1.take() else { - fail!("invalid usage"); + fail!("Invalid usage"); }; seed.deserialize(ValueDeserializer::new(value)) } @@ -865,7 +865,7 @@ impl<'de, 'a> serde::de::MapAccess<'de> for MapDeserializer<'a> { fn next_value_seed>(&mut self, seed: V) -> Result { let Some(value) = self.1.take() else { - 
fail!("invalid usage"); + fail!("Invalid usage"); }; seed.deserialize(ValueDeserializer::new(value)) } @@ -895,7 +895,7 @@ impl<'de> serde::de::VariantAccess<'de> for UnitVariantVariant { self, _seed: T, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected unit variant found newtype variant") } fn struct_variant>( @@ -903,7 +903,7 @@ impl<'de> serde::de::VariantAccess<'de> for UnitVariantVariant { _fields: &'static [&'static str], _visitor: V, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected unit variant found struct variant") } fn tuple_variant>( @@ -911,7 +911,7 @@ impl<'de> serde::de::VariantAccess<'de> for UnitVariantVariant { _len: usize, _visitor: V, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected unit variant, found tuple variant") } fn unit_variant(self) -> Result<()> { @@ -943,11 +943,11 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for TupleVariantVariant<'a> { self, _seed: T, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected tuple variant, found newtype variant") } fn unit_variant(self) -> std::prelude::v1::Result<(), Self::Error> { - fail!("invalid variant") + fail!("Invalid variant: expected tuple variant, found unit variant") } fn struct_variant>( @@ -955,7 +955,7 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for TupleVariantVariant<'a> { _fields: &'static [&'static str], _visitor: V, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected tuple variant, found struct variant") } fn tuple_variant>( @@ -992,7 +992,7 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for NewTypeVariantVariant<'a> { } fn unit_variant(self) -> std::prelude::v1::Result<(), Self::Error> { - fail!("invalid variant") + fail!("Invalid variant: expected newtype variant, found unit variant") } fn struct_variant>( @@ -1000,7 +1000,7 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for NewTypeVariantVariant<'a> { _fields: &'static [&'static str], _visitor: V, ) -> 
Result { - fail!("invalid variant") + fail!("Invalid variant: expected newtype variant, found struct variant") } fn tuple_variant>( @@ -1008,7 +1008,7 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for NewTypeVariantVariant<'a> { _len: usize, _visitor: V, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected newtype variant, found tuple variant") } } @@ -1036,7 +1036,7 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for StructVariantVariant<'a> { self, _seed: T, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected struct variant, found newtype variant") } fn struct_variant>( @@ -1052,11 +1052,11 @@ impl<'de, 'a> serde::de::VariantAccess<'de> for StructVariantVariant<'a> { _len: usize, _visitor: V, ) -> Result { - fail!("invalid variant") + fail!("Invalid variant: expected struct variant, found tuple variant") } fn unit_variant(self) -> Result<()> { - fail!("invalid variant") + fail!("Invalid variant: expected struct variant, found unit variant") } } diff --git a/serde_arrow/src/test/api_chrono.rs b/serde_arrow/src/test/api_chrono.rs index cf96d3a5..ef9a5cae 100644 --- a/serde_arrow/src/test/api_chrono.rs +++ b/serde_arrow/src/test/api_chrono.rs @@ -3,7 +3,7 @@ use chrono::{DateTime, FixedOffset, NaiveDate, TimeZone, Utc}; -use crate::internal::error::{error, Result}; +use crate::internal::error::{Error, Result}; #[test] fn test_parse_utc() -> Result<()> { @@ -18,7 +18,7 @@ fn test_parse_utc() -> Result<()> { #[test] fn test_chrono_api_naive_datetime() -> Result<()> { let dt = NaiveDate::from_ymd(2021, 8, 3).and_hms(12, 0, 0); - let dt_str = serde_json::to_string(&dt).map_err(|err| error!("{err}"))?; + let dt_str = serde_json::to_string(&dt).map_err(|err| Error::custom(err.to_string()))?; assert_eq!(dt_str, "\"2021-08-03T12:00:00\""); Ok(()) } @@ -26,7 +26,7 @@ fn test_chrono_api_naive_datetime() -> Result<()> { #[test] fn test_chrono_api_datetime() -> Result<()> { let dt = Utc.ymd(730, 12, 1).and_hms(2, 3, 50); - let 
dt_str = serde_json::to_string(&dt).map_err(|err| error!("{err}"))?; + let dt_str = serde_json::to_string(&dt).map_err(|err| Error::custom(err.to_string()))?; assert_eq!(dt_str, "\"0730-12-01T02:03:50Z\""); Ok(()) @@ -46,7 +46,7 @@ fn test_chrono_fixed_offset() -> Result<()> { let dt = FixedOffset::east(5 * 3600) .ymd(2020, 12, 24) .and_hms(13, 30, 00); - let dt_str = serde_json::to_string(&dt).map_err(|err| error!("{err}"))?; + let dt_str = serde_json::to_string(&dt).map_err(|err| Error::custom(err.to_string()))?; assert_eq!(dt_str, "\"2020-12-24T13:30:00+05:00\""); Ok(()) From d93a05d0e0d782700732bb47b2e541a52f96ea7a Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 16:30:47 +0200 Subject: [PATCH 148/178] Fix clippy --- serde_arrow/src/internal/serialization/bool_builder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index cbec6a6b..3976f155 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -20,7 +20,7 @@ pub struct BoolBuilder { impl BoolBuilder { pub fn new(path: String, is_nullable: bool) -> Self { Self { - path: path, + path, array: BooleanArray { len: 0, validity: is_nullable.then(Vec::new), From 2180e6339c641432cc3eab968874d141bb51c144 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 19:30:01 +0200 Subject: [PATCH 149/178] Make paths available throughout the deserializers --- .../deserialization/array_deserializer.rs | 187 +++++++++++------- .../deserialization/binary_deserializer.rs | 26 ++- .../deserialization/bool_deserializer.rs | 19 +- .../deserialization/date32_deserializer.rs | 36 +++- .../deserialization/date64_deserializer.rs | 41 ++-- .../deserialization/decimal_deserializer.rs | 14 +- .../dictionary_deserializer.rs | 18 +- .../deserialization/enum_deserializer.rs | 18 +- 
.../fixed_size_binary_deserializer.rs | 14 +- .../fixed_size_list_deserializer.rs | 13 +- .../deserialization/float_deserializer.rs | 44 +++-- .../deserialization/integer_deserializer.rs | 63 ++++-- .../deserialization/list_deserializer.rs | 18 +- .../deserialization/map_deserializer.rs | 13 +- .../deserialization/null_deserializer.rs | 21 +- .../outer_sequence_deserializer.rs | 12 +- .../deserialization/string_deserializer.rs | 24 ++- .../deserialization/struct_deserializer.rs | 13 +- .../src/internal/deserialization/test.rs | 22 ++- .../deserialization/time_deserializer.rs | 49 +++-- serde_arrow/src/internal/deserializer.rs | 2 +- .../serialization/outer_sequence_builder.rs | 14 +- serde_arrow/src/internal/utils/mod.rs | 12 ++ 23 files changed, 498 insertions(+), 195 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 9ba76024..4b749a9e 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -5,7 +5,7 @@ use crate::internal::{ arrow::{ArrayView, FieldMeta, PrimitiveArrayView, TimeUnit}, error::{fail, Error, Result}, schema::{Strategy, STRATEGY_KEY}, - utils::Mut, + utils::{ChildName, Mut}, }; use super::{ @@ -70,38 +70,41 @@ pub enum ArrayDeserializer<'a> { } impl<'a> ArrayDeserializer<'a> { - pub fn new(strategy: Option<&Strategy>, array: ArrayView<'a>) -> Result { + pub fn new(path: String, strategy: Option<&Strategy>, array: ArrayView<'a>) -> Result { use {ArrayDeserializer as D, ArrayView as V}; match array { - ArrayView::Null(_) => Ok(Self::Null(NullDeserializer {})), - V::Boolean(view) => Ok(D::Bool(BoolDeserializer::new(view))), - V::Int8(view) => Ok(D::I8(IntegerDeserializer::new(view))), - V::Int16(view) => Ok(D::I16(IntegerDeserializer::new(view))), - V::Int32(view) => Ok(D::I32(IntegerDeserializer::new(view))), - V::Int64(view) => 
Ok(D::I64(IntegerDeserializer::new(view))), - V::UInt8(view) => Ok(D::U8(IntegerDeserializer::new(view))), - V::UInt16(view) => Ok(D::U16(IntegerDeserializer::new(view))), - V::UInt32(view) => Ok(D::U32(IntegerDeserializer::new(view))), - V::UInt64(view) => Ok(D::U64(IntegerDeserializer::new(view))), - V::Float16(view) => Ok(D::F16(FloatDeserializer::new(view))), - V::Float32(view) => Ok(D::F32(FloatDeserializer::new(view))), - V::Float64(view) => Ok(D::F64(FloatDeserializer::new(view))), - V::Decimal128(view) => Ok(D::Decimal128(DecimalDeserializer::new(view))), + ArrayView::Null(_) => Ok(Self::Null(NullDeserializer::new(path))), + V::Boolean(view) => Ok(D::Bool(BoolDeserializer::new(path, view))), + V::Int8(view) => Ok(D::I8(IntegerDeserializer::new(path, view))), + V::Int16(view) => Ok(D::I16(IntegerDeserializer::new(path, view))), + V::Int32(view) => Ok(D::I32(IntegerDeserializer::new(path, view))), + V::Int64(view) => Ok(D::I64(IntegerDeserializer::new(path, view))), + V::UInt8(view) => Ok(D::U8(IntegerDeserializer::new(path, view))), + V::UInt16(view) => Ok(D::U16(IntegerDeserializer::new(path, view))), + V::UInt32(view) => Ok(D::U32(IntegerDeserializer::new(path, view))), + V::UInt64(view) => Ok(D::U64(IntegerDeserializer::new(path, view))), + V::Float16(view) => Ok(D::F16(FloatDeserializer::new(path, view))), + V::Float32(view) => Ok(D::F32(FloatDeserializer::new(path, view))), + V::Float64(view) => Ok(D::F64(FloatDeserializer::new(path, view))), + V::Decimal128(view) => Ok(D::Decimal128(DecimalDeserializer::new(path, view))), ArrayView::Date32(view) => Ok(Self::Date32(Date32Deserializer::new( + path, view.values, view.validity, ))), ArrayView::Date64(view) => Ok(Self::Date64(Date64Deserializer::new( + path, view.values, view.validity, TimeUnit::Millisecond, is_utc_date64(strategy)?, ))), - V::Time32(view) => Ok(D::Time32(TimeDeserializer::new(view))), - V::Time64(view) => Ok(D::Time64(TimeDeserializer::new(view))), + V::Time32(view) => 
Ok(D::Time32(TimeDeserializer::new(path, view))), + V::Time64(view) => Ok(D::Time64(TimeDeserializer::new(path, view))), ArrayView::Timestamp(view) => match strategy { Some(Strategy::NaiveStrAsDate64 | Strategy::UtcStrAsDate64) => { Ok(Self::Date64(Date64Deserializer::new( + path, view.values, view.validity, view.unit, @@ -110,50 +113,83 @@ impl<'a> ArrayDeserializer<'a> { } Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), None => Ok(Self::Date64(Date64Deserializer::new( + path, view.values, view.validity, view.unit, is_utc_timestamp(view.timezone.as_deref())?, ))), }, - V::Duration(view) => Ok(D::I64(IntegerDeserializer::new(PrimitiveArrayView { - values: view.values, - validity: view.validity, - }))), - V::Utf8(view) => Ok(D::Utf8(StringDeserializer::new(view))), - V::LargeUtf8(view) => Ok(D::LargeUtf8(StringDeserializer::new(view))), - V::Binary(view) => Ok(D::Binary(BinaryDeserializer::new(view))), - V::LargeBinary(view) => Ok(D::LargeBinary(BinaryDeserializer::new(view))), - V::FixedSizeBinary(view) => { - Ok(D::FixedSizeBinary(FixedSizeBinaryDeserializer::new(view)?)) - } - V::List(view) => Ok(D::List(ListDeserializer::new( - ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, - view.offsets, - view.validity, - )?)), - V::LargeList(view) => Ok(D::LargeList(ListDeserializer::new( - ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, - view.offsets, - view.validity, - )?)), - V::FixedSizeList(view) => Ok(D::FixedSizeList(FixedSizeListDeserializer::new( - ArrayDeserializer::new(get_strategy(&view.meta)?.as_ref(), *view.element)?, - view.validity, - view.n.try_into()?, - view.len, + V::Duration(view) => Ok(D::I64(IntegerDeserializer::new( + path, + PrimitiveArrayView { + values: view.values, + validity: view.validity, + }, ))), + V::Utf8(view) => Ok(D::Utf8(StringDeserializer::new(path, view))), + V::LargeUtf8(view) => Ok(D::LargeUtf8(StringDeserializer::new(path, view))), + 
V::Binary(view) => Ok(D::Binary(BinaryDeserializer::new(path, view))), + V::LargeBinary(view) => Ok(D::LargeBinary(BinaryDeserializer::new(path, view))), + V::FixedSizeBinary(view) => Ok(D::FixedSizeBinary(FixedSizeBinaryDeserializer::new( + path, view, + )?)), + V::List(view) => { + let child_path = format!("{path}.{child}", child = ChildName(&view.meta.name)); + Ok(D::List(ListDeserializer::new( + path, + ArrayDeserializer::new( + child_path, + get_strategy(&view.meta)?.as_ref(), + *view.element, + )?, + view.offsets, + view.validity, + )?)) + } + V::LargeList(view) => { + let child_path = format!("{path}.{child}", child = ChildName(&view.meta.name)); + Ok(D::LargeList(ListDeserializer::new( + path, + ArrayDeserializer::new( + child_path, + get_strategy(&view.meta)?.as_ref(), + *view.element, + )?, + view.offsets, + view.validity, + )?)) + } + V::FixedSizeList(view) => { + let child_path = format!("{path}.{child}", child = ChildName(&view.meta.name)); + Ok(D::FixedSizeList(FixedSizeListDeserializer::new( + path, + ArrayDeserializer::new( + child_path, + get_strategy(&view.meta)?.as_ref(), + *view.element, + )?, + view.validity, + view.n.try_into()?, + view.len, + ))) + } V::Struct(view) => { let mut fields = Vec::new(); for (field_view, field_meta) in view.fields { - let field_deserializer = - ArrayDeserializer::new(get_strategy(&field_meta)?.as_ref(), field_view)?; + let child_path = format!("{path}.{child}", child = ChildName(&field_meta.name)); + let field_deserializer = ArrayDeserializer::new( + child_path, + get_strategy(&field_meta)?.as_ref(), + field_view, + )?; let field_name = field_meta.name; fields.push((field_name, field_deserializer)); } Ok(D::Struct(StructDeserializer::new( + path, fields, view.validity, view.len, @@ -167,11 +203,22 @@ impl<'a> ArrayDeserializer<'a> { fail!("invalid entries field in map array") }; let [(keys_view, keys_meta), (values_view, values_meta)] = entries_fields; - let keys = 
ArrayDeserializer::new(get_strategy(&keys_meta)?.as_ref(), keys_view)?; - let values = - ArrayDeserializer::new(get_strategy(&values_meta)?.as_ref(), values_view)?; + let keys_path = format!("{path}.{child}", child = ChildName(&keys_meta.name)); + let keys = ArrayDeserializer::new( + keys_path, + get_strategy(&keys_meta)?.as_ref(), + keys_view, + )?; + + let values_path = format!("{path}.{child}", child = ChildName(&values_meta.name)); + let values = ArrayDeserializer::new( + values_path, + get_strategy(&values_meta)?.as_ref(), + values_view, + )?; Ok(D::Map(MapDeserializer::new( + path, keys, values, view.offsets, @@ -180,52 +227,52 @@ impl<'a> ArrayDeserializer<'a> { } V::Dictionary(view) => match (*view.indices, *view.values) { (V::Int8(keys), V::Utf8(values)) => Ok(D::DictionaryI8I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int16(keys), V::Utf8(values)) => Ok(D::DictionaryI16I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int32(keys), V::Utf8(values)) => Ok(D::DictionaryI32I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int64(keys), V::Utf8(values)) => Ok(D::DictionaryI64I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt8(keys), V::Utf8(values)) => Ok(Self::DictionaryU8I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt16(keys), V::Utf8(values)) => Ok(D::DictionaryU16I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt32(keys), V::Utf8(values)) => Ok(D::DictionaryU32I32( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt64(keys), V::Utf8(values)) => Ok(D::DictionaryU64I32( - DictionaryDeserializer::new(keys, values)?, + 
DictionaryDeserializer::new(path, keys, values)?, )), (V::Int8(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI8I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int16(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI16I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int32(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI32I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::Int64(keys), V::LargeUtf8(values)) => Ok(D::DictionaryI64I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt8(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU8I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt16(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU16I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt32(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU32I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), (V::UInt64(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU64I64( - DictionaryDeserializer::new(keys, values)?, + DictionaryDeserializer::new(path, keys, values)?, )), _ => fail!("unsupported dictionary array"), }, @@ -236,12 +283,16 @@ impl<'a> ArrayDeserializer<'a> { if usize::try_from(type_id) != Ok(idx) { fail!("Only unions with consecutive type ids are currently supported in arrow2"); } - let field_deserializer = - ArrayDeserializer::new(get_strategy(&field_meta)?.as_ref(), field_view)?; + let child_path = format!("{path}.{child}", child = ChildName(&field_meta.name)); + let field_deserializer = ArrayDeserializer::new( + child_path, + get_strategy(&field_meta)?.as_ref(), + field_view, + )?; fields.push((field_meta.name, field_deserializer)) } - 
Ok(Self::Enum(EnumDeserializer::new(view.types, fields))) + Ok(Self::Enum(EnumDeserializer::new(path, view.types, fields))) } } } diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index 7c9ff1df..05c9c645 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -2,20 +2,25 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BytesArrayView, - error::{fail, Error, Result}, - utils::{Mut, Offset}, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut, NamedType, Offset}, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct BinaryDeserializer<'a, O: Offset> { + pub path: String, pub view: BytesArrayView<'a, O>, pub next: (usize, usize), } impl<'a, O: Offset> BinaryDeserializer<'a, O> { - pub fn new(view: BytesArrayView<'a, O>) -> Self { - Self { view, next: (0, 0) } + pub fn new(path: String, view: BytesArrayView<'a, O>) -> Self { + Self { + path, + view, + next: (0, 0), + } } pub fn peek_next(&self) -> Result { @@ -51,7 +56,18 @@ impl<'a, O: Offset> BinaryDeserializer<'a, O> { } } -impl<'a, O: Offset> SimpleDeserializer<'a> for BinaryDeserializer<'a, O> { +impl<'a, O: Offset + NamedType> Context for BinaryDeserializer<'a, O> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match O::NAME { + "i32" => "Binary", + "i64" => "LargeBinary", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) + } +} + +impl<'a, O: Offset + NamedType> SimpleDeserializer<'a> for BinaryDeserializer<'a, O> { fn name() -> &'static str { "BinaryDeserializer" } diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index 4e7fd87d..ad8529bd 100644 --- 
a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -2,20 +2,25 @@ use serde::de::Visitor; use crate::internal::{ arrow::BooleanArrayView, - error::{fail, Result}, - utils::Mut, + error::{fail, Context, Result}, + utils::{btree_map, Mut}, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct BoolDeserializer<'a> { + pub path: String, pub view: BooleanArrayView<'a>, pub next: usize, } impl<'a> BoolDeserializer<'a> { - pub fn new(view: BooleanArrayView<'a>) -> Self { - Self { view, next: 0 } + pub fn new(path: String, view: BooleanArrayView<'a>) -> Self { + Self { + path, + view, + next: 0, + } } fn next(&mut self) -> Result> { @@ -57,6 +62,12 @@ impl<'a> BoolDeserializer<'a> { } } +impl<'de> Context for BoolDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Boolean") + } +} + impl<'de> SimpleDeserializer<'de> for BoolDeserializer<'de> { fn name() -> &'static str { "BoolDeserializer" diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index fca2d52f..9057605c 100644 --- a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -1,15 +1,25 @@ use chrono::{Duration, NaiveDate, NaiveDateTime}; use serde::de::Visitor; -use crate::internal::{arrow::BitsWithOffset, error::Result, utils::Mut}; +use crate::internal::{ + arrow::BitsWithOffset, + error::{Context, Result}, + utils::{btree_map, Mut}, +}; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; -pub struct Date32Deserializer<'a>(ArrayBufferIterator<'a, i32>); +pub struct Date32Deserializer<'a> { + path: String, + array: ArrayBufferIterator<'a, i32>, +} impl<'a> Date32Deserializer<'a> { - pub fn new(buffer: &'a [i32], 
validity: Option>) -> Self { - Self(ArrayBufferIterator::new(buffer, validity)) + pub fn new(path: String, buffer: &'a [i32], validity: Option>) -> Self { + Self { + path, + array: ArrayBufferIterator::new(buffer, validity), + } } pub fn get_string_repr(&self, ts: i32) -> Result { @@ -21,31 +31,37 @@ impl<'a> Date32Deserializer<'a> { } } +impl<'de> Context for Date32Deserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Date32") + } +} + impl<'de> SimpleDeserializer<'de> for Date32Deserializer<'de> { fn name() -> &'static str { "Date32Deserializer" } fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { self.deserialize_i32(visitor) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.0.next_required()?) + visitor.visit_i32(self.array.next_required()?) } fn deserialize_str>(&mut self, visitor: V) -> Result { @@ -53,7 +69,7 @@ impl<'de> SimpleDeserializer<'de> for Date32Deserializer<'de> { } fn deserialize_string>(&mut self, visitor: V) -> Result { - let ts = self.0.next_required()?; + let ts = self.array.next_required()?; visitor.visit_string(self.get_string_repr(ts)?) 
} } diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index d6dd2be2..dda67c5d 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -3,26 +3,37 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BitsWithOffset, TimeUnit}, - error::{fail, Result}, - utils::Mut, + error::{fail, Context, Result}, + utils::{btree_map, Mut}, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; -pub struct Date64Deserializer<'a>(ArrayBufferIterator<'a, i64>, TimeUnit, bool); +pub struct Date64Deserializer<'a> { + path: String, + array: ArrayBufferIterator<'a, i64>, + unit: TimeUnit, + is_utc: bool, +} impl<'a> Date64Deserializer<'a> { pub fn new( + path: String, buffer: &'a [i64], validity: Option>, unit: TimeUnit, is_utc: bool, ) -> Self { - Self(ArrayBufferIterator::new(buffer, validity), unit, is_utc) + Self { + path, + array: ArrayBufferIterator::new(buffer, validity), + unit, + is_utc, + } } pub fn get_string_repr(&self, ts: i64) -> Result { - let Some(date_time) = (match self.1 { + let Some(date_time) = (match self.unit { TimeUnit::Second => DateTime::from_timestamp(ts, 0), TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), TimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), @@ -31,7 +42,7 @@ impl<'a> Date64Deserializer<'a> { fail!("Unsupported timestamp value: {ts}"); }; - if self.2 { + if self.is_utc { // NOTE: chrono documents that Debug, not Display, can be parsed Ok(format!("{:?}", date_time)) } else { @@ -41,31 +52,37 @@ impl<'a> Date64Deserializer<'a> { } } +impl<'de> Context for Date64Deserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Date64") + } +} + impl<'de> SimpleDeserializer<'de> for Date64Deserializer<'de> { fn name() -> &'static str { 
"Date64Deserializer" } fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { self.deserialize_i64(visitor) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.0.next_required()?) + visitor.visit_i64(self.array.next_required()?) } fn deserialize_str>(&mut self, visitor: V) -> Result { @@ -73,7 +90,7 @@ impl<'de> SimpleDeserializer<'de> for Date64Deserializer<'de> { } fn deserialize_string>(&mut self, visitor: V) -> Result { - let ts = self.0.next_required()?; + let ts = self.array.next_required()?; visitor.visit_string(self.get_string_repr(ts)?) } } diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index 11b62a45..b84e72c9 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -2,26 +2,34 @@ use serde::de::Visitor; use crate::internal::{ arrow::DecimalArrayView, - error::Result, - utils::{decimal, Mut}, + error::{Context, Result}, + utils::{btree_map, decimal, Mut}, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; pub struct DecimalDeserializer<'a> { + path: String, inner: ArrayBufferIterator<'a, i128>, scale: i8, } impl<'a> DecimalDeserializer<'a> { - pub fn new(view: DecimalArrayView<'a, i128>) -> Self { + pub fn new(path: String, view: DecimalArrayView<'a, i128>) -> Self { Self { + path, inner: ArrayBufferIterator::new(view.values, view.validity), scale: view.scale, } } } +impl<'de> Context for DecimalDeserializer<'de> { + fn 
annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Decimal128(..)") + } +} + impl<'de> SimpleDeserializer<'de> for DecimalDeserializer<'de> { fn name() -> &'static str { "DecimalDeserializer" diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index ef9bb397..edb4d1c2 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BytesArrayView, PrimitiveArrayView}, - error::{fail, Result}, - utils::{Mut, Offset}, + error::{fail, Context, Result}, + utils::{btree_map, Mut, Offset}, }; use super::{ @@ -12,18 +12,24 @@ use super::{ }; pub struct DictionaryDeserializer<'a, K: Integer, V: Offset> { + path: String, keys: ArrayBufferIterator<'a, K>, offsets: &'a [V], data: &'a [u8], } impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { - pub fn new(keys: PrimitiveArrayView<'a, K>, values: BytesArrayView<'a, V>) -> Result { + pub fn new( + path: String, + keys: PrimitiveArrayView<'a, K>, + values: BytesArrayView<'a, V>, + ) -> Result { if values.validity.is_some() { // TODO: check whether all values are defined? 
fail!("dictionaries with nullable values are not supported"); } Ok(Self { + path, keys: ArrayBufferIterator::new(keys.values, keys.validity), offsets: values.offsets, data: values.data, @@ -47,6 +53,12 @@ impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { } } +impl<'de, K: Integer, V: Offset> Context for DictionaryDeserializer<'de, K, V> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Dictionary(..)") + } +} + impl<'de, K: Integer, V: Offset> SimpleDeserializer<'de> for DictionaryDeserializer<'de, K, V> { fn name() -> &'static str { "DictionaryDeserializer" diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index 8a322d89..6eda8064 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -1,21 +1,27 @@ use serde::de::{DeserializeSeed, Deserializer, EnumAccess, Visitor}; use crate::internal::{ - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer}; pub struct EnumDeserializer<'a> { + pub path: String, pub type_ids: &'a [i8], pub variants: Vec<(String, ArrayDeserializer<'a>)>, pub next: usize, } impl<'a> EnumDeserializer<'a> { - pub fn new(type_ids: &'a [i8], variants: Vec<(String, ArrayDeserializer<'a>)>) -> Self { + pub fn new( + path: String, + type_ids: &'a [i8], + variants: Vec<(String, ArrayDeserializer<'a>)>, + ) -> Self { Self { + path, type_ids, variants, next: 0, @@ -23,6 +29,12 @@ impl<'a> EnumDeserializer<'a> { } } +impl<'de> Context for EnumDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Union(..)") + } +} + impl<'de> SimpleDeserializer<'de> for 
EnumDeserializer<'de> { fn name() -> &'static str { "EnumDeserializer" diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index 757a6989..fd65987b 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -2,20 +2,21 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::FixedSizeBinaryArrayView, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; pub struct FixedSizeBinaryDeserializer<'a> { + pub path: String, pub view: FixedSizeBinaryArrayView<'a>, pub next: (usize, usize), pub shape: (usize, usize), } impl<'a> FixedSizeBinaryDeserializer<'a> { - pub fn new(view: FixedSizeBinaryArrayView<'a>) -> Result { + pub fn new(path: String, view: FixedSizeBinaryArrayView<'a>) -> Result { let n = usize::try_from(view.n)?; if view.data.len() % n != 0 { fail!( @@ -30,6 +31,7 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { let shape = (view.data.len() / n, n); Ok(Self { + path, view, shape, next: (0, 0), @@ -62,6 +64,12 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { } } +impl<'a> Context for FixedSizeBinaryDeserializer<'a> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "FixedSizeBinary(..)") + } +} + impl<'a> SimpleDeserializer<'a> for FixedSizeBinaryDeserializer<'a> { fn name() -> &'static str { "FixedSizeBinaryDeserializer" diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index e146b309..a35de9c7 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{IgnoredAny, SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{ @@ -12,6 +12,7 @@ use super::{ }; pub struct FixedSizeListDeserializer<'a> { + pub path: String, pub item: Box>, pub validity: Option>, pub shape: (usize, usize), @@ -20,12 +21,14 @@ pub struct FixedSizeListDeserializer<'a> { impl<'a> FixedSizeListDeserializer<'a> { pub fn new( + path: String, item: ArrayDeserializer<'a>, validity: Option>, n: usize, len: usize, ) -> Self { Self { + path, item: Box::new(item), validity, shape: (len, n), @@ -54,6 +57,12 @@ impl<'a> FixedSizeListDeserializer<'a> { } } +impl<'a> Context for FixedSizeListDeserializer<'a> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "FixedSizeList(..)") + } +} + impl<'a> SimpleDeserializer<'a> for FixedSizeListDeserializer<'a> { fn name() -> &'static str { "ListDeserializer" diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 36dba2a1..8d2fd873 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -1,6 +1,10 @@ use serde::de::Visitor; -use crate::internal::{arrow::PrimitiveArrayView, error::Result, utils::Mut}; +use crate::internal::{ + arrow::PrimitiveArrayView, + error::{Context, Result}, + utils::{btree_map, Mut, NamedType}, +}; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -14,42 +18,60 @@ pub trait Float: Copy { fn into_f64(self) -> Result; } -pub struct FloatDeserializer<'a, F: Float>(ArrayBufferIterator<'a, F>); +pub struct FloatDeserializer<'a, F: Float> { + path: String, + array: 
ArrayBufferIterator<'a, F>, +} impl<'a, F: Float> FloatDeserializer<'a, F> { - pub fn new(view: PrimitiveArrayView<'a, F>) -> Self { - Self(ArrayBufferIterator::new(view.values, view.validity)) + pub fn new(path: String, view: PrimitiveArrayView<'a, F>) -> Self { + Self { + path, + array: ArrayBufferIterator::new(view.values, view.validity), + } + } +} + +impl<'de, F: NamedType + Float> Context for FloatDeserializer<'de, F> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match F::NAME { + "f16" => "Float16", + "f32" => "Float32", + "f64" => "Float64", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) } } -impl<'de, F: Float> SimpleDeserializer<'de> for FloatDeserializer<'de, F> { +impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'de, F> { fn name() -> &'static str { "FloatDeserializer" } fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { F::deserialize_any(self, visitor) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_f32>(&mut self, visitor: V) -> Result { - visitor.visit_f32(self.0.next_required()?.into_f32()?) + visitor.visit_f32(self.array.next_required()?.into_f32()?) } fn deserialize_f64>(&mut self, visitor: V) -> Result { - visitor.visit_f64(self.0.next_required()?.into_f64()?) + visitor.visit_f64(self.array.next_required()?.into_f64()?) 
} } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 4c5afe1a..abb2153f 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -1,6 +1,10 @@ use serde::de::Visitor; -use crate::internal::{arrow::PrimitiveArrayView, error::Result, utils::Mut}; +use crate::internal::{ + arrow::PrimitiveArrayView, + error::{Context, Result}, + utils::{btree_map, Mut, NamedType}, +}; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -23,11 +27,34 @@ pub trait Integer: Sized + Copy { fn into_u64(self) -> Result; } -pub struct IntegerDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>); +pub struct IntegerDeserializer<'a, T: Integer> { + path: String, + array: ArrayBufferIterator<'a, T>, +} impl<'a, T: Integer> IntegerDeserializer<'a, T> { - pub fn new(view: PrimitiveArrayView<'a, T>) -> Self { - Self(ArrayBufferIterator::new(view.values, view.validity)) + pub fn new(path: String, view: PrimitiveArrayView<'a, T>) -> Self { + Self { + path, + array: ArrayBufferIterator::new(view.values, view.validity), + } + } +} + +impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match T::NAME { + "i8" => "Int8", + "i16" => "Int16", + "i32" => "Int32", + "i64" => "Int64", + "u8" => "UInt8", + "u16" => "UInt16", + "u32" => "UInt32", + "u64" => "UInt64", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) } } @@ -37,60 +64,60 @@ impl<'de, T: Integer> SimpleDeserializer<'de> for IntegerDeserializer<'de, T> { } fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? 
{ T::deserialize_any(self, visitor) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_bool>(&mut self, visitor: V) -> Result { - visitor.visit_bool(self.0.next_required()?.into_bool()?) + visitor.visit_bool(self.array.next_required()?.into_bool()?) } fn deserialize_char>(&mut self, visitor: V) -> Result { - visitor.visit_char(self.0.next_required()?.into_u32()?.try_into()?) + visitor.visit_char(self.array.next_required()?.into_u32()?.try_into()?) } fn deserialize_u8>(&mut self, visitor: V) -> Result { - visitor.visit_u8(self.0.next_required()?.into_u8()?) + visitor.visit_u8(self.array.next_required()?.into_u8()?) } fn deserialize_u16>(&mut self, visitor: V) -> Result { - visitor.visit_u16(self.0.next_required()?.into_u16()?) + visitor.visit_u16(self.array.next_required()?.into_u16()?) } fn deserialize_u32>(&mut self, visitor: V) -> Result { - visitor.visit_u32(self.0.next_required()?.into_u32()?) + visitor.visit_u32(self.array.next_required()?.into_u32()?) } fn deserialize_u64>(&mut self, visitor: V) -> Result { - visitor.visit_u64(self.0.next_required()?.into_u64()?) + visitor.visit_u64(self.array.next_required()?.into_u64()?) } fn deserialize_i8>(&mut self, visitor: V) -> Result { - visitor.visit_i8(self.0.next_required()?.into_i8()?) + visitor.visit_i8(self.array.next_required()?.into_i8()?) } fn deserialize_i16>(&mut self, visitor: V) -> Result { - visitor.visit_i16(self.0.next_required()?.into_i16()?) + visitor.visit_i16(self.array.next_required()?.into_i16()?) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.0.next_required()?.into_i32()?) + visitor.visit_i32(self.array.next_required()?.into_i32()?) 
} fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.0.next_required()?.into_i64()?) + visitor.visit_i64(self.array.next_required()?.into_i64()?) } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index ebf45562..10e60aa8 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Error, Result}, - utils::{Mut, Offset}, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut, NamedType, Offset}, }; use super::{ @@ -13,6 +13,7 @@ use super::{ }; pub struct ListDeserializer<'a, O: Offset> { + pub path: String, pub item: Box>, pub offsets: &'a [O], pub validity: Option>, @@ -21,6 +22,7 @@ pub struct ListDeserializer<'a, O: Offset> { impl<'a, O: Offset> ListDeserializer<'a, O> { pub fn new( + path: String, item: ArrayDeserializer<'a>, offsets: &'a [O], validity: Option>, @@ -28,6 +30,7 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { check_supported_list_layout(validity, offsets)?; Ok(Self { + path, item: Box::new(item), offsets, validity, @@ -51,6 +54,17 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { } } +impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match O::NAME { + "i32" => "List(..)", + "i64" => "LargeList(..)", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) + } +} + impl<'a, O: Offset> SimpleDeserializer<'a> for ListDeserializer<'a, O> { fn name() -> &'static str { "ListDeserializer" diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index fde9e0a0..73be7185 100644 --- 
a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{DeserializeSeed, MapAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{ @@ -13,6 +13,7 @@ use super::{ }; pub struct MapDeserializer<'a> { + path: String, key: Box>, value: Box>, offsets: &'a [i32], @@ -22,6 +23,7 @@ pub struct MapDeserializer<'a> { impl<'a> MapDeserializer<'a> { pub fn new( + path: String, key: ArrayDeserializer<'a>, value: ArrayDeserializer<'a>, offsets: &'a [i32], @@ -30,6 +32,7 @@ impl<'a> MapDeserializer<'a> { check_supported_list_layout(validity, offsets)?; Ok(Self { + path, key: Box::new(key), value: Box::new(value), offsets, @@ -54,6 +57,12 @@ impl<'a> MapDeserializer<'a> { } } +impl<'de> Context for MapDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Map(..)") + } +} + impl<'de> SimpleDeserializer<'de> for MapDeserializer<'de> { fn name() -> &'static str { "MapDeserializer" diff --git a/serde_arrow/src/internal/deserialization/null_deserializer.rs b/serde_arrow/src/internal/deserialization/null_deserializer.rs index 2909c491..adafe1c1 100644 --- a/serde_arrow/src/internal/deserialization/null_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/null_deserializer.rs @@ -1,10 +1,27 @@ use serde::de::Visitor; -use crate::internal::error::Result; +use crate::internal::{ + error::{Context, Result}, + utils::btree_map, +}; use super::simple_deserializer::SimpleDeserializer; -pub struct NullDeserializer; +pub struct NullDeserializer { + path: String, +} + +impl NullDeserializer { + pub fn new(path: String) -> Self { + Self { path } + } +} + +impl Context for NullDeserializer { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => 
self.path.clone(), "data_type" => "Null") + } +} impl<'de> SimpleDeserializer<'de> for NullDeserializer { fn name() -> &'static str { diff --git a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs index f4497ddc..ee2914d6 100644 --- a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs @@ -1,8 +1,8 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ - error::{Error, Result}, - utils::Mut, + error::{Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{ @@ -19,13 +19,19 @@ pub struct OuterSequenceDeserializer<'a> { impl<'a> OuterSequenceDeserializer<'a> { pub fn new(fields: Vec<(String, ArrayDeserializer<'a>)>, len: usize) -> Self { Self { - item: StructDeserializer::new(fields, None, len), + item: StructDeserializer::new(String::from("$"), fields, None, len), next: 0, len, } } } +impl<'de> Context for OuterSequenceDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!() + } +} + impl<'de> SimpleDeserializer<'de> for OuterSequenceDeserializer<'de> { fn name() -> &'static str { "OuterSequenceDeserializer" diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 76f31a77..6dc5df25 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -1,7 +1,7 @@ use crate::internal::{ arrow::BytesArrayView, - error::{fail, Result}, - utils::{Mut, Offset}, + error::{fail, Context, Result}, + utils::{btree_map, Mut, NamedType, Offset}, }; use super::{ @@ -9,13 +9,18 @@ use super::{ }; pub struct StringDeserializer<'a, O: Offset> { + pub path: String, pub view: BytesArrayView<'a, O>, pub next: usize, } impl<'a, O: Offset> StringDeserializer<'a, O> { - pub 
fn new(view: BytesArrayView<'a, O>) -> Self { - Self { view, next: 0 } + pub fn new(path: String, view: BytesArrayView<'a, O>) -> Self { + Self { + path, + view, + next: 0, + } } pub fn next(&mut self) -> Result> { @@ -63,6 +68,17 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { } } +impl<'a, O: NamedType + Offset> Context for StringDeserializer<'a, O> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match O::NAME { + "i32" => "Utf8", + "i64" => "LargeUtf8", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) + } +} + impl<'a, O: Offset> SimpleDeserializer<'a> for StringDeserializer<'a, O> { fn name() -> &'static str { "StringDeserializer" diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index df7b1879..83c32110 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -4,8 +4,8 @@ use serde::de::{ use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Error, Result}, - utils::Mut, + error::{fail, Context, Error, Result}, + utils::{btree_map, Mut}, }; use super::{ @@ -14,6 +14,7 @@ use super::{ }; pub struct StructDeserializer<'a> { + pub path: String, pub fields: Vec<(String, ArrayDeserializer<'a>)>, pub validity: Option>, pub next: (usize, usize), @@ -22,11 +23,13 @@ pub struct StructDeserializer<'a> { impl<'a> StructDeserializer<'a> { pub fn new( + path: String, fields: Vec<(String, ArrayDeserializer<'a>)>, validity: Option>, len: usize, ) -> Self { Self { + path, fields, validity, len, @@ -50,6 +53,12 @@ impl<'a> StructDeserializer<'a> { } } +impl<'de> Context for StructDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!("path" => self.path.clone(), "data_type" => "Struct(..)") + } +} + impl<'de> SimpleDeserializer<'de> for StructDeserializer<'de> { fn name() -> &'static str 
{ "StructDeserializer" diff --git a/serde_arrow/src/internal/deserialization/test.rs b/serde_arrow/src/internal/deserialization/test.rs index 362f4fee..5474050a 100644 --- a/serde_arrow/src/internal/deserialization/test.rs +++ b/serde_arrow/src/internal/deserialization/test.rs @@ -16,17 +16,23 @@ fn example() { vec![ ( String::from("a"), - ArrayDeserializer::I32(IntegerDeserializer::new(PrimitiveArrayView { - values: &[1, 2, 3], - validity: None, - })), + ArrayDeserializer::I32(IntegerDeserializer::new( + String::from("$"), + PrimitiveArrayView { + values: &[1, 2, 3], + validity: None, + }, + )), ), ( String::from("b"), - ArrayDeserializer::I32(IntegerDeserializer::new(PrimitiveArrayView { - values: &[4, 5, 6], - validity: None, - })), + ArrayDeserializer::I32(IntegerDeserializer::new( + String::from("$"), + PrimitiveArrayView { + values: &[4, 5, 6], + validity: None, + }, + )), ), ], 3, diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index ce5ae075..1d41ba2f 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -3,8 +3,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::{TimeArrayView, TimeUnit}, - error::{fail, Result}, - utils::Mut, + error::{fail, Context, Result}, + utils::{btree_map, Mut, NamedType}, }; use super::{ @@ -12,10 +12,15 @@ use super::{ utils::ArrayBufferIterator, }; -pub struct TimeDeserializer<'a, T: Integer>(ArrayBufferIterator<'a, T>, i64, i64); +pub struct TimeDeserializer<'a, T: Integer> { + path: String, + array: ArrayBufferIterator<'a, T>, + seconds_factor: i64, + nanoseconds_factor: i64, +} impl<'a, T: Integer> TimeDeserializer<'a, T> { - pub fn new(view: TimeArrayView<'a, T>) -> Self { + pub fn new(path: String, view: TimeArrayView<'a, T>) -> Self { let (seconds_factor, nanoseconds_factor) = match view.unit { TimeUnit::Nanosecond => (1_000_000_000, 
1), TimeUnit::Microsecond => (1_000_000, 1_000), @@ -23,16 +28,17 @@ impl<'a, T: Integer> TimeDeserializer<'a, T> { TimeUnit::Second => (1, 1_000_000_000), }; - Self( - ArrayBufferIterator::new(view.values, view.validity), + Self { + path, + array: ArrayBufferIterator::new(view.values, view.validity), seconds_factor, nanoseconds_factor, - ) + } } pub fn get_string_repr(&self, ts: i64) -> Result { - let seconds = (ts / self.1) as u32; - let nanoseconds = ((ts % self.1) / self.2) as u32; + let seconds = (ts / self.seconds_factor) as u32; + let nanoseconds = ((ts % self.seconds_factor) / self.nanoseconds_factor) as u32; let Some(res) = NaiveTime::from_num_seconds_from_midnight_opt(seconds, nanoseconds) else { fail!("Invalid timestamp"); @@ -41,35 +47,46 @@ impl<'a, T: Integer> TimeDeserializer<'a, T> { } } +impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { + fn annotations(&self) -> std::collections::BTreeMap { + let data_type = match T::NAME { + "i32" => "Time32", + "i64" => "Time64", + _ => "", + }; + btree_map!("path" => self.path.clone(), "data_type" => data_type) + } +} + impl<'de, T: Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { fn name() -> &'static str { "Time64Deserializer" } fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { T::deserialize_any(self, visitor) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.0.peek_next()? { + if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { - self.0.consume_next(); + self.array.consume_next(); visitor.visit_none() } } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.0.next_required()?.into_i32()?) + visitor.visit_i32(self.array.next_required()?.into_i32()?) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.0.next_required()?.into_i64()?) 
+ visitor.visit_i64(self.array.next_required()?.into_i64()?) } fn deserialize_str>(&mut self, visitor: V) -> Result { @@ -77,7 +94,7 @@ impl<'de, T: Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { } fn deserialize_string>(&mut self, visitor: V) -> Result { - let ts = self.0.next_required()?.into_i64()?; + let ts = self.array.next_required()?.into_i64()?; visitor.visit_string(self.get_string_repr(ts)?) } } diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index e4db6f9d..bcfd5f47 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -33,7 +33,7 @@ impl<'de> Deserializer<'de> { fail!("Cannot deserialize from arrays with different lengths"); } let strategy = get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(strategy.as_ref(), view)?; + let deserializer = ArrayDeserializer::new(String::from("$"), strategy.as_ref(), view)?; deserializers.push((field.name.clone(), deserializer)); } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 0e5fa071..3cafa254 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -11,7 +11,7 @@ use crate::internal::{ fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, }, - utils::{btree_map, meta_from_field, Mut}, + utils::{btree_map, meta_from_field, ChildName, Mut}, }; use super::{ @@ -285,15 +285,3 @@ fn is_utc_strategy(strategy: Option<&Strategy>) -> Result { Some(st) => fail!("Cannot builder Date64 builder with strategy {st}"), } } - -struct ChildName<'a>(&'a str); - -impl<'a> std::fmt::Display for ChildName<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if !self.0.is_empty() { - write!(f, "{}", self.0) - } else { - 
write!(f, "") - } - } -} diff --git a/serde_arrow/src/internal/utils/mod.rs b/serde_arrow/src/internal/utils/mod.rs index 8a39ba1e..7fa35e5d 100644 --- a/serde_arrow/src/internal/utils/mod.rs +++ b/serde_arrow/src/internal/utils/mod.rs @@ -219,3 +219,15 @@ macro_rules! btree_map { } pub(crate) use btree_map; + +pub struct ChildName<'a>(pub &'a str); + +impl<'a> std::fmt::Display for ChildName<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if !self.0.is_empty() { + write!(f, "{}", self.0) + } else { + write!(f, "") + } + } +} From 8dc74247451766c5ae76f451b16b4828c3772ca9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 19:38:52 +0200 Subject: [PATCH 150/178] Add context info for default impls in SimpleDeserializer --- .../deserialization/array_deserializer.rs | 10 +-- .../deserialization/binary_deserializer.rs | 12 ++-- .../deserialization/bool_deserializer.rs | 4 -- .../deserialization/date32_deserializer.rs | 4 -- .../deserialization/date64_deserializer.rs | 4 -- .../deserialization/decimal_deserializer.rs | 4 -- .../dictionary_deserializer.rs | 4 -- .../deserialization/enum_deserializer.rs | 4 -- .../fixed_size_binary_deserializer.rs | 12 ++-- .../fixed_size_list_deserializer.rs | 4 -- .../deserialization/float_deserializer.rs | 4 -- .../deserialization/integer_deserializer.rs | 6 +- .../deserialization/list_deserializer.rs | 6 +- .../deserialization/map_deserializer.rs | 4 -- .../deserialization/null_deserializer.rs | 4 -- .../outer_sequence_deserializer.rs | 4 -- .../deserialization/simple_deserializer.rs | 66 +++++++++---------- .../deserialization/string_deserializer.rs | 6 +- .../deserialization/struct_deserializer.rs | 4 -- .../deserialization/time_deserializer.rs | 6 +- 20 files changed, 51 insertions(+), 121 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index 4b749a9e..c343a475 100644 --- 
a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::{Deserialize, DeserializeSeed, VariantAccess, Visitor}; use crate::internal::{ arrow::{ArrayView, FieldMeta, PrimitiveArrayView, TimeUnit}, - error::{fail, Error, Result}, + error::{fail, Context, Error, Result}, schema::{Strategy, STRATEGY_KEY}, utils::{ChildName, Mut}, }; @@ -373,11 +373,13 @@ macro_rules! dispatch { }; } -impl<'de> SimpleDeserializer<'de> for ArrayDeserializer<'de> { - fn name() -> &'static str { - "ArrayDeserializer" +impl<'de> Context for ArrayDeserializer<'de> { + fn annotations(&self) -> std::collections::BTreeMap { + dispatch!(self, ArrayDeserializer(deser) => deser.annotations()) } +} +impl<'de> SimpleDeserializer<'de> for ArrayDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { dispatch!(self, ArrayDeserializer(deser) => deser.deserialize_any(visitor)) } diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index 05c9c645..f637dbd1 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -68,10 +68,6 @@ impl<'a, O: Offset + NamedType> Context for BinaryDeserializer<'a, O> { } impl<'a, O: Offset + NamedType> SimpleDeserializer<'a> for BinaryDeserializer<'a, O> { - fn name() -> &'static str { - "BinaryDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? 
{ self.deserialize_bytes(visitor) @@ -127,11 +123,13 @@ impl<'de, O: Offset> SeqAccess<'de> for BinaryDeserializer<'de, O> { struct U8Deserializer(u8); -impl<'de> SimpleDeserializer<'de> for U8Deserializer { - fn name() -> &'static str { - "U8Deserializer" +impl Context for U8Deserializer { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!() } +} +impl<'de> SimpleDeserializer<'de> for U8Deserializer { fn deserialize_u8>(&mut self, visitor: V) -> Result { visitor.visit_u8(self.0) } diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index ad8529bd..00ec4ef5 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -69,10 +69,6 @@ impl<'de> Context for BoolDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for BoolDeserializer<'de> { - fn name() -> &'static str { - "BoolDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? { self.deserialize_bool(visitor) diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index 9057605c..9aba5634 100644 --- a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -38,10 +38,6 @@ impl<'de> Context for Date32Deserializer<'de> { } impl<'de> SimpleDeserializer<'de> for Date32Deserializer<'de> { - fn name() -> &'static str { - "Date32Deserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.array.peek_next()? 
{ self.deserialize_i32(visitor) diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index dda67c5d..f184134e 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -59,10 +59,6 @@ impl<'de> Context for Date64Deserializer<'de> { } impl<'de> SimpleDeserializer<'de> for Date64Deserializer<'de> { - fn name() -> &'static str { - "Date64Deserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { self.deserialize_i64(visitor) diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index b84e72c9..ad052128 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -31,10 +31,6 @@ impl<'de> Context for DecimalDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for DecimalDeserializer<'de> { - fn name() -> &'static str { - "DecimalDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.inner.peek_next()? { self.deserialize_str(visitor) diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index edb4d1c2..20e99acc 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -60,10 +60,6 @@ impl<'de, K: Integer, V: Offset> Context for DictionaryDeserializer<'de, K, V> { } impl<'de, K: Integer, V: Offset> SimpleDeserializer<'de> for DictionaryDeserializer<'de, K, V> { - fn name() -> &'static str { - "DictionaryDeserializer" - } - fn deserialize_any>(&mut self, visitor: VV) -> Result { if self.keys.peek_next()? 
{ self.deserialize_str(visitor) diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index 6eda8064..8b819096 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -36,10 +36,6 @@ impl<'de> Context for EnumDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for EnumDeserializer<'de> { - fn name() -> &'static str { - "EnumDeserializer" - } - fn deserialize_enum>( &mut self, _: &'static str, diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index fd65987b..fa9f6fc8 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -71,10 +71,6 @@ impl<'a> Context for FixedSizeBinaryDeserializer<'a> { } impl<'a> SimpleDeserializer<'a> for FixedSizeBinaryDeserializer<'a> { - fn name() -> &'static str { - "FixedSizeBinaryDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? 
{ self.deserialize_bytes(visitor) @@ -130,11 +126,13 @@ impl<'de> SeqAccess<'de> for FixedSizeBinaryDeserializer<'de> { struct U8Deserializer(u8); -impl<'de> SimpleDeserializer<'de> for U8Deserializer { - fn name() -> &'static str { - "U8Deserializer" +impl Context for U8Deserializer { + fn annotations(&self) -> std::collections::BTreeMap { + btree_map!() } +} +impl<'de> SimpleDeserializer<'de> for U8Deserializer { fn deserialize_u8>(&mut self, visitor: V) -> Result { visitor.visit_u8(self.0) } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index a35de9c7..3f10c2e4 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -64,10 +64,6 @@ impl<'a> Context for FixedSizeListDeserializer<'a> { } impl<'a> SimpleDeserializer<'a> for FixedSizeListDeserializer<'a> { - fn name() -> &'static str { - "ListDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? { self.deserialize_seq(visitor) diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 8d2fd873..5edb30bf 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -45,10 +45,6 @@ impl<'de, F: NamedType + Float> Context for FloatDeserializer<'de, F> { } impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'de, F> { - fn name() -> &'static str { - "FloatDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.array.peek_next()? 
{ F::deserialize_any(self, visitor) diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index abb2153f..40124baf 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -58,11 +58,7 @@ impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { } } -impl<'de, T: Integer> SimpleDeserializer<'de> for IntegerDeserializer<'de, T> { - fn name() -> &'static str { - "IntegerDeserializer" - } - +impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for IntegerDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { T::deserialize_any(self, visitor) diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 10e60aa8..562eb988 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -65,11 +65,7 @@ impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { } } -impl<'a, O: Offset> SimpleDeserializer<'a> for ListDeserializer<'a, O> { - fn name() -> &'static str { - "ListDeserializer" - } - +impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for ListDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? 
{ self.deserialize_seq(visitor) diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 73be7185..69fdd230 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -64,10 +64,6 @@ impl<'de> Context for MapDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for MapDeserializer<'de> { - fn name() -> &'static str { - "MapDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? { self.deserialize_map(visitor) diff --git a/serde_arrow/src/internal/deserialization/null_deserializer.rs b/serde_arrow/src/internal/deserialization/null_deserializer.rs index adafe1c1..3f0a300d 100644 --- a/serde_arrow/src/internal/deserialization/null_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/null_deserializer.rs @@ -24,10 +24,6 @@ impl Context for NullDeserializer { } impl<'de> SimpleDeserializer<'de> for NullDeserializer { - fn name() -> &'static str { - "NullDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { visitor.visit_unit() } diff --git a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs index ee2914d6..c4504880 100644 --- a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs @@ -33,10 +33,6 @@ impl<'de> Context for OuterSequenceDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for OuterSequenceDeserializer<'de> { - fn name() -> &'static str { - "OuterSequenceDeserializer" - } - fn deserialize_newtype_struct>( &mut self, _: &'static str, diff --git a/serde_arrow/src/internal/deserialization/simple_deserializer.rs b/serde_arrow/src/internal/deserialization/simple_deserializer.rs index 7d2e50bc..4753bc86 100644 --- 
a/serde_arrow/src/internal/deserialization/simple_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/simple_deserializer.rs @@ -1,16 +1,14 @@ use serde::{de::Visitor, Deserializer}; use crate::internal::{ - error::{fail, Error, Result}, + error::{fail, Context, Error, Result}, utils::Mut, }; #[allow(unused)] -pub trait SimpleDeserializer<'de>: Sized { - fn name() -> &'static str; - +pub trait SimpleDeserializer<'de>: Context + Sized { fn deserialize_any>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_any", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_any"); } fn deserialize_ignored_any>(&mut self, visitor: V) -> Result { @@ -18,63 +16,63 @@ pub trait SimpleDeserializer<'de>: Sized { } fn deserialize_bool>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_bool", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_bool"); } fn deserialize_i8>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_i8", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_i8"); } fn deserialize_i16>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_i16", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_i16"); } fn deserialize_i32>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_i32", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_i32"); } fn deserialize_i64>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_i64", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_i64"); } fn deserialize_u8>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_u8", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_u8"); } fn deserialize_u16>(&mut self, visitor: V) -> Result { - fail!("{} does not implement 
deserialize_u16", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_u16"); } fn deserialize_u32>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_u32", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_u32"); } fn deserialize_u64>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_u64", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_u64"); } fn deserialize_f32>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_f32", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_f32"); } fn deserialize_f64>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_f64", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_f64"); } fn deserialize_char>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_char", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_char"); } fn deserialize_str>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_str", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_str"); } fn deserialize_string>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_string", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_string"); } fn deserialize_map>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_map", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_map"); } fn deserialize_struct>( @@ -83,15 +81,15 @@ pub trait SimpleDeserializer<'de>: Sized { fields: &'static [&'static str], visitor: V, ) -> Result { - fail!("{} does not implement deserialize_struct", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_struct"); } fn deserialize_byte_buf>(&mut self, visitor: V) -> Result { - fail!("{} does 
not implement deserialize_byte_buf", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_byte_buf"); } fn deserialize_bytes>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_bytes", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_bytes"); } fn deserialize_enum>( @@ -100,15 +98,15 @@ pub trait SimpleDeserializer<'de>: Sized { variants: &'static [&'static str], visitor: V, ) -> Result { - fail!("{} does not implement deserialize_enum", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_enum"); } fn deserialize_identifier>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_identifier", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_identifier"); } fn deserialize_option>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_option", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_option"); } fn deserialize_newtype_struct>( @@ -120,11 +118,11 @@ pub trait SimpleDeserializer<'de>: Sized { } fn deserialize_tuple>(&mut self, len: usize, visitor: V) -> Result { - fail!("{} does not implement deserialize_tuple", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_tuple"); } fn deserialize_seq>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_seq", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_seq"); } fn deserialize_tuple_struct>( @@ -133,14 +131,13 @@ pub trait SimpleDeserializer<'de>: Sized { len: usize, visitor: V, ) -> Result { - fail!( - "{} does not implement deserialize_tuple_struct", - Self::name() + fail!(in self, + "Deserializer does not implement deserialize_tuple_struct", ); } fn deserialize_unit>(&mut self, visitor: V) -> Result { - fail!("{} does not implement deserialize_unit", Self::name()); + fail!(in self, "Deserializer does not implement deserialize_unit"); 
} fn deserialize_unit_struct>( @@ -148,9 +145,8 @@ pub trait SimpleDeserializer<'de>: Sized { name: &'static str, visitor: V, ) -> Result { - fail!( - "{} does not implement deserialize_unit_struct", - Self::name() + fail!(in self, + "Deserializer does not implement deserialize_unit_struct", ); } } diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 6dc5df25..a5c46336 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -79,11 +79,7 @@ impl<'a, O: NamedType + Offset> Context for StringDeserializer<'a, O> { } } -impl<'a, O: Offset> SimpleDeserializer<'a> for StringDeserializer<'a, O> { - fn name() -> &'static str { - "StringDeserializer" - } - +impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for StringDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? { self.deserialize_str(visitor) diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index 83c32110..fe2db8ed 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -60,10 +60,6 @@ impl<'de> Context for StructDeserializer<'de> { } impl<'de> SimpleDeserializer<'de> for StructDeserializer<'de> { - fn name() -> &'static str { - "StructDeserializer" - } - fn deserialize_any>(&mut self, visitor: V) -> Result { if self.peek_next()? 
{ visitor.visit_map(self) diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 1d41ba2f..edb48ece 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -58,11 +58,7 @@ impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { } } -impl<'de, T: Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { - fn name() -> &'static str { - "Time64Deserializer" - } - +impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { T::deserialize_any(self, visitor) From 93580db45dce5dc3311ca6a31c5075f1a97f92ee Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 19:55:51 +0200 Subject: [PATCH 151/178] Implement Float, Integer, List, Map, Null, String, Time Deserializers --- .../deserialization/float_deserializer.rs | 26 ++++++- .../deserialization/integer_deserializer.rs | 74 ++++++++++++++++--- .../deserialization/list_deserializer.rs | 16 ++-- .../deserialization/map_deserializer.rs | 16 ++-- .../deserialization/null_deserializer.rs | 10 +-- .../deserialization/string_deserializer.rs | 44 +++++++++-- .../deserialization/time_deserializer.rs | 38 ++++++++-- 7 files changed, 176 insertions(+), 48 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 5edb30bf..765058be 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::PrimitiveArrayView, - error::{Context, Result}, + error::{Context, ContextSupport, Result}, utils::{btree_map, Mut, NamedType}, }; @@ -46,6 +46,24 @@ impl<'de, F: 
NamedType + Float> Context for FloatDeserializer<'de, F> { impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'de, F> { fn deserialize_any>(&mut self, visitor: V) -> Result { + self.deserialize_any_impl(visitor).ctx(self) + } + + fn deserialize_option>(&mut self, visitor: V) -> Result { + self.deserialize_option_impl(visitor).ctx(self) + } + + fn deserialize_f32>(&mut self, visitor: V) -> Result { + self.deserialize_f32_impl(visitor).ctx(self) + } + + fn deserialize_f64>(&mut self, visitor: V) -> Result { + self.deserialize_f64_impl(visitor).ctx(self) + } +} + +impl<'de, F: NamedType + Float> FloatDeserializer<'de, F> { + fn deserialize_any_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { F::deserialize_any(self, visitor) } else { @@ -54,7 +72,7 @@ impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'d } } - fn deserialize_option>(&mut self, visitor: V) -> Result { + fn deserialize_option_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { @@ -63,11 +81,11 @@ impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'d } } - fn deserialize_f32>(&mut self, visitor: V) -> Result { + fn deserialize_f32_impl>(&mut self, visitor: V) -> Result { visitor.visit_f32(self.array.next_required()?.into_f32()?) } - fn deserialize_f64>(&mut self, visitor: V) -> Result { + fn deserialize_f64_impl>(&mut self, visitor: V) -> Result { visitor.visit_f64(self.array.next_required()?.into_f64()?) 
} } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 40124baf..159c50f3 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::PrimitiveArrayView, - error::{Context, Result}, + error::{Context, ContextSupport, Result}, utils::{btree_map, Mut, NamedType}, }; @@ -60,6 +60,56 @@ impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for IntegerDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { + self.deserialize_any_impl(visitor).ctx(self) + } + + fn deserialize_option>(&mut self, visitor: V) -> Result { + self.deserialize_option_impl(visitor).ctx(self) + } + + fn deserialize_bool>(&mut self, visitor: V) -> Result { + self.deserialize_bool_impl(visitor).ctx(self) + } + + fn deserialize_char>(&mut self, visitor: V) -> Result { + self.deserialize_char_impl(visitor).ctx(self) + } + + fn deserialize_u8>(&mut self, visitor: V) -> Result { + self.deserialize_u8_impl(visitor).ctx(self) + } + + fn deserialize_u16>(&mut self, visitor: V) -> Result { + self.deserialize_u16_impl(visitor).ctx(self) + } + + fn deserialize_u32>(&mut self, visitor: V) -> Result { + self.deserialize_u32_impl(visitor).ctx(self) + } + + fn deserialize_u64>(&mut self, visitor: V) -> Result { + self.deserialize_u64_impl(visitor).ctx(self) + } + + fn deserialize_i8>(&mut self, visitor: V) -> Result { + self.deserialize_i8_impl(visitor).ctx(self) + } + + fn deserialize_i16>(&mut self, visitor: V) -> Result { + self.deserialize_i16_impl(visitor).ctx(self) + } + + fn deserialize_i32>(&mut self, visitor: V) -> Result { + self.deserialize_i32_impl(visitor).ctx(self) + } + + fn deserialize_i64>(&mut self, visitor: V) -> Result { + 
self.deserialize_i64_impl(visitor).ctx(self) + } +} + +impl<'de, T: NamedType + Integer> IntegerDeserializer<'de, T> { + fn deserialize_any_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { T::deserialize_any(self, visitor) } else { @@ -68,7 +118,7 @@ impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for IntegerDeserialize } } - fn deserialize_option>(&mut self, visitor: V) -> Result { + fn deserialize_option_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { @@ -77,43 +127,43 @@ impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for IntegerDeserialize } } - fn deserialize_bool>(&mut self, visitor: V) -> Result { + fn deserialize_bool_impl>(&mut self, visitor: V) -> Result { visitor.visit_bool(self.array.next_required()?.into_bool()?) } - fn deserialize_char>(&mut self, visitor: V) -> Result { + fn deserialize_char_impl>(&mut self, visitor: V) -> Result { visitor.visit_char(self.array.next_required()?.into_u32()?.try_into()?) } - fn deserialize_u8>(&mut self, visitor: V) -> Result { + fn deserialize_u8_impl>(&mut self, visitor: V) -> Result { visitor.visit_u8(self.array.next_required()?.into_u8()?) } - fn deserialize_u16>(&mut self, visitor: V) -> Result { + fn deserialize_u16_impl>(&mut self, visitor: V) -> Result { visitor.visit_u16(self.array.next_required()?.into_u16()?) } - fn deserialize_u32>(&mut self, visitor: V) -> Result { + fn deserialize_u32_impl>(&mut self, visitor: V) -> Result { visitor.visit_u32(self.array.next_required()?.into_u32()?) } - fn deserialize_u64>(&mut self, visitor: V) -> Result { + fn deserialize_u64_impl>(&mut self, visitor: V) -> Result { visitor.visit_u64(self.array.next_required()?.into_u64()?) } - fn deserialize_i8>(&mut self, visitor: V) -> Result { + fn deserialize_i8_impl>(&mut self, visitor: V) -> Result { visitor.visit_i8(self.array.next_required()?.into_i8()?) 
} - fn deserialize_i16>(&mut self, visitor: V) -> Result { + fn deserialize_i16_impl>(&mut self, visitor: V) -> Result { visitor.visit_i16(self.array.next_required()?.into_i16()?) } - fn deserialize_i32>(&mut self, visitor: V) -> Result { + fn deserialize_i32_impl>(&mut self, visitor: V) -> Result { visitor.visit_i32(self.array.next_required()?.into_i32()?) } - fn deserialize_i64>(&mut self, visitor: V) -> Result { + fn deserialize_i64_impl>(&mut self, visitor: V) -> Result { visitor.visit_i64(self.array.next_required()?.into_i64()?) } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 562eb988..8bc7c292 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut, NamedType, Offset}, }; @@ -67,20 +67,20 @@ impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for ListDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { + if self.peek_next().ctx(self)? { self.deserialize_seq(visitor) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { + if self.peek_next().ctx(self)? 
{ visitor.visit_some(Mut(self)) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } @@ -97,7 +97,7 @@ impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for ListDeserializer<'a, } } -impl<'de, O: Offset> SeqAccess<'de> for ListDeserializer<'de, O> { +impl<'de, O: NamedType + Offset> SeqAccess<'de> for ListDeserializer<'de, O> { type Error = Error; fn next_element_seed>( @@ -108,8 +108,8 @@ impl<'de, O: Offset> SeqAccess<'de> for ListDeserializer<'de, O> { if item + 1 >= self.offsets.len() { return Ok(None); } - let end = self.offsets[item + 1].try_into_usize()?; - let start = self.offsets[item].try_into_usize()?; + let end = self.offsets[item + 1].try_into_usize().ctx(self)?; + let start = self.offsets[item].try_into_usize().ctx(self)?; if offset >= end - start { self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 69fdd230..3c877169 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{DeserializeSeed, MapAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut}, }; @@ -65,20 +65,20 @@ impl<'de> Context for MapDeserializer<'de> { impl<'de> SimpleDeserializer<'de> for MapDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { + if self.peek_next().ctx(self)? { self.deserialize_map(visitor) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { + if self.peek_next().ctx(self)? 
{ visitor.visit_some(Mut(self)) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } @@ -96,10 +96,10 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { ) -> Result, Self::Error> { let (item, entry) = self.next; if item + 1 >= self.offsets.len() { - fail!("Exhausted MapDeserializer"); + fail!(in self, "Exhausted MapDeserializer"); } - let start: usize = self.offsets[item].try_into()?; - let end: usize = self.offsets[item + 1].try_into()?; + let start: usize = self.offsets[item].try_into().ctx(self)?; + let end: usize = self.offsets[item + 1].try_into().ctx(self)?; if entry >= (end - start) { self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/null_deserializer.rs b/serde_arrow/src/internal/deserialization/null_deserializer.rs index 3f0a300d..c07054d0 100644 --- a/serde_arrow/src/internal/deserialization/null_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/null_deserializer.rs @@ -1,7 +1,7 @@ use serde::de::Visitor; use crate::internal::{ - error::{Context, Result}, + error::{Context, ContextSupport, Error, Result}, utils::btree_map, }; @@ -25,15 +25,15 @@ impl Context for NullDeserializer { impl<'de> SimpleDeserializer<'de> for NullDeserializer { fn deserialize_any>(&mut self, visitor: V) -> Result { - visitor.visit_unit() + visitor.visit_unit::().ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - visitor.visit_none() + visitor.visit_none::().ctx(self) } fn deserialize_unit>(&mut self, visitor: V) -> Result { - visitor.visit_unit() + visitor.visit_unit::().ctx(self) } fn deserialize_unit_struct>( @@ -41,6 +41,6 @@ impl<'de> SimpleDeserializer<'de> for NullDeserializer { _: &'static str, visitor: V, ) -> Result { - visitor.visit_unit() + visitor.visit_unit::().ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index a5c46336..1f77be51 100644 --- 
a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -1,6 +1,6 @@ use crate::internal::{ arrow::BytesArrayView, - error::{fail, Context, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{btree_map, Mut, NamedType, Offset}, }; @@ -81,6 +81,34 @@ impl<'a, O: NamedType + Offset> Context for StringDeserializer<'a, O> { impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for StringDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { + self.deserialize_any_impl(visitor).ctx(self) + } + + fn deserialize_option>(&mut self, visitor: V) -> Result { + self.deserialize_option_impl(visitor).ctx(self) + } + + fn deserialize_str>(&mut self, visitor: V) -> Result { + self.deserialize_str_impl(visitor).ctx(self) + } + + fn deserialize_string>(&mut self, visitor: V) -> Result { + self.deserialize_string_impl(visitor).ctx(self) + } + + fn deserialize_enum>( + &mut self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result { + self.deserialize_enum_impl(name, variants, visitor) + .ctx(self) + } +} + +impl<'a, O: NamedType + Offset> StringDeserializer<'a, O> { + fn deserialize_any_impl>(&mut self, visitor: V) -> Result { if self.peek_next()? { self.deserialize_str(visitor) } else { @@ -89,7 +117,10 @@ impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for StringDeserializer<'a } } - fn deserialize_option>(&mut self, visitor: V) -> Result { + fn deserialize_option_impl>( + &mut self, + visitor: V, + ) -> Result { if self.peek_next()? { visitor.visit_some(Mut(self)) } else { @@ -98,15 +129,18 @@ impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for StringDeserializer<'a } } - fn deserialize_str>(&mut self, visitor: V) -> Result { + fn deserialize_str_impl>(&mut self, visitor: V) -> Result { visitor.visit_borrowed_str(self.next_required()?) 
} - fn deserialize_string>(&mut self, visitor: V) -> Result { + fn deserialize_string_impl>( + &mut self, + visitor: V, + ) -> Result { visitor.visit_string(self.next_required()?.to_owned()) } - fn deserialize_enum>( + fn deserialize_enum_impl>( &mut self, _: &'static str, _: &'static [&'static str], diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index edb48ece..06c53a6a 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::{TimeArrayView, TimeUnit}, - error::{fail, Context, Result}, + error::{fail, Context, ContextSupport, Result}, utils::{btree_map, Mut, NamedType}, }; @@ -60,6 +60,32 @@ impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { + self.deserialize_any_impl(visitor).ctx(self) + } + + fn deserialize_option>(&mut self, visitor: V) -> Result { + self.deserialize_option_impl(visitor).ctx(self) + } + + fn deserialize_i32>(&mut self, visitor: V) -> Result { + self.deserialize_i32_impl(visitor).ctx(self) + } + + fn deserialize_i64>(&mut self, visitor: V) -> Result { + self.deserialize_i64_impl(visitor).ctx(self) + } + + fn deserialize_str>(&mut self, visitor: V) -> Result { + self.deserialize_str_impl(visitor).ctx(self) + } + + fn deserialize_string>(&mut self, visitor: V) -> Result { + self.deserialize_string_impl(visitor).ctx(self) + } +} + +impl<'de, T: NamedType + Integer> TimeDeserializer<'de, T> { + fn deserialize_any_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? 
{ T::deserialize_any(self, visitor) } else { @@ -68,7 +94,7 @@ impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for TimeDeserializer<' } } - fn deserialize_option>(&mut self, visitor: V) -> Result { + fn deserialize_option_impl>(&mut self, visitor: V) -> Result { if self.array.peek_next()? { visitor.visit_some(Mut(self)) } else { @@ -77,19 +103,19 @@ impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for TimeDeserializer<' } } - fn deserialize_i32>(&mut self, visitor: V) -> Result { + fn deserialize_i32_impl>(&mut self, visitor: V) -> Result { visitor.visit_i32(self.array.next_required()?.into_i32()?) } - fn deserialize_i64>(&mut self, visitor: V) -> Result { + fn deserialize_i64_impl>(&mut self, visitor: V) -> Result { visitor.visit_i64(self.array.next_required()?.into_i64()?) } - fn deserialize_str>(&mut self, visitor: V) -> Result { + fn deserialize_str_impl>(&mut self, visitor: V) -> Result { self.deserialize_string(visitor) } - fn deserialize_string>(&mut self, visitor: V) -> Result { + fn deserialize_string_impl>(&mut self, visitor: V) -> Result { let ts = self.array.next_required()?.into_i64()?; visitor.visit_string(self.get_string_repr(ts)?) 
} From acb71c2b70122aa21146dc14cb997d4186781c6e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 9 Sep 2024 18:53:10 +0200 Subject: [PATCH 152/178] Rename path => field, impl binary, bool deserializer --- .../deserialization/binary_deserializer.rs | 26 +++--- .../deserialization/bool_deserializer.rs | 88 +++++++++++++++---- .../deserialization/date32_deserializer.rs | 2 +- .../deserialization/date64_deserializer.rs | 2 +- .../deserialization/decimal_deserializer.rs | 2 +- .../dictionary_deserializer.rs | 2 +- .../deserialization/enum_deserializer.rs | 2 +- .../fixed_size_binary_deserializer.rs | 2 +- .../fixed_size_list_deserializer.rs | 2 +- .../deserialization/float_deserializer.rs | 2 +- .../deserialization/integer_deserializer.rs | 2 +- .../deserialization/list_deserializer.rs | 2 +- .../deserialization/map_deserializer.rs | 2 +- .../deserialization/null_deserializer.rs | 2 +- .../deserialization/string_deserializer.rs | 2 +- .../deserialization/struct_deserializer.rs | 2 +- .../deserialization/time_deserializer.rs | 2 +- serde_arrow/src/internal/deserializer.rs | 8 +- .../serialization/outer_sequence_builder.rs | 2 +- .../src/test/error_messages/deserializers.rs | 64 ++++++++++++++ serde_arrow/src/test/error_messages/mod.rs | 1 + 21 files changed, 172 insertions(+), 47 deletions(-) create mode 100644 serde_arrow/src/test/error_messages/deserializers.rs diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index f637dbd1..9ce0ec64 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BytesArrayView, - error::{fail, Context, Error, Result}, + error::{fail, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut, NamedType, Offset}, }; @@ -63,39 +63,43 @@ impl<'a, O: Offset + 
NamedType> Context for BinaryDeserializer<'a, O> { "i64" => "LargeBinary", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } impl<'a, O: Offset + NamedType> SimpleDeserializer<'a> for BinaryDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - self.deserialize_bytes(visitor) + if self.peek_next().ctx(self)? { + self.deserialize_bytes(visitor).ctx(self) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_some(Mut(self)) + if self.peek_next().ctx(self)? { + visitor.visit_some(Mut(self)).ctx(self) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_seq>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + visitor.visit_seq(&mut *self).ctx(self) } fn deserialize_bytes>(&mut self, visitor: V) -> Result { - visitor.visit_borrowed_bytes(self.next_slice()?) + visitor + .visit_borrowed_bytes::(self.next_slice().ctx(self)?) + .ctx(self) } fn deserialize_byte_buf>(&mut self, visitor: V) -> Result { - visitor.visit_borrowed_bytes(self.next_slice()?) + visitor + .visit_borrowed_bytes::(self.next_slice().ctx(self)?) 
+ .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index 00ec4ef5..cd2070f5 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::BooleanArrayView, - error::{fail, Context, Result}, + error::{fail, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut}, }; @@ -25,7 +25,7 @@ impl<'a> BoolDeserializer<'a> { fn next(&mut self) -> Result> { if self.next >= self.view.len { - fail!("Exhausted BoolDeserializer"); + fail!("Exhausted Deserializer"); } if let Some(validty) = &self.view.validity { if !bitset_is_set(validty, self.next)? { @@ -49,7 +49,7 @@ impl<'a> BoolDeserializer<'a> { fn peek_next(&self) -> Result { if self.next >= self.view.len { - fail!("Exhausted BoolDeserializer"); + fail!("Exhausted Deserializer"); } else if let Some(validity) = &self.view.validity { bitset_is_set(validity, self.next) } else { @@ -64,62 +64,112 @@ impl<'a> BoolDeserializer<'a> { impl<'de> Context for BoolDeserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Boolean") + btree_map!("field" => self.path.clone(), "data_type" => "Boolean") } } impl<'de> SimpleDeserializer<'de> for BoolDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - self.deserialize_bool(visitor) + if self.peek_next().ctx(self)? { + self.deserialize_bool(visitor).ctx(self) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_some(Mut(self)) + if self.peek_next().ctx(self)? 
{ + visitor.visit_some(Mut(self)).ctx(self) } else { self.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } fn deserialize_bool>(&mut self, visitor: V) -> Result { - visitor.visit_bool(self.next_required()?) + visitor + .visit_bool::(self.next_required().ctx(self)?) + .ctx(self) } fn deserialize_u8>(&mut self, visitor: V) -> Result { - visitor.visit_u8(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_u8::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_u16>(&mut self, visitor: V) -> Result { - visitor.visit_u16(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_u16::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_u32>(&mut self, visitor: V) -> Result { - visitor.visit_u32(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_u32::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_u64>(&mut self, visitor: V) -> Result { - visitor.visit_u64(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_u64::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_i8>(&mut self, visitor: V) -> Result { - visitor.visit_i8(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_i8::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_i16>(&mut self, visitor: V) -> Result { - visitor.visit_i16(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_i16::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor.visit_i32(if self.next_required()? { 1 } else { 0 }) + visitor + .visit_i32::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor.visit_i64(if self.next_required()? 
{ 1 } else { 0 }) + visitor + .visit_i64::(if self.next_required().ctx(self)? { + 1 + } else { + 0 + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index 9aba5634..15dcb399 100644 --- a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -33,7 +33,7 @@ impl<'a> Date32Deserializer<'a> { impl<'de> Context for Date32Deserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Date32") + btree_map!("field" => self.path.clone(), "data_type" => "Date32") } } diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index f184134e..50193df9 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -54,7 +54,7 @@ impl<'a> Date64Deserializer<'a> { impl<'de> Context for Date64Deserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Date64") + btree_map!("field" => self.path.clone(), "data_type" => "Date64") } } diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index ad052128..4ced6d94 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -26,7 +26,7 @@ impl<'a> DecimalDeserializer<'a> { impl<'de> Context for DecimalDeserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Decimal128(..)") + btree_map!("field" => self.path.clone(), "data_type" => "Decimal128(..)") } } diff --git 
a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index 20e99acc..f61114b0 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -55,7 +55,7 @@ impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { impl<'de, K: Integer, V: Offset> Context for DictionaryDeserializer<'de, K, V> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Dictionary(..)") + btree_map!("field" => self.path.clone(), "data_type" => "Dictionary(..)") } } diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index 8b819096..fef1f3dc 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -31,7 +31,7 @@ impl<'a> EnumDeserializer<'a> { impl<'de> Context for EnumDeserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Union(..)") + btree_map!("field" => self.path.clone(), "data_type" => "Union(..)") } } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index fa9f6fc8..f4504858 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -66,7 +66,7 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { impl<'a> Context for FixedSizeBinaryDeserializer<'a> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "FixedSizeBinary(..)") + btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeBinary(..)") } } diff 
--git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index 3f10c2e4..f0795cc9 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -59,7 +59,7 @@ impl<'a> FixedSizeListDeserializer<'a> { impl<'a> Context for FixedSizeListDeserializer<'a> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "FixedSizeList(..)") + btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeList(..)") } } diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 765058be..64c6af42 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -40,7 +40,7 @@ impl<'de, F: NamedType + Float> Context for FloatDeserializer<'de, F> { "f64" => "Float64", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 159c50f3..b820cac2 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -54,7 +54,7 @@ impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { "u64" => "UInt64", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 8bc7c292..59b3a23f 
100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -61,7 +61,7 @@ impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { "i64" => "LargeList(..)", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 3c877169..0d3ca4e9 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -59,7 +59,7 @@ impl<'a> MapDeserializer<'a> { impl<'de> Context for MapDeserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Map(..)") + btree_map!("field" => self.path.clone(), "data_type" => "Map(..)") } } diff --git a/serde_arrow/src/internal/deserialization/null_deserializer.rs b/serde_arrow/src/internal/deserialization/null_deserializer.rs index c07054d0..58b7fdab 100644 --- a/serde_arrow/src/internal/deserialization/null_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/null_deserializer.rs @@ -19,7 +19,7 @@ impl NullDeserializer { impl Context for NullDeserializer { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Null") + btree_map!("field" => self.path.clone(), "data_type" => "Null") } } diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 1f77be51..92468d37 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -75,7 +75,7 @@ impl<'a, O: NamedType + Offset> Context for StringDeserializer<'a, O> { "i64" => 
"LargeUtf8", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index fe2db8ed..080ae4fb 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -55,7 +55,7 @@ impl<'a> StructDeserializer<'a> { impl<'de> Context for StructDeserializer<'de> { fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("path" => self.path.clone(), "data_type" => "Struct(..)") + btree_map!("field" => self.path.clone(), "data_type" => "Struct(..)") } } diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 06c53a6a..29516648 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -54,7 +54,7 @@ impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { "i64" => "Time64", _ => "", }; - btree_map!("path" => self.path.clone(), "data_type" => data_type) + btree_map!("field" => self.path.clone(), "data_type" => data_type) } } diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index bcfd5f47..a5642df4 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -11,6 +11,8 @@ use crate::internal::{ utils::array_view_ext::ArrayViewExt, }; +use super::utils::ChildName; + /// A structure to deserialize Arrow arrays into Rust objects /// #[cfg_attr(any(has_arrow, has_arrow2), doc = r"It can be constructed via")] @@ -33,7 +35,11 @@ impl<'de> Deserializer<'de> { fail!("Cannot deserialize from arrays with different lengths"); } let strategy = 
get_strategy_from_metadata(&field.metadata)?; - let deserializer = ArrayDeserializer::new(String::from("$"), strategy.as_ref(), view)?; + let deserializer = ArrayDeserializer::new( + format!("$.{child}", child = ChildName(&field.name)), + strategy.as_ref(), + view, + )?; deserializers.push((field.name.clone(), deserializer)); } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 3cafa254..30ec4e73 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -123,7 +123,7 @@ fn build_struct(path: String, struct_fields: &[Field], nullable: bool) -> Result fn build_builder(path: String, field: &Field) -> Result { use {ArrayBuilder as A, DataType as T}; - let ctx: BTreeMap = btree_map!("path" => path.clone()); + let ctx: BTreeMap = btree_map!("field" => path.clone()); let builder = match &field.data_type { T::Null => match get_strategy_from_metadata(&field.metadata)? 
{ diff --git a/serde_arrow/src/test/error_messages/deserializers.rs b/serde_arrow/src/test/error_messages/deserializers.rs new file mode 100644 index 00000000..6fbd3058 --- /dev/null +++ b/serde_arrow/src/test/error_messages/deserializers.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; + +use serde::Deserialize; +use serde_json::json; + +use crate::{ + internal::{ + arrow::{ArrayView, BitsWithOffset, BooleanArrayView, FieldMeta, StructArrayView}, + testing::assert_error_contains, + }, + schema::{SchemaLike, SerdeArrowSchema}, + Deserializer, +}; + +#[test] +fn example_exhausted() { + let views = vec![ArrayView::Struct(StructArrayView { + len: 5, + validity: None, + fields: vec![( + ArrayView::Boolean(BooleanArrayView { + len: 2, + validity: None, + values: BitsWithOffset { + data: &[0b_0001_0011], + offset: 0, + }, + }), + FieldMeta { + name: String::from("nested"), + nullable: false, + metadata: HashMap::new(), + }, + )], + })]; + + let schema = SerdeArrowSchema::from_value(&json!([{ + "name": "item", + "data_type": "Struct", + "children": [ + {"name": "nested", "data_type": "Bool"}, + ], + }])) + .unwrap(); + + let deserializer = Deserializer::new(&schema.fields, views).unwrap(); + + #[derive(Deserialize)] + struct S { + #[allow(dead_code)] + item: Nested, + } + + #[derive(Deserialize)] + struct Nested { + #[allow(dead_code)] + nested: bool, + } + + let res = Vec::::deserialize(deserializer); + assert_error_contains(&res, "Exhausted Deserializer"); + assert_error_contains(&res, "field: \"$.item.nested\""); + assert_error_contains(&res, "data_type: \"Boolean\""); +} diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index 7730e5aa..d1c9cca5 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1 +1,2 @@ +mod deserializers; mod push_validity; From c60958c1daec95b35f98f9f258824541ba8adc18 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 9 Sep 2024 
19:06:28 +0200 Subject: [PATCH 153/178] Add try_ function to simplify error handling code --- .../deserialization/binary_deserializer.rs | 42 ++++---- .../deserialization/bool_deserializer.rs | 100 +++++------------- .../deserialization/date32_deserializer.rs | 4 +- serde_arrow/src/internal/error.rs | 13 ++- 4 files changed, 64 insertions(+), 95 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index 9ce0ec64..acab4d0d 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BytesArrayView, - error::{fail, Context, ContextSupport, Error, Result}, + error::{fail, try_, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut, NamedType, Offset}, }; @@ -69,37 +69,39 @@ impl<'a, O: Offset + NamedType> Context for BinaryDeserializer<'a, O> { impl<'a, O: Offset + NamedType> SimpleDeserializer<'a> for BinaryDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - self.deserialize_bytes(visitor).ctx(self) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next().ctx(self)? { + self.deserialize_bytes(visitor).ctx(self) + } else { + self.consume_next(); + visitor.visit_none::().ctx(self) + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - visitor.visit_some(Mut(self)).ctx(self) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next().ctx(self)? 
{ + visitor.visit_some(Mut(self)).ctx(self) + } else { + self.consume_next(); + visitor.visit_none::().ctx(self) + } + }) + .ctx(self) } fn deserialize_seq>(&mut self, visitor: V) -> Result { - visitor.visit_seq(&mut *self).ctx(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } fn deserialize_bytes>(&mut self, visitor: V) -> Result { - visitor - .visit_borrowed_bytes::(self.next_slice().ctx(self)?) - .ctx(self) + try_(|| visitor.visit_borrowed_bytes::(self.next_slice()?)).ctx(self) } fn deserialize_byte_buf>(&mut self, visitor: V) -> Result { - visitor - .visit_borrowed_bytes::(self.next_slice().ctx(self)?) - .ctx(self) + try_(|| visitor.visit_borrowed_bytes::(self.next_slice()?)).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index cd2070f5..0ecde3a0 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::BooleanArrayView, - error::{fail, Context, ContextSupport, Error, Result}, + error::{fail, try_, Context, ContextSupport, Error, Result}, utils::{btree_map, Mut}, }; @@ -70,106 +70,62 @@ impl<'de> Context for BoolDeserializer<'de> { impl<'de> SimpleDeserializer<'de> for BoolDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - self.deserialize_bool(visitor).ctx(self) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? { + self.deserialize_bool(visitor) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - visitor.visit_some(Mut(self)).ctx(self) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? 
{ + visitor.visit_some(Mut(self)) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_bool>(&mut self, visitor: V) -> Result { - visitor - .visit_bool::(self.next_required().ctx(self)?) - .ctx(self) + try_(|| visitor.visit_bool::(self.next_required()?)).ctx(self) } fn deserialize_u8>(&mut self, visitor: V) -> Result { - visitor - .visit_u8::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_u8::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_u16>(&mut self, visitor: V) -> Result { - visitor - .visit_u16::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_u16::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_u32>(&mut self, visitor: V) -> Result { - visitor - .visit_u32::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_u32::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_u64>(&mut self, visitor: V) -> Result { - visitor - .visit_u64::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_u64::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_i8>(&mut self, visitor: V) -> Result { - visitor - .visit_i8::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_i8::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_i16>(&mut self, visitor: V) -> Result { - visitor - .visit_i16::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_i16::(if self.next_required()? { 1 } else { 0 })).ctx(self) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor - .visit_i32::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_i32::(if self.next_required()? 
{ 1 } else { 0 })).ctx(self) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor - .visit_i64::(if self.next_required().ctx(self)? { - 1 - } else { - 0 - }) - .ctx(self) + try_(|| visitor.visit_i64::(if self.next_required()? { 1 } else { 0 })).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index 15dcb399..edad92d0 100644 --- a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::BitsWithOffset, - error::{Context, Result}, + error::{Context, ContextSupport, Error, Result}, utils::{btree_map, Mut}, }; @@ -52,7 +52,7 @@ impl<'de> SimpleDeserializer<'de> for Date32Deserializer<'de> { visitor.visit_some(Mut(self)) } else { self.array.consume_next(); - visitor.visit_none() + visitor.visit_none::().ctx(self) } } diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 4cc29514..5020550b 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -4,6 +4,13 @@ use std::{ convert::Infallible, }; +/// Execute a faillible function and return the result +/// +/// This function is mostly useful to add annotations to a complex block of operations +pub fn try_(func: impl FnOnce() -> Result) -> Result { + func() +} + /// An object that offers additional context to an error pub trait Context { fn annotations(&self) -> BTreeMap; @@ -88,7 +95,11 @@ impl Error { impl Error { pub(crate) fn with_annotations(self, annotations: BTreeMap) -> Self { let Self::Custom(mut this) = self; - this.0.annotations = annotations; + for (k, v) in annotations { + if !this.0.annotations.contains_key(&k) { + this.0.annotations.insert(k, v); + } + } Self::Custom(this) } } From 19c3fdf0255488e52e2b73559ccf8b8398417440 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: 
Mon, 9 Sep 2024 20:18:11 +0200 Subject: [PATCH 154/178] Mutate map annotations directly --- .../deserialization/array_deserializer.rs | 4 +- .../deserialization/binary_deserializer.rs | 26 ++++++------ .../deserialization/bool_deserializer.rs | 9 ++-- .../deserialization/date32_deserializer.rs | 9 ++-- .../deserialization/date64_deserializer.rs | 9 ++-- .../deserialization/decimal_deserializer.rs | 9 ++-- .../dictionary_deserializer.rs | 9 ++-- .../deserialization/enum_deserializer.rs | 9 ++-- .../fixed_size_binary_deserializer.rs | 13 +++--- .../fixed_size_list_deserializer.rs | 9 ++-- .../deserialization/float_deserializer.rs | 24 ++++++----- .../deserialization/integer_deserializer.rs | 34 ++++++++------- .../deserialization/list_deserializer.rs | 22 ++++++---- .../deserialization/map_deserializer.rs | 9 ++-- .../deserialization/null_deserializer.rs | 10 ++--- .../outer_sequence_deserializer.rs | 6 +-- .../deserialization/string_deserializer.rs | 22 ++++++---- .../deserialization/struct_deserializer.rs | 9 ++-- .../deserialization/time_deserializer.rs | 22 ++++++---- serde_arrow/src/internal/error.rs | 41 +++++++++++-------- .../internal/serialization/array_builder.rs | 4 +- .../internal/serialization/binary_builder.rs | 26 ++++++------ .../internal/serialization/bool_builder.rs | 12 +++--- .../internal/serialization/date32_builder.rs | 12 +++--- .../internal/serialization/date64_builder.rs | 25 +++++------ .../internal/serialization/decimal_builder.rs | 8 ++-- .../serialization/dictionary_utf8_builder.rs | 9 ++-- .../serialization/duration_builder.rs | 12 +++--- .../fixed_size_binary_builder.rs | 13 +++--- .../serialization/fixed_size_list_builder.rs | 9 ++-- .../internal/serialization/float_builder.rs | 19 +++++---- .../src/internal/serialization/int_builder.rs | 34 ++++++++------- .../internal/serialization/list_builder.rs | 22 ++++++---- .../src/internal/serialization/map_builder.rs | 12 +++--- .../internal/serialization/null_builder.rs | 8 ++-- 
.../serialization/outer_sequence_builder.rs | 4 +- .../internal/serialization/struct_builder.rs | 13 +++--- .../internal/serialization/time_builder.rs | 22 ++++++---- .../internal/serialization/union_builder.rs | 25 +++++++---- .../serialization/unknown_variant_builder.rs | 8 ++-- .../internal/serialization/utf8_builder.rs | 23 ++++++----- 41 files changed, 339 insertions(+), 286 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index c343a475..ce5b94e7 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -374,8 +374,8 @@ macro_rules! dispatch { } impl<'de> Context for ArrayDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - dispatch!(self, ArrayDeserializer(deser) => deser.annotations()) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + dispatch!(self, ArrayDeserializer(deser) => deser.annotate(annotations)) } } diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index acab4d0d..9d463ef2 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BytesArrayView, - error::{fail, try_, Context, ContextSupport, Error, Result}, - utils::{btree_map, Mut, NamedType, Offset}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, + utils::{Mut, NamedType, Offset}, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; @@ -57,13 +57,17 @@ impl<'a, O: Offset> BinaryDeserializer<'a, O> { } impl<'a, O: Offset + NamedType> Context for BinaryDeserializer<'a, O> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = 
match O::NAME { - "i32" => "Binary", - "i64" => "LargeBinary", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match O::NAME { + "i32" => "Binary", + "i64" => "LargeBinary", + _ => "", + }, + ); } } @@ -130,9 +134,7 @@ impl<'de, O: Offset> SeqAccess<'de> for BinaryDeserializer<'de, O> { struct U8Deserializer(u8); impl Context for U8Deserializer { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!() - } + fn annotate(&self, _: &mut std::collections::BTreeMap) {} } impl<'de> SimpleDeserializer<'de> for U8Deserializer { diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index 0ecde3a0..010c53b4 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::BooleanArrayView, - error::{fail, try_, Context, ContextSupport, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, + utils::Mut, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; @@ -63,8 +63,9 @@ impl<'a> BoolDeserializer<'a> { } impl<'de> Context for BoolDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Boolean") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Boolean"); } } diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index edad92d0..de71602d 100644 --- 
a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -3,8 +3,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::BitsWithOffset, - error::{Context, ContextSupport, Error, Result}, - utils::{btree_map, Mut}, + error::{set_default, Context, ContextSupport, Error, Result}, + utils::Mut, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -32,8 +32,9 @@ impl<'a> Date32Deserializer<'a> { } impl<'de> Context for Date32Deserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Date32") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Date32"); } } diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index 50193df9..dd1761ac 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -3,8 +3,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BitsWithOffset, TimeUnit}, - error::{fail, Context, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, Result}, + utils::Mut, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -53,8 +53,9 @@ impl<'a> Date64Deserializer<'a> { } impl<'de> Context for Date64Deserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Date64") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Date64"); } } diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs 
b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index 4ced6d94..d3722e41 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::DecimalArrayView, - error::{Context, Result}, - utils::{btree_map, decimal, Mut}, + error::{set_default, Context, Result}, + utils::{decimal, Mut}, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -25,8 +25,9 @@ impl<'a> DecimalDeserializer<'a> { } impl<'de> Context for DecimalDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Decimal128(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Decimal128(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index f61114b0..027579ac 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BytesArrayView, PrimitiveArrayView}, - error::{fail, Context, Result}, - utils::{btree_map, Mut, Offset}, + error::{fail, set_default, Context, Result}, + utils::{Mut, Offset}, }; use super::{ @@ -54,8 +54,9 @@ impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { } impl<'de, K: Integer, V: Offset> Context for DictionaryDeserializer<'de, K, V> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Dictionary(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + 
set_default(annotations, "data_type", "Dictionary(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index fef1f3dc..6473696d 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -1,8 +1,8 @@ use serde::de::{DeserializeSeed, Deserializer, EnumAccess, Visitor}; use crate::internal::{ - error::{fail, Context, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, Error, Result}, + utils::Mut, }; use super::{array_deserializer::ArrayDeserializer, simple_deserializer::SimpleDeserializer}; @@ -30,8 +30,9 @@ impl<'a> EnumDeserializer<'a> { } impl<'de> Context for EnumDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Union(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Union(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index f4504858..82f12778 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::FixedSizeBinaryArrayView, - error::{fail, Context, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, Error, Result}, + utils::Mut, }; use super::{simple_deserializer::SimpleDeserializer, utils::bitset_is_set}; @@ -65,8 +65,9 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { } impl<'a> Context for FixedSizeBinaryDeserializer<'a> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => 
self.path.clone(), "data_type" => "FixedSizeBinary(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "FixedSizeBinary(..)"); } } @@ -127,9 +128,7 @@ impl<'de> SeqAccess<'de> for FixedSizeBinaryDeserializer<'de> { struct U8Deserializer(u8); impl Context for U8Deserializer { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!() - } + fn annotate(&self, _: &mut std::collections::BTreeMap) {} } impl<'de> SimpleDeserializer<'de> for U8Deserializer { diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index f0795cc9..d6829810 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{IgnoredAny, SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, Error, Result}, + utils::Mut, }; use super::{ @@ -58,8 +58,9 @@ impl<'a> FixedSizeListDeserializer<'a> { } impl<'a> Context for FixedSizeListDeserializer<'a> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeList(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "FixedSizeList(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 64c6af42..9d7731a1 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use 
crate::internal::{ arrow::PrimitiveArrayView, - error::{Context, ContextSupport, Result}, - utils::{btree_map, Mut, NamedType}, + error::{set_default, Context, ContextSupport, Result}, + utils::{Mut, NamedType}, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -33,14 +33,18 @@ impl<'a, F: Float> FloatDeserializer<'a, F> { } impl<'de, F: NamedType + Float> Context for FloatDeserializer<'de, F> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match F::NAME { - "f16" => "Float16", - "f32" => "Float32", - "f64" => "Float64", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match F::NAME { + "f16" => "Float16", + "f32" => "Float32", + "f64" => "Float64", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index b820cac2..5dda5431 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::PrimitiveArrayView, - error::{Context, ContextSupport, Result}, - utils::{btree_map, Mut, NamedType}, + error::{set_default, Context, ContextSupport, Result}, + utils::{Mut, NamedType}, }; use super::{simple_deserializer::SimpleDeserializer, utils::ArrayBufferIterator}; @@ -42,19 +42,23 @@ impl<'a, T: Integer> IntegerDeserializer<'a, T> { } impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match T::NAME { - "i8" => "Int8", - "i16" => "Int16", - "i32" => "Int32", - "i64" => "Int64", - "u8" => "UInt8", - "u16" => "UInt16", - "u32" => "UInt32", - "u64" => 
"UInt64", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match T::NAME { + "i8" => "Int8", + "i16" => "Int16", + "i32" => "Int32", + "i64" => "Int64", + "u8" => "UInt8", + "u16" => "UInt16", + "u32" => "UInt32", + "u64" => "UInt64", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 59b3a23f..00f06e5e 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, ContextSupport, Error, Result}, - utils::{btree_map, Mut, NamedType, Offset}, + error::{fail, set_default, Context, ContextSupport, Error, Result}, + utils::{Mut, NamedType, Offset}, }; use super::{ @@ -55,13 +55,17 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { } impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match O::NAME { - "i32" => "List(..)", - "i64" => "LargeList(..)", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "filed", &self.path); + set_default( + annotations, + "data_type", + match O::NAME { + "i32" => "List(..)", + "i64" => "LargeList(..)", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 0d3ca4e9..3941b027 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -2,8 +2,8 @@ use serde::de::{DeserializeSeed, MapAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, ContextSupport, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, ContextSupport, Error, Result}, + utils::Mut, }; use super::{ @@ -58,8 +58,9 @@ impl<'a> MapDeserializer<'a> { } impl<'de> Context for MapDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Map(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Map(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/null_deserializer.rs b/serde_arrow/src/internal/deserialization/null_deserializer.rs index 58b7fdab..5d76ab77 100644 --- a/serde_arrow/src/internal/deserialization/null_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/null_deserializer.rs @@ -1,9 +1,6 @@ use serde::de::Visitor; -use crate::internal::{ - error::{Context, ContextSupport, Error, Result}, - utils::btree_map, -}; +use crate::internal::error::{set_default, Context, ContextSupport, Error, Result}; use super::simple_deserializer::SimpleDeserializer; @@ -18,8 +15,9 @@ impl NullDeserializer { } impl Context for NullDeserializer { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Null") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Null"); } } diff --git a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs index c4504880..24f0407f 100644 --- a/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/outer_sequence_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ error::{Context, Error, Result}, - utils::{btree_map, Mut}, + utils::Mut, }; use super::{ @@ -27,9 +27,7 @@ impl<'a> OuterSequenceDeserializer<'a> { } impl<'de> Context for OuterSequenceDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!() - } + fn annotate(&self, _: &mut std::collections::BTreeMap) {} } impl<'de> SimpleDeserializer<'de> for OuterSequenceDeserializer<'de> { diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 92468d37..134173a8 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -1,7 +1,7 @@ use crate::internal::{ arrow::BytesArrayView, - error::{fail, Context, ContextSupport, Result}, - utils::{btree_map, Mut, NamedType, Offset}, + error::{fail, set_default, Context, ContextSupport, Result}, + utils::{Mut, NamedType, Offset}, }; use super::{ @@ -69,13 +69,17 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { } impl<'a, O: NamedType + Offset> Context for StringDeserializer<'a, O> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match O::NAME { - "i32" => "Utf8", - "i64" => "LargeUtf8", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match O::NAME { + "i32" => "Utf8", + "i64" => "LargeUtf8", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index 080ae4fb..32eff2e6 100644 --- 
a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -4,8 +4,8 @@ use serde::de::{ use crate::internal::{ arrow::BitsWithOffset, - error::{fail, Context, Error, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, Error, Result}, + utils::Mut, }; use super::{ @@ -54,8 +54,9 @@ impl<'a> StructDeserializer<'a> { } impl<'de> Context for StructDeserializer<'de> { - fn annotations(&self) -> std::collections::BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Struct(..)") + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Struct(..)"); } } diff --git a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index 29516648..aa4304ac 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -3,8 +3,8 @@ use serde::de::Visitor; use crate::internal::{ arrow::{TimeArrayView, TimeUnit}, - error::{fail, Context, ContextSupport, Result}, - utils::{btree_map, Mut, NamedType}, + error::{fail, set_default, Context, ContextSupport, Result}, + utils::{Mut, NamedType}, }; use super::{ @@ -48,13 +48,17 @@ impl<'a, T: Integer> TimeDeserializer<'a, T> { } impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match T::NAME { - "i32" => "Time32", - "i64" => "Time64", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut std::collections::BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match T::NAME { + "i32" => "Time32", + "i64" => "Time64", + _ => "", + }, + ); } } diff --git 
a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 5020550b..ec497123 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -4,6 +4,16 @@ use std::{ convert::Infallible, }; +pub fn set_default>( + annotations: &mut BTreeMap, + key: &str, + value: V, +) { + if !annotations.contains_key(key) { + annotations.insert(String::from(key), value.into()); + } +} + /// Execute a faillible function and return the result /// /// This function is mostly useful to add annotations to a complex block of operations @@ -13,12 +23,16 @@ pub fn try_(func: impl FnOnce() -> Result) -> Result { /// An object that offers additional context to an error pub trait Context { - fn annotations(&self) -> BTreeMap; + fn annotate(&self, annotations: &mut BTreeMap); } impl Context for BTreeMap { - fn annotations(&self) -> BTreeMap { - self.clone() + fn annotate(&self, annotations: &mut BTreeMap) { + for (k, v) in self { + if !annotations.contains_key(k) { + annotations.insert(k.to_owned(), v.to_owned()); + } + } } } @@ -44,7 +58,9 @@ impl> ContextSupport for E { type Output = Error; fn ctx(self, context: &C) -> Self::Output { - self.into().with_annotations(context.annotations()) + let Error::Custom(mut error) = self.into(); + context.annotate(&mut error.0.annotations); + Error::Custom(error) } } @@ -92,18 +108,6 @@ impl Error { } } -impl Error { - pub(crate) fn with_annotations(self, annotations: BTreeMap) -> Self { - let Self::Custom(mut this) = self; - for (k, v) in annotations { - if !this.0.annotations.contains_key(&k) { - this.0.annotations.insert(k, v); - } - } - Self::Custom(this) - } -} - /// Access information about the error impl Error { pub fn message(&self) -> &str { @@ -227,8 +231,9 @@ macro_rules! 
fail { { #[allow(unused)] use $crate::internal::error::Context; - let annotations = $context.annotations(); - return Err($crate::internal::error::Error::custom(format!($($tt)*)).with_annotations(annotations)) + let $crate::internal::error::Error::Custom(mut err) = $crate::internal::error::Error::custom(format!($($tt)*)); + $context.annotate(&mut err.0.annotations); + return Err($crate::internal::error::Error::Custom(err)); } }; ($($tt:tt)*) => { diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 3f9281aa..75eb6ef5 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -112,8 +112,8 @@ impl ArrayBuilder { } impl Context for ArrayBuilder { - fn annotations(&self) -> BTreeMap { - dispatch!(self, Self(builder) => builder.annotations()) + fn annotate(&self, annotations: &mut BTreeMap) { + dispatch!(self, Self(builder) => builder.annotate(annotations)) } } diff --git a/serde_arrow/src/internal/serialization/binary_builder.rs b/serde_arrow/src/internal/serialization/binary_builder.rs index 2126dea6..b5cb2ff4 100644 --- a/serde_arrow/src/internal/serialization/binary_builder.rs +++ b/serde_arrow/src/internal/serialization/binary_builder.rs @@ -4,10 +4,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, BytesArray}, - error::{Context, ContextSupport, Result}, + error::{set_default, Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt, SeqArrayExt}, - btree_map, Mut, NamedType, Offset, + Mut, NamedType, Offset, }, }; @@ -79,13 +79,17 @@ impl BinaryBuilder { } impl Context for BinaryBuilder { - fn annotations(&self) -> std::collections::BTreeMap { - let data_type = match O::NAME { - "i32" => "Binary", - "i64" => "LargeBinary", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut BTreeMap) { 
+ set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match O::NAME { + "i32" => "Binary", + "i64" => "LargeBinary", + _ => "", + }, + ); } } @@ -142,9 +146,7 @@ impl SimpleSerializer for BinaryBuilder { struct U8Serializer(u8); impl Context for U8Serializer { - fn annotations(&self) -> BTreeMap { - Default::default() - } + fn annotate(&self, _: &mut BTreeMap) {} } impl SimpleSerializer for U8Serializer { diff --git a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 3976f155..74bcd1cf 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -2,11 +2,8 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BooleanArray}, - error::{Context, ContextSupport, Result}, - utils::{ - array_ext::{set_bit_buffer, set_validity, set_validity_default}, - btree_map, - }, + error::{set_default, Context, ContextSupport, Result}, + utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -50,8 +47,9 @@ impl BoolBuilder { } impl Context for BoolBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Boolean") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Boolean"); } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs b/serde_arrow/src/internal/serialization/date32_builder.rs index e01a4d8a..6987ce0d 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -4,11 +4,8 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, ContextSupport, Result}, - utils::{ - 
array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, - }, + error::{set_default, Context, ContextSupport, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -54,8 +51,9 @@ impl Date32Builder { } impl Context for Date32Builder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Date32") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Date32"); } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index 7d85feb2..f06e3a68 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -2,11 +2,8 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{fail, Context, ContextSupport, Result}, - utils::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, - }, + error::{fail, set_default, Context, ContextSupport, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -94,13 +91,17 @@ impl Date64Builder { } impl Context for Date64Builder { - fn annotations(&self) -> BTreeMap { - let data_type = if self.meta.is_some() { - "Timestamp(..)" - } else { - "Date64" - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + if self.meta.is_some() { + "Timestamp(..)" + } else { + "Date64" + }, + ); } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs 
b/serde_arrow/src/internal/serialization/decimal_builder.rs index d1616ed3..2e75bd3d 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -2,10 +2,9 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, - error::{Context, ContextSupport, Result}, + error::{set_default, Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, decimal::{self, DecimalParser}, }, }; @@ -63,8 +62,9 @@ impl DecimalBuilder { } impl Context for DecimalBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Decimal128(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "filed", &self.path); + set_default(annotations, "data_type", "Decimal128(..)"); } } diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index 61998668..a7af2c63 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -4,8 +4,8 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, - error::{fail, Context, ContextSupport, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, ContextSupport, Result}, + utils::Mut, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -50,8 +50,9 @@ impl DictionaryUtf8Builder { } impl Context for DictionaryUtf8Builder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Dictionary(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Dictionary(..)"); } } diff --git 
a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index a1e37b09..3163d96b 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -2,11 +2,8 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Context, ContextSupport, Result}, - utils::{ - array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, - }, + error::{set_default, Context, ContextSupport, Result}, + utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -49,8 +46,9 @@ impl DurationBuilder { } impl Context for DurationBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Duration(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Duration(..)"); } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index 32ca9b85..b4cce14c 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -4,10 +4,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, - error::{fail, Context, ContextSupport, Result}, + error::{fail, set_default, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, - btree_map, Mut, + Mut, }, }; @@ -86,8 +86,9 @@ impl FixedSizeBinaryBuilder { } impl Context for FixedSizeBinaryBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeBinary(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + 
set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "FixedSizeBinary(..)"); } } @@ -163,9 +164,7 @@ impl SimpleSerializer for FixedSizeBinaryBuilder { struct U8Serializer(u8); impl Context for U8Serializer { - fn annotations(&self) -> BTreeMap { - btree_map!() - } + fn annotate(&self, _: &mut BTreeMap) {} } impl SimpleSerializer for U8Serializer { diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 0c986bd7..41adaac3 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -4,10 +4,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, - error::{fail, Context, ContextSupport, Result}, + error::{fail, set_default, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, - btree_map, Mut, + Mut, }, }; @@ -94,8 +94,9 @@ impl FixedSizeListBuilder { } impl Context for FixedSizeListBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "FixedSizeList(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "FixedSizeList(..)"); } } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 7769e8ba..8bbd0ac6 100644 --- a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -4,10 +4,10 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, ContextSupport, Result}, + error::{set_default, Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, Mut, + Mut, }, }; @@ -58,20 +58,23 @@ 
impl_into_array!(f32, F32, Float32); impl_into_array!(f64, F64, Float64); impl Context for FloatBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Float16") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Float16"); } } impl Context for FloatBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Float32") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Float32"); } } impl Context for FloatBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Float64") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Float64"); } } diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 7f65e4f7..3d9815c5 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -2,10 +2,10 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{Context, ContextSupport, Error, Result}, + error::{set_default, Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, NamedType, + NamedType, }, }; @@ -61,19 +61,23 @@ impl_into_array!(u32, U32, UInt32); impl_into_array!(u64, U64, UInt64); impl Context for IntBuilder { - fn annotations(&self) -> BTreeMap { - let data_type = match I::NAME { - "i8" => "Int8", - "i16" => "Int16", - "i32" => "Int32", - "i64" => "Int64", - "u8" => "UInt8", - "u16" => "UInt16", - "u32" => "UInt32", - "u64" => "UInt64", - _ => "", - }; - btree_map!("field" => self.path.clone(), "data_type" => 
data_type) + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match I::NAME { + "i8" => "Int8", + "i16" => "Int16", + "i32" => "Int32", + "i64" => "Int64", + "u8" => "UInt8", + "u16" => "UInt16", + "u32" => "UInt32", + "u64" => "UInt64", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index d7efb6d3..cff30c22 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -4,10 +4,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{Context, ContextSupport, Result}, + error::{set_default, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - btree_map, Mut, NamedType, Offset, + Mut, NamedType, Offset, }, }; @@ -92,13 +92,17 @@ impl ListBuilder { } impl Context for ListBuilder { - fn annotations(&self) -> BTreeMap { - let data_type = if O::NAME == "i32" { - "List" - } else { - "LargeList" - }; - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + if O::NAME == "i32" { + "List" + } else { + "LargeList" + }, + ); } } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index 09018096..b17457dd 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -4,11 +4,8 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{fail, Context, ContextSupport, Result}, - utils::{ - array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, - btree_map, - }, + error::{fail, set_default, Context, 
ContextSupport, Result}, + utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -71,8 +68,9 @@ impl MapBuilder { } impl Context for MapBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Map(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Map(..)"); } } diff --git a/serde_arrow/src/internal/serialization/null_builder.rs b/serde_arrow/src/internal/serialization/null_builder.rs index d8e056ac..a08c6c87 100644 --- a/serde_arrow/src/internal/serialization/null_builder.rs +++ b/serde_arrow/src/internal/serialization/null_builder.rs @@ -2,8 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, NullArray}, - error::{Context, Result}, - utils::btree_map, + error::{set_default, Context, Result}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -36,8 +35,9 @@ impl NullBuilder { } impl Context for NullBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Null") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Null"); } } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 30ec4e73..f727ef85 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -62,8 +62,8 @@ impl OuterSequenceBuilder { } impl Context for OuterSequenceBuilder { - fn annotations(&self) -> BTreeMap { - self.0.annotations() + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) } } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs 
b/serde_arrow/src/internal/serialization/struct_builder.rs index 4b0786ea..70cb35eb 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -4,10 +4,10 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, - error::{fail, Context, ContextSupport, Result}, + error::{fail, set_default, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, - btree_map, Mut, + Mut, }, }; @@ -123,8 +123,9 @@ impl StructBuilder { } impl Context for StructBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Struct(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Struct(..)"); } } @@ -301,9 +302,7 @@ impl<'a> KeyLookupSerializer<'a> { } impl<'a> Context for KeyLookupSerializer<'a> { - fn annotations(&self) -> BTreeMap { - btree_map!() - } + fn annotate(&self, _: &mut BTreeMap) {} } impl<'a> SimpleSerializer for KeyLookupSerializer<'a> { diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index 8728d068..e2fb2622 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -4,10 +4,10 @@ use chrono::Timelike; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{Context, ContextSupport, Error, Result}, + error::{set_default, Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, - btree_map, NamedType, + NamedType, }, }; @@ -71,13 +71,17 @@ impl TimeBuilder { } impl Context for TimeBuilder { - fn annotations(&self) -> BTreeMap { - let data_type = match I::NAME { - "i32" => "Time32", - "i64" => "Time64", - _ => "", - }; - btree_map!("field" => self.path.clone(), 
"data_type" => data_type) + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + match I::NAME { + "i32" => "Time32", + "i64" => "Time64", + _ => "", + }, + ); } } diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 1a9b7483..13a3c85e 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -2,8 +2,8 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, Context, ContextSupport, Result}, - utils::{btree_map, Mut}, + error::{fail, set_default, Context, ContextSupport, Result}, + utils::Mut, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -76,8 +76,9 @@ impl UnionBuilder { } impl Context for UnionBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "Union(..)") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Union(..)"); } } @@ -88,7 +89,9 @@ impl SimpleSerializer for UnionBuilder { variant_index: u32, _: &'static str, ) -> Result<()> { - let ctx = self.annotations(); + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + self.serialize_variant(variant_index) .ctx(&ctx)? 
.serialize_unit() @@ -101,7 +104,9 @@ impl SimpleSerializer for UnionBuilder { _: &'static str, value: &V, ) -> Result<()> { - let ctx = self.annotations(); + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; value.serialize(Mut(variant_builder)) } @@ -113,7 +118,9 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let ctx = self.annotations(); + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_struct_start(variant, len)?; Ok(variant_builder) @@ -126,7 +133,9 @@ impl SimpleSerializer for UnionBuilder { variant: &'static str, len: usize, ) -> Result<&'this mut ArrayBuilder> { - let ctx = self.annotations(); + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; variant_builder.serialize_tuple_struct_start(variant, len)?; Ok(variant_builder) diff --git a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs index 753cee8c..7cb164e4 100644 --- a/serde_arrow/src/internal/serialization/unknown_variant_builder.rs +++ b/serde_arrow/src/internal/serialization/unknown_variant_builder.rs @@ -4,8 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, NullArray}, - error::{fail, Context, Result}, - utils::btree_map, + error::{fail, set_default, Context, Result}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -36,8 +35,9 @@ impl UnknownVariantBuilder { } impl Context for UnknownVariantBuilder { - fn annotations(&self) -> BTreeMap { - btree_map!("field" => self.path.clone(), "data_type" => "") + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, 
"data_type", ""); } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 17fed8ce..3f839b24 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -2,10 +2,10 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BytesArray}, - error::{fail, Context, ContextSupport, Result}, + error::{fail, set_default, Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, - btree_map, NamedType, Offset, + NamedType, Offset, }, }; @@ -58,14 +58,17 @@ impl Utf8Builder { } impl Context for Utf8Builder { - fn annotations(&self) -> BTreeMap { - let data_type = if O::NAME == "i32" { - "Utf8" - } else { - "LargeUtf8" - }; - - btree_map!("field" => self.path.clone(), "data_type" => data_type) + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default( + annotations, + "data_type", + if O::NAME == "i32" { + "Utf8" + } else { + "LargeUtf8" + }, + ); } } From 2c6af574e1e60fbf20e3456f17832764193473f8 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 10 Sep 2024 21:17:19 +0200 Subject: [PATCH 155/178] Use try_().ctx() in most deserializers --- .../deserialization/date32_deserializer.rs | 43 +++++--- .../deserialization/date64_deserializer.rs | 43 +++++--- .../deserialization/decimal_deserializer.rs | 43 +++++--- .../dictionary_deserializer.rs | 43 +++++--- .../deserialization/enum_deserializer.rs | 9 +- .../fixed_size_binary_deserializer.rs | 38 ++++--- .../fixed_size_list_deserializer.rs | 34 +++--- .../deserialization/float_deserializer.rs | 54 ++++------ .../deserialization/integer_deserializer.rs | 102 +++++------------- .../deserialization/list_deserializer.rs | 42 ++++---- .../deserialization/map_deserializer.rs | 38 ++++--- .../deserialization/struct_deserializer.rs | 66 +++++++----- 
.../deserialization/time_deserializer.rs | 71 +++++------- 13 files changed, 315 insertions(+), 311 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/date32_deserializer.rs b/serde_arrow/src/internal/deserialization/date32_deserializer.rs index de71602d..f37f0b00 100644 --- a/serde_arrow/src/internal/deserialization/date32_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date32_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::BitsWithOffset, - error::{set_default, Context, ContextSupport, Error, Result}, + error::{set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -40,33 +40,42 @@ impl<'de> Context for Date32Deserializer<'de> { impl<'de> SimpleDeserializer<'de> for Date32Deserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - self.deserialize_i32(visitor) - } else { - self.array.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.array.peek_next()? { + self.deserialize_i32(visitor) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.array.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.array.peek_next()? { + visitor.visit_some(Mut(self)) + } else { + self.array.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.array.next_required()?) + try_(|| visitor.visit_i32(self.array.next_required()?)).ctx(self) } fn deserialize_str>(&mut self, visitor: V) -> Result { - self.deserialize_string(visitor) + try_(|| self.deserialize_string(visitor)).ctx(self) } fn deserialize_string>(&mut self, visitor: V) -> Result { - let ts = self.array.next_required()?; - visitor.visit_string(self.get_string_repr(ts)?) 
+ try_(|| { + let ts = self.array.next_required()?; + visitor.visit_string(self.get_string_repr(ts)?) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/date64_deserializer.rs b/serde_arrow/src/internal/deserialization/date64_deserializer.rs index dd1761ac..b6bc2601 100644 --- a/serde_arrow/src/internal/deserialization/date64_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/date64_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BitsWithOffset, TimeUnit}, - error::{fail, set_default, Context, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::Mut, }; @@ -61,33 +61,42 @@ impl<'de> Context for Date64Deserializer<'de> { impl<'de> SimpleDeserializer<'de> for Date64Deserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - self.deserialize_i64(visitor) - } else { - self.array.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.array.peek_next()? { + self.deserialize_i64(visitor) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.array.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.array.peek_next()? { + visitor.visit_some(Mut(self)) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.array.next_required()?) + try_(|| visitor.visit_i64(self.array.next_required()?)).ctx(self) } fn deserialize_str>(&mut self, visitor: V) -> Result { - self.deserialize_string(visitor) + try_(|| self.deserialize_string(visitor)).ctx(self) } fn deserialize_string>(&mut self, visitor: V) -> Result { - let ts = self.array.next_required()?; - visitor.visit_string(self.get_string_repr(ts)?) 
+ try_(|| { + let ts = self.array.next_required()?; + visitor.visit_string(self.get_string_repr(ts)?) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs index d3722e41..b58f5213 100644 --- a/serde_arrow/src/internal/deserialization/decimal_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/decimal_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::DecimalArrayView, - error::{set_default, Context, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{decimal, Mut}, }; @@ -33,28 +33,37 @@ impl<'de> Context for DecimalDeserializer<'de> { impl<'de> SimpleDeserializer<'de> for DecimalDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.inner.peek_next()? { - self.deserialize_str(visitor) - } else { - self.inner.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.inner.peek_next()? { + self.deserialize_str(visitor) + } else { + self.inner.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.inner.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.inner.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.inner.peek_next()? 
{ + visitor.visit_some(Mut(self)) + } else { + self.inner.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_str>(&mut self, visitor: V) -> Result { - let val = self.inner.next_required()?; - let mut buffer = [0; decimal::BUFFER_SIZE_I128]; - let formatted = decimal::format_decimal(&mut buffer, val, self.scale); + try_(|| { + let val = self.inner.next_required()?; + let mut buffer = [0; decimal::BUFFER_SIZE_I128]; + let formatted = decimal::format_decimal(&mut buffer, val, self.scale); - visitor.visit_str(formatted) + visitor.visit_str(formatted) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index 027579ac..6c4f6d8f 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::{BytesArrayView, PrimitiveArrayView}, - error::{fail, set_default, Context, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{Mut, Offset}, }; @@ -62,29 +62,35 @@ impl<'de, K: Integer, V: Offset> Context for DictionaryDeserializer<'de, K, V> { impl<'de, K: Integer, V: Offset> SimpleDeserializer<'de> for DictionaryDeserializer<'de, K, V> { fn deserialize_any>(&mut self, visitor: VV) -> Result { - if self.keys.peek_next()? { - self.deserialize_str(visitor) - } else { - self.keys.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.keys.peek_next()? { + self.deserialize_str(visitor) + } else { + self.keys.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: VV) -> Result { - if self.keys.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.keys.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.keys.peek_next()? 
{ + visitor.visit_some(Mut(self)) + } else { + self.keys.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_str>(&mut self, visitor: VV) -> Result { - visitor.visit_str(self.next_str()?) + try_(|| visitor.visit_str(self.next_str()?)).ctx(self) } fn deserialize_string>(&mut self, visitor: VV) -> Result { - visitor.visit_string(self.next_str()?.to_owned()) + try_(|| visitor.visit_string(self.next_str()?.to_owned())).ctx(self) } fn deserialize_enum>( @@ -93,7 +99,10 @@ impl<'de, K: Integer, V: Offset> SimpleDeserializer<'de> for DictionaryDeseriali _: &'static [&'static str], visitor: VV, ) -> Result { - let variant = self.next_str()?; - visitor.visit_enum(EnumAccess(variant)) + try_(|| { + let variant = self.next_str()?; + visitor.visit_enum(EnumAccess(variant)) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index 6473696d..502ff45d 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -1,7 +1,9 @@ +use std::collections::BTreeMap; + use serde::de::{DeserializeSeed, Deserializer, EnumAccess, Visitor}; use crate::internal::{ - error::{fail, set_default, Context, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -43,7 +45,10 @@ impl<'de> SimpleDeserializer<'de> for EnumDeserializer<'de> { _: &'static [&'static str], visitor: V, ) -> Result { - visitor.visit_enum(self) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(|| visitor.visit_enum(self)).ctx(&ctx) } } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index 82f12778..dbeb9094 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ 
b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::FixedSizeBinaryArrayView, - error::{fail, set_default, Context, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -73,33 +73,39 @@ impl<'a> Context for FixedSizeBinaryDeserializer<'a> { impl<'a> SimpleDeserializer<'a> for FixedSizeBinaryDeserializer<'a> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - self.deserialize_bytes(visitor) - } else { - self.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.peek_next()? { + self.deserialize_bytes(visitor) + } else { + self.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.consume_next(); - visitor.visit_none() - } + try_(|| { + if self.peek_next()? { + visitor.visit_some(Mut(self)) + } else { + self.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_seq>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } fn deserialize_bytes>(&mut self, visitor: V) -> Result { - visitor.visit_borrowed_bytes(self.next_slice()?) + try_(|| visitor.visit_borrowed_bytes(self.next_slice()?)).ctx(self) } fn deserialize_byte_buf>(&mut self, visitor: V) -> Result { - visitor.visit_borrowed_bytes(self.next_slice()?) 
+ try_(|| visitor.visit_borrowed_bytes(self.next_slice()?)).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index d6829810..fc4bf01a 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{IgnoredAny, SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, set_default, Context, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -66,25 +66,31 @@ impl<'a> Context for FixedSizeListDeserializer<'a> { impl<'a> SimpleDeserializer<'a> for FixedSizeListDeserializer<'a> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - self.deserialize_seq(visitor) - } else { - self.consume_next()?; - visitor.visit_none() - } + try_(|| { + if self.peek_next()? { + self.deserialize_seq(visitor) + } else { + self.consume_next()?; + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.consume_next()?; - visitor.visit_none() - } + try_(|| { + if self.peek_next()? 
{ + visitor.visit_some(Mut(&mut *self)) + } else { + self.consume_next()?; + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_seq>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/float_deserializer.rs b/serde_arrow/src/internal/deserialization/float_deserializer.rs index 9d7731a1..46db143e 100644 --- a/serde_arrow/src/internal/deserialization/float_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/float_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::PrimitiveArrayView, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{Mut, NamedType}, }; @@ -50,46 +50,34 @@ impl<'de, F: NamedType + Float> Context for FloatDeserializer<'de, F> { impl<'de, F: NamedType + Float> SimpleDeserializer<'de> for FloatDeserializer<'de, F> { fn deserialize_any>(&mut self, visitor: V) -> Result { - self.deserialize_any_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? { + F::deserialize_any(&mut *self, visitor) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - self.deserialize_option_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? { + visitor.visit_some(Mut(&mut *self)) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_f32>(&mut self, visitor: V) -> Result { - self.deserialize_f32_impl(visitor).ctx(self) + try_(|| visitor.visit_f32(self.array.next_required()?.into_f32()?)).ctx(self) } fn deserialize_f64>(&mut self, visitor: V) -> Result { - self.deserialize_f64_impl(visitor).ctx(self) - } -} - -impl<'de, F: NamedType + Float> FloatDeserializer<'de, F> { - fn deserialize_any_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? 
{ - F::deserialize_any(self, visitor) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_option_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_f32_impl>(&mut self, visitor: V) -> Result { - visitor.visit_f32(self.array.next_required()?.into_f32()?) - } - - fn deserialize_f64_impl>(&mut self, visitor: V) -> Result { - visitor.visit_f64(self.array.next_required()?.into_f64()?) + try_(|| visitor.visit_f64(self.array.next_required()?.into_f64()?)).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/integer_deserializer.rs b/serde_arrow/src/internal/deserialization/integer_deserializer.rs index 5dda5431..5ef6c3ba 100644 --- a/serde_arrow/src/internal/deserialization/integer_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/integer_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::PrimitiveArrayView, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{Mut, NamedType}, }; @@ -64,110 +64,66 @@ impl<'de, T: NamedType + Integer> Context for IntegerDeserializer<'de, T> { impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for IntegerDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { - self.deserialize_any_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? { + T::deserialize_any(&mut *self, visitor) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - self.deserialize_option_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? 
{ + visitor.visit_some(Mut(&mut *self)) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_bool>(&mut self, visitor: V) -> Result { - self.deserialize_bool_impl(visitor).ctx(self) + try_(|| visitor.visit_bool(self.array.next_required()?.into_bool()?)).ctx(self) } fn deserialize_char>(&mut self, visitor: V) -> Result { - self.deserialize_char_impl(visitor).ctx(self) + try_(|| visitor.visit_char(self.array.next_required()?.into_u32()?.try_into()?)).ctx(self) } fn deserialize_u8>(&mut self, visitor: V) -> Result { - self.deserialize_u8_impl(visitor).ctx(self) + try_(|| visitor.visit_u8(self.array.next_required()?.into_u8()?)).ctx(self) } fn deserialize_u16>(&mut self, visitor: V) -> Result { - self.deserialize_u16_impl(visitor).ctx(self) + try_(|| visitor.visit_u16(self.array.next_required()?.into_u16()?)).ctx(self) } fn deserialize_u32>(&mut self, visitor: V) -> Result { - self.deserialize_u32_impl(visitor).ctx(self) + try_(|| visitor.visit_u32(self.array.next_required()?.into_u32()?)).ctx(self) } fn deserialize_u64>(&mut self, visitor: V) -> Result { - self.deserialize_u64_impl(visitor).ctx(self) + try_(|| visitor.visit_u64(self.array.next_required()?.into_u64()?)).ctx(self) } fn deserialize_i8>(&mut self, visitor: V) -> Result { - self.deserialize_i8_impl(visitor).ctx(self) + try_(|| visitor.visit_i8(self.array.next_required()?.into_i8()?)).ctx(self) } fn deserialize_i16>(&mut self, visitor: V) -> Result { - self.deserialize_i16_impl(visitor).ctx(self) + try_(|| visitor.visit_i16(self.array.next_required()?.into_i16()?)).ctx(self) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - self.deserialize_i32_impl(visitor).ctx(self) + try_(|| visitor.visit_i32(self.array.next_required()?.into_i32()?)).ctx(self) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - self.deserialize_i64_impl(visitor).ctx(self) - } -} - -impl<'de, T: NamedType + Integer> IntegerDeserializer<'de, T> { - fn 
deserialize_any_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - T::deserialize_any(self, visitor) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_option_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_bool_impl>(&mut self, visitor: V) -> Result { - visitor.visit_bool(self.array.next_required()?.into_bool()?) - } - - fn deserialize_char_impl>(&mut self, visitor: V) -> Result { - visitor.visit_char(self.array.next_required()?.into_u32()?.try_into()?) - } - - fn deserialize_u8_impl>(&mut self, visitor: V) -> Result { - visitor.visit_u8(self.array.next_required()?.into_u8()?) - } - - fn deserialize_u16_impl>(&mut self, visitor: V) -> Result { - visitor.visit_u16(self.array.next_required()?.into_u16()?) - } - - fn deserialize_u32_impl>(&mut self, visitor: V) -> Result { - visitor.visit_u32(self.array.next_required()?.into_u32()?) - } - - fn deserialize_u64_impl>(&mut self, visitor: V) -> Result { - visitor.visit_u64(self.array.next_required()?.into_u64()?) - } - - fn deserialize_i8_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i8(self.array.next_required()?.into_i8()?) - } - - fn deserialize_i16_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i16(self.array.next_required()?.into_i16()?) - } - - fn deserialize_i32_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.array.next_required()?.into_i32()?) - } - - fn deserialize_i64_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.array.next_required()?.into_i64()?) 
+ try_(|| visitor.visit_i64(self.array.next_required()?.into_i64()?)).ctx(self) } } diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 00f06e5e..8fac54f1 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{SeqAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, set_default, Context, ContextSupport, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::{Mut, NamedType, Offset}, }; @@ -71,33 +71,39 @@ impl<'a, O: NamedType + Offset> Context for ListDeserializer<'a, O> { impl<'a, O: NamedType + Offset> SimpleDeserializer<'a> for ListDeserializer<'a, O> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - self.deserialize_seq(visitor) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? { + self.deserialize_seq(visitor) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - visitor.visit_some(Mut(self)) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? 
{ + visitor.visit_some(Mut(&mut *self)) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_seq>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } fn deserialize_bytes>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } fn deserialize_byte_buf>(&mut self, visitor: V) -> Result { - visitor.visit_seq(self) + try_(|| visitor.visit_seq(&mut *self)).ctx(self) } } @@ -112,8 +118,8 @@ impl<'de, O: NamedType + Offset> SeqAccess<'de> for ListDeserializer<'de, O> { if item + 1 >= self.offsets.len() { return Ok(None); } - let end = self.offsets[item + 1].try_into_usize().ctx(self)?; - let start = self.offsets[item].try_into_usize().ctx(self)?; + let end = self.offsets[item + 1].try_into_usize()?; + let start = self.offsets[item].try_into_usize()?; if offset >= end - start { self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index 3941b027..f8a1cf85 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -2,7 +2,7 @@ use serde::de::{DeserializeSeed, MapAccess, Visitor}; use crate::internal::{ arrow::BitsWithOffset, - error::{fail, set_default, Context, ContextSupport, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -66,25 +66,31 @@ impl<'de> Context for MapDeserializer<'de> { impl<'de> SimpleDeserializer<'de> for MapDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - self.deserialize_map(visitor) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? 
{ + self.deserialize_map(visitor) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next().ctx(self)? { - visitor.visit_some(Mut(self)) - } else { - self.consume_next(); - visitor.visit_none::().ctx(self) - } + try_(|| { + if self.peek_next()? { + visitor.visit_some(Mut(&mut *self)) + } else { + self.consume_next(); + visitor.visit_none::() + } + }) + .ctx(self) } fn deserialize_map>(&mut self, visitor: V) -> Result { - visitor.visit_map(self) + try_(|| visitor.visit_map(&mut *self)).ctx(self) } } @@ -99,8 +105,8 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { if item + 1 >= self.offsets.len() { fail!(in self, "Exhausted MapDeserializer"); } - let start: usize = self.offsets[item].try_into().ctx(self)?; - let end: usize = self.offsets[item + 1].try_into().ctx(self)?; + let start: usize = self.offsets[item].try_into()?; + let end: usize = self.offsets[item + 1].try_into()?; if entry >= (end - start) { self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index 32eff2e6..98244a15 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -4,7 +4,7 @@ use serde::de::{ use crate::internal::{ arrow::BitsWithOffset, - error::{fail, set_default, Context, Error, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Error, Result}, utils::Mut, }; @@ -62,31 +62,37 @@ impl<'de> Context for StructDeserializer<'de> { impl<'de> SimpleDeserializer<'de> for StructDeserializer<'de> { fn deserialize_any>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_map(self) - } else { - self.consume_next(); - for (_, field) in &mut self.fields { - field.deserialize_ignored_any(IgnoredAny)?; + try_(|| { + if self.peek_next()? 
{ + visitor.visit_map(&mut *self) + } else { + self.consume_next(); + for (_, field) in &mut self.fields { + field.deserialize_ignored_any(IgnoredAny)?; + } + visitor.visit_none() } - visitor.visit_none() - } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - if self.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.consume_next(); - for (_, field) in &mut self.fields { - field.deserialize_ignored_any(IgnoredAny)?; + try_(|| { + if self.peek_next()? { + visitor.visit_some(Mut(&mut *self)) + } else { + self.consume_next(); + for (_, field) in &mut self.fields { + field.deserialize_ignored_any(IgnoredAny)?; + } + visitor.visit_none() } - visitor.visit_none() - } + }) + .ctx(self) } fn deserialize_map>(&mut self, visitor: V) -> Result { - visitor.visit_map(self) + try_(|| visitor.visit_map(&mut *self)).ctx(self) } fn deserialize_struct>( @@ -95,15 +101,18 @@ impl<'de> SimpleDeserializer<'de> for StructDeserializer<'de> { _: &'static [&'static str], visitor: V, ) -> Result { - visitor.visit_map(self) + try_(|| visitor.visit_map(&mut *self)).ctx(self) } fn deserialize_tuple>(&mut self, _: usize, visitor: V) -> Result { - let res = visitor.visit_seq(&mut *self)?; + try_(|| { + let res = visitor.visit_seq(&mut *self)?; - // tuples do not consume the sequence until none is raised - self.consume_next(); - Ok(res) + // tuples do not consume the sequence until none is raised + self.consume_next(); + Ok(res) + }) + .ctx(self) } fn deserialize_tuple_struct>( @@ -112,11 +121,14 @@ impl<'de> SimpleDeserializer<'de> for StructDeserializer<'de> { _: usize, visitor: V, ) -> Result { - let res = visitor.visit_seq(&mut *self)?; + try_(|| { + let res = visitor.visit_seq(&mut *self)?; - // tuples do not consume the sequence until none is raised - self.consume_next(); - Ok(res) + // tuples do not consume the sequence until none is raised + self.consume_next(); + Ok(res) + }) + .ctx(self) } } diff --git 
a/serde_arrow/src/internal/deserialization/time_deserializer.rs b/serde_arrow/src/internal/deserialization/time_deserializer.rs index aa4304ac..92d15206 100644 --- a/serde_arrow/src/internal/deserialization/time_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/time_deserializer.rs @@ -3,7 +3,7 @@ use serde::de::Visitor; use crate::internal::{ arrow::{TimeArrayView, TimeUnit}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{Mut, NamedType}, }; @@ -64,63 +64,46 @@ impl<'de, T: NamedType + Integer> Context for TimeDeserializer<'de, T> { impl<'de, T: NamedType + Integer> SimpleDeserializer<'de> for TimeDeserializer<'de, T> { fn deserialize_any>(&mut self, visitor: V) -> Result { - self.deserialize_any_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? { + T::deserialize_any(&mut *self, visitor) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_option>(&mut self, visitor: V) -> Result { - self.deserialize_option_impl(visitor).ctx(self) + try_(|| { + if self.array.peek_next()? 
{ + visitor.visit_some(Mut(&mut *self)) + } else { + self.array.consume_next(); + visitor.visit_none() + } + }) + .ctx(self) } fn deserialize_i32>(&mut self, visitor: V) -> Result { - self.deserialize_i32_impl(visitor).ctx(self) + try_(|| visitor.visit_i32(self.array.next_required()?.into_i32()?)).ctx(self) } fn deserialize_i64>(&mut self, visitor: V) -> Result { - self.deserialize_i64_impl(visitor).ctx(self) + try_(|| visitor.visit_i64(self.array.next_required()?.into_i64()?)).ctx(self) } fn deserialize_str>(&mut self, visitor: V) -> Result { - self.deserialize_str_impl(visitor).ctx(self) + try_(|| self.deserialize_string(visitor)).ctx(self) } fn deserialize_string>(&mut self, visitor: V) -> Result { - self.deserialize_string_impl(visitor).ctx(self) - } -} - -impl<'de, T: NamedType + Integer> TimeDeserializer<'de, T> { - fn deserialize_any_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - T::deserialize_any(self, visitor) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_option_impl>(&mut self, visitor: V) -> Result { - if self.array.peek_next()? { - visitor.visit_some(Mut(self)) - } else { - self.array.consume_next(); - visitor.visit_none() - } - } - - fn deserialize_i32_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i32(self.array.next_required()?.into_i32()?) - } - - fn deserialize_i64_impl>(&mut self, visitor: V) -> Result { - visitor.visit_i64(self.array.next_required()?.into_i64()?) - } - - fn deserialize_str_impl>(&mut self, visitor: V) -> Result { - self.deserialize_string(visitor) - } - - fn deserialize_string_impl>(&mut self, visitor: V) -> Result { - let ts = self.array.next_required()?.into_i64()?; - visitor.visit_string(self.get_string_repr(ts)?) + try_(|| { + let ts = self.array.next_required()?.into_i64()?; + visitor.visit_string(self.get_string_repr(ts)?) 
+ }) + .ctx(self) } } From c9231f9eeea9d4cdfe8f3d4512ce11a0c1f81eca Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 10 Sep 2024 21:24:37 +0200 Subject: [PATCH 156/178] Add context only for non-annotated errors --- serde_arrow/src/internal/error.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index ec497123..35b11f30 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -59,7 +59,9 @@ impl> ContextSupport for E { fn ctx(self, context: &C) -> Self::Output { let Error::Custom(mut error) = self.into(); - context.annotate(&mut error.0.annotations); + if error.0.annotations.is_empty() { + context.annotate(&mut error.0.annotations); + } Error::Custom(error) } } From 2e72489b515d51821f081306fc241617e00b8e30 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 10 Sep 2024 21:48:48 +0200 Subject: [PATCH 157/178] use try_().ctx() in serializer --- .../internal/serialization/bool_builder.rs | 35 +++--- .../internal/serialization/date32_builder.rs | 15 ++- .../internal/serialization/date64_builder.rs | 15 ++- .../internal/serialization/decimal_builder.rs | 28 +++-- .../serialization/dictionary_utf8_builder.rs | 33 +++--- .../serialization/duration_builder.rs | 24 ++-- .../fixed_size_binary_builder.rs | 72 +++++++----- .../serialization/fixed_size_list_builder.rs | 48 ++++---- .../internal/serialization/float_builder.rs | 60 +++++----- .../src/internal/serialization/int_builder.rs | 41 +++---- .../internal/serialization/list_builder.rs | 39 ++++--- .../src/internal/serialization/map_builder.rs | 26 +++-- .../internal/serialization/struct_builder.rs | 109 ++++++++++-------- .../internal/serialization/time_builder.rs | 45 ++++---- .../internal/serialization/union_builder.rs | 31 +++-- .../internal/serialization/utf8_builder.rs | 10 +- 16 files changed, 340 insertions(+), 291 deletions(-) diff --git 
a/serde_arrow/src/internal/serialization/bool_builder.rs b/serde_arrow/src/internal/serialization/bool_builder.rs index 74bcd1cf..08b64877 100644 --- a/serde_arrow/src/internal/serialization/bool_builder.rs +++ b/serde_arrow/src/internal/serialization/bool_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BooleanArray}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::array_ext::{set_bit_buffer, set_validity, set_validity_default}, }; @@ -55,23 +55,32 @@ impl Context for BoolBuilder { impl SimpleSerializer for BoolBuilder { fn serialize_default(&mut self) -> Result<()> { - set_validity_default(self.array.validity.as_mut(), self.array.len); - set_bit_buffer(&mut self.array.values, self.array.len, false); - self.array.len += 1; - Ok(()) + try_(|| { + set_validity_default(self.array.validity.as_mut(), self.array.len); + set_bit_buffer(&mut self.array.values, self.array.len, false); + self.array.len += 1; + Ok(()) + }) + .ctx(self) } fn serialize_none(&mut self) -> Result<()> { - set_validity(self.array.validity.as_mut(), self.array.len, false).ctx(self)?; - set_bit_buffer(&mut self.array.values, self.array.len, false); - self.array.len += 1; - Ok(()) + try_(|| { + set_validity(self.array.validity.as_mut(), self.array.len, false)?; + set_bit_buffer(&mut self.array.values, self.array.len, false); + self.array.len += 1; + Ok(()) + }) + .ctx(self) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - set_validity(self.array.validity.as_mut(), self.array.len, true).ctx(self)?; - set_bit_buffer(&mut self.array.values, self.array.len, v); - self.array.len += 1; - Ok(()) + try_(|| { + set_validity(self.array.validity.as_mut(), self.array.len, true)?; + set_bit_buffer(&mut self.array.values, self.array.len, v); + self.array.len += 1; + Ok(()) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/date32_builder.rs 
b/serde_arrow/src/internal/serialization/date32_builder.rs index 6987ce0d..6fefed26 100644 --- a/serde_arrow/src/internal/serialization/date32_builder.rs +++ b/serde_arrow/src/internal/serialization/date32_builder.rs @@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime}; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; @@ -59,19 +59,22 @@ impl Context for Date32Builder { impl SimpleSerializer for Date32Builder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - let days_since_epoch = self.parse_str_to_days_since_epoch(v).ctx(self)?; - self.array.push_scalar_value(days_since_epoch).ctx(self) + try_(|| { + let days_since_epoch = self.parse_str_to_days_since_epoch(v)?; + self.array.push_scalar_value(days_since_epoch) + }) + .ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(v).ctx(self) + try_(|| self.array.push_scalar_value(v)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/date64_builder.rs b/serde_arrow/src/internal/serialization/date64_builder.rs index f06e3a68..405d9b28 100644 --- a/serde_arrow/src/internal/serialization/date64_builder.rs +++ b/serde_arrow/src/internal/serialization/date64_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeUnit, TimestampArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::array_ext::{new_primitive_array, ArrayExt, 
ScalarArrayExt}, }; @@ -107,19 +107,22 @@ impl Context for Date64Builder { impl SimpleSerializer for Date64Builder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - let timestamp = self.parse_str_to_timestamp(v).ctx(self)?; - self.array.push_scalar_value(timestamp) + try_(|| { + let timestamp = self.parse_str_to_timestamp(v)?; + self.array.push_scalar_value(timestamp) + }) + .ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v).ctx(self) + try_(|| self.array.push_scalar_value(v)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/decimal_builder.rs b/serde_arrow/src/internal/serialization/decimal_builder.rs index 2e75bd3d..0d4a22fb 100644 --- a/serde_arrow/src/internal/serialization/decimal_builder.rs +++ b/serde_arrow/src/internal/serialization/decimal_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DecimalArray, PrimitiveArray}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, decimal::{self, DecimalParser}, @@ -70,32 +70,30 @@ impl Context for DecimalBuilder { impl SimpleSerializer for DecimalBuilder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array - .push_scalar_value((v * self.f32_factor) as i128) - .ctx(self) + try_(|| 
self.array.push_scalar_value((v * self.f32_factor) as i128)).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array - .push_scalar_value((v * self.f64_factor) as i128) - .ctx(self) + try_(|| self.array.push_scalar_value((v * self.f64_factor) as i128)).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - let mut parse_buffer = [0; decimal::BUFFER_SIZE_I128]; - let val = self - .parser - .parse_decimal128(&mut parse_buffer, v.as_bytes()) - .ctx(self)?; + try_(|| { + let mut parse_buffer = [0; decimal::BUFFER_SIZE_I128]; + let val = self + .parser + .parse_decimal128(&mut parse_buffer, v.as_bytes())?; - self.array.push_scalar_value(val).ctx(self) + self.array.push_scalar_value(val) + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs index a7af2c63..66a50f81 100644 --- a/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/dictionary_utf8_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, DictionaryArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::Mut, }; @@ -58,25 +58,27 @@ impl Context for DictionaryUtf8Builder { impl SimpleSerializer for DictionaryUtf8Builder { fn serialize_default(&mut self) -> Result<()> { - self.indices.serialize_none().ctx(self) + try_(|| self.indices.serialize_none()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.indices.serialize_none().ctx(self) + try_(|| self.indices.serialize_none().ctx(self)).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - // the only faillible operations concern children: do not apply the context - let idx = match self.index.get(v) { - Some(idx) => *idx, - None => { - let idx = self.index.len(); - self.values.serialize_str(v)?; - 
self.index.insert(v.to_string(), idx); - idx - } - }; - idx.serialize(Mut(self.indices.as_mut())) + try_(|| { + let idx = match self.index.get(v) { + Some(idx) => *idx, + None => { + let idx = self.index.len(); + self.values.serialize_str(v)?; + self.index.insert(v.to_string(), idx); + idx + } + }; + idx.serialize(Mut(self.indices.as_mut())) + }) + .ctx(self) } fn serialize_unit_variant( @@ -85,8 +87,7 @@ impl SimpleSerializer for DictionaryUtf8Builder { _: u32, variant: &'static str, ) -> Result<()> { - // NOTE: context logic is implemented in serialize_str - self.serialize_str(variant) + try_(|| self.serialize_str(variant)).ctx(self) } fn serialize_tuple_variant_start<'this>( diff --git a/serde_arrow/src/internal/serialization/duration_builder.rs b/serde_arrow/src/internal/serialization/duration_builder.rs index 3163d96b..2c73e250 100644 --- a/serde_arrow/src/internal/serialization/duration_builder.rs +++ b/serde_arrow/src/internal/serialization/duration_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, }; @@ -54,44 +54,42 @@ impl Context for DurationBuilder { impl SimpleSerializer for DurationBuilder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| 
self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v).ctx(self) + try_(|| self.array.push_scalar_value(v)).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(i64::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(i64::from(v))).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array - .push_scalar_value(i64::try_from(v).ctx(self)?) 
- .ctx(self) + try_(|| self.array.push_scalar_value(i64::try_from(v)?)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs index b4cce14c..1b5cfa48 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_binary_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FixedSizeBinaryArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, Mut, @@ -94,70 +94,78 @@ impl Context for FixedSizeBinaryBuilder { impl SimpleSerializer for FixedSizeBinaryBuilder { fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default().ctx(self)?; - for _ in 0..self.n { - self.buffer.push(0); - } - Ok(()) + try_(|| { + self.seq.push_seq_default()?; + for _ in 0..self.n { + self.buffer.push(0); + } + Ok(()) + }) + .ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none().ctx(self)?; - for _ in 0..self.n { - self.buffer.push(0); - } - Ok(()) + try_(|| { + self.seq.push_seq_none()?; + for _ in 0..self.n { + self.buffer.push(0); + } + Ok(()) + }) + .ctx(self) } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - self.element(value).ctx(self) + try_(|| self.element(value)).ctx(self) } fn serialize_seq_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(value).ctx(self) + try_(|| self.element(value)).ctx(self) } fn 
serialize_tuple_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - self.element(value).ctx(self) + try_(|| self.element(value)).ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - if v.len() != self.n { - fail!( - in self, - "Invalid number of elements for fixed size binary: got {actual}, expected {expected}", - actual = v.len(), - expected = self.n, - ); - } - - self.seq.start_seq().ctx(self)?; - self.buffer.extend(v); - self.seq.end_seq().ctx(self) + try_(|| { + if v.len() != self.n { + fail!( + in self, + "Invalid number of elements for fixed size binary: got {actual}, expected {expected}", + actual = v.len(), + expected = self.n, + ); + } + + self.seq.start_seq()?; + self.buffer.extend(v); + self.seq.end_seq() + }).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs index 41adaac3..03595b23 100644 --- a/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs +++ b/serde_arrow/src/internal/serialization/fixed_size_list_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, FixedSizeListArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, Mut, @@ -76,7 +76,7 @@ impl FixedSizeListBuilder { fn element(&mut self, value: &V) -> Result<()> { self.current_count += 1; - self.seq.push_seq_elements(1).ctx(self)?; + self.seq.push_seq_elements(1)?; value.serialize(Mut(self.element.as_mut())) } 
@@ -102,54 +102,60 @@ impl Context for FixedSizeListBuilder { impl SimpleSerializer for FixedSizeListBuilder { fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default().ctx(self)?; - for _ in 0..self.n { - self.element.serialize_default()?; - } - Ok(()) + try_(|| { + self.seq.push_seq_default()?; + for _ in 0..self.n { + self.element.serialize_default()?; + } + Ok(()) + }) + .ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none().ctx(self)?; - for _ in 0..self.n { - self.element.serialize_default()?; - } - Ok(()) + try_(|| { + self.seq.push_seq_none()?; + for _ in 0..self.n { + self.element.serialize_default()?; + } + Ok(()) + }) + .ctx(self) } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_seq_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/float_builder.rs b/serde_arrow/src/internal/serialization/float_builder.rs index 8bbd0ac6..5eb6ad58 100644 --- 
a/serde_arrow/src/internal/serialization/float_builder.rs +++ b/serde_arrow/src/internal/serialization/float_builder.rs @@ -4,7 +4,7 @@ use half::f16; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, Mut, @@ -80,122 +80,122 @@ impl Context for FloatBuilder { impl SimpleSerializer for FloatBuilder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_some(&mut self, value: &V) -> Result<()> { - value.serialize(Mut(self)) + try_(|| value.serialize(Mut(&mut *self))).ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(v as 
f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(v).ctx(self) + try_(|| self.array.push_scalar_value(v)).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(v as f32).ctx(self) + try_(|| self.array.push_scalar_value(v as f32)).ctx(self) } } impl SimpleSerializer for FloatBuilder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| 
self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(v as f64).ctx(self) + try_(|| self.array.push_scalar_value(v as f64)).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(v).ctx(self) + try_(|| self.array.push_scalar_value(v)).ctx(self) } } impl SimpleSerializer for FloatBuilder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_f32(&mut self, v: f32) -> Result<()> { - self.array.push_scalar_value(f16::from_f32(v)).ctx(self) + try_(|| self.array.push_scalar_value(f16::from_f32(v))).ctx(self) } fn serialize_f64(&mut self, v: f64) -> Result<()> { - self.array.push_scalar_value(f16::from_f64(v)).ctx(self) + try_(|| self.array.push_scalar_value(f16::from_f64(v))).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/int_builder.rs b/serde_arrow/src/internal/serialization/int_builder.rs index 3d9815c5..34fa7e3a 100644 --- a/serde_arrow/src/internal/serialization/int_builder.rs +++ b/serde_arrow/src/internal/serialization/int_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, PrimitiveArray}, - error::{set_default, Context, ContextSupport, Error, Result}, + error::{set_default, try_, Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, NamedType, @@ -81,16 +81,6 @@ impl Context for IntBuilder { } } -impl IntBuilder { - fn push_value(&mut self, v: J) -> Result<()> - where - I: Default + TryFrom + 'static, - Error: 
From<>::Error>, - { - self.array.push_scalar_value(I::try_from(v)?) - } -} - impl SimpleSerializer for IntBuilder where I: NamedType @@ -114,51 +104,54 @@ where Error: From<>::Error>, { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_bool(&mut self, v: bool) -> Result<()> { - let v: u8 = if v { 1 } else { 0 }; - self.push_value(v).ctx(self) + try_(|| { + let v: u8 = if v { 1 } else { 0 }; + self.array.push_scalar_value(I::try_from(v)?) + }) + .ctx(self) } fn serialize_i8(&mut self, v: i8) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_i16(&mut self, v: i16) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_u8(&mut self, v: u8) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_u16(&mut self, v: u16) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_u32(&mut self, v: u32) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_u64(&mut self, v: u64) -> Result<()> { - self.push_value(v).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(v)?)).ctx(self) } fn serialize_char(&mut self, v: char) -> Result<()> { - 
self.push_value(u32::from(v)).ctx(self) + try_(|| self.array.push_scalar_value(I::try_from(u32::from(v))?)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/list_builder.rs b/serde_arrow/src/internal/serialization/list_builder.rs index cff30c22..321eb95b 100644 --- a/serde_arrow/src/internal/serialization/list_builder.rs +++ b/serde_arrow/src/internal/serialization/list_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{set_default, Context, ContextSupport, Result}, + error::{set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, Mut, NamedType, Offset, @@ -82,7 +82,7 @@ impl ListBuilder { } fn element(&mut self, value: &V) -> Result<()> { - self.offsets.push_seq_elements(1).ctx(self)?; + self.offsets.push_seq_elements(1)?; value.serialize(Mut(self.element.as_mut())) } @@ -108,54 +108,57 @@ impl Context for ListBuilder { impl SimpleSerializer for ListBuilder { fn serialize_default(&mut self) -> Result<()> { - self.offsets.push_seq_default().ctx(self) + try_(|| self.offsets.push_seq_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_seq_none().ctx(self) + try_(|| self.offsets.push_seq_none()).ctx(self) } fn serialize_seq_start(&mut self, _: Option) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_seq_element(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_seq_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| 
self.end()).ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - self.element(value) + try_(|| self.element(value)).ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_bytes(&mut self, v: &[u8]) -> Result<()> { - self.start().ctx(self)?; - for item in v { - self.element(item)?; - } - self.end().ctx(self) + try_(|| { + self.start()?; + for item in v { + self.element(item)?; + } + self.end() + }) + .ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/map_builder.rs b/serde_arrow/src/internal/serialization/map_builder.rs index b17457dd..b7529cce 100644 --- a/serde_arrow/src/internal/serialization/map_builder.rs +++ b/serde_arrow/src/internal/serialization/map_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, ListArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::array_ext::{ArrayExt, OffsetsArray, SeqArrayExt}, }; @@ -76,29 +76,35 @@ impl Context for MapBuilder { impl SimpleSerializer for MapBuilder { fn serialize_default(&mut self) -> Result<()> { - self.offsets.push_seq_default().ctx(self) + try_(|| self.offsets.push_seq_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.offsets.push_seq_none().ctx(self) + try_(|| self.offsets.push_seq_none()).ctx(self) } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - self.offsets.start_seq().ctx(self) + try_(|| self.offsets.start_seq()).ctx(self) } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - self.offsets.push_seq_elements(1).ctx(self)?; - self.entry.serialize_tuple_start(2).ctx(self)?; - self.entry.serialize_tuple_element(key) + try_(|| { + 
self.offsets.push_seq_elements(1)?; + self.entry.serialize_tuple_start(2)?; + self.entry.serialize_tuple_element(key) + }) + .ctx(self) } fn serialize_map_value(&mut self, value: &V) -> Result<()> { - self.entry.serialize_tuple_element(value)?; - self.entry.serialize_tuple_end().ctx(self) + try_(|| { + self.entry.serialize_tuple_element(value)?; + self.entry.serialize_tuple_end() + }) + .ctx(self) } fn serialize_map_end(&mut self) -> Result<()> { - self.offsets.end_seq().ctx(self) + try_(|| self.offsets.end_seq()).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 70cb35eb..d1a2e9e3 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -4,7 +4,7 @@ use serde::Serialize; use crate::internal::{ arrow::{Array, FieldMeta, StructArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{ArrayExt, CountArray, SeqArrayExt}, Mut, @@ -110,7 +110,7 @@ impl StructBuilder { } fn element(&mut self, idx: usize, value: &T) -> Result<()> { - self.seq.push_seq_elements(1).ctx(self)?; + self.seq.push_seq_elements(1)?; if self.seen[idx] { fail!(in self, "Duplicate field {key}", key = self.fields[idx].1.name); } @@ -131,24 +131,30 @@ impl Context for StructBuilder { impl SimpleSerializer for StructBuilder { fn serialize_default(&mut self) -> Result<()> { - self.seq.push_seq_default().ctx(self)?; - for (builder, _) in &mut self.fields { - builder.serialize_default()?; - } + try_(|| { + self.seq.push_seq_default()?; + for (builder, _) in &mut self.fields { + builder.serialize_default()?; + } - Ok(()) + Ok(()) + }) + .ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.seq.push_seq_none().ctx(self)?; - for (builder, _) in &mut self.fields { - builder.serialize_default()?; - } - Ok(()) + try_(|| { + 
self.seq.push_seq_none()?; + for (builder, _) in &mut self.fields { + builder.serialize_default()?; + } + Ok(()) + }) + .ctx(self) } fn serialize_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_struct_field( @@ -156,72 +162,83 @@ impl SimpleSerializer for StructBuilder { key: &'static str, value: &T, ) -> Result<()> { - let Some(idx) = self.lookup.lookup(self.next, key) else { - // ignore unknown fields - return Ok(()); - }; - self.element(idx, value) + try_(|| { + let Some(idx) = self.lookup.lookup(self.next, key) else { + // ignore unknown fields + return Ok(()); + }; + self.element(idx, value) + }) + .ctx(self) } fn serialize_struct_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_start(&mut self, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_element(&mut self, value: &V) -> Result<()> { - self.element(self.next, value) + try_(|| self.element(self.next, value)).ctx(self) } fn serialize_tuple_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_tuple_struct_start(&mut self, _: &'static str, _: usize) -> Result<()> { - self.start().ctx(self) + try_(|| self.start()).ctx(self) } fn serialize_tuple_struct_field(&mut self, value: &V) -> Result<()> { - // ignore extra tuple fields - if self.next < self.fields.len() { - self.element(self.next, value)?; - } - Ok(()) + try_(|| { + // ignore extra tuple fields + if self.next < self.fields.len() { + self.element(self.next, value)?; + } + Ok(()) + }) + .ctx(self) } fn serialize_tuple_struct_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } fn serialize_map_start(&mut self, _: Option) -> Result<()> { - self.start()?; - // always re-set to an invalid field to force that `_key()` is called before `_value()`. 
- self.next = UNKNOWN_KEY; - Ok(()) + try_(|| { + self.start()?; + // always re-set to an invalid field to force that `_key()` is called before `_value()`. + self.next = UNKNOWN_KEY; + Ok(()) + }) + .ctx(self) } fn serialize_map_key(&mut self, key: &V) -> Result<()> { - self.next = self - .lookup - .lookup_serialize(key) - .ctx(self)? - .unwrap_or(UNKNOWN_KEY); - Ok(()) + try_(|| { + self.next = self.lookup.lookup_serialize(key)?.unwrap_or(UNKNOWN_KEY); + Ok(()) + }) + .ctx(self) } fn serialize_map_value(&mut self, value: &V) -> Result<()> { - if self.next != UNKNOWN_KEY { - self.element(self.next, value)?; - } - // see serialize_map_start - self.next = UNKNOWN_KEY; - Ok(()) + try_(|| { + if self.next != UNKNOWN_KEY { + self.element(self.next, value)?; + } + // see serialize_map_start + self.next = UNKNOWN_KEY; + Ok(()) + }) + .ctx(self) } fn serialize_map_end(&mut self) -> Result<()> { - self.end().ctx(self) + try_(|| self.end()).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/time_builder.rs b/serde_arrow/src/internal/serialization/time_builder.rs index e2fb2622..d52543ec 100644 --- a/serde_arrow/src/internal/serialization/time_builder.rs +++ b/serde_arrow/src/internal/serialization/time_builder.rs @@ -4,7 +4,7 @@ use chrono::Timelike; use crate::internal::{ arrow::{Array, PrimitiveArray, TimeArray, TimeUnit}, - error::{set_default, Context, ContextSupport, Error, Result}, + error::{set_default, try_, Context, ContextSupport, Error, Result}, utils::{ array_ext::{new_primitive_array, ArrayExt, ScalarArrayExt}, NamedType, @@ -92,40 +92,37 @@ where Error: From<>::Error>, { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - let (seconds_factor, 
nanoseconds_factor) = match self.unit { - TimeUnit::Nanosecond => (1_000_000_000, 1), - TimeUnit::Microsecond => (1_000_000, 1_000), - TimeUnit::Millisecond => (1_000, 1_000_000), - TimeUnit::Second => (1, 1_000_000_000), - }; - - use chrono::naive::NaiveTime; - let time = v.parse::().ctx(self)?; - let timestamp = i64::from(time.num_seconds_from_midnight()) * seconds_factor - + i64::from(time.nanosecond()) / nanoseconds_factor; - - self.array - .push_scalar_value(timestamp.try_into().ctx(self)?) - .ctx(self) + try_(|| { + let (seconds_factor, nanoseconds_factor) = match self.unit { + TimeUnit::Nanosecond => (1_000_000_000, 1), + TimeUnit::Microsecond => (1_000_000, 1_000), + TimeUnit::Millisecond => (1_000, 1_000_000), + TimeUnit::Second => (1, 1_000_000_000), + }; + + use chrono::naive::NaiveTime; + let time = v.parse::()?; + let timestamp = i64::from(time.num_seconds_from_midnight()) * seconds_factor + + i64::from(time.nanosecond()) / nanoseconds_factor; + + self.array.push_scalar_value(timestamp.try_into()?) + }) + .ctx(self) } fn serialize_i32(&mut self, v: i32) -> Result<()> { - self.array - .push_scalar_value(v.try_into().ctx(self)?) - .ctx(self) + try_(|| self.array.push_scalar_value(v.try_into()?)).ctx(self) } fn serialize_i64(&mut self, v: i64) -> Result<()> { - self.array - .push_scalar_value(v.try_into().ctx(self)?) 
- .ctx(self) + try_(|| self.array.push_scalar_value(v.try_into()?)).ctx(self) } } diff --git a/serde_arrow/src/internal/serialization/union_builder.rs b/serde_arrow/src/internal/serialization/union_builder.rs index 13a3c85e..c675d480 100644 --- a/serde_arrow/src/internal/serialization/union_builder.rs +++ b/serde_arrow/src/internal/serialization/union_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, DenseUnionArray, FieldMeta}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::Mut, }; @@ -92,9 +92,7 @@ impl SimpleSerializer for UnionBuilder { let mut ctx = BTreeMap::new(); self.annotate(&mut ctx); - self.serialize_variant(variant_index) - .ctx(&ctx)? - .serialize_unit() + try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx) } fn serialize_newtype_variant( @@ -107,8 +105,11 @@ impl SimpleSerializer for UnionBuilder { let mut ctx = BTreeMap::new(); self.annotate(&mut ctx); - let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; - value.serialize(Mut(variant_builder)) + try_(|| { + let variant_builder = self.serialize_variant(variant_index)?; + value.serialize(Mut(variant_builder)) + }) + .ctx(&ctx) } fn serialize_struct_variant_start<'this>( @@ -121,9 +122,12 @@ impl SimpleSerializer for UnionBuilder { let mut ctx = BTreeMap::new(); self.annotate(&mut ctx); - let variant_builder = self.serialize_variant(variant_index).ctx(&ctx)?; - variant_builder.serialize_struct_start(variant, len)?; - Ok(variant_builder) + try_(|| { + let variant_builder = self.serialize_variant(variant_index)?; + variant_builder.serialize_struct_start(variant, len)?; + Ok(variant_builder) + }) + .ctx(&ctx) } fn serialize_tuple_variant_start<'this>( @@ -136,8 +140,11 @@ impl SimpleSerializer for UnionBuilder { let mut ctx = BTreeMap::new(); self.annotate(&mut ctx); - let variant_builder = 
self.serialize_variant(variant_index).ctx(&ctx)?; - variant_builder.serialize_tuple_struct_start(variant, len)?; - Ok(variant_builder) + try_(|| { + let variant_builder = self.serialize_variant(variant_index)?; + variant_builder.serialize_tuple_struct_start(variant, len)?; + Ok(variant_builder) + }) + .ctx(&ctx) } } diff --git a/serde_arrow/src/internal/serialization/utf8_builder.rs b/serde_arrow/src/internal/serialization/utf8_builder.rs index 3f839b24..baf9a93f 100644 --- a/serde_arrow/src/internal/serialization/utf8_builder.rs +++ b/serde_arrow/src/internal/serialization/utf8_builder.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::internal::{ arrow::{Array, BytesArray}, - error::{fail, set_default, Context, ContextSupport, Result}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, utils::{ array_ext::{new_bytes_array, ArrayExt, ScalarArrayExt}, NamedType, Offset, @@ -74,15 +74,15 @@ impl Context for Utf8Builder { impl SimpleSerializer for Utf8Builder { fn serialize_default(&mut self) -> Result<()> { - self.array.push_scalar_default().ctx(self) + try_(|| self.array.push_scalar_default()).ctx(self) } fn serialize_none(&mut self) -> Result<()> { - self.array.push_scalar_none().ctx(self) + try_(|| self.array.push_scalar_none()).ctx(self) } fn serialize_str(&mut self, v: &str) -> Result<()> { - self.array.push_scalar_value(v.as_bytes()).ctx(self) + try_(|| self.array.push_scalar_value(v.as_bytes())).ctx(self) } fn serialize_unit_variant( @@ -91,7 +91,7 @@ impl SimpleSerializer for Utf8Builder { _: u32, variant: &'static str, ) -> Result<()> { - self.array.push_scalar_value(variant.as_bytes()).ctx(self) + try_(|| self.array.push_scalar_value(variant.as_bytes())).ctx(self) } fn serialize_tuple_variant_start<'this>( From 8f8e717f31073f0284d93cd320fce92410518dbc Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Tue, 10 Sep 2024 22:07:52 +0200 Subject: [PATCH 158/178] Unify error message format --- 
.../deserialization/array_deserializer.rs | 18 +++++++++++------- .../deserialization/binary_deserializer.rs | 4 ++-- .../deserialization/bool_deserializer.rs | 4 ++-- .../deserialization/dictionary_deserializer.rs | 6 +++--- .../deserialization/enum_deserializer.rs | 4 ++-- .../deserialization/enums_as_string_impl.rs | 8 ++++---- .../fixed_size_binary_deserializer.rs | 4 ++-- .../fixed_size_list_deserializer.rs | 2 +- .../deserialization/list_deserializer.rs | 2 +- .../deserialization/map_deserializer.rs | 4 ++-- .../deserialization/string_deserializer.rs | 6 +++--- .../deserialization/struct_deserializer.rs | 6 +++--- .../src/internal/deserialization/utils.rs | 12 ++++++------ .../src/test/error_messages/deserializers.rs | 2 +- 14 files changed, 43 insertions(+), 39 deletions(-) diff --git a/serde_arrow/src/internal/deserialization/array_deserializer.rs b/serde_arrow/src/internal/deserialization/array_deserializer.rs index ce5b94e7..c423d105 100644 --- a/serde_arrow/src/internal/deserialization/array_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/array_deserializer.rs @@ -111,7 +111,9 @@ impl<'a> ArrayDeserializer<'a> { is_utc_timestamp(view.timezone.as_deref())?, ))) } - Some(strategy) => fail!("invalid strategy {strategy} for timestamp field"), + Some(strategy) => { + fail!("Invalid strategy: {strategy} is not supported for timestamp field") + } None => Ok(Self::Date64(Date64Deserializer::new( path, view.values, @@ -197,10 +199,10 @@ impl<'a> ArrayDeserializer<'a> { } V::Map(view) => { let ArrayView::Struct(entries_view) = *view.element else { - fail!("invalid entries field in map array"); + fail!("Invalid entries field in map array"); }; let Ok(entries_fields) = <[_; 2]>::try_from(entries_view.fields) else { - fail!("invalid entries field in map array") + fail!("Invalid entries field in map array") }; let [(keys_view, keys_meta), (values_view, values_meta)] = entries_fields; let keys_path = format!("{path}.{child}", child = 
ChildName(&keys_meta.name)); @@ -274,14 +276,14 @@ impl<'a> ArrayDeserializer<'a> { (V::UInt64(keys), V::LargeUtf8(values)) => Ok(D::DictionaryU64I64( DictionaryDeserializer::new(path, keys, values)?, )), - _ => fail!("unsupported dictionary array"), + _ => fail!("Unsupported dictionary array type"), }, ArrayView::DenseUnion(view) => { let mut fields = Vec::new(); for (idx, (type_id, field_view, field_meta)) in view.fields.into_iter().enumerate() { if usize::try_from(type_id) != Ok(idx) { - fail!("Only unions with consecutive type ids are currently supported in arrow2"); + fail!("Only unions with consecutive type ids are currently supported"); } let child_path = format!("{path}.{child}", child = ChildName(&field_meta.name)); let field_deserializer = ArrayDeserializer::new( @@ -301,7 +303,7 @@ impl<'a> ArrayDeserializer<'a> { fn is_utc_timestamp(timezone: Option<&str>) -> Result { match timezone { Some(tz) if tz.to_lowercase() == "utc" => Ok(true), - Some(tz) => fail!("unsupported timezone {}", tz), + Some(tz) => fail!("Unsupported timezone: {} is not supported", tz), None => Ok(false), } } @@ -310,7 +312,9 @@ fn is_utc_date64(strategy: Option<&Strategy>) -> Result { match strategy { None | Some(Strategy::UtcStrAsDate64) => Ok(true), Some(Strategy::NaiveStrAsDate64) => Ok(false), - Some(strategy) => fail!("invalid strategy for date64 deserializer: {strategy}"), + Some(strategy) => { + fail!("Invalid strategy: {strategy} is not supported for date64 deserializer") + } } } diff --git a/serde_arrow/src/internal/deserialization/binary_deserializer.rs b/serde_arrow/src/internal/deserialization/binary_deserializer.rs index 9d463ef2..29e77a7d 100644 --- a/serde_arrow/src/internal/deserialization/binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/binary_deserializer.rs @@ -25,7 +25,7 @@ impl<'a, O: Offset> BinaryDeserializer<'a, O> { pub fn peek_next(&self) -> Result { if self.next.0 + 1 >= self.view.offsets.len() { - fail!("Exhausted ListDeserializer") + 
fail!("Exhausted deserializer") } if let Some(validity) = &self.view.validity { bitset_is_set(validity, self.next.0) @@ -41,7 +41,7 @@ impl<'a, O: Offset> BinaryDeserializer<'a, O> { pub fn peek_next_slice_range(&self) -> Result<(usize, usize)> { let (item, _) = self.next; if item + 1 >= self.view.offsets.len() { - fail!("called next_slices on exhausted BinaryDeserializer"); + fail!("Exhausted deserializer"); } let end = self.view.offsets[item + 1].try_into_usize()?; let start = self.view.offsets[item].try_into_usize()?; diff --git a/serde_arrow/src/internal/deserialization/bool_deserializer.rs b/serde_arrow/src/internal/deserialization/bool_deserializer.rs index 010c53b4..b91d3b8e 100644 --- a/serde_arrow/src/internal/deserialization/bool_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/bool_deserializer.rs @@ -25,7 +25,7 @@ impl<'a> BoolDeserializer<'a> { fn next(&mut self) -> Result> { if self.next >= self.view.len { - fail!("Exhausted Deserializer"); + fail!("Exhausted deserializer"); } if let Some(validty) = &self.view.validity { if !bitset_is_set(validty, self.next)? { @@ -49,7 +49,7 @@ impl<'a> BoolDeserializer<'a> { fn peek_next(&self) -> Result { if self.next >= self.view.len { - fail!("Exhausted Deserializer"); + fail!("Exhausted deserializer"); } else if let Some(validity) = &self.view.validity { bitset_is_set(validity, self.next) } else { diff --git a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs index 6c4f6d8f..84d86338 100644 --- a/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/dictionary_deserializer.rs @@ -26,7 +26,7 @@ impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { ) -> Result { if values.validity.is_some() { // TODO: check whether all values are defined? 
- fail!("dictionaries with nullable values are not supported"); + fail!("Null for non-nullable type: dictionaries do not support nullable values"); } Ok(Self { path, @@ -39,12 +39,12 @@ impl<'a, K: Integer, V: Offset> DictionaryDeserializer<'a, K, V> { pub fn next_str(&mut self) -> Result<&str> { let k: usize = self.keys.next_required()?.into_u64()?.try_into()?; let Some(start) = self.offsets.get(k) else { - fail!("invalid index"); + fail!("Invalid index"); }; let start = start.try_into_usize()?; let Some(end) = self.offsets.get(k + 1) else { - fail!("invalid index"); + fail!("Invalid index"); }; let end = end.try_into_usize()?; diff --git a/serde_arrow/src/internal/deserialization/enum_deserializer.rs b/serde_arrow/src/internal/deserialization/enum_deserializer.rs index 502ff45d..6ad2fbfa 100644 --- a/serde_arrow/src/internal/deserialization/enum_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/enum_deserializer.rs @@ -58,7 +58,7 @@ impl<'a, 'de> EnumAccess<'de> for &'a mut EnumDeserializer<'de> { fn variant_seed>(self, seed: V) -> Result<(V::Value, Self::Variant)> { if self.next >= self.type_ids.len() { - fail!("Exhausted EnumDeserializer"); + fail!("Exhausted deserializer"); } let type_id = self.type_ids[self.next]; self.next += 1; @@ -79,7 +79,7 @@ struct VariantIdDeserializer<'a> { macro_rules! unimplemented { ($lifetime:lifetime, $name:ident $($tt:tt)*) => { fn $name>(self $($tt)*, _: V) -> Result { - fail!("{} is not implemented", stringify!($name)) + fail!("Unsupported: EnumDeserializer does not implement {}", stringify!($name)) } }; } diff --git a/serde_arrow/src/internal/deserialization/enums_as_string_impl.rs b/serde_arrow/src/internal/deserialization/enums_as_string_impl.rs index fd4cd6fd..fcfc91e5 100644 --- a/serde_arrow/src/internal/deserialization/enums_as_string_impl.rs +++ b/serde_arrow/src/internal/deserialization/enums_as_string_impl.rs @@ -17,7 +17,7 @@ impl<'a, 'de> serde::de::EnumAccess<'de> for EnumAccess<'a> { macro_rules! 
unimplemented { ($lifetime:lifetime, $name:ident $($tt:tt)*) => { fn $name>(self $($tt)*, _: V) -> Result { - fail!("{} is not implemented", stringify!($name)) + fail!("Unsupported: EnumDeserializer does not implement {}", stringify!($name)) } }; } @@ -78,15 +78,15 @@ impl<'de> serde::de::VariantAccess<'de> for UnitVariant { type Error = Error; fn newtype_variant_seed>(self, _: T) -> Result { - fail!("cannot deserialize enums with data from strings") + fail!("Unsupported: cannot deserialize enums with data from strings") } fn struct_variant>(self, _: &'static [&'static str], _: V) -> Result { - fail!("cannot deserialize enums with data from strings") + fail!("Unsupported: cannot deserialize enums with data from strings") } fn tuple_variant>(self, _: usize, _: V) -> Result { - fail!("cannot deserialize enums with data from strings") + fail!("Unsupported: cannot deserialize enums with data from strings") } fn unit_variant(self) -> Result<(), Self::Error> { diff --git a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs index dbeb9094..fe91c0d3 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_binary_deserializer.rs @@ -40,7 +40,7 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { pub fn peek_next(&self) -> Result { if self.next.0 >= self.shape.0 { - fail!("Exhausted ListDeserializer") + fail!("Exhausted deserializer") } if let Some(validity) = &self.view.validity { Ok(bitset_is_set(validity, self.next.0)?) 
@@ -56,7 +56,7 @@ impl<'a> FixedSizeBinaryDeserializer<'a> { pub fn next_slice(&mut self) -> Result<&'a [u8]> { let (item, _) = self.next; if item >= self.shape.0 { - fail!("called next_slices on exhausted BinaryDeserializer"); + fail!("Exhausted deserializer"); } self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs index fc4bf01a..8be733fb 100644 --- a/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/fixed_size_list_deserializer.rs @@ -38,7 +38,7 @@ impl<'a> FixedSizeListDeserializer<'a> { pub fn peek_next(&self) -> Result { if self.next.0 >= self.shape.0 { - fail!("Exhausted ListDeserializer") + fail!("Exhausted deserializer") } if let Some(validity) = &self.validity { Ok(bitset_is_set(validity, self.next.0)?) diff --git a/serde_arrow/src/internal/deserialization/list_deserializer.rs b/serde_arrow/src/internal/deserialization/list_deserializer.rs index 8fac54f1..c93a81c1 100644 --- a/serde_arrow/src/internal/deserialization/list_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/list_deserializer.rs @@ -40,7 +40,7 @@ impl<'a, O: Offset> ListDeserializer<'a, O> { pub fn peek_next(&self) -> Result { if self.next.0 + 1 >= self.offsets.len() { - fail!("Exhausted ListDeserializer") + fail!("Exhausted deserializer") } if let Some(validity) = &self.validity { Ok(bitset_is_set(validity, self.next.0)?) 
diff --git a/serde_arrow/src/internal/deserialization/map_deserializer.rs b/serde_arrow/src/internal/deserialization/map_deserializer.rs index f8a1cf85..e08fca98 100644 --- a/serde_arrow/src/internal/deserialization/map_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/map_deserializer.rs @@ -43,7 +43,7 @@ impl<'a> MapDeserializer<'a> { pub fn peek_next(&self) -> Result { if self.next.0 + 1 >= self.offsets.len() { - fail!("Exhausted ListDeserializer") + fail!("Exhausted deserializer") } if let Some(validity) = &self.validity { Ok(bitset_is_set(validity, self.next.0)?) @@ -103,7 +103,7 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { ) -> Result, Self::Error> { let (item, entry) = self.next; if item + 1 >= self.offsets.len() { - fail!(in self, "Exhausted MapDeserializer"); + fail!(in self, "Exhausted deserializer"); } let start: usize = self.offsets[item].try_into()?; let end: usize = self.offsets[item + 1].try_into()?; diff --git a/serde_arrow/src/internal/deserialization/string_deserializer.rs b/serde_arrow/src/internal/deserialization/string_deserializer.rs index 134173a8..067e79a1 100644 --- a/serde_arrow/src/internal/deserialization/string_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/string_deserializer.rs @@ -25,7 +25,7 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { pub fn next(&mut self) -> Result> { if self.next + 1 > self.view.offsets.len() { - fail!("Tried to deserialize a value from an exhausted StringDeserializer"); + fail!("Exhausted deserializer: tried to deserialize a value from an exhausted StringDeserializer"); } if let Some(validity) = &self.view.validity { @@ -45,14 +45,14 @@ impl<'a, O: Offset> StringDeserializer<'a, O> { pub fn next_required(&mut self) -> Result<&'a str> { let Some(next) = self.next()? 
else { - fail!("Tried to deserialize a value from StringDeserializer, but value is missing") + fail!("Exhausted deserializer: tried to deserialize a value from StringDeserializer, but value is missing") }; Ok(next) } pub fn peek_next(&self) -> Result { if self.next + 1 > self.view.offsets.len() { - fail!("Tried to deserialize a value from an exhausted StringDeserializer"); + fail!("Exhausted deserializer: tried to deserialize a value from an exhausted StringDeserializer"); } if let Some(validity) = &self.view.validity { diff --git a/serde_arrow/src/internal/deserialization/struct_deserializer.rs b/serde_arrow/src/internal/deserialization/struct_deserializer.rs index 98244a15..bb889e1f 100644 --- a/serde_arrow/src/internal/deserialization/struct_deserializer.rs +++ b/serde_arrow/src/internal/deserialization/struct_deserializer.rs @@ -39,7 +39,7 @@ impl<'a> StructDeserializer<'a> { pub fn peek_next(&self) -> Result { if self.next.0 >= self.len { - fail!("Exhausted StructDeserializer"); + fail!("Exhausted deserializer"); } if let Some(validity) = &self.validity { Ok(bitset_is_set(validity, self.next.0)?) 
@@ -138,7 +138,7 @@ impl<'de> MapAccess<'de> for StructDeserializer<'de> { fn next_key_seed>(&mut self, seed: K) -> Result> { let (item, field) = self.next; if item >= self.len { - fail!("Exhausted StructDeserializer"); + fail!("Exhausted deserializer"); } if field >= self.fields.len() { self.next = (item + 1, 0); @@ -166,7 +166,7 @@ impl<'de> SeqAccess<'de> for StructDeserializer<'de> { ) -> Result, Self::Error> { let (item, field) = self.next; if item >= self.len { - fail!("Exhausted StructDeserializer"); + fail!("Exhausted deserializer"); } if field >= self.fields.len() { self.next = (item + 1, 0); diff --git a/serde_arrow/src/internal/deserialization/utils.rs b/serde_arrow/src/internal/deserialization/utils.rs index eb55086c..3f85e58e 100644 --- a/serde_arrow/src/internal/deserialization/utils.rs +++ b/serde_arrow/src/internal/deserialization/utils.rs @@ -7,7 +7,7 @@ use crate::internal::{ pub fn bitset_is_set(set: &BitsWithOffset<'_>, idx: usize) -> Result { let flag = 1 << ((idx + set.offset) % 8); let Some(byte) = set.data.get((idx + set.offset) / 8) else { - fail!("invalid access in bitset"); + fail!("Invalid access in bitset"); }; Ok(byte & flag == flag) } @@ -29,7 +29,7 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { pub fn next(&mut self) -> Result> { if self.next > self.buffer.len() { - fail!("Tried to deserialize a value from an exhausted FloatDeserializer"); + fail!("Exhausted deserializer"); } if let Some(validity) = &self.validity { @@ -45,14 +45,14 @@ impl<'a, T: Copy> ArrayBufferIterator<'a, T> { pub fn next_required(&mut self) -> Result { let Some(next) = self.next()? 
else { - fail!("missing value"); + fail!("Exhausted deserializer"); }; Ok(next) } pub fn peek_next(&self) -> Result { if self.next > self.buffer.len() { - fail!("Tried to deserialize a value from an exhausted StringDeserializer"); + fail!("Exhausted deserializer"); } if let Some(validity) = &self.validity { @@ -86,14 +86,14 @@ pub fn check_supported_list_layout<'a, O: Offset>( }; if offsets.is_empty() { - fail!("list offsets must be non empty"); + fail!("Unsupported: list offsets must be non empty"); } for i in 0..offsets.len().saturating_sub(1) { let curr = offsets[i].try_into_usize()?; let next = offsets[i + 1].try_into_usize()?; if !bitset_is_set(&validity, i)? && (next - curr) != 0 { - fail!("lists with data in null values are currently not supported in deserialization"); + fail!("Unsupported: lists with data in null values are currently not supported in deserialization"); } } diff --git a/serde_arrow/src/test/error_messages/deserializers.rs b/serde_arrow/src/test/error_messages/deserializers.rs index 6fbd3058..fd9f8977 100644 --- a/serde_arrow/src/test/error_messages/deserializers.rs +++ b/serde_arrow/src/test/error_messages/deserializers.rs @@ -58,7 +58,7 @@ fn example_exhausted() { } let res = Vec::::deserialize(deserializer); - assert_error_contains(&res, "Exhausted Deserializer"); + assert_error_contains(&res, "Exhausted deserializer"); assert_error_contains(&res, "field: \"$.item.nested\""); assert_error_contains(&res, "data_type: \"Boolean\""); } From 9e513a75ea1a5e74662ca59deca9d152c9a78189 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 21:48:10 +0200 Subject: [PATCH 159/178] Add arrow=53 --- Cargo.lock | 205 ++++++++++++-------- example/Cargo.toml | 6 +- integration_tests/Cargo.toml | 6 +- serde_arrow/Cargo.toml | 15 +- serde_arrow/benches/groups/impls.rs | 4 +- serde_arrow/benches/groups/json_to_arrow.rs | 6 +- serde_arrow/build.rs | 2 + serde_arrow/src/arrow_impl/array.rs | 16 +- serde_arrow/src/lib.rs | 2 + 9 files changed, 
162 insertions(+), 100 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb6f06a6..70601619 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,35 +76,35 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" +checksum = "45aef0d9cf9a039bf6cd1acc451b137aca819977b0928dece52bd92811b640ba" dependencies = [ "arrow-arith", - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", "arrow-cast", "arrow-csv", - "arrow-data 52.0.0", + "arrow-data 53.0.0", "arrow-ipc", "arrow-json", "arrow-ord", "arrow-row", - "arrow-schema 52.0.0", + "arrow-schema 53.0.0", "arrow-select", "arrow-string", ] [[package]] name = "arrow-arith" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" +checksum = "03675e42d1560790f3524800e41403b40d0da1c793fe9528929fde06d8c7649a" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "chrono", "half 2.2.1", "num", @@ -202,7 +202,7 @@ dependencies = [ "arrow-schema 42.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -218,7 +218,7 @@ dependencies = [ "arrow-schema 43.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -234,7 +234,7 @@ dependencies = [ "arrow-schema 44.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", "num-complex", ] @@ -251,7 +251,7 @@ dependencies = [ "arrow-schema 45.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -267,7 +267,7 @@ 
dependencies = [ "arrow-schema 46.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -283,7 +283,7 @@ dependencies = [ "arrow-schema 47.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -299,7 +299,7 @@ dependencies = [ "arrow-schema 48.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -315,7 +315,7 @@ dependencies = [ "arrow-schema 49.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -331,7 +331,7 @@ dependencies = [ "arrow-schema 50.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -347,7 +347,7 @@ dependencies = [ "arrow-schema 51.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", "num", ] @@ -363,7 +363,23 @@ dependencies = [ "arrow-schema 52.0.0", "chrono", "half 2.2.1", - "hashbrown 0.14.0", + "hashbrown 0.14.5", + "num", +] + +[[package]] +name = "arrow-array" +version = "53.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd2bf348cf9f02a5975c5962c7fa6dee107a2009a7b41ac5fb1a027e12dc033f" +dependencies = [ + "ahash 0.8.3", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", + "chrono", + "half 2.2.1", + "hashbrown 0.14.5", "num", ] @@ -534,16 +550,27 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-buffer" +version = "53.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3092e37715f168976012ce52273c3989b5793b0db5f06cbaa246be25e5f0924d" +dependencies = [ + "bytes", + "half 2.2.1", + "num", +] + [[package]] name = "arrow-cast" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab" +checksum = "7ce1018bb710d502f9db06af026ed3561552e493e989a79d0d0f5d9cf267a785" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 
52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "arrow-select", "atoi", "base64", @@ -556,15 +583,15 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc" +checksum = "fd178575f45624d045e4ebee714e246a05d9652e41363ee3f57ec18cca97f740" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", "arrow-cast", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "chrono", "csv", "csv-core", @@ -765,31 +792,43 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-data" +version = "53.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4ac0c4ee79150afe067dc4857154b3ee9c1cd52b5f40d59a77306d0ed18d65" +dependencies = [ + "arrow-buffer 53.0.0", + "arrow-schema 53.0.0", + "half 2.2.1", + "num", +] + [[package]] name = "arrow-ipc" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8" +checksum = "bb307482348a1267f91b0912e962cd53440e5de0f7fb24c5f7b10da70b38c94a" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", "arrow-cast", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "flatbuffers", ] [[package]] name = "arrow-json" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1" +checksum = "d24805ba326758effdd6f2cbdd482fcfab749544f21b134701add25b33f474e6" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", + 
"arrow-array 53.0.0", + "arrow-buffer 53.0.0", "arrow-cast", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "chrono", "half 2.2.1", "indexmap", @@ -801,14 +840,14 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" +checksum = "644046c479d80ae8ed02a7f1e1399072ea344ca6a7b0e293ab2d5d9ed924aa3b" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "arrow-select", "half 2.2.1", "num", @@ -816,17 +855,16 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" +checksum = "a29791f8eb13b340ce35525b723f5f0df17ecb955599e11f65c2a94ab34e2efb" dependencies = [ "ahash 0.8.3", - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "half 2.2.1", - "hashbrown 0.14.0", ] [[package]] @@ -973,30 +1011,39 @@ dependencies = [ "serde", ] +[[package]] +name = "arrow-schema" +version = "53.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c85320a3a2facf2b2822b57aa9d6d9d55edb8aee0b6b5d3b8df158e503d10858" +dependencies = [ + "serde", +] + [[package]] name = "arrow-select" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102" +checksum = "9cc7e6b582e23855fd1625ce46e51647aa440c20ea2e71b1d748e0839dd73cba" dependencies = [ "ahash 0.8.3", - 
"arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "num", ] [[package]] name = "arrow-string" -version = "52.0.0" +version = "53.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" +checksum = "0775b6567c66e56ded19b87a954b6b1beffbdd784ef95a3a2b03f59570c1d230" dependencies = [ - "arrow-array 52.0.0", - "arrow-buffer 52.0.0", - "arrow-data 52.0.0", - "arrow-schema 52.0.0", + "arrow-array 53.0.0", + "arrow-buffer 53.0.0", + "arrow-data 53.0.0", + "arrow-schema 53.0.0", "arrow-select", "memchr", "num", @@ -1605,9 +1652,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" [[package]] name = "hermit-abi" @@ -1646,7 +1693,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.5", ] [[package]] @@ -1797,9 +1844,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "num" @@ -2248,6 +2295,7 @@ dependencies = [ "arrow-array 50.0.0", "arrow-array 51.0.0", "arrow-array 52.0.0", + "arrow-array 53.0.0", "arrow-buffer 37.0.0", "arrow-buffer 38.0.0", "arrow-buffer 39.0.0", @@ -2264,6 +2312,7 @@ dependencies = [ "arrow-buffer 50.0.0", "arrow-buffer 51.0.0", 
"arrow-buffer 52.0.0", + "arrow-buffer 53.0.0", "arrow-data 37.0.0", "arrow-data 38.0.0", "arrow-data 39.0.0", @@ -2280,6 +2329,7 @@ dependencies = [ "arrow-data 50.0.0", "arrow-data 51.0.0", "arrow-data 52.0.0", + "arrow-data 53.0.0", "arrow-json", "arrow-schema 37.0.0", "arrow-schema 38.0.0", @@ -2297,6 +2347,7 @@ dependencies = [ "arrow-schema 50.0.0", "arrow-schema 51.0.0", "arrow-schema 52.0.0", + "arrow-schema 53.0.0", "arrow2 0.16.0", "arrow2 0.17.0", "arrow2_convert", diff --git a/example/Cargo.toml b/example/Cargo.toml index 95731707..1b28ad77 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -7,11 +7,11 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -# arrow-version:replace: arrow = {{ version = "52.0", features = [{version}] }} -arrow = {version = "52.0", features = ["ipc"] } +# arrow-version:replace: arrow = {{ version = "{version}", features = ["ipc"] }} +arrow = {version = "53.0", features = ["ipc"] } chrono = { version = "0.4", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } # arrow-version:replace: serde_arrow = {{ path = "../serde_arrow", features = ["arrow-{version}"] }} -serde_arrow = { path = "../serde_arrow", features = ["arrow-52"] } +serde_arrow = { path = "../serde_arrow", features = ["arrow-53"] } diff --git a/integration_tests/Cargo.toml b/integration_tests/Cargo.toml index fd398ca3..a45707e9 100644 --- a/integration_tests/Cargo.toml +++ b/integration_tests/Cargo.toml @@ -4,12 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] -# arrow-version:replace: arrow = {{ version = "52.0", features = [{version}] }} -arrow = {version = "52.0", features = ["ipc"] } +# arrow-version:replace: arrow = {{ version = "{version}", features = ["ipc"] }} +arrow = {version = "53.0", features = ["ipc"] } chrono = { version = "0.4", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1" # 
arrow-version:replace: serde_arrow = {{ path = "../serde_arrow", features = ["arrow-{version}"] }} -serde_arrow = { path = "../serde_arrow", features = ["arrow-52"] } +serde_arrow = { path = "../serde_arrow", features = ["arrow-53"] } diff --git a/serde_arrow/Cargo.toml b/serde_arrow/Cargo.toml index a185071f..8ef43e3b 100644 --- a/serde_arrow/Cargo.toml +++ b/serde_arrow/Cargo.toml @@ -14,17 +14,18 @@ bench = false [[bench]] name = "serde_arrow_bench" # arrow-version:replace: required-features = ["arrow2-0-17", "arrow-{version}"] -required-features = ["arrow2-0-17", "arrow-52"] +required-features = ["arrow2-0-17", "arrow-53"] harness = false [package.metadata.docs.rs] # arrow-version:replace: features = ["arrow2-0-17", "arrow-{version}"] -features = ["arrow2-0-17", "arrow-52"] +features = ["arrow2-0-17", "arrow-53"] [features] default = [] # arrow-version:insert: arrow-{version} = ["dep:arrow-array-{version}", "dep:arrow-schema-{version}", "dep:arrow-data-{version}", "dep:arrow-buffer-{version}"] +arrow-53 = ["dep:arrow-array-53", "dep:arrow-schema-53", "dep:arrow-data-53", "dep:arrow-buffer-53"] arrow-52 = ["dep:arrow-array-52", "dep:arrow-schema-52", "dep:arrow-data-52", "dep:arrow-buffer-52"] arrow-51 = ["dep:arrow-array-51", "dep:arrow-schema-51", "dep:arrow-data-51", "dep:arrow-buffer-51"] arrow-50 = ["dep:arrow-array-50", "dep:arrow-schema-50", "dep:arrow-data-50", "dep:arrow-buffer-50"] @@ -50,6 +51,7 @@ half = { version = "2", features = ["bytemuck"], default-features = false } serde = { version = "1.0", features = ["derive", "std"], default-features = false } # arrow-version:insert: arrow-array-{version} = {{ package = "arrow-array", version = "{version}", optional = true, default-features = false }} +arrow-array-53 = { package = "arrow-array", version = "53", optional = true, default-features = false } arrow-array-52 = { package = "arrow-array", version = "52", optional = true, default-features = false } arrow-array-51 = { package = "arrow-array", 
version = "51", optional = true, default-features = false } arrow-array-50 = { package = "arrow-array", version = "50", optional = true, default-features = false } @@ -68,6 +70,7 @@ arrow-array-38 = { package = "arrow-array", version = "38", optional = true, def arrow-array-37 = { package = "arrow-array", version = "37", optional = true, default-features = false } # arrow-version:insert: arrow-buffer-{version} = {{ package = "arrow-buffer", version = "{version}", optional = true, default-features = false }} +arrow-buffer-53 = { package = "arrow-buffer", version = "53", optional = true, default-features = false } arrow-buffer-52 = { package = "arrow-buffer", version = "52", optional = true, default-features = false } arrow-buffer-51 = { package = "arrow-buffer", version = "51", optional = true, default-features = false } arrow-buffer-50 = { package = "arrow-buffer", version = "50", optional = true, default-features = false } @@ -86,6 +89,7 @@ arrow-buffer-38 = { package = "arrow-buffer", version = "38", optional = true, d arrow-buffer-37 = { package = "arrow-buffer", version = "37", optional = true, default-features = false } # arrow-version:insert: arrow-data-{version} = {{ package = "arrow-data", version="{version}", optional = true, default-features = false }} +arrow-data-53 = { package = "arrow-data", version="53", optional = true, default-features = false } arrow-data-52 = { package = "arrow-data", version="52", optional = true, default-features = false } arrow-data-51 = { package = "arrow-data", version="51", optional = true, default-features = false } arrow-data-50 = { package = "arrow-data", version="50", optional = true, default-features = false } @@ -104,6 +108,7 @@ arrow-data-38 = { package = "arrow-data", version="38", optional = true, default arrow-data-37 = { package = "arrow-data", version="37", optional = true, default-features = false } # arrow-version:insert: arrow-schema-{version} = {{ package = "arrow-schema", version = "{version}", optional = 
true, default-features = false }} +arrow-schema-53 = { package = "arrow-schema", version = "53", optional = true, default-features = false } arrow-schema-52 = { package = "arrow-schema", version = "52", optional = true, default-features = false } arrow-schema-51 = { package = "arrow-schema", version = "51", optional = true, default-features = false } arrow-schema-50 = { package = "arrow-schema", version = "50", optional = true, default-features = false } @@ -136,13 +141,14 @@ uuid = { version = "1.10.0", features = ["serde", "v4"] } # for benchmarks # arrow-version:replace: arrow-json-{version} = {{ package = "arrow-json", version = "{version}" }} -arrow-json-52 = { package = "arrow-json", version = "52" } +arrow-json-53 = { package = "arrow-json", version = "53" } criterion = "0.5" arrow2_convert = "0.5.0" serde-transcode = "1" simd-json = "0.13.8" # arrow-version:insert: arrow-schema-{version} = {{ package = "arrow-schema", version = "{version}", default-features = false, features = ["serde"] }} +arrow-schema-53 = { package = "arrow-schema", version = "53", default-features = false, features = ["serde"] } arrow-schema-52 = { package = "arrow-schema", version = "52", default-features = false, features = ["serde"] } arrow-schema-51 = { package = "arrow-schema", version = "51", default-features = false, features = ["serde"] } arrow-schema-50 = { package = "arrow-schema", version = "50", default-features = false, features = ["serde"] } @@ -178,6 +184,7 @@ check-cfg = [ 'cfg(has_arrow)', 'cfg(has_arrow_fixed_binary_support)', # arrow-version:insert: 'cfg(has_arrow_{version})', + 'cfg(has_arrow_53)', 'cfg(has_arrow_52)', 'cfg(has_arrow_51)', 'cfg(has_arrow_50)', @@ -194,4 +201,4 @@ check-cfg = [ 'cfg(has_arrow_39)', 'cfg(has_arrow_38)', 'cfg(has_arrow_37)', -] +] \ No newline at end of file diff --git a/serde_arrow/benches/groups/impls.rs b/serde_arrow/benches/groups/impls.rs index c8210479..6faac10b 100644 --- a/serde_arrow/benches/groups/impls.rs +++ 
b/serde_arrow/benches/groups/impls.rs @@ -125,9 +125,9 @@ pub mod arrow { use std::sync::Arc; // arrow-version:replace: use arrow_json_{version}::ReaderBuilder; - use arrow_json_52::ReaderBuilder; + use arrow_json_53::ReaderBuilder; // arrow-version:replace: use arrow_schema_{version}::Schema; - use arrow_schema_52::Schema; + use arrow_schema_53::Schema; use serde::Serialize; diff --git a/serde_arrow/benches/groups/json_to_arrow.rs b/serde_arrow/benches/groups/json_to_arrow.rs index f0fef921..063a92d2 100644 --- a/serde_arrow/benches/groups/json_to_arrow.rs +++ b/serde_arrow/benches/groups/json_to_arrow.rs @@ -6,13 +6,13 @@ use { }; // arrow-version:replace: use arrow_json_{version}::ReaderBuilder; -use arrow_json_52::ReaderBuilder; +use arrow_json_53::ReaderBuilder; // arrow-version:replace: use arrow_schema_{version}::{{FieldRef, Schema as ArrowSchema}}; -use arrow_schema_52::{FieldRef, Schema as ArrowSchema}; +use arrow_schema_53::{FieldRef, Schema as ArrowSchema}; // arrow-version:replace: use arrow_array_{version}::RecordBatch; -use arrow_array_52::RecordBatch; +use arrow_array_53::RecordBatch; use serde_json::Value; fn benchmark_json_to_arrow(c: &mut criterion::Criterion) { diff --git a/serde_arrow/build.rs b/serde_arrow/build.rs index 7c5b9639..2f2b7f7c 100644 --- a/serde_arrow/build.rs +++ b/serde_arrow/build.rs @@ -15,6 +15,8 @@ fn main() { let max_arrow_version: Option = [ // arrow-version:insert: #[cfg(feature = "arrow-{version}")]{\n}{version}, + #[cfg(feature = "arrow-53")] + 53, #[cfg(feature = "arrow-52")] 52, #[cfg(feature = "arrow-51")] diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index 9c03b3dd..a256ca3e 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -51,7 +51,7 @@ impl TryFrom for ArrayData { T::Boolean, // NOTE: use the explicit len arr.len, - arr.validity.map(Buffer::from), + arr.validity.map(Buffer::from_vec), 0, 
vec![ScalarBuffer::from(arr.values).into_inner()], vec![], @@ -114,7 +114,7 @@ impl TryFrom for ArrayData { Ok(ArrayData::builder(data_type) .len(arr.len) - .null_bit_buffer(arr.validity.map(Buffer::from)) + .null_bit_buffer(arr.validity.map(Buffer::from_vec)) .child_data(data) .build()?) } @@ -153,7 +153,7 @@ impl TryFrom for ArrayData { Ok(ArrayData::try_new( T::FixedSizeList(Arc::new(field), arr.n), child.len() / usize::try_from(arr.n)?, - arr.validity.map(Buffer::from), + arr.validity.map(Buffer::from_vec), 0, vec![], vec![child], @@ -170,7 +170,7 @@ impl TryFrom for ArrayData { Ok(ArrayData::try_new( T::FixedSizeBinary(arr.n), arr.data.len() / usize::try_from(arr.n)?, - arr.validity.map(Buffer::from), + arr.validity.map(Buffer::from_vec), 0, vec![ScalarBuffer::from(arr.data).into_inner()], vec![], @@ -196,7 +196,7 @@ impl TryFrom for ArrayData { Ok(ArrayData::try_new( T::Map(Arc::new(field), false), arr.offsets.len().saturating_sub(1), - arr.validity.map(Buffer::from), + arr.validity.map(Buffer::from_vec), 0, vec![ScalarBuffer::from(arr.offsets).into_inner()], vec![child], @@ -543,7 +543,7 @@ fn primitive_into_data( Ok(ArrayData::try_new( data_type, values.len(), - validity.map(Buffer::from), + validity.map(Buffer::from_vec), 0, vec![ScalarBuffer::from(values).into_inner()], vec![], @@ -559,7 +559,7 @@ fn bytes_into_data( Ok(ArrayData::try_new( data_type, offsets.len().saturating_sub(1), - validity.map(Buffer::from), + validity.map(Buffer::from_vec), 0, vec![ ScalarBuffer::from(offsets).into_inner(), @@ -579,7 +579,7 @@ fn list_into_data( Ok(ArrayData::try_new( data_type, len, - validity.map(Buffer::from), + validity.map(Buffer::from_vec), 0, vec![ScalarBuffer::from(offsets).into_inner()], vec![child_data], diff --git a/serde_arrow/src/lib.rs b/serde_arrow/src/lib.rs index 1ce4d392..16a3cbcb 100644 --- a/serde_arrow/src/lib.rs +++ b/serde_arrow/src/lib.rs @@ -145,6 +145,7 @@ //! | Arrow Feature | Arrow Version | //! 
|---------------|---------------| // arrow-version:insert: //! | `arrow-{version}` | `arrow={version}` | +//! | `arrow-53` | `arrow=53` | //! | `arrow-52` | `arrow=52` | //! | `arrow-51` | `arrow=51` | //! | `arrow-50` | `arrow=50` | @@ -268,6 +269,7 @@ pub mod _impl { } // arrow-version:insert: #[cfg(has_arrow_{version})] build_arrow_crate!(arrow_array_{version}, arrow_buffer_{version}, arrow_data_{version}, arrow_schema_{version}); +#[cfg(has_arrow_53)] build_arrow_crate!(arrow_array_53, arrow_buffer_53, arrow_data_53, arrow_schema_53); #[cfg(has_arrow_52)] build_arrow_crate!(arrow_array_52, arrow_buffer_52, arrow_data_52, arrow_schema_52); #[cfg(has_arrow_51)] build_arrow_crate!(arrow_array_51, arrow_buffer_51, arrow_data_51, arrow_schema_51); #[cfg(has_arrow_50)] build_arrow_crate!(arrow_array_50, arrow_buffer_50, arrow_data_50, arrow_schema_50); From b9ec6bfc8bee9c657a951d4cafeafe44e44e6083 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 21:48:19 +0200 Subject: [PATCH 160/178] Fix x.py add-arrow-version --- x.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/x.py b/x.py index df6d900e..144ddc5e 100644 --- a/x.py +++ b/x.py @@ -7,6 +7,7 @@ all_arrow_features = [ # arrow-version:insert: "arrow-{version}", + "arrow-53", "arrow-52", "arrow-51", "arrow-50", @@ -538,8 +539,8 @@ def add_arrow_version(version): for p in [ self_path / "x.py", - *self_path.glob("serde_arrow/**/*.rs"), - *self_path.glob("serde_arrow/**/*.toml"), + *self_path.glob("*/**/*.rs"), + *self_path.glob("*/**/*.toml"), ]: content = p.read_text() if "arrow-version" not in content: From 5688bdb0f4a7fcc1aeba0a5aeb496cefca5c1e6d Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 21:48:24 +0200 Subject: [PATCH 161/178] Fix clippy --- serde_arrow/src/arrow_impl/schema.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 
338e51dc..064e291d 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -141,13 +141,13 @@ impl TryFrom<&ArrowDataType> for DataType { AT::Date32 => Ok(T::Date32), AT::Date64 => Ok(T::Date64), AT::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - AT::Time32(unit) => Ok(T::Time32(unit.clone().into())), - AT::Time64(unit) => Ok(T::Time64(unit.clone().into())), + AT::Time32(unit) => Ok(T::Time32((*unit).into())), + AT::Time64(unit) => Ok(T::Time64((*unit).into())), AT::Timestamp(unit, tz) => Ok(T::Timestamp( - unit.clone().into(), + (*unit).into(), tz.as_ref().map(|s| s.to_string()), )), - AT::Duration(unit) => Ok(T::Duration(unit.clone().into())), + AT::Duration(unit) => Ok(T::Duration((*unit).into())), AT::Binary => Ok(T::Binary), AT::LargeBinary => Ok(T::LargeBinary), AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), From e5ee1e16fce6b592d2138c91cd80db7a8cca3c29 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 21:58:40 +0200 Subject: [PATCH 162/178] Revert Copy change to support previous Arrow versions --- serde_arrow/src/arrow_impl/schema.rs | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 064e291d..2e3da229 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -141,13 +141,27 @@ impl TryFrom<&ArrowDataType> for DataType { AT::Date32 => Ok(T::Date32), AT::Date64 => Ok(T::Date64), AT::Decimal128(precision, scale) => Ok(T::Decimal128(*precision, *scale)), - AT::Time32(unit) => Ok(T::Time32((*unit).into())), - AT::Time64(unit) => Ok(T::Time64((*unit).into())), + AT::Time32(unit) => Ok(T::Time32( + // only some arrow version implement Copy for unit + #[allow(clippy::clone_on_copy)] + unit.clone().into(), + )), + AT::Time64(unit) => Ok(T::Time64( + // only some arrow version implement Copy for unit + 
#[allow(clippy::clone_on_copy)] + unit.clone().into(), + )), AT::Timestamp(unit, tz) => Ok(T::Timestamp( - (*unit).into(), + // only some arrow version implement Copy for unit + #[allow(clippy::clone_on_copy)] + unit.clone().into(), tz.as_ref().map(|s| s.to_string()), )), - AT::Duration(unit) => Ok(T::Duration((*unit).into())), + AT::Duration(unit) => Ok(T::Duration( + // only some arrow version implement Copy for unit + #[allow(clippy::clone_on_copy)] + unit.clone().into(), + )), AT::Binary => Ok(T::Binary), AT::LargeBinary => Ok(T::LargeBinary), AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)), From 2f57a6fd8ec8074be301de10be77f332f62c4b31 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 22:04:03 +0200 Subject: [PATCH 163/178] Update changelog --- Changes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/Changes.md b/Changes.md index 03c3664c..aca50d64 100644 --- a/Changes.md +++ b/Changes.md @@ -6,6 +6,7 @@ Refactor the underlying implementation to prepare for further development New features +- Add `arrow=53` support - Add `Binary` / `LargeBinary` support for `arrow2` - Add support to serialize / deserialize `bool` from integer arrays - Add a helper to construct `Bool8` arrays From 095b5cd77cb4ce055c62ed0b868bac6ee767bf6c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 22:25:13 +0200 Subject: [PATCH 164/178] Address clippy --- serde_arrow/src/arrow2_impl/schema.rs | 8 +++---- serde_arrow/src/arrow_impl/schema.rs | 14 ++++++------ serde_arrow/src/internal/schema/mod.rs | 4 ++-- serde_arrow/src/internal/schema/tracer.rs | 26 ----------------------- 4 files changed, 12 insertions(+), 40 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 08aa8930..10805f3a 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -6,7 +6,7 @@ use crate::{ internal::{ arrow::{DataType, Field, TimeUnit, UnionMode}, error::{fail, 
Error, Result}, - schema::{validate_field, DataTypeDisplay, SchemaLike, Sealed, SerdeArrowSchema}, + schema::{validate_field, DataTypeDisplay, SchemaLike, Sealed, SerdeArrowSchema, TracingOptions}, }, }; @@ -45,15 +45,13 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( - options: crate::schema::TracingOptions, - ) -> Result { + fn from_type<'de, T: serde::Deserialize<'de>>(options: TracingOptions) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, - options: crate::schema::TracingOptions, + options: TracingOptions, ) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 338e51dc..7f48ac9f 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -8,7 +8,7 @@ use crate::{ internal::{ arrow::{DataType, Field, TimeUnit, UnionMode}, error::{fail, Error, Result}, - schema::{validate_field, SchemaLike, Sealed, SerdeArrowSchema}, + schema::{validate_field, SchemaLike, Sealed, SerdeArrowSchema, TracingOptions}, }, }; @@ -80,15 +80,15 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( - options: crate::schema::TracingOptions, + fn from_type<'de, T: serde::Deserialize<'de>>( + options: TracingOptions, ) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, - options: crate::schema::TracingOptions, + options: TracingOptions, ) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } @@ -103,15 +103,15 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de> + ?Sized>( - options: crate::schema::TracingOptions, + fn from_type<'de, T: serde::Deserialize<'de>>( + options: TracingOptions, ) 
-> Result { SerdeArrowSchema::from_type::(options)?.try_into() } fn from_samples( samples: &T, - options: crate::schema::TracingOptions, + options: TracingOptions, ) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 7537b268..09157a88 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -195,7 +195,7 @@ pub trait SchemaLike: Sized + Sealed { /// # #[cfg(not(has_arrow))] /// # fn main() { } /// ``` - fn from_type<'de, T: Deserialize<'de> + ?Sized>(options: TracingOptions) -> Result; + fn from_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> Result; /// Determine the schema from samples. See [`TracingOptions`] for customization options. /// @@ -290,7 +290,7 @@ impl SchemaLike for SerdeArrowSchema { value::transmute(value) } - fn from_type<'de, T: Deserialize<'de> + ?Sized>(options: TracingOptions) -> Result { + fn from_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> Result { Tracer::from_type::(options)?.to_schema() } diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 37189b91..ed9f620a 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -138,10 +138,6 @@ impl Tracer { } impl Tracer { - pub fn is_unknown(&self) -> bool { - matches!(self, Tracer::Unknown(_)) - } - pub fn is_complete(&self) -> bool { dispatch_tracer!(self, tracer => tracer.is_complete()) } @@ -995,28 +991,6 @@ pub struct PrimitiveTracer { } impl PrimitiveTracer { - pub fn new( - name: String, - path: String, - options: Arc, - item_type: DataType, - nullable: bool, - ) -> Self { - Self { - name, - path, - options, - item_type, - nullable, - strategy: None, - } - } - - pub fn with_strategy(mut self, strategy: Option) -> Self { - self.strategy = strategy; - self - } - pub fn finish(&mut self) -> Result<()> { 
Ok(()) } From 9976af604a5653660d641a47c27c501b4cc12b2e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Wed, 11 Sep 2024 22:28:04 +0200 Subject: [PATCH 165/178] Reformat code --- serde_arrow/src/arrow2_impl/schema.rs | 4 +++- serde_arrow/src/arrow_impl/schema.rs | 8 ++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index 10805f3a..a2acc88e 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -6,7 +6,9 @@ use crate::{ internal::{ arrow::{DataType, Field, TimeUnit, UnionMode}, error::{fail, Error, Result}, - schema::{validate_field, DataTypeDisplay, SchemaLike, Sealed, SerdeArrowSchema, TracingOptions}, + schema::{ + validate_field, DataTypeDisplay, SchemaLike, Sealed, SerdeArrowSchema, TracingOptions, + }, }, }; diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 7f48ac9f..b60e989d 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -80,9 +80,7 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de>>( - options: TracingOptions, - ) -> Result { + fn from_type<'de, T: serde::Deserialize<'de>>(options: TracingOptions) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } @@ -103,9 +101,7 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de>>( - options: TracingOptions, - ) -> Result { + fn from_type<'de, T: serde::Deserialize<'de>>(options: TracingOptions) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } From d01fb59c004d68efe819daa2755630983f10816c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 15:48:41 +0200 Subject: [PATCH 166/178] Implement error context for SchemaLike::from_type --- .../src/internal/schema/from_type/mod.rs | 293 
+++++++++++------- serde_arrow/src/internal/schema/tracer.rs | 39 ++- serde_arrow/src/test/error_messages/mod.rs | 5 +- .../test/error_messages/trace_from_type.rs | 18 ++ 4 files changed, 246 insertions(+), 109 deletions(-) create mode 100644 serde_arrow/src/test/error_messages/trace_from_type.rs diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index d127ca08..24b51d46 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -2,7 +2,7 @@ #[cfg(test)] mod test_error_messages; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use serde::{ de::{DeserializeSeed, Visitor}, @@ -11,7 +11,7 @@ use serde::{ use crate::internal::{ arrow::DataType, - error::{fail, Error, Result}, + error::{fail, try_, Context, ContextSupport, Error, Result}, schema::{TracingMode, TracingOptions}, }; @@ -46,11 +46,19 @@ impl Tracer { struct TraceAny<'a>(&'a mut Tracer); +impl<'a> Context for TraceAny<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } +} + impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { type Error = Error; fn deserialize_any>(self, _visitor: V) -> Result { - fail!(concat!( + fail!( + in self, + concat!( "Non self describing types cannot be traced with `from_type`. ", "Consider using `from_samples`. ", "One example is `serde_json::Value`. 
", @@ -59,93 +67,147 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { } fn deserialize_bool>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Boolean)?; - visitor.visit_bool(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Boolean)?; + visitor.visit_bool(Default::default()) + }) + .ctx(&self) } fn deserialize_i8>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Int8)?; - visitor.visit_i8(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Int8)?; + visitor.visit_i8(Default::default()) + }) + .ctx(&self) } fn deserialize_i16>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Int16)?; - visitor.visit_i16(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Int16)?; + visitor.visit_i16(Default::default()) + }) + .ctx(&self) } fn deserialize_i32>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Int32)?; - visitor.visit_i32(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Int32)?; + visitor.visit_i32(Default::default()) + }) + .ctx(&self) } fn deserialize_i64>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Int64)?; - visitor.visit_i64(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Int64)?; + visitor.visit_i64(Default::default()) + }) + .ctx(&self) } fn deserialize_u8>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::UInt8)?; - visitor.visit_u8(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::UInt8)?; + visitor.visit_u8(Default::default()) + }) + .ctx(&self) } fn deserialize_u16>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::UInt16)?; - visitor.visit_u16(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::UInt16)?; + visitor.visit_u16(Default::default()) + }) + .ctx(&self) } fn deserialize_u32>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::UInt32)?; - visitor.visit_u32(Default::default()) + 
try_(|| { + self.0.ensure_primitive(DataType::UInt32)?; + visitor.visit_u32(Default::default()) + }) + .ctx(&self) } fn deserialize_u64>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::UInt64)?; - visitor.visit_u64(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::UInt64)?; + visitor.visit_u64(Default::default()) + }) + .ctx(&self) } fn deserialize_f32>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Float32)?; - visitor.visit_f32(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Float32)?; + visitor.visit_f32(Default::default()) + }) + .ctx(&self) } fn deserialize_f64>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Float64)?; - visitor.visit_f64(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::Float64)?; + visitor.visit_f64(Default::default()) + }) + .ctx(&self) } fn deserialize_char>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::UInt32)?; - visitor.visit_char(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::UInt32)?; + visitor.visit_char(Default::default()) + }) + .ctx(&self) } fn deserialize_str>(self, visitor: V) -> Result { - self.0.ensure_utf8(DataType::LargeUtf8, None)?; - visitor.visit_borrowed_str("") + try_(|| { + self.0.ensure_utf8(DataType::LargeUtf8, None)?; + visitor.visit_borrowed_str("") + }) + .ctx(&self) } fn deserialize_string>(self, visitor: V) -> Result { - self.0.ensure_utf8(DataType::LargeUtf8, None)?; - visitor.visit_string(Default::default()) + try_(|| { + self.0.ensure_utf8(DataType::LargeUtf8, None)?; + visitor.visit_string(Default::default()) + }) + .ctx(&self) } fn deserialize_bytes>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::LargeBinary)?; - visitor.visit_borrowed_bytes(&[]) + try_(|| { + self.0.ensure_primitive(DataType::LargeBinary)?; + visitor.visit_borrowed_bytes(&[]) + }) + .ctx(&self) } fn deserialize_byte_buf>(self, visitor: V) -> Result { - 
self.0.ensure_primitive(DataType::LargeBinary)?; - visitor.visit_byte_buf(Default::default()) + try_(|| { + self.0.ensure_primitive(DataType::LargeBinary)?; + visitor.visit_byte_buf(Default::default()) + }) + .ctx(&self) } fn deserialize_option>(self, visitor: V) -> Result { - self.0.mark_nullable(); - visitor.visit_some(self) + try_(|| { + self.0.mark_nullable(); + visitor.visit_some(TraceAny(&mut *self.0)) + }) + .ctx(&self) } fn deserialize_unit>(self, visitor: V) -> Result { - self.0.ensure_primitive(DataType::Null)?; - visitor.visit_unit() + try_(|| { + self.0.ensure_primitive(DataType::Null)?; + visitor.visit_unit() + }) + .ctx(&self) } fn deserialize_unit_struct>( @@ -153,8 +215,11 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { _name: &'static str, visitor: V, ) -> Result { - self.0.ensure_primitive(DataType::Null)?; - visitor.visit_unit() + try_(|| { + self.0.ensure_primitive(DataType::Null)?; + visitor.visit_unit() + }) + .ctx(&self) } fn deserialize_newtype_struct>( @@ -162,28 +227,34 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { _name: &'static str, visitor: V, ) -> Result { - visitor.visit_newtype_struct(self) + try_(|| visitor.visit_newtype_struct(TraceAny(&mut *self.0))).ctx(&self) } fn deserialize_seq>(self, visitor: V) -> Result { - self.0.ensure_list()?; - let Tracer::List(tracer) = self.0 else { - unreachable!() - }; + try_(|| { + self.0.ensure_list()?; + let Tracer::List(tracer) = self.0 else { + unreachable!() + }; - visitor.visit_seq(TraceSeq(&mut tracer.item_tracer, true)) + visitor.visit_seq(TraceSeq(&mut tracer.item_tracer, true)) + }) + .ctx(&self) } fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - self.0.ensure_tuple(len)?; - let Tracer::Tuple(tracer) = self.0 else { - unreachable!(); - }; - - visitor.visit_seq(TraceTupleStruct { - tracers: &mut tracer.field_tracers, - pos: 0, + try_(|| { + self.0.ensure_tuple(len)?; + let Tracer::Tuple(tracer) = self.0 else { + unreachable!(); + }; + + 
visitor.visit_seq(TraceTupleStruct { + tracers: &mut tracer.field_tracers, + pos: 0, + }) }) + .ctx(&self) } fn deserialize_tuple_struct>( @@ -192,27 +263,30 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { len: usize, visitor: V, ) -> Result { - self.deserialize_tuple(len, visitor) + try_(|| TraceAny(&mut *self.0).deserialize_tuple(len, visitor)).ctx(&self) } fn deserialize_map>(self, visitor: V) -> Result { - if self.0.get_options().map_as_struct { - fail!(concat!( - "Cannot trace maps as structs with `from_type`. ", - "The struct fields cannot be known from the type alone.", - "Consider using `from_samples`. ", - )); - } + try_(|| { + if self.0.get_options().map_as_struct { + fail!(concat!( + "Cannot trace maps as structs with `from_type`. ", + "The struct fields cannot be known from the type alone.", + "Consider using `from_samples`. ", + )); + } - self.0.ensure_map()?; - let Tracer::Map(tracer) = self.0 else { - unreachable!() - }; - visitor.visit_map(TraceMap { - key_tracer: &mut tracer.key_tracer, - value_tracer: &mut tracer.value_tracer, - active: true, + self.0.ensure_map()?; + let Tracer::Map(tracer) = self.0 else { + unreachable!() + }; + visitor.visit_map(TraceMap { + key_tracer: &mut tracer.key_tracer, + value_tracer: &mut tracer.value_tracer, + active: true, + }) }) + .ctx(&self) } fn deserialize_struct>( @@ -221,16 +295,19 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { fields: &'static [&'static str], visitor: V, ) -> Result { - self.0.ensure_struct(fields, StructMode::Struct)?; - let Tracer::Struct(tracer) = self.0 else { - unreachable!() - }; - - visitor.visit_map(TraceStruct { - fields: &mut tracer.fields, - pos: 0, - names: fields, + try_(|| { + self.0.ensure_struct(fields, StructMode::Struct)?; + let Tracer::Struct(tracer) = self.0 else { + unreachable!() + }; + + visitor.visit_map(TraceStruct { + fields: &mut tracer.fields, + pos: 0, + names: fields, + }) }) + .ctx(&self) } fn deserialize_enum>( @@ -239,39 
+316,45 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { variants: &'static [&'static str], visitor: V, ) -> Result { - self.0.ensure_union(variants)?; - let Tracer::Union(tracer) = self.0 else { - unreachable!(); - }; - - let idx = tracer - .variants - .iter() - .position(|opt| !opt.as_ref().unwrap().tracer.is_complete()) - .unwrap_or_default(); - if idx >= tracer.variants.len() { - fail!("invalid variant index"); - } + try_(|| { + self.0.ensure_union(variants)?; + let Tracer::Union(tracer) = self.0 else { + unreachable!(); + }; + + let idx = tracer + .variants + .iter() + .position(|opt| !opt.as_ref().unwrap().tracer.is_complete()) + .unwrap_or_default(); + if idx >= tracer.variants.len() { + fail!("invalid variant index"); + } - let Some(variant) = tracer.variants[idx].as_mut() else { - fail!("invalid state"); - }; + let Some(variant) = tracer.variants[idx].as_mut() else { + fail!("invalid state"); + }; - let res = visitor.visit_enum(TraceEnum { - tracer: &mut variant.tracer, - pos: idx, - variant: &variant.name, - })?; - Ok(res) + let res = visitor.visit_enum(TraceEnum { + tracer: &mut variant.tracer, + pos: idx, + variant: &variant.name, + })?; + Ok(res) + }) + .ctx(&self) } fn deserialize_identifier>(self, visitor: V) -> Result { - self.deserialize_str(visitor) + try_(|| TraceAny(&mut *self.0).deserialize_str(visitor)).ctx(&self) } fn deserialize_ignored_any>(self, visitor: V) -> Result { - // TODO: is this correct? - visitor.visit_unit() + try_(|| { + // TODO: is this correct? 
+ visitor.visit_unit() + }) + .ctx(&self) } } diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index ed9f620a..21b3090d 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -1,11 +1,11 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; use crate::internal::{ arrow::{DataType, Field, UnionMode}, - error::{fail, Result}, + error::{fail, set_default, Context, Result}, schema::{ DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, TracingMode, TracingOptions, STRATEGY_KEY, @@ -516,6 +516,41 @@ impl Tracer { } } +impl Context for Tracer { + fn annotate(&self, annotations: &mut BTreeMap) { + match self { + Tracer::Unknown(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Unknown"); + } + Tracer::Primitive(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Primitive"); + } + Tracer::List(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "List"); + } + Tracer::Map(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Map"); + } + Tracer::Struct(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Struct"); + } + Tracer::Tuple(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Tuple"); + } + Tracer::Union(tracer) => { + set_default(annotations, "path", &tracer.path); + set_default(annotations, "tracer_type", "Union"); + } + } + } +} + fn coerce_primitive_type( prev: (&DataType, bool, Option<&Strategy>), curr: (DataType, Option), diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index d1c9cca5..b1a3e4c1 100644 --- 
a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1,2 +1,3 @@ -mod deserializers; -mod push_validity; +mod deserializers; +mod push_validity; +mod trace_from_type; diff --git a/serde_arrow/src/test/error_messages/trace_from_type.rs b/serde_arrow/src/test/error_messages/trace_from_type.rs new file mode 100644 index 00000000..17f8bef2 --- /dev/null +++ b/serde_arrow/src/test/error_messages/trace_from_type.rs @@ -0,0 +1,18 @@ +use serde_json::Value; + +use crate::{ + internal::{ + schema::{SchemaLike, SerdeArrowSchema}, + testing::assert_error_contains, + utils::Item, + }, + schema::TracingOptions, +}; + +#[test] +fn example() { + // NOTE: Value cannot be traced with from_type, as it is not self-describing + let res = SerdeArrowSchema::from_type::>>(TracingOptions::default()); + assert_error_contains(&res, "path: \"$.item.element\""); + assert_error_contains(&res, "tracer_type: \"Unknown\""); +} From a6f2751a60f0b108f3413116bee4b8e321edc445 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 16:03:26 +0200 Subject: [PATCH 167/178] Start to implement error context in from_samples --- .../src/internal/schema/from_samples/mod.rs | 186 ++++++++++++++---- serde_arrow/src/test/error_messages/mod.rs | 3 +- .../test/error_messages/trace_from_samples.rs | 9 + 3 files changed, 161 insertions(+), 37 deletions(-) create mode 100644 serde_arrow/src/test/error_messages/trace_from_samples.rs diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 52ec026c..daaa9a2c 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -3,13 +3,13 @@ mod chrono; #[cfg(test)] mod test_error_messages; -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use serde::{ser::Impossible, Serialize}; use crate::internal::{ arrow::DataType, - error::{fail, Error, Result}, + error::{fail, 
try_, Context, ContextSupport, Error, Result}, schema::{Strategy, TracingMode, TracingOptions}, }; @@ -38,14 +38,17 @@ mod impl_outer_sequence_serializer { use super::*; macro_rules! unimplemented_fn { - ($name:ident $($args:tt)* ) => { - fn $name $($args)* { - fail!("Cannot trace non-sequences with `from_samples`. Consider wrapping the argument in an array."); - } + ($ctx:ident ) => { + fail!(in $ctx, "Cannot trace non-sequences with `from_samples`. Consider wrapping the argument in an array.") }; } - #[rustfmt::skip] + impl<'a> Context for OuterSequenceSerializer<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } + } + impl<'a> serde::ser::Serializer for OuterSequenceSerializer<'a> { type Ok = (); type Error = Error; @@ -61,7 +64,13 @@ mod impl_outer_sequence_serializer { Ok(self) } - fn serialize_tuple_variant(self, _: &'static str, _: u32, _: &'static str, _: usize) -> Result { + fn serialize_tuple_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { Ok(self) } @@ -70,31 +79,130 @@ mod impl_outer_sequence_serializer { type SerializeStructVariant = Impossible; type SerializeTupleStruct = Impossible; - unimplemented_fn!(serialize_bool(self, _: bool) -> Result); - unimplemented_fn!(serialize_i8(self, _: i8) -> Result); - unimplemented_fn!(serialize_i16(self, _: i16) -> Result); - unimplemented_fn!(serialize_i32(self, _: i32) -> Result); - unimplemented_fn!(serialize_i64(self, _: i64) -> Result); - unimplemented_fn!(serialize_u8(self, _: u8) -> Result); - unimplemented_fn!(serialize_u16(self, _: u16) -> Result); - unimplemented_fn!(serialize_u32(self, _: u32) -> Result); - unimplemented_fn!(serialize_u64(self, _: u64) -> Result); - unimplemented_fn!(serialize_f32(self, _: f32) -> Result); - unimplemented_fn!(serialize_f64(self, _: f64) -> Result); - unimplemented_fn!(serialize_char(self, _: char) -> Result); - unimplemented_fn!(serialize_unit(self) -> Result); - 
unimplemented_fn!(serialize_str(self, _: &str) -> Result); - unimplemented_fn!(serialize_bytes(self, _: &[u8]) -> Result); - unimplemented_fn!(serialize_none(self) -> Result); - unimplemented_fn!(serialize_map(self, _: Option) -> Result); - unimplemented_fn!(serialize_struct(self, _: &'static str, _: usize) -> Result); - unimplemented_fn!(serialize_struct_variant(self, _: &'static str, _: u32, _: &'static str, _: usize) -> Result); - unimplemented_fn!(serialize_tuple_struct(self, _: &'static str, _: usize) -> Result); - unimplemented_fn!(serialize_unit_struct(self, _: &'static str) -> Result); - unimplemented_fn!(serialize_unit_variant(self, _: &'static str, _: u32, _: &'static str) -> Result); - unimplemented_fn!(serialize_some(self, _: &T) -> Result); - unimplemented_fn!(serialize_newtype_struct(self, _: &'static str, _: &T) -> Result); - unimplemented_fn!(serialize_newtype_variant(self, _: &'static str, _: u32, _: &'static str, _: &T) -> Result); + fn serialize_bool(self, _: bool) -> Result { + unimplemented_fn!(self) + } + + fn serialize_i8(self, _: i8) -> Result { + unimplemented_fn!(self) + } + + fn serialize_i16(self, _: i16) -> Result { + unimplemented_fn!(self) + } + + fn serialize_i32(self, _: i32) -> Result { + unimplemented_fn!(self) + } + + fn serialize_i64(self, _: i64) -> Result { + unimplemented_fn!(self) + } + + fn serialize_u8(self, _: u8) -> Result { + unimplemented_fn!(self) + } + + fn serialize_u16(self, _: u16) -> Result { + unimplemented_fn!(self) + } + + fn serialize_u32(self, _: u32) -> Result { + unimplemented_fn!(self) + } + + fn serialize_u64(self, _: u64) -> Result { + unimplemented_fn!(self) + } + + fn serialize_f32(self, _: f32) -> Result { + unimplemented_fn!(self) + } + + fn serialize_f64(self, _: f64) -> Result { + unimplemented_fn!(self) + } + + fn serialize_char(self, _: char) -> Result { + unimplemented_fn!(self) + } + + fn serialize_unit(self) -> Result { + unimplemented_fn!(self) + } + + fn serialize_str(self, _: &str) -> 
Result { + unimplemented_fn!(self) + } + + fn serialize_bytes(self, _: &[u8]) -> Result { + unimplemented_fn!(self) + } + + fn serialize_none(self) -> Result { + unimplemented_fn!(self) + } + + fn serialize_map(self, _: Option) -> Result { + unimplemented_fn!(self) + } + + fn serialize_struct(self, _: &'static str, _: usize) -> Result { + unimplemented_fn!(self) + } + + fn serialize_struct_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { + unimplemented_fn!(self) + } + + fn serialize_tuple_struct( + self, + _: &'static str, + _: usize, + ) -> Result { + unimplemented_fn!(self) + } + + fn serialize_unit_struct(self, _: &'static str) -> Result { + unimplemented_fn!(self) + } + + fn serialize_unit_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + ) -> Result { + unimplemented_fn!(self) + } + + fn serialize_some(self, _: &T) -> Result { + unimplemented_fn!(self) + } + + fn serialize_newtype_struct( + self, + _: &'static str, + _: &T, + ) -> Result { + unimplemented_fn!(self) + } + + fn serialize_newtype_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: &T, + ) -> Result { + unimplemented_fn!(self) + } } } @@ -103,7 +211,7 @@ impl<'a> serde::ser::SerializeSeq for OuterSequenceSerializer<'a> { type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> { - value.serialize(TracerSerializer(&mut *self.0)) + try_(|| value.serialize(TracerSerializer(&mut *self.0))).ctx(self) } fn end(self) -> Result { @@ -116,7 +224,7 @@ impl<'a> serde::ser::SerializeTuple for OuterSequenceSerializer<'a> { type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> { - value.serialize(TracerSerializer(&mut *self.0)) + try_(|| value.serialize(TracerSerializer(&mut *self.0))).ctx(self) } fn end(self) -> Result { @@ -129,7 +237,7 @@ impl<'a> serde::ser::SerializeTupleVariant for OuterSequenceSerializer<'a> { type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> 
{ - value.serialize(TracerSerializer(&mut *self.0)) + try_(|| value.serialize(TracerSerializer(&mut *self.0))).ctx(self) } fn end(self) -> Result { @@ -139,6 +247,12 @@ impl<'a> serde::ser::SerializeTupleVariant for OuterSequenceSerializer<'a> { struct TracerSerializer<'a>(&'a mut Tracer); +impl<'a> Context for TracerSerializer<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } +} + impl<'a> TracerSerializer<'a> { fn ensure_union_variant( self, diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index b1a3e4c1..14378381 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1,3 +1,4 @@ mod deserializers; mod push_validity; -mod trace_from_type; +mod trace_from_samples; +mod trace_from_type; \ No newline at end of file diff --git a/serde_arrow/src/test/error_messages/trace_from_samples.rs b/serde_arrow/src/test/error_messages/trace_from_samples.rs new file mode 100644 index 00000000..95d1e21d --- /dev/null +++ b/serde_arrow/src/test/error_messages/trace_from_samples.rs @@ -0,0 +1,9 @@ +use crate::{internal::testing::assert_error_contains, schema::{SchemaLike, SerdeArrowSchema, TracingOptions}}; + + +#[test] +fn non_sequence() { + let res = SerdeArrowSchema::from_samples(&42, TracingOptions::default()); + assert_error_contains(&res, "Cannot trace non-sequences with `from_samples`"); + assert_error_contains(&res, "path: \"$\""); +} From c49f23960406541dfab249d0f583f80a723ea2e9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 16:34:42 +0200 Subject: [PATCH 168/178] Implement error context for from_samples --- .../src/internal/schema/from_samples/mod.rs | 382 +++++++++++------- .../from_samples/test_error_messages.rs | 4 +- .../schema/from_type/test_error_messages.rs | 2 +- serde_arrow/src/internal/schema/tracer.rs | 80 ++-- serde_arrow/src/test/error_messages/mod.rs | 2 +- 
.../test/error_messages/trace_from_samples.rs | 14 +- 6 files changed, 311 insertions(+), 173 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index daaa9a2c..a03b1aab 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -39,7 +39,7 @@ mod impl_outer_sequence_serializer { macro_rules! unimplemented_fn { ($ctx:ident ) => { - fail!(in $ctx, "Cannot trace non-sequences with `from_samples`. Consider wrapping the argument in an array.") + fail!(in $ctx, "Cannot trace non-sequences with `from_samples`: consider wrapping the argument in an array") }; } @@ -82,75 +82,75 @@ mod impl_outer_sequence_serializer { fn serialize_bool(self, _: bool) -> Result { unimplemented_fn!(self) } - + fn serialize_i8(self, _: i8) -> Result { unimplemented_fn!(self) } - + fn serialize_i16(self, _: i16) -> Result { unimplemented_fn!(self) } - + fn serialize_i32(self, _: i32) -> Result { unimplemented_fn!(self) } - + fn serialize_i64(self, _: i64) -> Result { unimplemented_fn!(self) } - + fn serialize_u8(self, _: u8) -> Result { unimplemented_fn!(self) } - + fn serialize_u16(self, _: u16) -> Result { unimplemented_fn!(self) } - + fn serialize_u32(self, _: u32) -> Result { unimplemented_fn!(self) } - + fn serialize_u64(self, _: u64) -> Result { unimplemented_fn!(self) } - + fn serialize_f32(self, _: f32) -> Result { unimplemented_fn!(self) } - + fn serialize_f64(self, _: f64) -> Result { unimplemented_fn!(self) } - + fn serialize_char(self, _: char) -> Result { unimplemented_fn!(self) } - + fn serialize_unit(self) -> Result { unimplemented_fn!(self) } - + fn serialize_str(self, _: &str) -> Result { unimplemented_fn!(self) } - + fn serialize_bytes(self, _: &[u8]) -> Result { unimplemented_fn!(self) } - + fn serialize_none(self) -> Result { unimplemented_fn!(self) } - + fn serialize_map(self, _: Option) -> Result { unimplemented_fn!(self) } - + 
fn serialize_struct(self, _: &'static str, _: usize) -> Result { unimplemented_fn!(self) } - + fn serialize_struct_variant( self, _: &'static str, @@ -160,7 +160,7 @@ mod impl_outer_sequence_serializer { ) -> Result { unimplemented_fn!(self) } - + fn serialize_tuple_struct( self, _: &'static str, @@ -168,11 +168,11 @@ mod impl_outer_sequence_serializer { ) -> Result { unimplemented_fn!(self) } - + fn serialize_unit_struct(self, _: &'static str) -> Result { unimplemented_fn!(self) } - + fn serialize_unit_variant( self, _: &'static str, @@ -181,11 +181,11 @@ mod impl_outer_sequence_serializer { ) -> Result { unimplemented_fn!(self) } - + fn serialize_some(self, _: &T) -> Result { unimplemented_fn!(self) } - + fn serialize_newtype_struct( self, _: &'static str, @@ -285,86 +285,95 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { type SerializeTupleVariant = TupleSerializer<'a>; fn serialize_bool(self, _: bool) -> Result { - self.0.ensure_primitive(DataType::Boolean) + try_(|| self.0.ensure_primitive(DataType::Boolean)).ctx(&self) } fn serialize_i8(self, _: i8) -> Result { - self.0.ensure_number(DataType::Int8) + try_(|| self.0.ensure_number(DataType::Int8)).ctx(&self) } fn serialize_i16(self, _: i16) -> Result { - self.0.ensure_number(DataType::Int16) + try_(|| self.0.ensure_number(DataType::Int16)).ctx(&self) } fn serialize_i32(self, _: i32) -> Result { - self.0.ensure_number(DataType::Int32) + try_(|| self.0.ensure_number(DataType::Int32)).ctx(&self) } fn serialize_i64(self, _: i64) -> Result { - self.0.ensure_number(DataType::Int64) + try_(|| self.0.ensure_number(DataType::Int64)).ctx(&self) } fn serialize_u8(self, _: u8) -> Result { - self.0.ensure_number(DataType::UInt8) + try_(|| self.0.ensure_number(DataType::UInt8)).ctx(&self) } fn serialize_u16(self, _: u16) -> Result { - self.0.ensure_number(DataType::UInt16) + try_(|| self.0.ensure_number(DataType::UInt16)).ctx(&self) } fn serialize_u32(self, _: u32) -> Result { - 
self.0.ensure_number(DataType::UInt32) + try_(|| self.0.ensure_number(DataType::UInt32)).ctx(&self) } fn serialize_u64(self, _: u64) -> Result { - self.0.ensure_number(DataType::UInt64) + try_(|| self.0.ensure_number(DataType::UInt64)).ctx(&self) } fn serialize_f32(self, _: f32) -> Result { - self.0.ensure_number(DataType::Float32) + try_(|| self.0.ensure_number(DataType::Float32)).ctx(&self) } fn serialize_f64(self, _: f64) -> Result { - self.0.ensure_number(DataType::Float64) + try_(|| self.0.ensure_number(DataType::Float64)).ctx(&self) } fn serialize_char(self, _: char) -> Result { - self.0.ensure_primitive(DataType::UInt32) + try_(|| self.0.ensure_primitive(DataType::UInt32)).ctx(&self) } fn serialize_unit(self) -> Result { - self.0.ensure_primitive(DataType::Null) + try_(|| self.0.ensure_primitive(DataType::Null)).ctx(&self) } fn serialize_str(self, s: &str) -> Result { - let guess_dates = self.0.get_options().guess_dates; - if guess_dates && chrono::matches_naive_datetime(s) { - self.0 - .ensure_utf8(DataType::Date64, Some(Strategy::NaiveStrAsDate64)) - } else if guess_dates && chrono::matches_utc_datetime(s) { - self.0 - .ensure_utf8(DataType::Date64, Some(Strategy::UtcStrAsDate64)) - } else { - self.0.ensure_utf8(DataType::LargeUtf8, None) - } + try_(|| { + let guess_dates = self.0.get_options().guess_dates; + if guess_dates && chrono::matches_naive_datetime(s) { + self.0 + .ensure_utf8(DataType::Date64, Some(Strategy::NaiveStrAsDate64)) + } else if guess_dates && chrono::matches_utc_datetime(s) { + self.0 + .ensure_utf8(DataType::Date64, Some(Strategy::UtcStrAsDate64)) + } else { + self.0.ensure_utf8(DataType::LargeUtf8, None) + } + }) + .ctx(&self) } fn serialize_bytes(self, _: &[u8]) -> Result { - self.0.ensure_primitive(DataType::LargeBinary) + try_(|| self.0.ensure_primitive(DataType::LargeBinary)).ctx(&self) } fn serialize_none(self) -> Result { - self.0.mark_nullable(); - Ok(()) + try_(|| { + self.0.mark_nullable(); + Ok(()) + }) + .ctx(&self) } fn 
serialize_some(self, value: &T) -> Result { - self.0.mark_nullable(); - value.serialize(self) + try_(|| { + self.0.mark_nullable(); + value.serialize(TracerSerializer(&mut *self.0)) + }) + .ctx(&self) } fn serialize_unit_struct(self, _: &'static str) -> Result { - self.serialize_unit() + try_(|| TracerSerializer(&mut *self.0).serialize_unit()).ctx(&self) } fn serialize_newtype_struct( @@ -372,47 +381,71 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { _: &'static str, value: &T, ) -> Result { - value.serialize(self) + try_(|| value.serialize(TracerSerializer(&mut *self.0))).ctx(&self) } fn serialize_map(self, _: Option) -> Result { - if self.0.get_options().map_as_struct { - self.0.ensure_struct::<&str>(&[], StructMode::Map)?; - let Tracer::Struct(tracer) = self.0 else { - unreachable!(); - }; - Ok(MapSerializer::AsStruct(tracer, None)) - } else { - self.0.ensure_map()?; - let Tracer::Map(tracer) = self.0 else { - unreachable!(); - }; - Ok(MapSerializer::AsMap(tracer)) - } + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(move || { + if self.0.get_options().map_as_struct { + self.0.ensure_struct::<&str>(&[], StructMode::Map)?; + let Tracer::Struct(tracer) = self.0 else { + unreachable!(); + }; + Ok(MapSerializer::AsStruct(tracer, None)) + } else { + self.0.ensure_map()?; + let Tracer::Map(tracer) = self.0 else { + unreachable!(); + }; + Ok(MapSerializer::AsMap(tracer)) + } + }) + .ctx(&ctx) } fn serialize_seq(self, _: Option) -> Result { - self.0.ensure_list()?; - let Tracer::List(tracer) = self.0 else { - unreachable!(); - }; - Ok(ListSerializer(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(move || { + self.0.ensure_list()?; + let Tracer::List(tracer) = self.0 else { + unreachable!(); + }; + Ok(ListSerializer(tracer)) + }) + .ctx(&ctx) } fn serialize_struct(self, _: &'static str, _: usize) -> Result { - self.0.ensure_struct::<&str>(&[], StructMode::Struct)?; - let Tracer::Struct(tracer) = self.0 else 
{ - unreachable!(); - }; - Ok(StructSerializer(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(move || { + self.0.ensure_struct::<&str>(&[], StructMode::Struct)?; + let Tracer::Struct(tracer) = self.0 else { + unreachable!(); + }; + Ok(StructSerializer(tracer)) + }) + .ctx(&ctx) } fn serialize_tuple(self, len: usize) -> Result { - self.0.ensure_tuple(len)?; - let Tracer::Tuple(tracer) = self.0 else { - unreachable!(); - }; - Ok(TupleSerializer::new(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(move || { + self.0.ensure_tuple(len)?; + let Tracer::Tuple(tracer) = self.0 else { + unreachable!(); + }; + Ok(TupleSerializer::new(tracer)) + }) + .ctx(&ctx) } fn serialize_tuple_struct( @@ -420,11 +453,17 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { _: &'static str, len: usize, ) -> Result { - self.0.ensure_tuple(len)?; - let Tracer::Tuple(tracer) = self.0 else { - unreachable!(); - }; - Ok(TupleSerializer::new(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(move || { + self.0.ensure_tuple(len)?; + let Tracer::Tuple(tracer) = self.0 else { + unreachable!(); + }; + Ok(TupleSerializer::new(tracer)) + }) + .ctx(&ctx) } fn serialize_unit_variant( @@ -433,8 +472,14 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { variant_index: u32, variant_name: &'static str, ) -> Result { - let variant = self.ensure_union_variant(variant_name, variant_index)?; - variant.tracer.ensure_primitive(DataType::Null) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(|| { + let variant = self.ensure_union_variant(variant_name, variant_index)?; + variant.tracer.ensure_primitive(DataType::Null) + }) + .ctx(&ctx) } fn serialize_newtype_variant( @@ -444,8 +489,14 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { variant_name: &'static str, value: &T, ) -> Result { - let variant = self.ensure_union_variant(variant_name, variant_index)?; - 
value.serialize(TracerSerializer(&mut variant.tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(|| { + let variant = self.ensure_union_variant(variant_name, variant_index)?; + value.serialize(TracerSerializer(&mut variant.tracer)) + }) + .ctx(&ctx) } fn serialize_struct_variant( @@ -455,14 +506,20 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { variant_name: &'static str, _: usize, ) -> Result { - let variant = self.ensure_union_variant(variant_name, variant_index)?; - variant - .tracer - .ensure_struct::<&str>(&[], StructMode::Struct)?; - let Tracer::Struct(tracer) = &mut variant.tracer else { - unreachable!(); - }; - Ok(StructSerializer(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(|| { + let variant = self.ensure_union_variant(variant_name, variant_index)?; + variant + .tracer + .ensure_struct::<&str>(&[], StructMode::Struct)?; + let Tracer::Struct(tracer) = &mut variant.tracer else { + unreachable!(); + }; + Ok(StructSerializer(tracer)) + }) + .ctx(&ctx) } fn serialize_tuple_variant( @@ -472,17 +529,29 @@ impl<'a> serde::ser::Serializer for TracerSerializer<'a> { variant_name: &'static str, len: usize, ) -> Result { - let variant = self.ensure_union_variant(variant_name, variant_index)?; - variant.tracer.ensure_tuple(len)?; - let Tracer::Tuple(tracer) = &mut variant.tracer else { - unreachable!(); - }; - Ok(TupleSerializer::new(tracer)) + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + + try_(|| { + let variant = self.ensure_union_variant(variant_name, variant_index)?; + variant.tracer.ensure_tuple(len)?; + let Tracer::Tuple(tracer) = &mut variant.tracer else { + unreachable!(); + }; + Ok(TupleSerializer::new(tracer)) + }) + .ctx(&ctx) } } struct StructSerializer<'a>(&'a mut StructTracer); +impl<'a> Context for StructSerializer<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } +} + impl<'a> serde::ser::SerializeStruct for 
StructSerializer<'a> { type Ok = (); type Error = Error; @@ -492,15 +561,18 @@ impl<'a> serde::ser::SerializeStruct for StructSerializer<'a> { key: &'static str, value: &T, ) -> Result<()> { - let field_idx = self.0.ensure_field(key)?; - let Some(field_tracer) = self.0.get_field_tracer_mut(field_idx) else { - unreachable!(); - }; - value.serialize(TracerSerializer(field_tracer)) + try_(|| { + let field_idx = self.0.ensure_field(key)?; + let Some(field_tracer) = self.0.get_field_tracer_mut(field_idx) else { + unreachable!(); + }; + value.serialize(TracerSerializer(field_tracer)) + }) + .ctx(self) } fn end(self) -> Result { - self.0.end() + try_(|| self.0.end()).ctx(&self) } } @@ -513,26 +585,35 @@ impl<'a> serde::ser::SerializeStructVariant for StructSerializer<'a> { key: &'static str, value: &T, ) -> Result<()> { - let field_idx = self.0.ensure_field(key)?; - let Some(field_tracer) = self.0.get_field_tracer_mut(field_idx) else { - unreachable!(); - }; - value.serialize(TracerSerializer(field_tracer)) + try_(|| { + let field_idx = self.0.ensure_field(key)?; + let Some(field_tracer) = self.0.get_field_tracer_mut(field_idx) else { + unreachable!(); + }; + value.serialize(TracerSerializer(field_tracer)) + }) + .ctx(self) } fn end(self) -> Result { - self.0.end() + try_(|| self.0.end()).ctx(&self) } } struct ListSerializer<'a>(&'a mut ListTracer); +impl<'a> Context for ListSerializer<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } +} + impl<'a> serde::ser::SerializeSeq for ListSerializer<'a> { type Ok = (); type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> { - value.serialize(TracerSerializer(&mut self.0.item_tracer)) + try_(|| value.serialize(TracerSerializer(&mut self.0.item_tracer))).ctx(self) } fn end(self) -> Result { @@ -542,6 +623,12 @@ impl<'a> serde::ser::SerializeSeq for ListSerializer<'a> { struct TupleSerializer<'a>(&'a mut TupleTracer, usize); +impl<'a> Context for TupleSerializer<'a> { 
+ fn annotate(&self, annotations: &mut BTreeMap) { + self.0.annotate(annotations) + } +} + impl<'a> TupleSerializer<'a> { fn new(tracer: &'a mut TupleTracer) -> Self { Self(tracer, 0) @@ -553,10 +640,13 @@ impl<'a> serde::ser::SerializeTuple for TupleSerializer<'a> { type Error = Error; fn serialize_element(&mut self, value: &T) -> Result<()> { - let pos = self.1; - value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; - self.1 += 1; - Ok(()) + try_(|| { + let pos = self.1; + value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; + self.1 += 1; + Ok(()) + }) + .ctx(self) } fn end(self) -> Result { @@ -569,10 +659,13 @@ impl<'a> serde::ser::SerializeTupleStruct for TupleSerializer<'a> { type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> { - let pos = self.1; - value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; - self.1 += 1; - Ok(()) + try_(|| { + let pos = self.1; + value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; + self.1 += 1; + Ok(()) + }) + .ctx(self) } fn end(self) -> Result { @@ -585,10 +678,13 @@ impl<'a> serde::ser::SerializeTupleVariant for TupleSerializer<'a> { type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<()> where { - let pos = self.1; - value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; - self.1 += 1; - Ok(()) + try_(|| { + let pos = self.1; + value.serialize(TracerSerializer(self.0.field_tracer(pos)))?; + self.1 += 1; + Ok(()) + }) + .ctx(self) } fn end(self) -> Result { @@ -601,22 +697,32 @@ enum MapSerializer<'a> { AsMap(&'a mut MapTracer), } +impl<'a> Context for MapSerializer<'a> { + fn annotate(&self, annotations: &mut BTreeMap) { + match self { + Self::AsStruct(tracer, _) => tracer.annotate(annotations), + Self::AsMap(tracer) => tracer.annotate(annotations), + } + } +} + impl<'a> serde::ser::SerializeMap for MapSerializer<'a> { type Ok = (); type Error = Error; fn serialize_key(&mut self, key: &T) -> Result<()> { - match self { + try_(|| match self 
{ Self::AsStruct(_, next_key) => { *next_key = Some(key.serialize(SerializeToString)?); Ok(()) } Self::AsMap(tracer) => key.serialize(TracerSerializer(&mut tracer.key_tracer)), - } + }) + .ctx(self) } fn serialize_value(&mut self, value: &T) -> Result<()> { - match self { + try_(|| match self { Self::AsStruct(tracer, next_key) => { let Some(next_key) = next_key.take() else { fail!("invalid operations"); @@ -628,14 +734,16 @@ impl<'a> serde::ser::SerializeMap for MapSerializer<'a> { value.serialize(TracerSerializer(field_tracer)) } Self::AsMap(tracer) => value.serialize(TracerSerializer(&mut tracer.value_tracer)), - } + }) + .ctx(self) } - fn end(self) -> Result { - match self { + fn end(mut self) -> Result { + try_(|| match &mut self { Self::AsStruct(tracer, _) => tracer.end(), Self::AsMap(_) => Ok(()), - } + }) + .ctx(&self) } } diff --git a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs index f9411eba..5a94ce55 100644 --- a/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_samples/test_error_messages.rs @@ -31,8 +31,8 @@ fn outer_sequence_issue_97() { }; let res = SerdeArrowSchema::from_samples(&b, TracingOptions::default()); - assert_error_contains(&res, "Cannot trace non-sequences with `from_samples`."); - assert_error_contains(&res, "Consider wrapping the argument in an array."); + assert_error_contains(&res, "Cannot trace non-sequences with `from_samples`"); + assert_error_contains(&res, "consider wrapping the argument in an array"); } #[test] diff --git a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs index f2fc5ac6..9e4d5cea 100644 --- a/serde_arrow/src/internal/schema/from_type/test_error_messages.rs +++ b/serde_arrow/src/internal/schema/from_type/test_error_messages.rs @@ -39,7 +39,7 @@ fn map_as_struct() { 
TracingOptions::default().map_as_struct(true), ); assert_error_contains(&res, "Cannot trace maps as structs with `from_type`"); - assert_error_contains(&res, "Consider using `from_samples`."); + assert_error_contains(&res, "Consider using `from_samples`"); } #[test] diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 21b3090d..3c18723d 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -518,36 +518,7 @@ impl Tracer { impl Context for Tracer { fn annotate(&self, annotations: &mut BTreeMap) { - match self { - Tracer::Unknown(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Unknown"); - } - Tracer::Primitive(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Primitive"); - } - Tracer::List(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "List"); - } - Tracer::Map(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Map"); - } - Tracer::Struct(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Struct"); - } - Tracer::Tuple(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Tuple"); - } - Tracer::Union(tracer) => { - set_default(annotations, "path", &tracer.path); - set_default(annotations, "tracer_type", "Union"); - } - } + dispatch_tracer!(self, tracer => tracer.annotate(annotations)) } } @@ -596,6 +567,13 @@ pub struct UnknownTracer { pub nullable: bool, } +impl Context for UnknownTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Unknown"); + } +} + impl UnknownTracer { pub fn new(name: String, path: String, options: Arc) -> Self { Self 
{ @@ -645,6 +623,13 @@ pub struct MapTracer { pub value_tracer: Box, } +impl Context for MapTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Map"); + } +} + impl MapTracer { pub fn get_path(&self) -> &str { &self.path @@ -693,6 +678,13 @@ pub struct ListTracer { pub item_tracer: Box, } +impl Context for ListTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "List"); + } +} + impl ListTracer { pub fn get_path(&self) -> &str { &self.path @@ -729,6 +721,13 @@ pub struct TupleTracer { pub field_tracers: Vec, } +impl Context for TupleTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Tuple"); + } +} + impl TupleTracer { pub fn get_path(&self) -> &str { &self.path @@ -807,6 +806,13 @@ pub enum StructMode { Map, } +impl Context for StructTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Struct"); + } +} + impl StructTracer { pub fn get_field_tracer_mut(&mut self, idx: usize) -> Option<&mut Tracer> { Some(&mut self.fields.get_mut(idx)?.tracer) @@ -920,6 +926,13 @@ impl UnionVariant { } } +impl Context for UnionTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Union"); + } +} + impl UnionTracer { pub fn ensure_variant + AsRef>( &mut self, @@ -1025,6 +1038,13 @@ pub struct PrimitiveTracer { pub item_type: DataType, } +impl Context for PrimitiveTracer { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "path", &self.path); + set_default(annotations, "tracer_type", "Primitive"); + } +} + impl PrimitiveTracer { pub fn finish(&mut self) -> Result<()> { 
Ok(()) diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index 14378381..6b98bb21 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1,4 +1,4 @@ mod deserializers; mod push_validity; mod trace_from_samples; -mod trace_from_type; \ No newline at end of file +mod trace_from_type; diff --git a/serde_arrow/src/test/error_messages/trace_from_samples.rs b/serde_arrow/src/test/error_messages/trace_from_samples.rs index 95d1e21d..9ad309d4 100644 --- a/serde_arrow/src/test/error_messages/trace_from_samples.rs +++ b/serde_arrow/src/test/error_messages/trace_from_samples.rs @@ -1,5 +1,8 @@ -use crate::{internal::testing::assert_error_contains, schema::{SchemaLike, SerdeArrowSchema, TracingOptions}}; - +use crate::{ + internal::testing::assert_error_contains, + schema::{SchemaLike, SerdeArrowSchema, TracingOptions}, + utils::Item, +}; #[test] fn non_sequence() { @@ -7,3 +10,10 @@ fn non_sequence() { assert_error_contains(&res, "Cannot trace non-sequences with `from_samples`"); assert_error_contains(&res, "path: \"$\""); } + +#[test] +fn incompatible_primitives() { + let res = + SerdeArrowSchema::from_samples(&(Item(42_u32), Item("foo bar")), TracingOptions::default()); + assert_error_contains(&res, "path: \"$.item\""); +} From 78d1a6ddf1ca05716136d113744ccbc55291ac62 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 16:37:12 +0200 Subject: [PATCH 169/178] Update changelog --- Changes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/Changes.md b/Changes.md index aca50d64..34e4159d 100644 --- a/Changes.md +++ b/Changes.md @@ -10,6 +10,7 @@ New features - Add `Binary` / `LargeBinary` support for `arrow2` - Add support to serialize / deserialize `bool` from integer arrays - Add a helper to construct `Bool8` arrays +- Include the path of the field that caused an error in the error message API changes From abc7909c4e0102f2ce91a2e6cf22441defa70599 Mon 
Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 16:45:40 +0200 Subject: [PATCH 170/178] Fix capitalization of the error messages --- .../src/internal/schema/from_samples/mod.rs | 2 +- serde_arrow/src/internal/schema/from_type/mod.rs | 4 ++-- serde_arrow/src/internal/schema/tracer.rs | 16 ++++++++-------- .../impls/issue_90_type_tracing.rs | 4 +++- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index a03b1aab..6914e40d 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -725,7 +725,7 @@ impl<'a> serde::ser::SerializeMap for MapSerializer<'a> { try_(|| match self { Self::AsStruct(tracer, next_key) => { let Some(next_key) = next_key.take() else { - fail!("invalid operations"); + fail!("Invalid call to serialization methods: serialize_value called without prior call to serialize_key"); }; let field_idx = tracer.ensure_field(&next_key)?; let Some(field_tracer) = tracer.get_field_tracer_mut(field_idx) else { diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index 24b51d46..3d449312 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -328,11 +328,11 @@ impl<'de, 'a> serde::de::Deserializer<'de> for TraceAny<'a> { .position(|opt| !opt.as_ref().unwrap().tracer.is_complete()) .unwrap_or_default(); if idx >= tracer.variants.len() { - fail!("invalid variant index"); + fail!("Invalid variant index"); } let Some(variant) = tracer.variants[idx].as_mut() else { - fail!("invalid state"); + fail!("Invalid state"); }; let res = visitor.visit_enum(TraceEnum { diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 3c18723d..88c8b30b 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ 
b/serde_arrow/src/internal/schema/tracer.rs @@ -15,7 +15,7 @@ use crate::internal::{ // TODO: allow to customize const MAX_TYPE_DEPTH: usize = 20; const RECURSIVE_TYPE_WARNING: &str = - "too deeply nested type detected. Recursive types are not supported in schema tracing"; + "Too deeply nested type detected: recursive types are not supported in schema tracing"; fn default_dictionary_field(name: &str, nullable: bool) -> Field { Field { @@ -177,7 +177,7 @@ impl Tracer { pub fn check(&self) -> Result<()> { if dispatch_tracer!(self, tracer => tracer.name != "$") { - fail!("check must be called on the root tracer"); + fail!("Check must be called on the root tracer"); } let options = self.get_options(); self.check_overwrites(&options.overwrites) @@ -314,7 +314,7 @@ impl Tracer { // TODO: check fields are equal Self::Struct(_tracer) => {} _ => fail!( - "mismatched types, previous {:?}, current struct", + "Mismatched types: previous {:?}, current struct", self.get_type() ), } @@ -348,7 +348,7 @@ impl Tracer { // TODO: check fields are equal Self::Tuple(_tracer) => {} _ => fail!( - "mismatched types, previous {:?}, current struct", + "Mismatched types, previous {:?}, current struct", self.get_type() ), } @@ -386,7 +386,7 @@ impl Tracer { // TODO: check fields are equal or fill missing fields Self::Union(_tracer) => {} _ => fail!( - "mismatched types, previous {:?}, current union", + "Mismatched types: previous {:?}, current union", self.get_type() ), } @@ -415,7 +415,7 @@ impl Tracer { } Self::List(_tracer) => {} _ => fail!( - "mismatched types, previous {:?}, current list", + "Mismatched types: previous {:?}, current list", self.get_type() ), } @@ -449,7 +449,7 @@ impl Tracer { } Self::Map(_tracer) => {} _ => fail!( - "mismatched types, previous {:?}, current list", + "Mismatched types: previous {:?}, current list", self.get_type() ), } @@ -821,7 +821,7 @@ impl StructTracer { pub fn ensure_field(&mut self, key: &str) -> Result { if let Some(&field_idx) = 
self.index.get(key) { let Some(field) = self.fields.get_mut(field_idx) else { - fail!("invalid state"); + fail!("Invalid state: no tracer found for field with name {key}"); }; field.last_seen_in_sample = self.seen_samples; diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index 69c0bf6a..3170e5a8 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -236,5 +236,7 @@ fn unsupported_recursive_types() { } let res = Tracer::from_type::(TracingOptions::default()); - assert_error_contains(&res, "too deeply nested type detected"); + assert_error_contains(&res, "Too deeply nested type detected"); + // NOTE: do not check the complete path, it depends on the recursion limit + assert_error_contains(&res, "path: \"$.left.left.left.left.left.left"); } From 411fc935a5604e823eeaf765eead1843db5f5360 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 16:58:23 +0200 Subject: [PATCH 171/178] Add hint for coerce_numbers when tracing different number formats --- serde_arrow/src/internal/schema/tracer.rs | 88 ++++++++++++++++--- .../test/error_messages/trace_from_samples.rs | 9 ++ .../test/error_messages/trace_from_type.rs | 16 ++++ .../impls/issue_90_type_tracing.rs | 16 ---- 4 files changed, 103 insertions(+), 26 deletions(-) diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 88c8b30b..9a7ecc47 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -533,32 +533,100 @@ fn coerce_primitive_type( }; let res = match (prev, curr) { - ((prev_ty, nullable, prev_st), (curr_ty, curr_st)) if prev_ty == &curr_ty && prev_st == curr_st.as_ref() => (curr_ty, nullable, curr_st), + ((prev_ty, nullable, prev_st), (curr_ty, curr_st)) + if prev_ty == &curr_ty && prev_st == curr_st.as_ref() => + { + 
(curr_ty, nullable, curr_st) + } ((Null, _, _), (curr_ty, curr_st)) => (curr_ty, true, curr_st), ((prev_ty, _, prev_st), (Null, _)) => (prev_ty.clone(), true, prev_st.cloned()), // unsigned x unsigned -> u64 - ((UInt8 | UInt16 | UInt32 | UInt64, nullable, _), (UInt8 | UInt16 | UInt32 | UInt64, _,)) if options.coerce_numbers => (UInt64, nullable, None), + ( + (UInt8 | UInt16 | UInt32 | UInt64, nullable, _), + (UInt8 | UInt16 | UInt32 | UInt64, _), + ) if options.coerce_numbers => (UInt64, nullable, None), // signed x signed -> i64 - ((Int8 | Int16 | Int32 | Int64, nullable, _), (Int8 | Int16 | Int32 | Int64, _)) if options.coerce_numbers => (Int64, nullable, None), + ((Int8 | Int16 | Int32 | Int64, nullable, _), (Int8 | Int16 | Int32 | Int64, _)) + if options.coerce_numbers => + { + (Int64, nullable, None) + } // signed x unsigned -> i64 - ((Int8 | Int16 | Int32 | Int64, nullable, _), (UInt8 | UInt16 | UInt32 | UInt64, _)) if options.coerce_numbers => (Int64, nullable, None), + ((Int8 | Int16 | Int32 | Int64, nullable, _), (UInt8 | UInt16 | UInt32 | UInt64, _)) + if options.coerce_numbers => + { + (Int64, nullable, None) + } // unsigned x signed -> i64 - ((UInt8 | UInt16 | UInt32 | UInt64, nullable, _), (Int8 | Int16 | Int32 | Int64, _)) if options.coerce_numbers => (Int64, nullable, None), + ((UInt8 | UInt16 | UInt32 | UInt64, nullable, _), (Int8 | Int16 | Int32 | Int64, _)) + if options.coerce_numbers => + { + (Int64, nullable, None) + } // float x float -> f64 - ((Float32 | Float64, nullable, _), (Float32 | Float64, _)) if options.coerce_numbers=> (Float64, nullable, None), + ((Float32 | Float64, nullable, _), (Float32 | Float64, _)) if options.coerce_numbers => { + (Float64, nullable, None) + } // int x float -> f64 - ((Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, nullable, _), (Float32 | Float64, _)) if options.coerce_numbers => (Float64, nullable, None), + ( + (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, nullable, _), + 
(Float32 | Float64, _), + ) if options.coerce_numbers => (Float64, nullable, None), // float x int -> f64 - ((Float32 | Float64, nullable, _), (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, _)) if options.coerce_numbers => (Float64, nullable, None), + ( + (Float32 | Float64, nullable, _), + (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, _), + ) if options.coerce_numbers => (Float64, nullable, None), // incompatible formats, coerce to string ((Date64, nullable, _), (LargeUtf8, _)) => (LargeUtf8, nullable, None), ((LargeUtf8, nullable, _), (Date64, _)) => (LargeUtf8, nullable, None), - ((Date64, nullable, prev_st), (Date64, curr_st)) if prev_st != curr_st.as_ref() => (LargeUtf8, nullable, None), - ((prev_ty, _, prev_st), (curr_ty, curr_st)) => fail!("Cannot accept event {curr_ty:?} with strategy {curr_st:?} for tracer of primitive type {prev_ty:?} with strategy {prev_st:?}"), + ((Date64, nullable, prev_st), (Date64, curr_st)) if prev_st != curr_st.as_ref() => { + (LargeUtf8, nullable, None) + } + ((prev_ty, _, prev_st), (curr_ty, curr_st)) => { + let extra = if is_numeric(prev_ty) && is_numeric(&curr_ty) { + ": consider setting `coerce_numbers` to `true` to coerce different numeric types." 
+ } else { + "" + }; + fail!( + "Cannot accept {curr_ty:?} {curr_st} for tracer of primitive type {prev_ty:?} {prev_st}{extra}", + curr_st = OptionalStrategyDisplay(curr_st.as_ref()), + prev_st = OptionalStrategyDisplay(prev_st), + ) + } }; Ok(res) } +struct OptionalStrategyDisplay<'a>(Option<&'a Strategy>); + +impl<'a> std::fmt::Display for OptionalStrategyDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + None => Ok(()), + Some(strategy) => write!(f, " with strategy {strategy}"), + } + } +} + +fn is_numeric(dt: &DataType) -> bool { + matches!( + dt, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + ) +} + #[derive(Debug, PartialEq, Clone)] pub struct UnknownTracer { pub name: String, diff --git a/serde_arrow/src/test/error_messages/trace_from_samples.rs b/serde_arrow/src/test/error_messages/trace_from_samples.rs index 9ad309d4..29369808 100644 --- a/serde_arrow/src/test/error_messages/trace_from_samples.rs +++ b/serde_arrow/src/test/error_messages/trace_from_samples.rs @@ -17,3 +17,12 @@ fn incompatible_primitives() { SerdeArrowSchema::from_samples(&(Item(42_u32), Item("foo bar")), TracingOptions::default()); assert_error_contains(&res, "path: \"$.item\""); } + +#[test] +fn number_coercion() { + let res = SerdeArrowSchema::from_samples(&(&32.0_f32, 42_u64), TracingOptions::default()); + assert_error_contains( + &res, + "consider setting `coerce_numbers` to `true` to coerce different numeric types.", + ); +} diff --git a/serde_arrow/src/test/error_messages/trace_from_type.rs b/serde_arrow/src/test/error_messages/trace_from_type.rs index 17f8bef2..acc13021 100644 --- a/serde_arrow/src/test/error_messages/trace_from_type.rs +++ b/serde_arrow/src/test/error_messages/trace_from_type.rs @@ -1,3 +1,4 @@ +use serde::Deserialize; use 
serde_json::Value; use crate::{ @@ -16,3 +17,18 @@ fn example() { assert_error_contains(&res, "path: \"$.item.element\""); assert_error_contains(&res, "tracer_type: \"Unknown\""); } + +#[test] +fn unsupported_recursive_types() { + #[allow(unused)] + #[derive(Deserialize)] + struct Tree { + left: Option>, + right: Option>, + } + + let res = SerdeArrowSchema::from_type::(TracingOptions::default()); + assert_error_contains(&res, "Too deeply nested type detected"); + // NOTE: do not check the complete path, it depends on the recursion limit + assert_error_contains(&res, "path: \"$.left.left.left.left.left.left"); +} diff --git a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs index 3170e5a8..8bbf13f0 100644 --- a/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs +++ b/serde_arrow/src/test_with_arrow/impls/issue_90_type_tracing.rs @@ -6,7 +6,6 @@ use serde_json::json; use crate::internal::{ arrow::{DataType, Field, UnionMode}, schema::{tracer::Tracer, transmute_field, Strategy, TracingOptions, STRATEGY_KEY}, - testing::assert_error_contains, utils::Item, }; @@ -225,18 +224,3 @@ fn trace_map() { ); assert_eq!(actual, expected); } - -#[test] -fn unsupported_recursive_types() { - #[allow(unused)] - #[derive(Deserialize)] - struct Tree { - left: Option>, - right: Option>, - } - - let res = Tracer::from_type::(TracingOptions::default()); - assert_error_contains(&res, "Too deeply nested type detected"); - // NOTE: do not check the complete path, it depends on the recursion limit - assert_error_contains(&res, "path: \"$.left.left.left.left.left.left"); -} From d199d64da4097ab044abe6fb12103b3f152b9707 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 17:20:05 +0200 Subject: [PATCH 172/178] Implement the display reprenstations of anyhow --- serde_arrow/src/internal/error.rs | 32 ++++++++++++++------- serde_arrow/src/test/error_messages/misc.rs | 10 +++++++ 
serde_arrow/src/test/error_messages/mod.rs | 1 + 3 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 serde_arrow/src/test/error_messages/misc.rs diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 35b11f30..c8e88f9e 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -72,13 +72,18 @@ pub type Result = std::result::Result; /// Common errors during `serde_arrow`'s usage /// -/// At the moment only a generic string error is supported, but it is planned to -/// offer concrete types to match against. +/// At the moment only a generic string error is supported, but it is planned to offer concrete +/// types to match against. /// -/// The error carries a backtrace if `RUST_BACKTRACE=1`, see [`std::backtrace`] -/// for details. This backtrace is included when printing the error. If the -/// error is caused by another error, that error can be retrieved with -/// [`source()`][std::error::Error::source]. +/// The error carries a backtrace if `RUST_BACKTRACE=1`, see [`std::backtrace`] for details. This +/// backtrace is included when printing the error. If the error is caused by another error, that +/// error can be retrieved with [`source()`][std::error::Error::source]. +/// +/// # Display representation +/// +/// This error type follows anyhow's display representation: when printed with display format (`{}`) +/// (or converted to string) the error does not include a backtrace. Use the debug format (`{:?}`) +/// to include the backtrace information. 
/// #[derive(PartialEq)] #[non_exhaustive] @@ -150,7 +155,13 @@ impl std::cmp::PartialEq for CustomErrorImpl { impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "<{self}>") + write!( + f, + "Error: {msg}{annotations}\n{bt}", + msg = self.message(), + annotations = AnnotationsDisplay(self.annotations()), + bt = BacktraceDisplay(self.backtrace()), + ) } } @@ -158,10 +169,9 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Error: {msg}{annotations}\n{bt}", + "Error: {msg}{annotations}", msg = self.message(), annotations = AnnotationsDisplay(self.annotations()), - bt = BacktraceDisplay(self.backtrace()), ) } } @@ -194,8 +204,8 @@ impl<'a> std::fmt::Display for BacktraceDisplay<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.0.status() { BacktraceStatus::Captured => write!(f, "Backtrace:\n{bt}", bt=self.0), - BacktraceStatus::Disabled => write!(f, "No backtrace captured. Set the `RUST_BACKTRACE=1` env variable to enable."), - _ => write!(f, "No backtrace captured. 
Most likely backtraces are not supported on the current platform."), + BacktraceStatus::Disabled => write!(f, "Backtrace not captured; set the `RUST_BACKTRACE=1` env variable to enable"), + _ => write!(f, "Backtrace not captured: most likely backtraces are not supported on the current platform"), } } } diff --git a/serde_arrow/src/test/error_messages/misc.rs b/serde_arrow/src/test/error_messages/misc.rs new file mode 100644 index 00000000..93b00e26 --- /dev/null +++ b/serde_arrow/src/test/error_messages/misc.rs @@ -0,0 +1,10 @@ +use crate::internal::error::Error; + +#[test] +fn backtrace_on_debug() { + let err = Error::custom(String::from("foo bar")); + + // NOTE: the exact message depends on the ability of Rust to capture a backtrace + assert_eq!(format!("{}", err).contains("Backtrace"), false); + assert_eq!(format!("{:?}", err).contains("Backtrace"), true); +} diff --git a/serde_arrow/src/test/error_messages/mod.rs b/serde_arrow/src/test/error_messages/mod.rs index 6b98bb21..cafb0e5b 100644 --- a/serde_arrow/src/test/error_messages/mod.rs +++ b/serde_arrow/src/test/error_messages/mod.rs @@ -1,4 +1,5 @@ mod deserializers; +mod misc; mod push_validity; mod trace_from_samples; mod trace_from_type; From 607d4617a7147a699abee152a47e638e94b166f6 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 17:20:50 +0200 Subject: [PATCH 173/178] Update changelog --- Changes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/Changes.md b/Changes.md index 34e4159d..17d602c8 100644 --- a/Changes.md +++ b/Changes.md @@ -11,6 +11,7 @@ New features - Add support to serialize / deserialize `bool` from integer arrays - Add a helper to construct `Bool8` arrays - Include the path of the field that caused an error in the error message +- Include backtrace information only for the debug representations of errors API changes From 16fe8f6970d24a60c0f87d7e4621e8bd0a4f6eff Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 18:47:30 +0200 Subject: 
[PATCH 174/178] Remove Ref + ?Sized from from_value and from_samples --- serde_arrow/src/arrow2_impl/schema.rs | 7 ++----- serde_arrow/src/arrow_impl/schema.rs | 20 ++++++++----------- .../src/internal/schema/from_samples/mod.rs | 5 +---- serde_arrow/src/internal/schema/mod.rs | 8 ++++---- serde_arrow/src/internal/utils/value.rs | 2 +- 5 files changed, 16 insertions(+), 26 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/schema.rs b/serde_arrow/src/arrow2_impl/schema.rs index a2acc88e..18808774 100644 --- a/serde_arrow/src/arrow2_impl/schema.rs +++ b/serde_arrow/src/arrow2_impl/schema.rs @@ -43,7 +43,7 @@ impl Sealed for Vec {} /// Schema support for `Vec` (*requires one of the /// `arrow2-*` features*) impl SchemaLike for Vec { - fn from_value(value: &T) -> Result { + fn from_value(value: T) -> Result { SerdeArrowSchema::from_value(value)?.try_into() } @@ -51,10 +51,7 @@ impl SchemaLike for Vec { SerdeArrowSchema::from_type::(options)?.try_into() } - fn from_samples( - samples: &T, - options: TracingOptions, - ) -> Result { + fn from_samples(samples: T, options: TracingOptions) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } } diff --git a/serde_arrow/src/arrow_impl/schema.rs b/serde_arrow/src/arrow_impl/schema.rs index 5bf45f3d..a692bad9 100644 --- a/serde_arrow/src/arrow_impl/schema.rs +++ b/serde_arrow/src/arrow_impl/schema.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use serde::{Deserialize, Serialize}; + use crate::{ _impl::arrow::datatypes::{ DataType as ArrowDataType, Field as ArrowField, FieldRef, TimeUnit as ArrowTimeUnit, @@ -76,18 +78,15 @@ impl Sealed for Vec {} /// Schema support for `Vec` (*requires one of the /// `arrow-*` features*) impl SchemaLike for Vec { - fn from_value(value: &T) -> Result { + fn from_value(value: T) -> Result { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de>>(options: TracingOptions) -> Result { + fn from_type<'de, T: Deserialize<'de>>(options: 
TracingOptions) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } - fn from_samples( - samples: &T, - options: TracingOptions, - ) -> Result { + fn from_samples(samples: T, options: TracingOptions) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } } @@ -97,18 +96,15 @@ impl Sealed for Vec {} /// Schema support for `Vec` (*requires one of the /// `arrow-*` features*) impl SchemaLike for Vec { - fn from_value(value: &T) -> Result { + fn from_value(value: T) -> Result { SerdeArrowSchema::from_value(value)?.try_into() } - fn from_type<'de, T: serde::Deserialize<'de>>(options: TracingOptions) -> Result { + fn from_type<'de, T: Deserialize<'de>>(options: TracingOptions) -> Result { SerdeArrowSchema::from_type::(options)?.try_into() } - fn from_samples( - samples: &T, - options: TracingOptions, - ) -> Result { + fn from_samples(samples: T, options: TracingOptions) -> Result { SerdeArrowSchema::from_samples(samples, options)?.try_into() } } diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 6914e40d..18ef607e 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -18,10 +18,7 @@ use super::tracer::{ }; impl Tracer { - pub fn from_samples( - samples: &T, - options: TracingOptions, - ) -> Result { + pub fn from_samples(samples: T, options: TracingOptions) -> Result { let options = options.tracing_mode(TracingMode::FromSamples); let mut tracer = Tracer::new(String::from("$"), String::from("$"), Arc::new(options)); samples.serialize(OuterSequenceSerializer(&mut tracer))?; diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 09157a88..b196b7e9 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -133,7 +133,7 @@ pub trait SchemaLike: Sized + Sealed { /// fields, named `"key"` of integer type and named 
`"value"` of string /// type /// - fn from_value(value: &T) -> Result; + fn from_value(value: T) -> Result; /// Determine the schema from the given record type. See [`TracingOptions`] for customization /// options. @@ -271,7 +271,7 @@ pub trait SchemaLike: Sized + Sealed { /// # #[cfg(not(has_arrow))] /// # fn main() { } /// ``` - fn from_samples(samples: &T, options: TracingOptions) -> Result; + fn from_samples(samples: T, options: TracingOptions) -> Result; } /// A collection of fields as understood by `serde_arrow` @@ -286,7 +286,7 @@ pub struct SerdeArrowSchema { impl Sealed for SerdeArrowSchema {} impl SchemaLike for SerdeArrowSchema { - fn from_value(value: &T) -> Result { + fn from_value(value: T) -> Result { value::transmute(value) } @@ -294,7 +294,7 @@ impl SchemaLike for SerdeArrowSchema { Tracer::from_type::(options)?.to_schema() } - fn from_samples(samples: &T, options: TracingOptions) -> Result { + fn from_samples(samples: T, options: TracingOptions) -> Result { Tracer::from_samples(samples, options)?.to_schema() } } diff --git a/serde_arrow/src/internal/utils/value.rs b/serde_arrow/src/internal/utils/value.rs index 2dc7c1b5..68037dfe 100644 --- a/serde_arrow/src/internal/utils/value.rs +++ b/serde_arrow/src/internal/utils/value.rs @@ -82,7 +82,7 @@ impl std::hash::Hash for HashF64 { } } -pub fn transmute(value: &S) -> Result { +pub fn transmute(value: S) -> Result { let value = value.serialize(ValueSerializer)?; T::deserialize(ValueDeserializer::new(&value)) } From 5d94db090bf5da99469f9ee1d34757eca236f763 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 18:49:38 +0200 Subject: [PATCH 175/178] Remove Ref + ?Sized from ArrayBuilder::push / ArrayBuilder::extend --- serde_arrow/src/internal/array_builder.rs | 4 ++-- .../src/internal/serialization/outer_sequence_builder.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs 
index 6b07eec8..27ea7958 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -73,13 +73,13 @@ impl std::fmt::Debug for ArrayBuilder { impl ArrayBuilder { /// Add a single record to the arrays /// - pub fn push(&mut self, item: &T) -> Result<()> { + pub fn push(&mut self, item: T) -> Result<()> { self.builder.push(item) } /// Add multiple records to the arrays /// - pub fn extend(&mut self, items: &T) -> Result<()> { + pub fn extend(&mut self, items: T) -> Result<()> { self.builder.extend(items) } } diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index f727ef85..58bbf1ba 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -45,13 +45,13 @@ impl OuterSequenceBuilder { } /// Extend the builder with a sequence of items - pub fn extend(&mut self, value: &T) -> Result<()> { + pub fn extend(&mut self, value: T) -> Result<()> { value.serialize(Mut(self)) } /// Push a single item into the builder - pub fn push(&mut self, value: &T) -> Result<()> { - self.element(value) + pub fn push(&mut self, value: T) -> Result<()> { + self.element(&value) } } From b79f86dffa2d345b3e220b3379ff780b251a1539 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 18:53:40 +0200 Subject: [PATCH 176/178] Remove Ref + ?Sized for to_arrow, to_record_batch, to_arrow2 --- serde_arrow/src/arrow2_impl/api.rs | 5 +---- serde_arrow/src/arrow_impl/api.rs | 7 ++----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index 6127d956..0c2a4aa2 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -52,10 +52,7 @@ use crate::{ /// # } /// ``` /// -pub fn to_arrow2(fields: &[ArrowField], items: &T) -> Result>> -where - T: 
Serialize + ?Sized, -{ +pub fn to_arrow2(fields: &[ArrowField], items: T) -> Result>> { let builder = ArrayBuilder::from_arrow2(fields)?; items .serialize(Serializer::new(builder))? diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index 92644fa7..cc91c084 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -57,7 +57,7 @@ use super::type_support::fields_from_field_refs; /// # } /// ``` /// -pub fn to_arrow(fields: &[FieldRef], items: &T) -> Result> { +pub fn to_arrow(fields: &[FieldRef], items: T) -> Result> { let builder = ArrayBuilder::from_arrow(fields)?; items .serialize(Serializer::new(builder))? @@ -138,10 +138,7 @@ where /// # Ok(()) /// # } /// ``` -pub fn to_record_batch( - fields: &[FieldRef], - items: &T, -) -> Result { +pub fn to_record_batch(fields: &[FieldRef], items: &T) -> Result { let builder = ArrayBuilder::from_arrow(fields)?; items .serialize(Serializer::new(builder))? From 16bbcc93bc4b9c699cfbf7cc392c4fb5816f43c4 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Mon, 16 Sep 2024 19:20:10 +0200 Subject: [PATCH 177/178] Add support for FixedSizeBinary / FixedSizeList for arrow2 --- Changes.md | 2 +- serde_arrow/src/arrow2_impl/array.rs | 51 +++++++++++++++++-- .../src/test_with_arrow/impls/bytes.rs | 3 -- .../test_with_arrow/impls/fixed_size_list.rs | 3 -- 4 files changed, 47 insertions(+), 12 deletions(-) diff --git a/Changes.md b/Changes.md index 17d602c8..9ac4f1a6 100644 --- a/Changes.md +++ b/Changes.md @@ -7,7 +7,7 @@ Refactor the underlying implementation to prepare for further development New features - Add `arrow=53` support -- Add `Binary` / `LargeBinary` support for `arrow2` +- Add `Binary`, `LargeBinary`, `FixedSizeBinary(n)`, `FixedSizeList(n)` support for `arrow2` - Add support to serialize / deserialize `bool` from integer arrays - Add a helper to construct `Bool8` arrays - Include the path of the field that caused an error in the error message diff --git 
a/serde_arrow/src/arrow2_impl/array.rs b/serde_arrow/src/arrow2_impl/array.rs index 6a32fd5e..19e50e4f 100644 --- a/serde_arrow/src/arrow2_impl/array.rs +++ b/serde_arrow/src/arrow2_impl/array.rs @@ -3,8 +3,9 @@ use std::borrow::Cow; use crate::{ _impl::arrow2::{ array::{ - Array as A2Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, ListArray, - MapArray, NullArray, PrimitiveArray, StructArray, UnionArray, Utf8Array, + Array as A2Array, BinaryArray, BooleanArray, DictionaryArray, DictionaryKey, + FixedSizeBinaryArray, FixedSizeListArray, ListArray, MapArray, NullArray, + PrimitiveArray, StructArray, UnionArray, Utf8Array, }, bitmap::Bitmap, buffer::Buffer, @@ -14,7 +15,8 @@ use crate::{ internal::{ arrow::{ Array, ArrayView, BitsWithOffset, BooleanArrayView, BytesArrayView, DecimalArrayView, - DenseUnionArrayView, DictionaryArrayView, FieldMeta, ListArrayView, NullArrayView, + DenseUnionArrayView, DictionaryArrayView, FieldMeta, FixedSizeBinaryArrayView, + FixedSizeListArrayView, ListArrayView, NullArrayView, PrimitiveArray as InternalPrimitiveArray, PrimitiveArrayView, StructArrayView, TimeArrayView, TimestampArrayView, }, @@ -161,8 +163,27 @@ impl TryFrom for ArrayRef { Some(arr.offsets.into()), )?)) } - A::FixedSizeList(_) => fail!("FixedSizeList is not supported by arrow2"), - A::FixedSizeBinary(_) => fail!("FixedSizeBinary is not supported by arrow2"), + A::FixedSizeList(arr) => { + let child: ArrayRef = (*arr.element).try_into()?; + let child_field = field_from_array_and_meta(child.as_ref(), arr.meta); + let data_type = T::FixedSizeList(Box::new(child_field), arr.n.try_into()?); + let validity = arr.validity.map(|v| Bitmap::from_u8_vec(v, arr.len)); + + Ok(Box::new(FixedSizeListArray::try_new( + data_type, child, validity, + )?)) + } + A::FixedSizeBinary(arr) => { + let n = usize::try_from(arr.n)?; + let len = arr.data.len() / n; + let validity = arr.validity.map(|v| Bitmap::from_u8_vec(v, len)); + + Ok(Box::new(FixedSizeBinaryArray::try_new( 
+ T::FixedSizeBinary(n), + Buffer::from(arr.data), + validity, + )?)) + } } } } @@ -387,6 +408,26 @@ impl<'a> TryFrom<&'a dyn A2Array> for ArrayView<'a> { offsets: offsets.as_slice(), fields, })) + } else if let Some(array) = any.downcast_ref::() { + let T::FixedSizeList(field, _) = array.data_type() else { + fail!("Invalid type: expected FixedSizeList"); + }; + + let child_view: ArrayView<'_> = array.values().as_ref().try_into()?; + + Ok(V::FixedSizeList(FixedSizeListArrayView { + len: array.len(), + n: array.size().try_into()?, + validity: bits_with_offset_from_bitmap(array.validity()), + meta: meta_from_field(field.as_ref().try_into()?), + element: Box::new(child_view), + })) + } else if let Some(array) = any.downcast_ref::() { + Ok(V::FixedSizeBinary(FixedSizeBinaryArrayView { + n: array.size().try_into()?, + validity: bits_with_offset_from_bitmap(array.validity()), + data: array.values().as_slice(), + })) } else { fail!( "Cannot convert array with data type {:?} into an array view", diff --git a/serde_arrow/src/test_with_arrow/impls/bytes.rs b/serde_arrow/src/test_with_arrow/impls/bytes.rs index c6e6d930..9578098d 100644 --- a/serde_arrow/src/test_with_arrow/impls/bytes.rs +++ b/serde_arrow/src/test_with_arrow/impls/bytes.rs @@ -157,7 +157,6 @@ mod fixed_size_binary { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "FixedSizeBinary(3)"}])) .serialize(&items) .deserialize(&items); @@ -172,7 +171,6 @@ mod fixed_size_binary { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "FixedSizeBinary(3)"}])) .serialize(&items) .deserialize(&items); @@ -187,7 +185,6 @@ mod fixed_size_binary { ]; Test::new() - .skip_arrow2() .with_schema(json!([{"name": "item", "data_type": "FixedSizeBinary(3)"}])) .serialize(&items) .deserialize_borrowed(&items); diff --git a/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs b/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs index a714e8ee..51969b8e 100644 --- 
a/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs +++ b/serde_arrow/src/test_with_arrow/impls/fixed_size_list.rs @@ -12,7 +12,6 @@ fn example() { let items = [Item(vec![0_u8, 1]), Item(vec![2, 3]), Item(vec![4, 5])]; Test::new() - .skip_arrow2() .with_schema(json!([{ "name": "item", "data_type": "FixedSizeList(2)", @@ -27,7 +26,6 @@ fn example_nullable_no_nulls() { let items = [Item(vec![0_u16, 1]), Item(vec![2, 3]), Item(vec![4, 5])]; Test::new() - .skip_arrow2() .with_schema(json!([{ "name": "item", "data_type": "FixedSizeList(2)", @@ -48,7 +46,6 @@ fn example_nullable_with_nulls() { ]; Test::new() - .skip_arrow2() .with_schema(json!([{ "name": "item", "data_type": "FixedSizeList(2)", From 78ebdbfb08f1f157ba143d71471c38c970c490d9 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 22 Sep 2024 21:32:46 +0200 Subject: [PATCH 178/178] Add ArrayBuilder::to_arrays --- serde_arrow/src/arrow2_impl/api.rs | 9 ++++----- serde_arrow/src/arrow_impl/api.rs | 12 +++++------- serde_arrow/src/arrow_impl/array.rs | 14 +++++++++++--- serde_arrow/src/internal/array_builder.rs | 10 +++++++++- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/serde_arrow/src/arrow2_impl/api.rs b/serde_arrow/src/arrow2_impl/api.rs index 0c2a4aa2..41cad44b 100644 --- a/serde_arrow/src/arrow2_impl/api.rs +++ b/serde_arrow/src/arrow2_impl/api.rs @@ -107,11 +107,10 @@ impl crate::internal::array_builder::ArrayBuilder { /// Construct `arrow2` arrays and reset the builder (*requires one of the /// `arrow2-*` features*) pub fn to_arrow2(&mut self) -> Result>> { - let mut arrays = Vec::new(); - for field in self.builder.take_records()? { - arrays.push(field.into_array()?.try_into()?); - } - Ok(arrays) + self.to_arrays()? 
+ .into_iter() + .map(Box::::try_from) + .collect() } } diff --git a/serde_arrow/src/arrow_impl/api.rs b/serde_arrow/src/arrow_impl/api.rs index cc91c084..cf298f27 100644 --- a/serde_arrow/src/arrow_impl/api.rs +++ b/serde_arrow/src/arrow_impl/api.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use crate::{ _impl::arrow::{ - array::{make_array, Array, ArrayRef, RecordBatch}, + array::{Array, ArrayRef, RecordBatch}, datatypes::{FieldRef, Schema}, }, internal::{ @@ -186,12 +186,10 @@ impl crate::internal::array_builder::ArrayBuilder { /// Construct `arrow` arrays and reset the builder (*requires one of the /// `arrow-*` features*) pub fn to_arrow(&mut self) -> Result> { - let mut arrays = Vec::new(); - for field in self.builder.take_records()? { - let data = field.into_array()?.try_into()?; - arrays.push(make_array(data)); - } - Ok(arrays) + self.to_arrays()? + .into_iter() + .map(ArrayRef::try_from) + .collect() } /// Construct a [`RecordBatch`] and reset the builder (*requires one of the diff --git a/serde_arrow/src/arrow_impl/array.rs b/serde_arrow/src/arrow_impl/array.rs index a256ca3e..ae0a2c4d 100644 --- a/serde_arrow/src/arrow_impl/array.rs +++ b/serde_arrow/src/arrow_impl/array.rs @@ -6,9 +6,9 @@ use half::f16; use crate::{ _impl::arrow::{ array::{ - Array, ArrayData, BooleanArray, DictionaryArray, FixedSizeBinaryArray, - FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, MapArray, - NullArray, PrimitiveArray, StructArray, UnionArray, + make_array, Array, ArrayData, ArrayRef, BooleanArray, DictionaryArray, + FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, + GenericStringArray, MapArray, NullArray, PrimitiveArray, StructArray, UnionArray, }, buffer::{Buffer, ScalarBuffer}, datatypes::{ @@ -34,6 +34,14 @@ use crate::{ }, }; +impl TryFrom for ArrayRef { + type Error = Error; + + fn try_from(value: crate::internal::arrow::Array) -> Result { + Ok(make_array(ArrayData::try_from(value)?)) + } +} + 
impl TryFrom for ArrayData { type Error = Error; diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs index 27ea7958..869c9694 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -1,7 +1,7 @@ use serde::Serialize; use crate::internal::{ - error::Result, schema::SerdeArrowSchema, serialization::OuterSequenceBuilder, + arrow::Array, error::Result, schema::SerdeArrowSchema, serialization::OuterSequenceBuilder, }; /// Construct arrays by pushing individual records @@ -82,6 +82,14 @@ impl ArrayBuilder { pub fn extend(&mut self, items: T) -> Result<()> { self.builder.extend(items) } + + pub(crate) fn to_arrays(&mut self) -> Result> { + let mut arrays = Vec::new(); + for field in self.builder.take_records()? { + arrays.push(field.into_array()?); + } + Ok(arrays) + } } impl std::convert::AsRef for ArrayBuilder {