Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customizable equality and hash functions to UDFs #252

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
.await;
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
"Arrow error: Invalid argument error: Nested comparison: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) (hint: use make_comparator instead)"
);
Ok(())
}
80 changes: 75 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,8 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -377,6 +382,56 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

#[ignore]
#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
Comment on lines +385 to +387
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works on main but not here. Unfortunately something else must be broken too 😢

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fixed in apache#10473

let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
let udf1 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 1,
});
let udf2 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 2,
});

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.aggregate(
[col("text")],
[
udf1.call(vec![col("text")]).alias("a"),
udf2.call(vec![col("text")]).alias("b"),
],
)?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]"
);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+------+---+---+",
"| text | a | b |",
"+------+---+---+",
"| foo | 1 | 2 |",
"+------+---+---+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -735,6 +790,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
self.result == other.result && self.signature == other.signature
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.signature.hash(hasher);
self.result.hash(hasher);
hasher.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
128 changes: 125 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
// under the License.

use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use parking_lot::Mutex;
use regex::Regex;
use sqlparser::ast::Ident;

use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -37,8 +46,6 @@ use datafusion_expr::{
Volatility,
};
use datafusion_functions_array::range::range_udf;
use parking_lot::Mutex;
use sqlparser::ast::Ident;

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -1010,6 +1017,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
Ok(())
}

#[derive(Debug)]
struct MyRegexUdf {
signature: Signature,
regex: Regex,
}

impl MyRegexUdf {
fn new(pattern: &str) -> Self {
Self {
signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
regex: Regex::new(pattern).expect("regex"),
}
}

fn matches(&self, value: Option<&str>) -> Option<bool> {
Some(self.regex.is_match(value?))
}
}

impl ScalarUDFImpl for MyRegexUdf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"regex_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
if matches!(args, [DataType::Utf8]) {
Ok(DataType::Boolean)
} else {
plan_err!("regex_udf only accepts a Utf8 argument")
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
self.matches(value.as_deref()),
)))
}
[ColumnarValue::Array(values)] => {
let mut builder = BooleanBuilder::with_capacity(values.len());
for value in values.as_string::<i32>() {
builder.append_option(self.matches(value))
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
}
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
self.regex.as_str() == other.regex.as_str()
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.regex.as_str().hash(hasher);
hasher.finish()
}
}

#[tokio::test]
async fn test_parameterized_scalar_udf() -> Result<()> {
let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.filter(
foo_udf
.call(vec![col("text")])
.and(bar_udf.call(vec![col("text")])),
)?
.filter(col("text").is_not_null())?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+--------+",
"| text |",
"+--------+",
"| foobar |",
"| barfoo |",
"+--------+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
71 changes: 58 additions & 13 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;

use arrow::datatypes::{DataType, Field};

use datafusion_common::{exec_err, not_impl_err, Result};

use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
Expand All @@ -25,12 +35,6 @@ use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -70,20 +74,19 @@ pub struct AggregateUDF {

impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref())
}
}

impl Eq for AggregateUDF {}

impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

impl std::fmt::Display for AggregateUDF {
impl fmt::Display for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.name())
}
Expand Down Expand Up @@ -276,7 +279,7 @@ where
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
/// };
/// }
///
/// impl GeoMeanUdf {
/// fn new() -> Self {
Expand Down Expand Up @@ -503,6 +506,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}

/// Return true if this aggregate UDF is equal to the other.
///
/// Allows customizing the equality of aggregate UDFs.
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
///
/// - reflexive: `a.equals(a)`;
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
///
/// By default, compares [`Self::name`] and [`Self::signature`].
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
self.name() == other.name() && self.signature() == other.signature()
}

/// Returns a hash value for this aggregate UDF.
///
/// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
///
/// By default, hashes [`Self::name`] and [`Self::signature`].
fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.name().hash(hasher);
self.signature().hash(hasher);
hasher.finish()
}
}

pub enum ReversedUDAF {
Expand Down Expand Up @@ -558,6 +588,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn aliases(&self) -> &[String] {
&self.aliases
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.inner.hash_value().hash(hasher);
self.aliases.hash(hasher);
hasher.finish()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
Loading
Loading