Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: collect to list (non-windowed) (primitive/strings/booleans) #569

Merged
merged 13 commits into from
Jul 31, 2023
70 changes: 49 additions & 21 deletions crates/sparrow-compiler/src/ast_to_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,31 +522,28 @@ pub fn add_to_dfg(
dfg.bind("$condition_input", args[0].inner().clone());

let window = &expr.unwrap().args()[1];
let (condition, duration) = match window.op() {
ExprOp::Call(window_name) => {
flatten_window_args(window_name, window, dfg, data_context, diagnostics)?
}
ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => {
// Unwindowed aggregations just use nulls
let null_arg = dfg.add_literal(LiteralValue::Null.to_scalar()?)?;
let null_arg = Located::new(
add_literal(
dfg,
null_arg,
FenlType::Concrete(DataType::Null),
window.location().clone(),
)?,
window.location().clone(),
);

(null_arg.clone(), null_arg)
}
unexpected => anyhow::bail!("expected window, found {:?}", unexpected),
};
let (condition, duration) =
flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?;
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved

dfg.exit_env();
// [agg_input, condition, duration]
vec![args[0].clone(), condition, duration]
} else if function.name() == "collect" {
// The collect function contains a window, but does not follow the same signature
// pattern as aggregations, so it requires a different flattening strategy.
//
// TODO: Flattening the window arguments is hacky and confusing. We should instead
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
// incorporate the tick directly into the function containing the window.
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());

let window = &expr.unwrap().args()[2];
let (condition, duration) =
flatten_window_args_if_needed(window, dfg, data_context, diagnostics)?;

dfg.exit_env();
// [max, input, condition, duration]
vec![args[0].clone(), args[1].clone(), condition, duration]
} else if function.name() == "when" || function.name() == "if" {
dfg.enter_env();
dfg.bind("$condition_input", args[1].inner().clone());
Expand Down Expand Up @@ -605,6 +602,37 @@ pub fn add_to_dfg(
}
}

#[allow(clippy::type_complexity)]
fn flatten_window_args_if_needed(
window: &Located<Box<ResolvedExpr>>,
dfg: &mut Dfg,
data_context: &mut DataContext,
diagnostics: &mut DiagnosticCollector<'_>,
) -> anyhow::Result<(Located<Arc<AstDfg>>, Located<Arc<AstDfg>>)> {
jordanrfrazier marked this conversation as resolved.
Show resolved Hide resolved
let (condition, duration) = match window.op() {
ExprOp::Call(window_name) => {
flatten_window_args(window_name, window, dfg, data_context, diagnostics)?
}
ExprOp::Literal(v) if v.inner() == &LiteralValue::Null => {
// Unwindowed aggregations just use nulls
let null_arg = dfg.add_literal(LiteralValue::Null.to_scalar()?)?;
let null_arg = Located::new(
add_literal(
dfg,
null_arg,
FenlType::Concrete(DataType::Null),
window.location().clone(),
)?,
window.location().clone(),
);

(null_arg.clone(), null_arg)
}
unexpected => anyhow::bail!("expected window, found {:?}", unexpected),
};
Ok((condition, duration))
}

// Verify that the arguments are compatibly partitioned.
fn verify_same_partitioning(
data_context: &DataContext,
Expand Down
5 changes: 5 additions & 0 deletions crates/sparrow-compiler/src/functions/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@ pub(super) fn register(registry: &mut Registry) {
.register("index<T: any>(i: i64, list: list<T>) -> T")
.with_implementation(Implementation::Instruction(InstOp::Index))
.set_internal();

registry
.register("collect<T: any>(const max: i64, input: T, window: window = null) -> list<T>")
.with_implementation(Implementation::Instruction(InstOp::Collect))
.set_internal();
}
10 changes: 10 additions & 0 deletions crates/sparrow-instructions/src/evaluators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ fn create_simple_evaluator(
create_number_evaluator!(&info.args[0].data_type, ClampEvaluator, info)
}
InstOp::Coalesce => CoalesceEvaluator::try_new(info),
InstOp::Collect => {
create_typed_evaluator!(
&info.args[1].data_type,
CollectPrimitiveEvaluator,
CollectMapEvaluator,
CollectBooleanEvaluator,
CollectStringEvaluator,
info
)
}
InstOp::CountIf => CountIfEvaluator::try_new(info),
InstOp::DayOfMonth => DayOfMonthEvaluator::try_new(info),
InstOp::DayOfMonth0 => DayOfMonth0Evaluator::try_new(info),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Tokens representing keys for compute storage.

mod boolean_accum_token;
mod collect_token;
mod count_accum_token;
pub mod lag_token;
mod map_accum_token;
Expand All @@ -12,6 +13,7 @@ mod two_stacks_primitive_accum_token;
mod two_stacks_string_accum_token;

pub use boolean_accum_token::*;
pub use collect_token::*;
pub use count_accum_token::*;
pub use map_accum_token::*;
pub use primitive_accum_token::*;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::VecDeque;

use crate::{ComputeStore, StateToken, StoreKey};

/// State token used for the lag operator.
#[derive(Default, Debug)]
pub struct CollectToken<T>
where
T: Clone,
T: Serialize + DeserializeOwned,
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
state: Vec<VecDeque<Option<T>>>,
}

impl<T> CollectToken<T>
where
T: Clone,
T: Serialize + DeserializeOwned,
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
pub fn resize(&mut self, len: usize) {
if len >= self.state.len() {
self.state.resize(len + 1, VecDeque::new());
}
}

pub fn add_value(&mut self, max: usize, index: usize, input: Option<T>) {
self.state[index].push_back(input);
if self.state[index].len() > max {
self.state[index].pop_front();
}
}

pub fn state(&self, index: usize) -> &VecDeque<Option<T>> {
&self.state[index]
}
}

impl<T> StateToken for CollectToken<T>
where
T: Clone,
T: Serialize + DeserializeOwned,
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned,
{
fn restore(&mut self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> {
if let Some(state) = store.get(key)? {
self.state = state;
} else {
self.state.clear();
}

Ok(())
}

fn store(&self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> {
store.put(key, &self.state)
}
}
9 changes: 9 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
mod collect_boolean;
mod collect_map;
mod collect_primitive;
mod collect_string;
mod index;

pub(super) use collect_boolean::*;
pub(super) use collect_map::*;
pub(super) use collect_primitive::*;
pub(super) use collect_string::*;
pub(super) use index::*;
108 changes: 108 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo};
use arrow::array::{ArrayRef, AsArray, BooleanBuilder, ListBuilder};
use arrow::datatypes::DataType;
use itertools::izip;
use sparrow_arrow::scalar_value::ScalarValue;
use sparrow_plan::ValueRef;
use std::sync::Arc;

/// Evaluator for the `collect` instruction.
///
/// Collects a stream of values into a List. A list is produced
/// for each input value received, growing up to a maximum size.
///
/// If the list is empty, an empty list is returned (rather than `null`).
#[derive(Debug)]
pub struct CollectBooleanEvaluator {
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
/// value pushed to the back.
max: usize,
input: ValueRef,
tick: ValueRef,
duration: ValueRef,
/// Contains the buffer of values for each entity
token: CollectToken<bool>,
}

impl EvaluatorFactory for CollectBooleanEvaluator {
fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
let input_type = info.args[1].data_type();
let result_type = info.result_type;
match result_type {
DataType::List(t) => anyhow::ensure!(t.data_type() == input_type),
other => anyhow::bail!("expected list result type, saw {:?}", other),
};

let max = match info.args[0].value_ref.literal_value() {
Some(ScalarValue::Int64(Some(v))) if *v <= 0 => {
anyhow::bail!("unexpected value of `max` -- must be > 0")
}
Some(ScalarValue::Int64(Some(v))) => *v as usize,
// If a user specifies `max = null`, we use usize::MAX value as a way
// to have an "unlimited" buffer.
Some(ScalarValue::Int64(None)) => usize::MAX,
Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other),
None => anyhow::bail!("expected literal value for max parameter"),
};

let (_, input, tick, duration) = info.unpack_arguments()?;
Ok(Box::new(Self {
max,
input,
tick,
duration,
token: CollectToken::default(),
}))
}
}

impl Evaluator for CollectBooleanEvaluator {
fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
match (self.tick.is_literal_null(), self.duration.is_literal_null()) {
(true, true) => self.evaluate_non_windowed(info),
(true, false) => unimplemented!("since window aggregation unsupported"),
(false, false) => panic!("sliding window aggregation should use other evaluator"),
(_, _) => anyhow::bail!("saw invalid combination of tick and duration"),
}
}

fn state_token(&self) -> Option<&dyn StateToken> {
Some(&self.token)
}

fn state_token_mut(&mut self) -> Option<&mut dyn StateToken> {
Some(&mut self.token)
}
}

impl CollectBooleanEvaluator {
fn ensure_entity_capacity(&mut self, len: usize) {
self.token.resize(len)
}

fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
let input = info.value(&self.input)?.array_ref()?;
let key_capacity = info.grouping().num_groups();
let entity_indices = info.grouping().group_indices();
assert_eq!(entity_indices.len(), input.len());

self.ensure_entity_capacity(key_capacity);

let input = input.as_boolean();
let builder = BooleanBuilder::new();
let mut list_builder = ListBuilder::new(builder);

izip!(entity_indices.values(), input).for_each(|(entity_index, input)| {
let entity_index = *entity_index as usize;

self.token.add_value(self.max, entity_index, input);
let cur_list = self.token.state(entity_index);

list_builder.append_value(cur_list.iter().copied());
});

Ok(Arc::new(list_builder.finish()))
}
}
31 changes: 31 additions & 0 deletions crates/sparrow-instructions/src/evaluators/list/collect_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::{Evaluator, EvaluatorFactory, RuntimeInfo, StaticInfo};
use arrow::array::ArrayRef;
use sparrow_plan::ValueRef;

/// Evaluator for the `collect` instruction.
///
/// Collect collects a stream of values into a List<T>. A list is produced
/// for each input value received, growing up to a maximum size.
#[derive(Debug)]
pub struct CollectMapEvaluator {
/// The max size of the buffer.
///
/// Once the max size is reached, the front will be popped and the new
/// value pushed to the back.
_max: i64,
_input: ValueRef,
_tick: ValueRef,
_duration: ValueRef,
}

impl EvaluatorFactory for CollectMapEvaluator {
fn try_new(_info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> {
unimplemented!("map collect evaluator is unsupported")
}
}

impl Evaluator for CollectMapEvaluator {
fn evaluate(&mut self, _info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> {
unimplemented!("map collect evaluator is unsupported")
}
}
Loading
Loading