From 07316123a26ba0a7e0554885015284c4494b67e8 Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:18:46 +0200 Subject: [PATCH 1/3] Rewrite FieldLookup helper to ensure ArrayBuilder is Send + Sync --- serde_arrow/src/internal/array_builder.rs | 8 ++++++++ .../internal/serialization/struct_builder.rs | 18 ++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs index 3e65a6e4..5b86e198 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -95,3 +95,11 @@ impl std::convert::AsMut for ArrayBuilder { self } } + + +const _: () = { + const fn assert_send_sync() { + () + } + assert_send_sync::() +}; diff --git a/serde_arrow/src/internal/serialization/struct_builder.rs b/serde_arrow/src/internal/serialization/struct_builder.rs index 68753f13..228f3b1f 100644 --- a/serde_arrow/src/internal/serialization/struct_builder.rs +++ b/serde_arrow/src/internal/serialization/struct_builder.rs @@ -205,12 +205,23 @@ impl SimpleSerializer for StructBuilder { } } +/// Optimize field lookups for static names #[derive(Debug, Clone)] pub struct FieldLookup { - pub cached_names: Vec>, + pub cached_names: Vec>, pub index: BTreeMap, } +/// A wrapper around a static field name that compares using ptr and length +#[derive(Debug, Clone)] +pub struct StaticFieldName(&'static str); + +impl std::cmp::PartialEq for StaticFieldName { + fn eq(&self, other: &Self) -> bool { + (self.0.as_ptr(), self.0.len()) == (other.0.as_ptr(), other.0.len()) + } +} + impl FieldLookup { pub fn new(field_names: Vec) -> Result { let mut index = BTreeMap::new(); @@ -234,13 +245,12 @@ impl FieldLookup { } pub fn lookup(&mut self, guess: usize, key: &'static str) -> Option { - let fast_key = (key.as_ptr(), key.len()); - if self.cached_names.get(guess) == Some(&Some(fast_key)) { + if self.cached_names.get(guess) == Some(&Some(StaticFieldName(key))) { Some(guess) } else { let &idx = self.index.get(key)?; if self.cached_names[idx].is_none() { - self.cached_names[idx] = Some(fast_key); + self.cached_names[idx] = Some(StaticFieldName(key)); } Some(idx) } From ac8bcafdd2f230cfdb84237c7bd60a9ae538102e Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:42:33 +0200 Subject: [PATCH 2/3] Assert that all public types implement Send + Sync --- serde_arrow/src/internal/array_builder.rs | 7 ++----- serde_arrow/src/internal/deserializer.rs | 5 +++++ serde_arrow/src/internal/error.rs | 6 ++++++ serde_arrow/src/internal/schema/extensions/mod.rs | 7 +++++++ serde_arrow/src/internal/schema/mod.rs | 8 ++++++++ serde_arrow/src/internal/serializer.rs | 5 +++++ 6 files changed, 33 insertions(+), 5 deletions(-) diff --git a/serde_arrow/src/internal/array_builder.rs b/serde_arrow/src/internal/array_builder.rs index 5b86e198..6b07eec8 100644 --- a/serde_arrow/src/internal/array_builder.rs +++ b/serde_arrow/src/internal/array_builder.rs @@ -96,10 +96,7 @@ impl std::convert::AsMut for ArrayBuilder { } } - const _: () = { - const fn assert_send_sync() { - () - } - assert_send_sync::() + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for ArrayBuilder {} }; diff --git a/serde_arrow/src/internal/deserializer.rs b/serde_arrow/src/internal/deserializer.rs index 723615b3..3fdaaa2c 100644 --- a/serde_arrow/src/internal/deserializer.rs +++ b/serde_arrow/src/internal/deserializer.rs @@ -194,3 +194,8 @@ impl<'de> serde::de::Deserializer<'de> for Deserializer<'de> { false } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl<'de> AssertSendSync for Deserializer<'de> {} +}; diff --git a/serde_arrow/src/internal/error.rs b/serde_arrow/src/internal/error.rs index 610cd1ab..66b22f58 100644 --- a/serde_arrow/src/internal/error.rs +++ b/serde_arrow/src/internal/error.rs @@ -211,3 +211,9 @@ impl From for PanicOnErrorError { panic!("{value}"); } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Error {} + impl AssertSendSync for Result {} +}; diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs index fa879d7c..9b11016e 100644 --- a/serde_arrow/src/internal/schema/extensions/mod.rs +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -6,3 +6,10 @@ mod variable_shape_tensor_field; pub use bool8_field::Bool8Field; pub use fixed_shape_tensor_field::FixedShapeTensorField; pub use variable_shape_tensor_field::VariableShapeTensorField; + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Bool8Field {} + impl AssertSendSync for FixedShapeTensorField {} + impl AssertSendSync for VariableShapeTensorField {} +}; diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 8ff90182..7537b268 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -546,3 +546,11 @@ impl<'a> std::fmt::Display for DataTypeDisplay<'a> { } } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for SerdeArrowSchema {} + impl AssertSendSync for TracingOptions {} + impl AssertSendSync for Strategy {} + impl AssertSendSync for Overwrites {} +}; diff --git a/serde_arrow/src/internal/serializer.rs b/serde_arrow/src/internal/serializer.rs index add07604..70927cf2 100644 --- a/serde_arrow/src/internal/serializer.rs +++ b/serde_arrow/src/internal/serializer.rs @@ -281,3 +281,8 @@ impl> serde::ser::SerializeTupleVariant for CollectionSer Ok(Serializer(self.0)) } } + +const _: () = { + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for Serializer {} +}; From d99cbee8a52f144452627693a1ba42934e33a98c Mon Sep 17 00:00:00 2001 From: Christopher Prohm Date: Sun, 1 Sep 2024 08:47:08 +0200 Subject: [PATCH 3/3] Execute test workflow on PRs for develop branches --- .github/workflows/test.yml | 3 ++- x.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1a9a1cfc..a69e321e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,8 @@ "workflow_dispatch": {}, "pull_request": { "branches": [ - "main" + "main", + "develop-*" ], "types": [ "opened", diff --git a/x.py b/x.py index cbcfc488..df6d900e 100644 --- a/x.py +++ b/x.py @@ -35,7 +35,7 @@ "on": { "workflow_dispatch": {}, "pull_request": { - "branches": ["main"], + "branches": ["main", "develop-*"], "types": [ "opened", "edited",