From 0f33802e4a04da216ff791ae7b784d9943204f67 Mon Sep 17 00:00:00 2001 From: jordanrfrazier <122494242+jordanrfrazier@users.noreply.github.com> Date: Mon, 31 Jul 2023 13:24:46 -0700 Subject: [PATCH] feat: collect to list (non-windowed) (primitive/strings/booleans) (#569) Adds the `collect` function for non-windowed, primitive/string/boolean types. This doesn't support `LargeUtf8` because the other string aggregations that use the `create_typed_evaluator` macro don't support generic offset size, so that will come in a follow up. --------- Co-authored-by: Kevin J Nguyen --- crates/sparrow-compiler/src/ast_to_dfg.rs | 70 +++-- .../src/functions/collection.rs | 5 + crates/sparrow-instructions/src/evaluators.rs | 10 + .../src/evaluators/aggregation/token.rs | 2 + .../aggregation/token/collect_token.rs | 61 ++++ .../src/evaluators/list.rs | 9 + .../src/evaluators/list/collect_boolean.rs | 108 +++++++ .../src/evaluators/list/collect_map.rs | 31 ++ .../src/evaluators/list/collect_primitive.rs | 131 ++++++++ .../src/evaluators/list/collect_string.rs | 109 +++++++ .../sparrow-main/tests/e2e/collect_tests.rs | 295 ++++++++++++++++++ crates/sparrow-main/tests/e2e/main.rs | 1 + crates/sparrow-plan/src/inst.rs | 5 + 13 files changed, 816 insertions(+), 21 deletions(-) create mode 100644 crates/sparrow-instructions/src/evaluators/aggregation/token/collect_token.rs create mode 100644 crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs create mode 100644 crates/sparrow-instructions/src/evaluators/list/collect_map.rs create mode 100644 crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs create mode 100644 crates/sparrow-instructions/src/evaluators/list/collect_string.rs create mode 100644 crates/sparrow-main/tests/e2e/collect_tests.rs diff --git a/crates/sparrow-compiler/src/ast_to_dfg.rs b/crates/sparrow-compiler/src/ast_to_dfg.rs index a615cbdbd..56994e4e9 100644 --- a/crates/sparrow-compiler/src/ast_to_dfg.rs +++ 
b/crates/sparrow-compiler/src/ast_to_dfg.rs @@ -522,31 +522,28 @@ pub fn add_to_dfg( dfg.bind("$condition_input", args[0].inner().clone()); let window = &expr.unwrap().args()[1]; - let (condition, duration) = match window.op() { - ExprOp::Call(window_name) => { - flatten_window_args(window_name, window, dfg, data_context, diagnostics)? - } - ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => { - // Unwindowed aggregations just use nulls - let null_arg = dfg.add_literal(LiteralValue::Null.to_scalar()?)?; - let null_arg = Located::new( - add_literal( - dfg, - null_arg, - FenlType::Concrete(DataType::Null), - window.location().clone(), - )?, - window.location().clone(), - ); - - (null_arg.clone(), null_arg) - } - unexpected => anyhow::bail!("expected window, found {:?}", unexpected), - }; + let (condition, duration) = + flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?; dfg.exit_env(); // [agg_input, condition, duration] vec![args[0].clone(), condition, duration] + } else if function.name() == "collect" { + // The collect function contains a window, but does not follow the same signature + // pattern as aggregations, so it requires a different flattening strategy. + // + // TODO: Flattening the window arguments is hacky and confusing. We should instead + // incorporate the tick directly into the function containing the window. 
+ dfg.enter_env(); + dfg.bind("$condition_input", args[1].inner().clone()); + + let window = &expr.unwrap().args()[2]; + let (condition, duration) = + flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?; + + dfg.exit_env(); + // [max, input, condition, duration] + vec![args[0].clone(), args[1].clone(), condition, duration] } else if function.name() == "when" || function.name() == "if" { dfg.enter_env(); dfg.bind("$condition_input", args[1].inner().clone()); @@ -605,6 +602,37 @@ pub fn add_to_dfg( } } +#[allow(clippy::type_complexity)] +fn flatten_window_args_if_needed( + window: &Located>, + dfg: &mut Dfg, + data_context: &mut DataContext, + diagnostics: &mut DiagnosticCollector<'_>, +) -> anyhow::Result<(Located>, Located>)> { + let (condition, duration) = match window.op() { + ExprOp::Call(window_name) => { + flatten_window_args(window_name, window, dfg, data_context, diagnostics)? + } + ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => { + // Unwindowed aggregations just use nulls + let null_arg = dfg.add_literal(LiteralValue::Null.to_scalar()?)?; + let null_arg = Located::new( + add_literal( + dfg, + null_arg, + FenlType::Concrete(DataType::Null), + window.location().clone(), + )?, + window.location().clone(), + ); + + (null_arg.clone(), null_arg) + } + unexpected => anyhow::bail!("expected window, found {:?}", unexpected), + }; + Ok((condition, duration)) +} + // Verify that the arguments are compatibly partitioned. 
fn verify_same_partitioning( data_context: &DataContext, diff --git a/crates/sparrow-compiler/src/functions/collection.rs b/crates/sparrow-compiler/src/functions/collection.rs index 78f003c09..f28129e9a 100644 --- a/crates/sparrow-compiler/src/functions/collection.rs +++ b/crates/sparrow-compiler/src/functions/collection.rs @@ -12,4 +12,9 @@ pub(super) fn register(registry: &mut Registry) { .register("index(i: i64, list: list) -> T") .with_implementation(Implementation::Instruction(InstOp::Index)) .set_internal(); + + registry + .register("collect(const max: i64, input: T, window: window = null) -> list") + .with_implementation(Implementation::Instruction(InstOp::Collect)) + .set_internal(); } diff --git a/crates/sparrow-instructions/src/evaluators.rs b/crates/sparrow-instructions/src/evaluators.rs index 770e923cf..5341ef6ee 100644 --- a/crates/sparrow-instructions/src/evaluators.rs +++ b/crates/sparrow-instructions/src/evaluators.rs @@ -179,6 +179,16 @@ fn create_simple_evaluator( create_number_evaluator!(&info.args[0].data_type, ClampEvaluator, info) } InstOp::Coalesce => CoalesceEvaluator::try_new(info), + InstOp::Collect => { + create_typed_evaluator!( + &info.args[1].data_type, + CollectPrimitiveEvaluator, + CollectMapEvaluator, + CollectBooleanEvaluator, + CollectStringEvaluator, + info + ) + } InstOp::CountIf => CountIfEvaluator::try_new(info), InstOp::DayOfMonth => DayOfMonthEvaluator::try_new(info), InstOp::DayOfMonth0 => DayOfMonth0Evaluator::try_new(info), diff --git a/crates/sparrow-instructions/src/evaluators/aggregation/token.rs b/crates/sparrow-instructions/src/evaluators/aggregation/token.rs index 766028640..72ff04303 100644 --- a/crates/sparrow-instructions/src/evaluators/aggregation/token.rs +++ b/crates/sparrow-instructions/src/evaluators/aggregation/token.rs @@ -1,6 +1,7 @@ //! Tokens representing keys for compute storage. 
mod boolean_accum_token; +mod collect_token; mod count_accum_token; pub mod lag_token; mod map_accum_token; @@ -12,6 +13,7 @@ mod two_stacks_primitive_accum_token; mod two_stacks_string_accum_token; pub use boolean_accum_token::*; +pub use collect_token::*; pub use count_accum_token::*; pub use map_accum_token::*; pub use primitive_accum_token::*; diff --git a/crates/sparrow-instructions/src/evaluators/aggregation/token/collect_token.rs b/crates/sparrow-instructions/src/evaluators/aggregation/token/collect_token.rs new file mode 100644 index 000000000..ad0dcb73d --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/aggregation/token/collect_token.rs @@ -0,0 +1,61 @@ +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::collections::VecDeque; + +use crate::{ComputeStore, StateToken, StoreKey}; + +/// State token used for the `collect` operator. +#[derive(Default, Debug)] +pub struct CollectToken +where + T: Clone, + T: Serialize + DeserializeOwned, + Vec>>: Serialize + DeserializeOwned, +{ + state: Vec>>, +} + +impl CollectToken +where + T: Clone, + T: Serialize + DeserializeOwned, + Vec>>: Serialize + DeserializeOwned, +{ + pub fn resize(&mut self, len: usize) { + if len >= self.state.len() { + self.state.resize(len + 1, VecDeque::new()); + } + } + + pub fn add_value(&mut self, max: usize, index: usize, input: Option) { + self.state[index].push_back(input); + if self.state[index].len() > max { + self.state[index].pop_front(); + } + } + + pub fn state(&self, index: usize) -> &VecDeque> { + &self.state[index] + } +} + +impl StateToken for CollectToken +where + T: Clone, + T: Serialize + DeserializeOwned, + Vec>>: Serialize + DeserializeOwned, +{ + fn restore(&mut self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> { + if let Some(state) = store.get(key)? 
{ + self.state = state; + } else { + self.state.clear(); + } + + Ok(()) + } + + fn store(&self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> { + store.put(key, &self.state) + } +} diff --git a/crates/sparrow-instructions/src/evaluators/list.rs b/crates/sparrow-instructions/src/evaluators/list.rs index c1c37f4f9..971f1790b 100644 --- a/crates/sparrow-instructions/src/evaluators/list.rs +++ b/crates/sparrow-instructions/src/evaluators/list.rs @@ -1,2 +1,11 @@ +mod collect_boolean; +mod collect_map; +mod collect_primitive; +mod collect_string; mod index; + +pub(super) use collect_boolean::*; +pub(super) use collect_map::*; +pub(super) use collect_primitive::*; +pub(super) use collect_string::*; pub(super) use index::*; diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs b/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs new file mode 100644 index 000000000..7e0945dd5 --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs @@ -0,0 +1,108 @@ +use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo}; +use arrow::array::{ArrayRef, AsArray, BooleanBuilder, ListBuilder}; +use arrow::datatypes::DataType; +use itertools::izip; +use sparrow_arrow::scalar_value::ScalarValue; +use sparrow_plan::ValueRef; +use std::sync::Arc; + +/// Evaluator for the `collect` instruction. +/// +/// Collects a stream of values into a List. A list is produced +/// for each input value received, growing up to a maximum size. +/// +/// If the list is empty, an empty list is returned (rather than `null`). +#[derive(Debug)] +pub struct CollectBooleanEvaluator { + /// The max size of the buffer. + /// + /// Once the max size is reached, the front will be popped and the new + /// value pushed to the back. 
+ max: usize, + input: ValueRef, + tick: ValueRef, + duration: ValueRef, + /// Contains the buffer of values for each entity + token: CollectToken, +} + +impl EvaluatorFactory for CollectBooleanEvaluator { + fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { + let input_type = info.args[1].data_type(); + let result_type = info.result_type; + match result_type { + DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), + other => anyhow::bail!("expected list result type, saw {:?}", other), + }; + + let max = match info.args[0].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { + anyhow::bail!("unexpected value of `max` -- must be > 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `max = null`, we use usize::MAX value as a way + // to have an "unlimited" buffer. + Some(ScalarValue::Int64(None)) => usize::MAX, + Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for max parameter"), + }; + + let (_, input, tick, duration) = info.unpack_arguments()?; + Ok(Box::new(Self { + max, + input, + tick, + duration, + token: CollectToken::default(), + })) + } +} + +impl Evaluator for CollectBooleanEvaluator { + fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + match (self.tick.is_literal_null(), self.duration.is_literal_null()) { + (true, true) => self.evaluate_non_windowed(info), + (true, false) => unimplemented!("since window aggregation unsupported"), + (false, false) => panic!("sliding window aggregation should use other evaluator"), + (_, _) => anyhow::bail!("saw invalid combination of tick and duration"), + } + } + + fn state_token(&self) -> Option<&dyn StateToken> { + Some(&self.token) + } + + fn state_token_mut(&mut self) -> Option<&mut dyn StateToken> { + Some(&mut self.token) + } +} + +impl CollectBooleanEvaluator { + fn ensure_entity_capacity(&mut self, len: usize) { + self.token.resize(len) + } 
+ + fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + let input = info.value(&self.input)?.array_ref()?; + let key_capacity = info.grouping().num_groups(); + let entity_indices = info.grouping().group_indices(); + assert_eq!(entity_indices.len(), input.len()); + + self.ensure_entity_capacity(key_capacity); + + let input = input.as_boolean(); + let builder = BooleanBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + + izip!(entity_indices.values(), input).for_each(|(entity_index, input)| { + let entity_index = *entity_index as usize; + + self.token.add_value(self.max, entity_index, input); + let cur_list = self.token.state(entity_index); + + list_builder.append_value(cur_list.iter().copied()); + }); + + Ok(Arc::new(list_builder.finish())) + } +} diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_map.rs b/crates/sparrow-instructions/src/evaluators/list/collect_map.rs new file mode 100644 index 000000000..a0ecea95b --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/list/collect_map.rs @@ -0,0 +1,31 @@ +use crate::{Evaluator, EvaluatorFactory, RuntimeInfo, StaticInfo}; +use arrow::array::ArrayRef; +use sparrow_plan::ValueRef; + +/// Evaluator for the `collect` instruction. +/// +/// Collects a stream of values into a List. A list is produced +/// for each input value received, growing up to a maximum size. +#[derive(Debug)] +pub struct CollectMapEvaluator { + /// The max size of the buffer. + /// + /// Once the max size is reached, the front will be popped and the new + /// value pushed to the back. 
+ _max: i64, + _input: ValueRef, + _tick: ValueRef, + _duration: ValueRef, +} + +impl EvaluatorFactory for CollectMapEvaluator { + fn try_new(_info: StaticInfo<'_>) -> anyhow::Result> { + unimplemented!("map collect evaluator is unsupported") + } +} + +impl Evaluator for CollectMapEvaluator { + fn evaluate(&mut self, _info: &dyn RuntimeInfo) -> anyhow::Result { + unimplemented!("map collect evaluator is unsupported") + } +} diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs b/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs new file mode 100644 index 000000000..3dcf31b57 --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/list/collect_primitive.rs @@ -0,0 +1,131 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, ListBuilder, PrimitiveBuilder}; +use arrow::datatypes::{ArrowPrimitiveType, DataType}; + +use itertools::izip; +use serde::de::DeserializeOwned; +use serde::Serialize; +use sparrow_arrow::downcast::downcast_primitive_array; +use sparrow_arrow::scalar_value::ScalarValue; + +use sparrow_plan::ValueRef; + +use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo}; + +/// Evaluator for the `collect` instruction. +/// +/// Collects a stream of values into a List. A list is produced +/// for each input value received, growing up to a maximum size. +/// +/// If the list is empty, an empty list is returned (rather than `null`). +#[derive(Debug)] +pub struct CollectPrimitiveEvaluator +where + T: ArrowPrimitiveType, + T::Native: Serialize + DeserializeOwned + Copy, +{ + /// The max size of the buffer. + /// + /// Once the max size is reached, the front will be popped and the new + /// value pushed to the back. 
+ max: usize, + input: ValueRef, + tick: ValueRef, + duration: ValueRef, + /// Contains the buffer of values for each entity + token: CollectToken, +} + +impl EvaluatorFactory for CollectPrimitiveEvaluator +where + T: ArrowPrimitiveType + Send + Sync, + T::Native: Serialize + DeserializeOwned + Copy, +{ + fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { + let input_type = info.args[1].data_type(); + let result_type = info.result_type; + match result_type { + DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), + other => anyhow::bail!("expected list result type, saw {:?}", other), + }; + + let max = match info.args[0].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { + anyhow::bail!("unexpected value of `max` -- must be > 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `max = null`, we use usize::MAX value as a way + // to have an "unlimited" buffer. + Some(ScalarValue::Int64(None)) => usize::MAX, + Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for max parameter"), + }; + + let (_, input, tick, duration) = info.unpack_arguments()?; + Ok(Box::new(Self { + max, + input, + tick, + duration, + token: CollectToken::default(), + })) + } +} + +impl Evaluator for CollectPrimitiveEvaluator +where + T: ArrowPrimitiveType + Send + Sync, + T::Native: Serialize + DeserializeOwned + Copy, +{ + fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + match (self.tick.is_literal_null(), self.duration.is_literal_null()) { + (true, true) => self.evaluate_non_windowed(info), + (true, false) => unimplemented!("since window aggregation unsupported"), + (false, false) => panic!("sliding window aggregation should use other evaluator"), + (_, _) => anyhow::bail!("saw invalid combination of tick and duration"), + } + } + + fn state_token(&self) -> Option<&dyn StateToken> { + Some(&self.token) + } + + fn 
state_token_mut(&mut self) -> Option<&mut dyn StateToken> { + Some(&mut self.token) + } +} + +impl CollectPrimitiveEvaluator +where + T: ArrowPrimitiveType + Send + Sync, + T::Native: Serialize + DeserializeOwned + Copy, +{ + fn ensure_entity_capacity(&mut self, len: usize) { + self.token.resize(len) + } + + fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + let input = info.value(&self.input)?.array_ref()?; + let key_capacity = info.grouping().num_groups(); + let entity_indices = info.grouping().group_indices(); + assert_eq!(entity_indices.len(), input.len()); + + self.ensure_entity_capacity(key_capacity); + + let input = downcast_primitive_array::(input.as_ref())?; + let builder = PrimitiveBuilder::::new(); + let mut list_builder = ListBuilder::new(builder); + + izip!(entity_indices.values(), input).for_each(|(entity_index, input)| { + let entity_index = *entity_index as usize; + + self.token.add_value(self.max, entity_index, input); + let cur_list = self.token.state(entity_index); + + list_builder.append_value(cur_list.iter().copied()); + }); + + Ok(Arc::new(list_builder.finish())) + } +} diff --git a/crates/sparrow-instructions/src/evaluators/list/collect_string.rs b/crates/sparrow-instructions/src/evaluators/list/collect_string.rs new file mode 100644 index 000000000..b60e0aa34 --- /dev/null +++ b/crates/sparrow-instructions/src/evaluators/list/collect_string.rs @@ -0,0 +1,109 @@ +use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo}; +use arrow::array::{ArrayRef, AsArray, ListBuilder, StringBuilder}; +use arrow::datatypes::DataType; +use itertools::izip; +use sparrow_arrow::scalar_value::ScalarValue; +use sparrow_plan::ValueRef; +use std::sync::Arc; + +/// Evaluator for the `collect` instruction. +/// +/// Collects a stream of values into a List. A list is produced +/// for each input value received, growing up to a maximum size. 
+/// +/// If the list is empty, an empty list is returned (rather than `null`). +#[derive(Debug)] +pub struct CollectStringEvaluator { + /// The max size of the buffer. + /// + /// Once the max size is reached, the front will be popped and the new + /// value pushed to the back. + max: usize, + input: ValueRef, + tick: ValueRef, + duration: ValueRef, + /// Contains the buffer of values for each entity + token: CollectToken, +} + +impl EvaluatorFactory for CollectStringEvaluator { + fn try_new(info: StaticInfo<'_>) -> anyhow::Result> { + let input_type = info.args[1].data_type(); + let result_type = info.result_type; + match result_type { + DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), + other => anyhow::bail!("expected list result type, saw {:?}", other), + }; + + let max = match info.args[0].value_ref.literal_value() { + Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { + anyhow::bail!("unexpected value of `max` -- must be > 0") + } + Some(ScalarValue::Int64(Some(v))) => *v as usize, + // If a user specifies `max = null`, we use usize::MAX value as a way + // to have an "unlimited" buffer. 
+ Some(ScalarValue::Int64(None)) => usize::MAX, + Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other), + None => anyhow::bail!("expected literal value for max parameter"), + }; + + let (_, input, tick, duration) = info.unpack_arguments()?; + Ok(Box::new(Self { + max, + input, + tick, + duration, + token: CollectToken::default(), + })) + } +} + +impl Evaluator for CollectStringEvaluator { + fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + match (self.tick.is_literal_null(), self.duration.is_literal_null()) { + (true, true) => self.evaluate_non_windowed(info), + (true, false) => unimplemented!("since window aggregation unsupported"), + (false, false) => panic!("sliding window aggregation should use other evaluator"), + (_, _) => anyhow::bail!("saw invalid combination of tick and duration"), + } + } + + fn state_token(&self) -> Option<&dyn StateToken> { + Some(&self.token) + } + + fn state_token_mut(&mut self) -> Option<&mut dyn StateToken> { + Some(&mut self.token) + } +} + +impl CollectStringEvaluator { + fn ensure_entity_capacity(&mut self, len: usize) { + self.token.resize(len) + } + + fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result { + let input = info.value(&self.input)?.array_ref()?; + let key_capacity = info.grouping().num_groups(); + let entity_indices = info.grouping().group_indices(); + assert_eq!(entity_indices.len(), input.len()); + + self.ensure_entity_capacity(key_capacity); + + let input = input.as_string::(); + let builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + + izip!(entity_indices.values(), input).for_each(|(entity_index, input)| { + let entity_index = *entity_index as usize; + + self.token + .add_value(self.max, entity_index, input.map(|s| s.to_owned())); + let cur_list = self.token.state(entity_index); + + list_builder.append_value(cur_list.clone()); + }); + + Ok(Arc::new(list_builder.finish())) + } +} diff --git 
a/crates/sparrow-main/tests/e2e/collect_tests.rs b/crates/sparrow-main/tests/e2e/collect_tests.rs new file mode 100644 index 000000000..f1ffafa08 --- /dev/null +++ b/crates/sparrow-main/tests/e2e/collect_tests.rs @@ -0,0 +1,295 @@ +//! e2e tests for collect function + +use indoc::indoc; +use sparrow_api::kaskada::v1alpha::TableConfig; +use uuid::Uuid; + +use crate::{fixture::DataFixture, QueryFixture}; + +pub(crate) async fn collect_data_fixture() -> DataFixture { + DataFixture::new() + .with_table_from_csv( + TableConfig::new_with_table_source( + "Collect", + &Uuid::new_v4(), + "time", + Some("subsort"), + "key", + "", + ), + indoc! {" + time,subsort,key,s,n,b,index + 1996-12-19T16:39:57-08:00,0,A,hEllo,0,true,0 + 1996-12-19T16:40:57-08:00,0,A,hi,2,false,1 + 1996-12-19T16:41:57-08:00,0,A,hey,9,,2 + 1996-12-19T16:42:57-08:00,0,A,ay,-1,true,1 + 1996-12-19T16:43:57-08:00,0,A,hIlo,10,true, + 1996-12-20T16:40:57-08:00,0,B,h,5,false,0 + 1996-12-20T16:41:57-08:00,0,B,he,-2,,1 + 1996-12-20T16:42:57-08:00,0,B,,,true,2 + 1996-12-20T16:43:57-08:00,0,B,hel,2,false,1 + 1996-12-20T16:44:57-08:00,0,B,,,true,1 + 1996-12-21T16:44:57-08:00,0,C,g,1,true,2 + 1996-12-21T16:45:57-08:00,0,C,go,2,true,0 + 1996-12-21T16:46:57-08:00,0,C,goo,3,true, + 1996-12-21T16:47:57-08:00,0,C,good,4,true,1 + "}, + ) + .await + .unwrap() +} + +#[tokio::test] +async fn test_collect_with_null_max() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(max = null) | index(0), f2: Collect.b | collect(max = null) | index(0), f3: Collect.s | collect(max = null) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1,f2,f3 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0,true,hEllo + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,0,true,hEllo + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,0,true,hEllo + 
1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,0,true,hEllo + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,0,true,hEllo + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,5,false,h + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,5,false,h + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,5,false,h + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,5,false,h + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,5,false,h + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,1,true,g + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,1,true,g + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,1,true,g + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,1,true,g + "###); +} + +#[tokio::test] +async fn test_collect_to_list_i64() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,5 + 
1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,1 + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,1 + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,1 + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,1 + "###); +} + +#[tokio::test] +async fn test_collect_to_list_i64_dynamic() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,2 + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,9 + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,2 + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,-2 + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,-2 + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,-2 + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,1 + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,2 + "###); +} + +#[tokio::test] +async fn test_collect_to_small_list_i64() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.n | collect(2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 
1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,0 + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,2 + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,-1 + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,5 + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,-2 + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,2 + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,1 + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,4 + "###); +} + +#[tokio::test] +async fn test_collect_to_list_string() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,h + 
1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,g + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,g + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,g + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,g + "###); +} + +#[tokio::test] +async fn test_collect_to_list_string_dynamic() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hi + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,hey + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,hi + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,he + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,he + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,he + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,g + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,go + "###); +} + +#[tokio::test] +async fn test_collect_to_small_list_string() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(2) | index(Collect.index) 
}").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,hEllo + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,hi + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,ay + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,h + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,he + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,hel + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,g + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,good + "###); +} + +#[tokio::test] +async fn test_collect_to_list_boolean() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(10) | index(0) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,false + 
1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C,true + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,true + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C,true + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,true + "###); +} + +#[tokio::test] +async fn test_collect_to_list_boolean_dynamic() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(10) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,false + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,false + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B,true + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,true + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 
1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,true + "###); +} + +#[tokio::test] +async fn test_collect_to_small_list_boolean() { + insta::assert_snapshot!(QueryFixture::new("{ f1: Collect.b | collect(2) | index(Collect.index) }").run_to_csv(&collect_data_fixture().await).await.unwrap(), @r###" + _time,_subsort,_key_hash,_key,f1 + 1996-12-20T00:39:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:40:57.000000000,9223372036854775808,12960666915911099378,A,false + 1996-12-20T00:41:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-20T00:42:57.000000000,9223372036854775808,12960666915911099378,A,true + 1996-12-20T00:43:57.000000000,9223372036854775808,12960666915911099378,A, + 1996-12-21T00:40:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:41:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:42:57.000000000,9223372036854775808,2867199309159137213,B, + 1996-12-21T00:43:57.000000000,9223372036854775808,2867199309159137213,B,false + 1996-12-21T00:44:57.000000000,9223372036854775808,2867199309159137213,B,true + 1996-12-22T00:44:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:45:57.000000000,9223372036854775808,2521269998124177631,C,true + 1996-12-22T00:46:57.000000000,9223372036854775808,2521269998124177631,C, + 1996-12-22T00:47:57.000000000,9223372036854775808,2521269998124177631,C,true + "###); +} + +#[tokio::test] +async fn test_require_literal_max() { + // TODO: We should figure out how to not report the second error -- type variables with + // error propagation needs some fixing. 
+ insta::assert_yaml_snapshot!(QueryFixture::new("{ f1: Collect.s | collect(Collect.index) | index(1) }") + .run_to_csv(&collect_data_fixture().await).await.unwrap_err(), @r###" + --- + code: Client specified an invalid argument + message: 2 errors in Fenl statements; see diagnostics + fenl_diagnostics: + - severity: error + code: E0014 + message: Invalid non-constant argument + formatted: + - "error[E0014]: Invalid non-constant argument" + - " --> Query:1:27" + - " |" + - "1 | { f1: Collect.s | collect(Collect.index) | index(1) }" + - " | ^^^^^^^^^^^^^ Argument 'max' to 'collect' must be constant, but was not" + - "" + - "" + - severity: error + code: E0010 + message: Invalid argument type(s) + formatted: + - "error[E0010]: Invalid argument type(s)" + - " --> Query:1:44" + - " |" + - "1 | { f1: Collect.s | collect(Collect.index) | index(1) }" + - " | ^^^^^ Invalid types for parameter 'list' in call to 'index'" + - " |" + - " --> internal:1:1" + - " |" + - 1 | $input + - " | ------ Actual type: error" + - " |" + - " --> built-in signature 'index<T: any>(i: i64, list: list<T>) -> T':1:29" + - " |" + - "1 | index<T: any>(i: i64, list: list<T>) -> T" + - " | ------- Expected type: list<T>" + - "" + - "" + "###); +} diff --git a/crates/sparrow-main/tests/e2e/main.rs b/crates/sparrow-main/tests/e2e/main.rs index 84ecde202..97f25c56a 100644 --- a/crates/sparrow-main/tests/e2e/main.rs +++ b/crates/sparrow-main/tests/e2e/main.rs @@ -19,6 +19,7 @@ mod aggregation_tests; mod basic_error_tests; mod cast_tests; mod coalesce_tests; +mod collect_tests; mod comparison_tests; mod decoration_tests; mod entity_key_output_tests; diff --git a/crates/sparrow-plan/src/inst.rs b/crates/sparrow-plan/src/inst.rs index 5fcd8ad53..6d370a78c 100644 --- a/crates/sparrow-plan/src/inst.rs +++ b/crates/sparrow-plan/src/inst.rs @@ -60,6 +60,11 @@ pub enum InstOp { Clamp, #[strum(props(signature = "coalesce(values+: T) -> T"))] Coalesce, + #[strum(props( + dfg_signature = "collect(max: i64, input: T, window: window = null) 
-> list", + plan_signature = "collect(max: i64, input: T, ticks: bool = null, slide_duration: i64 = null) -> list" + ))] + Collect, #[strum(props( dfg_signature = "count_if(input: T, window: window = null) -> u32", plan_signature = "count_if(input: T, ticks: bool = null, slide_duration: i64 = null) -> \