From 9f2d985bc61fd007412ab23de826c38b01efd86a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Dahlb=C3=A6k?= <30782351+dahlbaek@users.noreply.github.com> Date: Sun, 9 Apr 2023 01:38:53 +0200 Subject: [PATCH] Issue #3: Fix test failures - Fix GroupByKey test by using KV instead of tuple. - Fix data race when writing to SERIALIZED_FNS. - Enable ensure_assert_fails test which was presumably previously causing spurious failures due to data race. --- sdks/rust/src/elem_types/kv/mod.rs | 4 ++-- sdks/rust/src/internals/serialize.rs | 23 ++++++++++++++++------- sdks/rust/src/tests/primitives_test.rs | 3 --- sdks/rust/src/worker/operators.rs | 3 ++- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sdks/rust/src/elem_types/kv/mod.rs b/sdks/rust/src/elem_types/kv/mod.rs index 152fcf9193ec9..f3b546170b54f 100644 --- a/sdks/rust/src/elem_types/kv/mod.rs +++ b/sdks/rust/src/elem_types/kv/mod.rs @@ -2,8 +2,8 @@ use std::{cmp, fmt}; #[derive(Clone, PartialEq, Eq, Debug)] pub struct KV { - k: K, - v: V, + pub k: K, + pub v: V, } impl KV diff --git a/sdks/rust/src/internals/serialize.rs b/sdks/rust/src/internals/serialize.rs index e62152402294d..bda1bfc52b9e7 100644 --- a/sdks/rust/src/internals/serialize.rs +++ b/sdks/rust/src/internals/serialize.rs @@ -8,12 +8,15 @@ use std::sync::Mutex; use once_cell::sync::Lazy; +use crate::elem_types::kv::KV; + static SERIALIZED_FNS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); pub fn serialize_fn(obj: Box) -> String { - let name = format!("object{}", SERIALIZED_FNS.lock().unwrap().len()); - SERIALIZED_FNS.lock().unwrap().insert(name.to_string(), obj); + let mut serialized_fns = SERIALIZED_FNS.lock().unwrap(); + let name = format!("object{}", serialized_fns.len()); + serialized_fns.insert(name.to_string(), obj); name } @@ -83,7 +86,7 @@ pub fn to_generic_dofn_dyn + 'static>( } pub trait KeyExtractor: Sync + Send { - fn extract(&self, kv: &dyn Any) -> (String, Box); + fn extract(&self, kv: &dyn Any) -> KV>; fn recombine( &self, key: &str, @@ -104,9 +107,12 @@ impl Default for TypedKeyExtractor { } impl KeyExtractor for TypedKeyExtractor { - fn extract(&self, kv: &dyn Any) -> (String, Box) { - let typed_kv = kv.downcast_ref::<(String, V)>().unwrap(); - (typed_kv.0.clone(), Box::new(typed_kv.1.clone())) + fn extract(&self, kv: &dyn Any) -> KV> { + let typed_kv = kv.downcast_ref::>().unwrap(); + KV { + k: typed_kv.k.clone(), + v: Box::new(typed_kv.v.clone()), + } } fn recombine( &self, @@ -117,6 +123,9 @@ impl KeyExtractor for TypedKeyExtractor { for untyped_value in values.iter() { typed_values.push(untyped_value.downcast_ref::().unwrap().clone()); } - Box::new((key.to_string(), typed_values)) + Box::new(KV { + k: key.to_string(), + v: typed_values, + }) } } diff --git a/sdks/rust/src/tests/primitives_test.rs b/sdks/rust/src/tests/primitives_test.rs index 1acf3f736427b..313a41cfa33c2 100644 --- a/sdks/rust/src/tests/primitives_test.rs +++ b/sdks/rust/src/tests/primitives_test.rs @@ -39,8 +39,6 @@ mod tests { .await; } - // TODO: Enabling these tests seem to cause random failures in other - // tests in this file. #[tokio::test] #[should_panic] // This tests that AssertEqualUnordered is actually doing its job. @@ -54,7 +52,6 @@ mod tests { } #[tokio::test] - #[ignore] #[should_panic] async fn ensure_assert_fails_on_empty() { DirectRunner::new() diff --git a/sdks/rust/src/worker/operators.rs b/sdks/rust/src/worker/operators.rs index a0fe4a4a6ef34..2a8866f6fd3c6 100644 --- a/sdks/rust/src/worker/operators.rs +++ b/sdks/rust/src/worker/operators.rs @@ -27,6 +27,7 @@ use std::sync::{Arc, Mutex}; use once_cell::sync::Lazy; use serde_json; +use crate::elem_types::kv::KV; use crate::internals::serialize; use crate::internals::urns; use crate::proto::beam_api::fn_execution::ProcessBundleDescriptor; @@ -500,7 +501,7 @@ impl OperatorI for GroupByKeyWithinBundleOperator { fn process(&self, element: &WindowedValue) { // TODO: assumes global window let untyped_value: &dyn Any = &*element.value; - let (key, value) = self.key_extractor.extract(untyped_value); + let KV { k: key, v: value } = self.key_extractor.extract(untyped_value); let mut grouped_values = self.grouped_values.lock().unwrap(); if !grouped_values.contains_key(&key) { grouped_values.insert(key.clone(), Box::default());