Skip to content

Commit

Permalink
Implement special min/max accumulator for Strings and Binary (10% fas…
Browse files Browse the repository at this point in the history
…ter for Clickbench Q28) (#12792)

* Implement special min/max accumulator for Strings: `MinMaxBytesAccumulator`

* fix bug

* fix msrv

* move code, handle filters

* simplify

* Add functional tests

* remove unecessary test

* improve docs

* improve docs

* cleanup

* improve comments

* fix diagram

* fix accounting

* Use correct type in memory accounting

* Add TODO comment
  • Loading branch information
alamb authored Oct 13, 2024
1 parent ebfc155 commit 646f40a
Show file tree
Hide file tree
Showing 5 changed files with 872 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl NullState {
///
/// When value_fn is called it also sets
///
/// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale
/// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
pub fn accumulate<T, F>(
&mut self,
group_indices: &[usize],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@
// specific language governing permissions and limitations
// under the License.

//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls
//! [`set_nulls`], other utilities for working with nulls
use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
use arrow::array::{
Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray,
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray,
};
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, Result};
use std::sync::Arc;

/// Sets the validity mask for a `PrimitiveArray` to `nulls`
/// replacing any existing null mask
///
/// See [`set_nulls_dyn`] for a version that works with `Array`
pub fn set_nulls<T: ArrowNumericType + Send>(
array: PrimitiveArray<T>,
nulls: Option<NullBuffer>,
Expand Down Expand Up @@ -91,3 +100,105 @@ pub fn filtered_null_mask(
let opt_filter = opt_filter.and_then(filter_to_nulls);
NullBuffer::union(opt_filter.as_ref(), input.nulls())
}

/// Applies optional filter to input, returning a new array of the same type
/// with the same data, but with any values that were filtered out set to null
pub fn apply_filter_as_nulls(
input: &dyn Array,
opt_filter: Option<&BooleanArray>,
) -> Result<ArrayRef> {
let nulls = filtered_null_mask(opt_filter, input);
set_nulls_dyn(input, nulls)
}

/// Replaces the nulls in the input array with the given `NullBuffer`
///
/// TODO: replace when upstreamed in arrow-rs: <https://github.com/apache/arrow-rs/issues/6528>
pub fn set_nulls_dyn(input: &dyn Array, nulls: Option<NullBuffer>) -> Result<ArrayRef> {
if let Some(nulls) = nulls.as_ref() {
assert_eq!(nulls.len(), input.len());
}

let output: ArrayRef = match input.data_type() {
DataType::Utf8 => {
let input = input.as_string::<i32>();
// safety: values / offsets came from a valid string array, so are valid utf8
// and we checked nulls has the same length as values
unsafe {
Arc::new(StringArray::new_unchecked(
input.offsets().clone(),
input.values().clone(),
nulls,
))
}
}
DataType::LargeUtf8 => {
let input = input.as_string::<i64>();
// safety: values / offsets came from a valid string array, so are valid utf8
// and we checked nulls has the same length as values
unsafe {
Arc::new(LargeStringArray::new_unchecked(
input.offsets().clone(),
input.values().clone(),
nulls,
))
}
}
DataType::Utf8View => {
let input = input.as_string_view();
// safety: values / views came from a valid string view array, so are valid utf8
// and we checked nulls has the same length as values
unsafe {
Arc::new(StringViewArray::new_unchecked(
input.views().clone(),
input.data_buffers().to_vec(),
nulls,
))
}
}

DataType::Binary => {
let input = input.as_binary::<i32>();
// safety: values / offsets came from a valid binary array
// and we checked nulls has the same length as values
unsafe {
Arc::new(BinaryArray::new_unchecked(
input.offsets().clone(),
input.values().clone(),
nulls,
))
}
}
DataType::LargeBinary => {
let input = input.as_binary::<i64>();
// safety: values / offsets came from a valid large binary array
// and we checked nulls has the same length as values
unsafe {
Arc::new(LargeBinaryArray::new_unchecked(
input.offsets().clone(),
input.values().clone(),
nulls,
))
}
}
DataType::BinaryView => {
let input = input.as_binary_view();
// safety: values / views came from a valid binary view array
// and we checked nulls has the same length as values
unsafe {
Arc::new(BinaryViewArray::new_unchecked(
input.views().clone(),
input.data_buffers().to_vec(),
nulls,
))
}
}
_ => {
return not_impl_err!("Applying nulls {:?}", input.data_type());
}
};
assert_eq!(input.len(), output.len());
assert_eq!(input.data_type(), output.data_type());

Ok(output)
}
123 changes: 69 additions & 54 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
mod min_max_bytes;

use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
Expand Down Expand Up @@ -50,6 +52,7 @@ use arrow::datatypes::{
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
};

use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
use datafusion_common::ScalarValue;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature,
Expand Down Expand Up @@ -104,7 +107,7 @@ impl Default for Max {
/// the specified [`ArrowPrimitiveType`].
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! instantiate_max_accumulator {
macro_rules! primitive_max_accumulator {
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
Expand All @@ -123,7 +126,7 @@ macro_rules! instantiate_max_accumulator {
///
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! instantiate_min_accumulator {
macro_rules! primitive_min_accumulator {
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
Expand Down Expand Up @@ -231,6 +234,12 @@ impl AggregateUDFImpl for Max {
| Time32(_)
| Time64(_)
| Timestamp(_, _)
| Utf8
| LargeUtf8
| Utf8View
| Binary
| LargeBinary
| BinaryView
)
}

Expand All @@ -242,58 +251,58 @@ impl AggregateUDFImpl for Max {
use TimeUnit::*;
let data_type = args.return_type;
match data_type {
Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type),
Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type),
Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type),
UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type),
UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_max_accumulator!(data_type, f16, Float16Type)
primitive_max_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_max_accumulator!(data_type, f32, Float32Type)
primitive_max_accumulator!(data_type, f32, Float32Type)
}
Float64 => {
instantiate_max_accumulator!(data_type, f64, Float64Type)
primitive_max_accumulator!(data_type, f64, Float64Type)
}
Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type),
Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type),
Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
Time32(Second) => {
instantiate_max_accumulator!(data_type, i32, Time32SecondType)
primitive_max_accumulator!(data_type, i32, Time32SecondType)
}
Time32(Millisecond) => {
instantiate_max_accumulator!(data_type, i32, Time32MillisecondType)
primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
}
Time64(Microsecond) => {
instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType)
primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
}
Time64(Nanosecond) => {
instantiate_max_accumulator!(data_type, i64, Time64NanosecondType)
primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
}
Timestamp(Second, _) => {
instantiate_max_accumulator!(data_type, i64, TimestampSecondType)
primitive_max_accumulator!(data_type, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType)
primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType)
primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType)
primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
}
Decimal128(_, _) => {
instantiate_max_accumulator!(data_type, i128, Decimal128Type)
primitive_max_accumulator!(data_type, i128, Decimal128Type)
}
Decimal256(_, _) => {
instantiate_max_accumulator!(data_type, i256, Decimal256Type)
primitive_max_accumulator!(data_type, i256, Decimal256Type)
}
Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
}

// It would be nice to have a fast implementation for Strings as well
// https://github.com/apache/datafusion/issues/6906

// This is only reached if groups_accumulator_supported is out of sync
_ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
Expand Down Expand Up @@ -1057,6 +1066,12 @@ impl AggregateUDFImpl for Min {
| Time32(_)
| Time64(_)
| Timestamp(_, _)
| Utf8
| LargeUtf8
| Utf8View
| Binary
| LargeBinary
| BinaryView
)
}

Expand All @@ -1068,58 +1083,58 @@ impl AggregateUDFImpl for Min {
use TimeUnit::*;
let data_type = args.return_type;
match data_type {
Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type),
Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type),
Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type),
Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type),
UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type),
UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
Float16 => {
instantiate_min_accumulator!(data_type, f16, Float16Type)
primitive_min_accumulator!(data_type, f16, Float16Type)
}
Float32 => {
instantiate_min_accumulator!(data_type, f32, Float32Type)
primitive_min_accumulator!(data_type, f32, Float32Type)
}
Float64 => {
instantiate_min_accumulator!(data_type, f64, Float64Type)
primitive_min_accumulator!(data_type, f64, Float64Type)
}
Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type),
Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type),
Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
Time32(Second) => {
instantiate_min_accumulator!(data_type, i32, Time32SecondType)
primitive_min_accumulator!(data_type, i32, Time32SecondType)
}
Time32(Millisecond) => {
instantiate_min_accumulator!(data_type, i32, Time32MillisecondType)
primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
}
Time64(Microsecond) => {
instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType)
primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
}
Time64(Nanosecond) => {
instantiate_min_accumulator!(data_type, i64, Time64NanosecondType)
primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
}
Timestamp(Second, _) => {
instantiate_min_accumulator!(data_type, i64, TimestampSecondType)
primitive_min_accumulator!(data_type, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType)
primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType)
primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType)
primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
}
Decimal128(_, _) => {
instantiate_min_accumulator!(data_type, i128, Decimal128Type)
primitive_min_accumulator!(data_type, i128, Decimal128Type)
}
Decimal256(_, _) => {
instantiate_min_accumulator!(data_type, i256, Decimal256Type)
primitive_min_accumulator!(data_type, i256, Decimal256Type)
}
Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
}

// It would be nice to have a fast implementation for Strings as well
// https://github.com/apache/datafusion/issues/6906

// This is only reached if groups_accumulator_supported is out of sync
_ => internal_err!("GroupsAccumulator not supported for min({})", data_type),
Expand Down
Loading

0 comments on commit 646f40a

Please sign in to comment.