Skip to content

Commit

Permalink
feat: collect to list (non-windowed) (primitive/strings/booleans) (#569)
Browse files Browse the repository at this point in the history
Adds the `collect` function for non-windowed, primitive/string/boolean
types.

This doesn't support `LargeUtf8` because the other string aggregations
that use the `create_typed_evaluator` macro don't support generic offset
sizes, so that will come in a follow-up.

---------

Co-authored-by: Kevin J Nguyen <kevin.nguyen@datastax.com>
  • Loading branch information
jordanrfrazier and kevinjnguyen authored Jul 31, 2023
1 parent f58306e commit 0f33802
Show file tree
Hide file tree
Showing 13 changed files with 816 additions and 21 deletions.
70 changes: 49 additions & 21 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,31 +522,28 @@ pub fn add_to_dfg(
dfg.bind("$condition_input", args[0].inner().clone());

let window = &expr.unwrap().args()[1];
let (condition, duration) = match window.op() {
ExprOp::Call(window_name) => {
flatten_window_args(window_name, window, dfg, data_context, diagnostics)?
}
ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => {
// Unwindowed aggregations just use nulls
let null_arg = dfg.add_literal(LiteralValue::Null.to_scalar()?)?;
let null_arg = Located::new(
add_literal(
dfg,
null_arg,
FenlType::Concrete(DataType::Null),
window.location().clone(),
)?,
window.location().clone(),
);

(null_arg.clone(), null_arg)
}
unexpected => anyhow::bail!("expected window, found {:?}", unexpected),
};
let (condition, duration) =
flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?;

dfg.exit_env();
// [agg_input, condition, duration]
vec![args[0].clone(), condition, duration]
} else if function.name() == "collect" {
// The collect function contains a window, but does not follow the same signature
// pattern as aggregations, so it requires a different flattening strategy.
//
// TODO: Flattening the window arguments is hacky and confusing. We should instead
// incorporate the tick directly into the function containing the window.
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());

let window = &expr.unwrap().args()[2];
let (condition, duration) =
flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?;

dfg.exit_env();
// [max, input, condition, duration]
vec![args[0].clone(), args[1].clone(), condition, duration]
} else if function.name() == "when" || function.name() == "if" {
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());
Expand Down Expand Up @@ -605,6 +602,37 @@ pub fn add_to_dfg(
}
}

/// Resolves a window argument into its flattened `(condition, duration)` pair.
///
/// * If `window` is a call expression, the work is delegated to
///   `flatten_window_args`.
/// * If `window` is the literal `null` (no window supplied), both the
///   condition and the duration are the same null literal node.
/// * Any other expression is rejected with an error.
// The return type is a pair of located DFG nodes, which clippy considers
// complex; a dedicated alias would only be used here.
#[allow(clippy::type_complexity)]
fn flatten_window_args_if_needed(
    window: &Located<Box<ResolvedExpr>>,
    dfg: &mut Dfg,
    data_context: &mut DataContext,
    diagnostics: &mut DiagnosticCollector<'_>,
) -> anyhow::Result<(Located<Arc<AstDfg>>, Located<Arc<AstDfg>>)> {
    let flattened = match window.op() {
        ExprOp::Call(window_name) => {
            flatten_window_args(window_name, window, dfg, data_context, diagnostics)?
        }
        ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => {
            // Unwindowed aggregations just use nulls for both window arguments.
            let null_literal = dfg.add_literal(LiteralValue::Null.to_scalar()?)?;
            let null_node = add_literal(
                dfg,
                null_literal,
                FenlType::Concrete(DataType::Null),
                window.location().clone(),
            )?;
            let null_arg = Located::new(null_node, window.location().clone());

            (null_arg.clone(), null_arg)
        }
        unexpected => anyhow::bail!("expected window, found {:?}", unexpected),
    };

    Ok(flattened)
}

// Verify that the arguments are compatibly partitioned.
fn verify_same_partitioning(
data_context: &DataContext,
Expand Down
5 changes: 5 additions & 0 deletions crates/sparrow-compiler/src/functions/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@ pub(super) fn register(registry: &mut Registry) {
.register("index<T: any>(i: i64, list: list<T>) -> T")
.with_implementation(Implementation::Instruction(InstOp::Index))
.set_internal();

registry
.register("collect<T: any>(const max: i64, input: T, window: window = null) -> list<T>")
.with_implementation(Implementation::Instruction(InstOp::Collect))
.set_internal();
}
10 changes: 10 additions & 0 deletions crates/sparrow-instructions/src/evaluators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ fn create_simple_evaluator(
create_number_evaluator!(&info.args[0].data_type, ClampEvaluator, info)
}
InstOp::Coalesce => CoalesceEvaluator::try_new(info),
InstOp::Collect => {
create_typed_evaluator!(
&info.args[1].data_type,
CollectPrimitiveEvaluator,
CollectMapEvaluator,
CollectBooleanEvaluator,
CollectStringEvaluator,
info
)
}
InstOp::CountIf => CountIfEvaluator::try_new(info),
InstOp::DayOfMonth => DayOfMonthEvaluator::try_new(info),
InstOp::DayOfMonth0 => DayOfMonth0Evaluator::try_new(info),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Tokens representing keys for compute storage.

mod boolean_accum_token;
mod collect_token;
mod count_accum_token;
pub mod lag_token;
mod map_accum_token;
Expand All @@ -12,6 +13,7 @@ mod two_stacks_primitive_accum_token;
mod two_stacks_string_accum_token;

pub use boolean_accum_token::*;
pub use collect_token::*;
pub use count_accum_token::*;
pub use map_accum_token::*;
pub use primitive_accum_token::*;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::VecDeque;

use crate::{ComputeStore, StateToken, StoreKey};

/// State token used for the `collect` operator.
///
/// Holds one buffer of optional values per entity index. The buffers are
/// the state that is written to and restored from the compute store.
#[derive(Default, Debug)]
pub struct CollectToken<T>
where
T: Clone,
T: Serialize + DeserializeOwned,
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
/// Buffered values; the outer `Vec` is indexed by entity index.
state: Vec<VecDeque<Option<T>>>,
}

impl<T> CollectToken<T>
where
    T: Clone,
    T: Serialize + DeserializeOwned,
    Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
    /// Grows the per-entity state so that index `len` is addressable.
    ///
    /// The state is never shrunk; newly added entities start with an empty
    /// buffer.
    pub fn resize(&mut self, len: usize) {
        if self.state.len() <= len {
            self.state.resize_with(len + 1, VecDeque::new);
        }
    }

    /// Appends `input` to the buffer of entity `index`, evicting the oldest
    /// value when the buffer would otherwise hold more than `max` values.
    ///
    /// Panics if `index` is out of bounds; call [`Self::resize`] first.
    pub fn add_value(&mut self, max: usize, index: usize, input: Option<T>) {
        let buffer = &mut self.state[index];
        buffer.push_back(input);
        if buffer.len() > max {
            buffer.pop_front();
        }
    }

    /// Returns the current buffer of entity `index`, oldest value first.
    ///
    /// Panics if `index` is out of bounds.
    pub fn state(&self, index: usize) -> &VecDeque<Option<T>> {
        &self.state[index]
    }
}

impl<T> StateToken for CollectToken<T>
where
    T: Clone,
    T: Serialize + DeserializeOwned,
    Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
    /// Replaces the in-memory buffers with those stored under `key`, or
    /// empties them when the store has no entry for `key`.
    fn restore(&mut self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> {
        match store.get(key)? {
            Some(state) => self.state = state,
            None => self.state.clear(),
        }

        Ok(())
    }

    /// Writes the current buffers to the store under `key`.
    fn store(&self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> {
        store.put(key, &self.state)
    }
}
9 changes: 9 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
mod collect_boolean;
mod collect_map;
mod collect_primitive;
mod collect_string;
mod index;

pub(super) use collect_boolean::*;
pub(super) use collect_map::*;
pub(super) use collect_primitive::*;
pub(super) use collect_string::*;
pub(super) use index::*;
108 changes: 108 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo};
use arrow::array::{ArrayRef, AsArray, BooleanBuilder, ListBuilder};
use arrow::datatypes::DataType;
use itertools::izip;
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_plan::ValueRef;
use std::sync::Arc;

/// Evaluator for the `collect` instruction over boolean inputs.
///
/// Collects a stream of values into a List. A list is produced
/// for each input value received, growing up to a maximum size.
///
/// If the list is empty, an empty list is returned (rather than `null`).
#[derive(Debug)]
pub struct CollectBooleanEvaluator {
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
/// value pushed to the back.
max: usize,
/// The values to collect (the second argument of the instruction).
input: ValueRef,
/// Window tick argument; a literal null when no window is supplied.
tick: ValueRef,
/// Window duration argument; a literal null when no window is supplied.
duration: ValueRef,
/// Contains the buffer of values for each entity
token: CollectToken<bool>,
}

impl EvaluatorFactory for CollectBooleanEvaluator {
    /// Builds the evaluator from `[max, input, tick, duration]`.
    ///
    /// Fails when the result type is not a list of the input type, or when
    /// `max` is not a literal `i64` that is either null or strictly positive.
    fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
        let input_type = info.args[1].data_type();
        let result_type = info.result_type;
        if let DataType::List(t) = result_type {
            anyhow::ensure!(t.data_type() == input_type);
        } else {
            anyhow::bail!("expected list result type, saw {:?}", result_type);
        }

        let max = match info.args[0].value_ref.literal_value() {
            Some(ScalarValue::Int64(Some(v))) => {
                if *v <= 0 {
                    anyhow::bail!("unexpected value of `max` -- must be > 0");
                }
                // `v` is strictly positive here.
                *v as usize
            }
            // If a user specifies `max = null`, we use usize::MAX value as a way
            // to have an "unlimited" buffer.
            Some(ScalarValue::Int64(None)) => usize::MAX,
            Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other),
            None => anyhow::bail!("expected literal value for max parameter"),
        };

        // The first argument (`max`) was consumed above as a literal.
        let (_, input, tick, duration) = info.unpack_arguments()?;
        let evaluator = Self {
            max,
            input,
            tick,
            duration,
            token: CollectToken::default(),
        };
        Ok(Box::new(evaluator))
    }
}

impl Evaluator for CollectBooleanEvaluator {
    /// Evaluates `collect` for the current batch.
    ///
    /// The window shape is encoded by which of `tick` and `duration` are
    /// literal nulls. Only the non-windowed form (both null) is implemented.
    ///
    /// # Errors
    ///
    /// Returns an error for since windows (unsupported) and for a duration
    /// without a tick (invalid).
    ///
    /// # Panics
    ///
    /// Panics when both a tick and a duration are present, since sliding
    /// windows are expected to be routed to a different evaluator.
    fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
        match (self.tick.is_literal_null(), self.duration.is_literal_null()) {
            (true, true) => self.evaluate_non_windowed(info),
            // NOTE(review): assumes a `since` window flattens to a non-null
            // tick with a null duration -- confirm against `flatten_window_args`.
            // Reported as an error rather than a panic because a user can
            // reach this by passing a window to `collect`.
            (false, true) => anyhow::bail!("since window aggregation unsupported"),
            (false, false) => panic!("sliding window aggregation should use other evaluator"),
            // A duration without a tick does not describe any window.
            (true, false) => anyhow::bail!("saw invalid combination of tick and duration"),
        }
    }

    /// Exposes the per-entity buffers so they can be persisted.
    fn state_token(&self) -> Option<&dyn StateToken> {
        Some(&self.token)
    }

    /// Exposes the per-entity buffers so they can be restored.
    fn state_token_mut(&mut self) -> Option<&mut dyn StateToken> {
        Some(&mut self.token)
    }
}

impl CollectBooleanEvaluator {
    /// Makes sure the token has a buffer for every entity index up to `len`.
    fn ensure_entity_capacity(&mut self, len: usize) {
        self.token.resize(len)
    }

    /// Evaluates `collect` without a window.
    ///
    /// For each input row, the value (possibly null) is pushed onto the
    /// buffer of the row's entity, and that entity's whole buffer is emitted
    /// as the output list for the row.
    fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
        let input = info.value(&self.input)?.array_ref()?;
        let key_capacity = info.grouping().num_groups();
        let entity_indices = info.grouping().group_indices();
        assert_eq!(entity_indices.len(), input.len());

        self.ensure_entity_capacity(key_capacity);

        let values = input.as_boolean();
        let mut list_builder = ListBuilder::new(BooleanBuilder::new());

        for (entity_index, value) in izip!(entity_indices.values(), values) {
            let entity_index = *entity_index as usize;

            self.token.add_value(self.max, entity_index, value);

            // Emit a snapshot of the entity's buffer after the new value is added.
            let snapshot = self.token.state(entity_index);
            list_builder.append_value(snapshot.iter().copied());
        }

        Ok(Arc::new(list_builder.finish()))
    }
}
31 changes: 31 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list/collect_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::{Evaluator, EvaluatorFactory, RuntimeInfo, StaticInfo};
use arrow::array::ArrayRef;
use sparrow_plan::ValueRef;

/// Evaluator for the `collect` instruction over map inputs.
///
/// Collect collects a stream of values into a List<T>. A list is produced
/// for each input value received, growing up to a maximum size.
///
/// This is a placeholder: collecting maps is not supported, so the fields
/// below are unused (hence the leading underscores).
#[derive(Debug)]
pub struct CollectMapEvaluator {
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
/// value pushed to the back.
_max: i64,
/// The values to collect.
_input: ValueRef,
/// Window tick argument.
_tick: ValueRef,
/// Window duration argument.
_duration: ValueRef,
}

impl EvaluatorFactory for CollectMapEvaluator {
    /// Always fails: collecting map values is not supported.
    ///
    /// # Errors
    ///
    /// Returns an error unconditionally. An error is used instead of a panic
    /// because the evaluator is selected from the input's data type, so this
    /// can be reached by a query that collects map values.
    fn try_new(_info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
        Err(anyhow::anyhow!("map collect evaluator is unsupported"))
    }
}

impl Evaluator for CollectMapEvaluator {
/// Not implemented: collecting map values is unsupported.
///
/// # Panics
///
/// Always panics via `unimplemented!`.
fn evaluate(&mut self, _info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
unimplemented!("map collect evaluator is unsupported")
}
}
Loading

0 comments on commit 0f33802

Please sign in to comment.