Support GroupsAccumulator accumulator for udaf #8892
Changes from all commits
@@ -16,29 +16,35 @@
// under the License.

use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};

use arrow::{
    array::{ArrayRef, Float32Array},
    array::{
        ArrayRef, AsArray, Float32Array, PrimitiveArray, PrimitiveBuilder, UInt32Array,
    },
    datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type},
    record_batch::RecordBatch,
};
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
use datafusion_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements
/// a function `accumulator` that returns the `Accumulator` instance.
///
/// To do so, we must implement the `AggregateUDFImpl` trait.
#[derive(Debug, Clone)]
struct GeoMeanUdf {
struct GeoMeanUdaf {
    signature: Signature,
}

impl GeoMeanUdf {
    /// Create a new instance of the GeoMeanUdf struct
impl GeoMeanUdaf {
    /// Create a new instance of the GeoMeanUdaf struct
    fn new() -> Self {
        Self {
            signature: Signature::exact(
@@ -52,7 +58,7 @@ impl GeoMeanUdf {
    }
}

impl AggregateUDFImpl for GeoMeanUdf {
impl AggregateUDFImpl for GeoMeanUdaf {
    /// We implement as_any so that we can downcast the AggregateUDFImpl trait object
    fn as_any(&self) -> &dyn Any {
        self
@@ -74,6 +80,11 @@ impl AggregateUDFImpl for GeoMeanUdf {
    }

    /// This is the accumulator factory; DataFusion uses it to create new accumulators.
    ///
    /// Even when `GroupsAccumulator` is supported, DataFusion will use this row-oriented
    /// accumulator when the aggregate function is used as a window function or when there
    /// are only aggregates (no GROUP BY columns) in the plan.
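    ///
    /// For example, a plain aggregate such as `SELECT geo_mean(a) FROM t` (no GROUP BY), or a
    /// window expression such as `geo_mean(a) OVER (PARTITION BY b)` (illustrative; not part
    /// of this example), is evaluated with this row-oriented accumulator.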
    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(GeometricMean::new()))
    }

@@ -82,6 +93,16 @@ impl AggregateUDFImpl for GeoMeanUdf {
    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
        Ok(vec![DataType::Float64, DataType::UInt32])
    }

    /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`,
    /// which is used when there are grouping columns in the query
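    /// (when this returns true, DataFusion is expected to call `create_groups_accumulator`
    /// below instead of `accumulator` for grouped aggregates)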
    fn groups_accumulator_supported(&self) -> bool {

[review comment] I think it would be good to add some context annotating this function for the example.

        true
    }

    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        Ok(Box::new(GeometricMeanGroupsAccumulator::new()))
    }
}

/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.

@@ -173,16 +194,25 @@ fn create_context() -> Result<SessionContext> {
    use datafusion::arrow::datatypes::{Field, Schema};
    use datafusion::datasource::MemTable;
    // define a schema.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ]));

    // define data in two partitions
    let batch1 = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
        vec![
            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
        ],
    )?;
    let batch2 = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Float32Array::from(vec![64.0]))],
        vec![
            Arc::new(Float32Array::from(vec![64.0])),
            Arc::new(Float32Array::from(vec![2.0])),
        ],
    )?;
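    // Because the table is registered with two partitions, a grouped aggregate over it will
    // typically be evaluated as a partial aggregation per partition followed by a final
    // aggregation that combines the partial results via `state()` and `merge_batch()` below.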

    // declare a new context. In the Spark API, this corresponds to a new Spark SQL session
@@ -194,15 +224,183 @@ fn create_context() -> Result<SessionContext> {
    Ok(ctx)
}

/// Define a `GroupsAccumulator` for GeometricMean, which handles the accumulator state
/// for multiple groups at once. This API is significantly more complicated than
/// `Accumulator`, which manages the state for a single group, but it can be significantly
/// faster for queries with a large number of groups. See the `GroupsAccumulator`
/// documentation for more information.
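/// (The geometric mean of `n` values is the `n`-th root of their product,
/// `(x1 * x2 * ... * xn).powf(1.0 / n)`, so the only per-group state needed here is a
/// running product and a count.)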
struct GeometricMeanGroupsAccumulator {
    /// The type of the internal product
    prod_data_type: DataType,

    /// The type of the returned mean
    return_data_type: DataType,

    /// Count per group (use u32 to make UInt32Array)
    counts: Vec<u32>,

    /// Product per group, stored as the native type (not `ScalarValue`)
    prods: Vec<f64>,

    /// Track nulls in the input / filters
    null_state: NullState,
}

impl GeometricMeanGroupsAccumulator {
    fn new() -> Self {
        Self {
            prod_data_type: DataType::Float64,
            return_data_type: DataType::Float64,
            counts: vec![],
            prods: vec![],
            null_state: NullState::new(),
        }
    }
}

impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
    /// Updates the accumulator state given the input. DataFusion provides `group_indices`,
    /// the group that each row in `values` belongs to, as well as an optional filter of which rows passed.
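    ///
    /// For example (illustrative values), with `values = [2.0, 4.0, 8.0]` and
    /// `group_indices = [0, 1, 0]`, rows 0 and 2 update group 0 and row 1 updates group 1.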
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&arrow::array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "single argument to update_batch");
        let values = values[0].as_primitive::<Float64Type>();

        // increment counts, update products
        self.counts.resize(total_num_groups, 0);
        self.prods.resize(total_num_groups, 1.0);
        // Use the `NullState` structure to generate specialized code for null / non null input elements
        self.null_state.accumulate(
            group_indices,
            values,
            opt_filter,
            total_num_groups,
            |group_index, new_value| {
                let prod = &mut self.prods[group_index];
                *prod = prod.mul_wrapping(new_value);

                self.counts[group_index] += 1;
            },
        );

        Ok(())
    }

    /// Merge the intermediate state (the output of `state()`) from other accumulators
    /// into this accumulator's state.
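    ///
    /// For example, with the data in `create_context`, one partition's partial state for the
    /// single group would be `prod = 64.0, count = 3` and the other's `prod = 64.0, count = 1`;
    /// merging them yields `prod = 4096.0, count = 4`.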
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&arrow::array::BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 2, "two arguments to merge_batch");
        // first array is the partial products, second is the partial counts
        let partial_prods = values[0].as_primitive::<Float64Type>();
        let partial_counts = values[1].as_primitive::<UInt32Type>();
        // update counts with partial counts
        self.counts.resize(total_num_groups, 0);
        self.null_state.accumulate(
            group_indices,
            partial_counts,
            opt_filter,
            total_num_groups,
            |group_index, partial_count| {
                self.counts[group_index] += partial_count;
            },
        );

        // update prods
        self.prods.resize(total_num_groups, 1.0);
        self.null_state.accumulate(
            group_indices,
            partial_prods,
            opt_filter,
            total_num_groups,
            |group_index, new_value: <Float64Type as ArrowPrimitiveType>::Native| {
                let prod = &mut self.prods[group_index];
                *prod = prod.mul_wrapping(new_value);
            },
        );

        Ok(())
    }

    /// Generate output, as specified by `emit_to`, and update the intermediate state
    fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<ArrayRef> {
        let counts = emit_to.take_needed(&mut self.counts);
        let prods = emit_to.take_needed(&mut self.prods);
        let nulls = self.null_state.build(emit_to);

        assert_eq!(nulls.len(), prods.len());
        assert_eq!(counts.len(), prods.len());

        // don't evaluate the geometric mean for null inputs, to avoid errors on null values

        let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
            let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
            let iter = prods.into_iter().zip(counts).zip(nulls.iter());

            for ((prod, count), is_valid) in iter {
                if is_valid {
                    builder.append_value(prod.powf(1.0 / count as f64))
                } else {
                    builder.append_null();
                }
            }
            builder.finish()
        } else {
            let geo_mean: Vec<<Float64Type as ArrowPrimitiveType>::Native> = prods
                .into_iter()
                .zip(counts)
                .map(|(prod, count)| prod.powf(1.0 / count as f64))
                .collect::<Vec<_>>();
            PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
                .with_data_type(self.return_data_type.clone())
        };

        Ok(Arc::new(array))
    }

    // return the partial products and counts as arrays
    fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result<Vec<ArrayRef>> {
        let nulls = self.null_state.build(emit_to);
        let nulls = Some(nulls);

        let counts = emit_to.take_needed(&mut self.counts);
        let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy

        let prods = emit_to.take_needed(&mut self.prods);
        let prods = PrimitiveArray::<Float64Type>::new(prods.into(), nulls) // zero copy
            .with_data_type(self.prod_data_type.clone());

        Ok(vec![
            Arc::new(prods) as ArrayRef,
            Arc::new(counts) as ArrayRef,
        ])
    }

    fn size(&self) -> usize {
        // account for the heap allocations of the two per-group vectors
        self.counts.capacity() * std::mem::size_of::<u32>()
            + self.prods.capacity() * std::mem::size_of::<f64>()
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    // create the AggregateUDF
    let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
    let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
    ctx.register_udaf(geometric_mean.clone());

    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
    sql_df.show().await?;

    // get a DataFrame from the context
[review comment] I recommend we add a note to `accumulator()` above about when this is used. Now that I write this, maybe we should also put some of this information on the docstrings for `AggregateUDF::groups_accumulator`.