Skip to content

Commit

Permalink
feat(query): Support st_collect aggregate function (#16996)
Browse files Browse the repository at this point in the history
feat(query): Support st_collect aggregate function
  • Loading branch information
b41sh authored Dec 4, 2024
1 parent 2b4782d commit 0343b57
Show file tree
Hide file tree
Showing 15 changed files with 1,187 additions and 20 deletions.
11 changes: 7 additions & 4 deletions src/query/functions/src/aggregates/aggregate_array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::alloc::Layout;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;

use borsh::BorshDeserialize;
Expand Down Expand Up @@ -101,8 +102,9 @@ where
match inner_type.remove_nullable() {
DataType::Decimal(decimal_type) => {
let size = decimal_type.size();
for value in &self.values {
let val = T::upcast_scalar(value.clone());
let values = mem::take(&mut self.values);
for value in values.into_iter() {
let val = T::upcast_scalar(value);
let decimal_val = val.as_decimal().unwrap();
let new_val = match decimal_val {
DecimalScalar::Decimal128(v, _) => {
Expand All @@ -116,8 +118,9 @@ where
}
}
_ => {
for value in &self.values {
let val = T::upcast_scalar(value.clone());
let values = mem::take(&mut self.values);
for value in values.into_iter() {
let val = T::upcast_scalar(value);
inner_builder.push(val.as_ref());
}
}
Expand Down
22 changes: 13 additions & 9 deletions src/query/functions/src/aggregates/aggregate_function_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ use super::AggregateFunctionOrNullAdaptor;
use crate::aggregates::AggregateFunctionRef;
use crate::aggregates::Aggregators;

// The NULL value in the those function needs to be handled separately.
const NEED_NULL_AGGREGATE_FUNCTIONS: [&str; 7] = [
"array_agg",
"list",
"json_array_agg",
"json_object_agg",
"group_array_moving_avg",
"group_array_moving_sum",
"st_collect",
];

const STATE_SUFFIX: &str = "_state";

pub type AggregateFunctionCreator =
Expand Down Expand Up @@ -172,15 +183,8 @@ impl AggregateFunctionFactory {
) -> Result<AggregateFunctionRef> {
let name = name.as_ref();
let mut features = AggregateFunctionFeatures::default();
// The NULL value in the array_agg function needs to be added to the returned array column,
// so handled separately.
if name == "array_agg"
|| name == "list"
|| name == "json_array_agg"
|| name == "json_object_agg"
|| name == "group_array_moving_avg"
|| name == "group_array_moving_sum"
{

if NEED_NULL_AGGREGATE_FUNCTIONS.contains(&name) {
let agg = self.get_impl(name, params, arguments, &mut features)?;
return Ok(agg);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::alloc::Layout;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;

use borsh::BorshDeserialize;
Expand Down Expand Up @@ -102,8 +103,9 @@ where
fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> {
let tz = TimeZone::UTC;
let mut items = Vec::with_capacity(self.values.len());
for value in &self.values {
let v = T::upcast_scalar(value.clone());
let values = mem::take(&mut self.values);
for value in values.into_iter() {
let v = T::upcast_scalar(value);
// NULL values are omitted from the output.
if v == Scalar::Null {
continue;
Expand Down
12 changes: 7 additions & 5 deletions src/query/functions/src/aggregates/aggregate_json_object_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::alloc::Layout;
use std::collections::BTreeMap;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::sync::Arc;

use borsh::BorshDeserialize;
Expand Down Expand Up @@ -166,19 +167,20 @@ where

fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> {
let tz = TimeZone::UTC;
let mut kvs = Vec::with_capacity(self.kvs.len());
for (key, value) in &self.kvs {
let v = V::upcast_scalar(value.clone());
let mut values = Vec::with_capacity(self.kvs.len());
let kvs = mem::take(&mut self.kvs);
for (key, value) in kvs.into_iter() {
let v = V::upcast_scalar(value);
// NULL values are omitted from the output.
if v == Scalar::Null {
continue;
}
let mut val = vec![];
cast_scalar_to_variant(v.as_ref(), &tz, &mut val);
kvs.push((key, val));
values.push((key, val));
}
let mut data = vec![];
jsonb::build_object(kvs.iter().map(|(k, v)| (k, &v[..])), &mut data).unwrap();
jsonb::build_object(values.iter().map(|(k, v)| (k, &v[..])), &mut data).unwrap();

let object_value = Scalar::Variant(data);
builder.push(object_value.as_ref());
Expand Down
Loading

0 comments on commit 0343b57

Please sign in to comment.