Fix the schema mismatch between logical and physical for aggregate function, add AggregateUDFImpl::is_null #11989

Merged · 23 commits · Aug 21, 2024

Changes from 4 commits
9 changes: 6 additions & 3 deletions datafusion/core/src/physical_planner.rs
@@ -670,6 +670,10 @@ impl DefaultPhysicalPlanner {
let input_exec = children.one()?;
let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
let physical_input_schema_from_logical: Arc<Schema> =
logical_input_schema.as_ref().clone().into();

debug_assert_eq!(physical_input_schema_from_logical, physical_input_schema, "Physical input schema should be the same as the one converted from logical input schema. Please file an issue or send the PR");
jayzhan211 (Contributor, Author) commented on Aug 14, 2024:

The main goal of this change is to ensure the two schemas are the same, and we then pass physical_input_schema through to the functions that require the input's schema.

Contributor commented:

Nice!

Did you consider making this function return an internal_error rather than using a debug_assert?

If we are concerned about breaking existing tests, we could add a config setting like datafusion.optimizer.skip_failed_rules to let users bypass the check.

jayzhan211 (Contributor, Author) commented on Aug 17, 2024:

The objective here is to ensure that the logical schema from ExprSchemable and the physical schema from ExecutionPlan.schema() are equivalent. If they are not, it indicates a potential schema mismatch issue. This is also why most of the code changes in this PR are schema-related fixes, and they are all required, so I don't think we should let users bypass the check 🤔

If we encounter inconsistent schemas, it raises an important question: which schema should we use?

> Did you consider making this function return an internal_error rather than using a debug_assert?

That looks good to me.
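
A sketch of what returning a hard error instead of the debug_assert might look like (illustrative only, assuming the internal_err! macro from datafusion_common; not the code merged in this PR):

    let physical_input_schema_from_logical: Arc<Schema> =
        logical_input_schema.as_ref().clone().into();

    // Fail planning instead of only asserting in debug builds.
    if physical_input_schema_from_logical != physical_input_schema {
        return internal_err!(
            "Physical input schema should be the same as the one converted from \
             logical input schema"
        );
    }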


let groups = self.create_grouping_physical_expr(
group_expr,
@@ -1548,7 +1552,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: Option<String>,
logical_input_schema: &DFSchema,
_physical_input_schema: &Schema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<AggregateExprWithOptionalArgs> {
match e {
@@ -1599,11 +1603,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);

let schema: Schema = logical_input_schema.clone().into();
jayzhan211 (Contributor, Author) commented:

Workaround cleanup.

let agg_expr =
AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec())
.order_by(ordering_reqs.to_vec())
.schema(Arc::new(schema))
.schema(Arc::new(physical_input_schema.to_owned()))
.alias(name)
.with_ignore_nulls(ignore_nulls)
.with_distinct(*distinct)
39 changes: 37 additions & 2 deletions datafusion/expr/src/expr_schema.rs
@@ -328,10 +328,45 @@ impl ExprSchemable for Expr {
Ok(true)
}
}
Expr::WindowFunction(WindowFunction { fun, .. }) => {
Contributor commented:

Is this change required for this PR, or is it a "drive-by" improvement?

jayzhan211 (Contributor, Author) commented:

Required.

match fun {
WindowFunctionDefinition::BuiltInWindowFunction(func) => {
if func.name() == "ROW_NUMBER"
|| func.name() == "RANK"
|| func.name() == "NTILE"
|| func.name() == "CUME_DIST"
{
Ok(false)
} else {
Ok(true)
}
}
WindowFunctionDefinition::AggregateUDF(func) => {
// TODO: UDF should be able to customize nullability
if func.name() == "count" {
// TODO: there is issue unsolved for count with window, should return false
jayzhan211 (Contributor, Author) commented:

I'm not so familiar with window functions yet, so I left this as a TODO.

Contributor commented:

Perhaps we can file a ticket to track this -- ideally it would eventually be part of the window function definition itself rather than relying on names.

Ok(true)
} else {
Ok(true)
}
}
_ => Ok(true),
}
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
// If all the element in coalesce is non-null, the result is non-null
Contributor commented:

We should probably add an API to ScalarUDFImpl to signal its null/non-nullness (as a follow-on PR) instead of hard-coding this function name:

     func.is_nullable(args)
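
A rough sketch of such a trait method (the name, signature, and default are assumptions for illustration; this API is not added in this PR):

    /// Hypothetical addition to ScalarUDFImpl: let the UDF itself report
    /// whether its result can be null for the given arguments.
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
        // Conservative default: assume the result may be null.
        true
    }

A coalesce implementation could then override it with the same all-arguments-non-null check that is hard-coded below.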

if func.name() == "coalesce"
&& args
.iter()
.all(|e| !e.nullable(input_schema).ok().unwrap_or(true))
{
return Ok(false);
}

Ok(true)
}
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
| Expr::WindowFunction { .. }
| Expr::Unnest(_)
| Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
12 changes: 7 additions & 5 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -2015,10 +2015,9 @@ impl Projection {
/// produced by the projection operation. If the schema computation is successful,
/// the `Result` will contain the schema; otherwise, it will contain an error.
pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result<Arc<DFSchema>> {
let mut schema = DFSchema::new_with_metadata(
exprlist_to_fields(exprs, input)?,
input.schema().metadata().clone(),
)?;
let metadata = input.schema().metadata().clone();
let mut schema =
DFSchema::new_with_metadata(exprlist_to_fields(exprs, input)?, metadata)?;
schema = schema.with_functional_dependencies(calc_func_dependencies_for_project(
exprs, input,
)?)?;
@@ -2659,7 +2658,10 @@ impl Aggregate {

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);

let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?;
let schema = DFSchema::new_with_metadata(
qualified_fields,
input.schema().metadata().clone(),
)?;

Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
}
8 changes: 8 additions & 0 deletions datafusion/expr/src/udaf.rs
@@ -196,6 +196,10 @@ impl AggregateUDF {
self.inner.state_fields(args)
}

pub fn fields(&self, args: StateFieldsArgs) -> Result<Field> {
Contributor commented:

Could we document this function and what it is for (also in AggregateUDFImpl)?

Also, the name is strange to me -- it is fields but it returns a single Field, and the corresponding function on AggregateUDFImpl is called field (no s) 🤔

self.inner.field(args)
}
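
A documented version might read something like this (the doc wording is a sketch, not taken from the PR):

    /// Return the `Field` (name, data type, and nullability) of this
    /// aggregate's output column, as reported by the UDF implementation.
    ///
    /// See [`AggregateUDFImpl::field`] for more details.
    pub fn fields(&self, args: StateFieldsArgs) -> Result<Field> {
        self.inner.field(args)
    }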

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
self.inner.groups_accumulator_supported(args)
@@ -383,6 +387,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
.collect())
}

fn field(&self, args: StateFieldsArgs) -> Result<Field> {
Ok(Field::new(args.name, args.return_type.clone(), true))
}

/// If the aggregate expression has a specialized
/// [`GroupsAccumulator`] implementation. If this returns true,
/// `[Self::create_groups_accumulator]` will be called.
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate-common/src/aggregate.rs
@@ -171,6 +171,9 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
fn get_minmax_desc(&self) -> Option<(Field, bool)> {
None
}

/// Get function's name, for example `count(x)` returns `count`
fn func_name(&self) -> &str;
Contributor commented:

Is there a reason this isn't name()? func_name is fine, it just seems inconsistent with the rest of the code.

jayzhan211 (Contributor, Author) commented on Aug 17, 2024:

This is to identify the function (i.e. count); there is already name(), but it includes the arguments (i.e. count(x)), which is not what I want.
An alternative is to introduce nullable() for AggregateUDF, so we don't need name checking. Maybe I should have done that before this PR.
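
A self-contained sketch of that alternative (names and shapes are assumptions; the real AggregateUDFImpl trait has many more methods):

    use std::fmt::Debug;

    /// Stand-in for a nullability hook that could live on AggregateUDFImpl.
    trait AggregateNullability: Debug {
        /// Whether the aggregate's output column may contain nulls. Most
        /// aggregates can return null (e.g. min/max over an empty group),
        /// so default to true.
        fn is_nullable(&self) -> bool {
            true
        }
    }

    #[derive(Debug)]
    struct Count;

    impl AggregateNullability for Count {
        /// count always produces a value, so its output is never null.
        fn is_nullable(&self) -> bool {
            false
        }
    }

    fn main() {
        // The physical layer could consult this instead of matching on "count".
        assert!(!Count.is_nullable());
    }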

}

/// Stores the physical expressions used inside the `AggregateExpr`.
5 changes: 5 additions & 0 deletions datafusion/functions-aggregate/src/count.rs
@@ -138,6 +138,11 @@ impl AggregateUDFImpl for Count {
}
}

fn field(&self, args: StateFieldsArgs) -> Result<Field> {
// count always return non-null value
Ok(Field::new(args.name, args.return_type.clone(), false))
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if !acc_args.is_distinct {
return Ok(Box::new(CountAccumulator::new()));
68 changes: 42 additions & 26 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -17,7 +17,6 @@

//! Optimizer rule for type validation and coercion

use std::collections::HashMap;
use std::sync::Arc;

use itertools::izip;
@@ -821,9 +820,18 @@ fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
.iter()
.map(|f| f.is_nullable())
.collect::<Vec<_>>();
let mut union_field_meta = base_schema
.fields()
.iter()
.map(|f| f.metadata().clone())
.collect::<Vec<_>>();

let mut metadata = base_schema.metadata().clone();

for (i, plan) in inputs.iter().enumerate().skip(1) {
let plan_schema = plan.schema();
metadata.extend(plan_schema.metadata().clone());

if plan_schema.fields().len() != base_schema.fields().len() {
return plan_err!(
"Union schemas have different number of fields: \
@@ -833,39 +841,47 @@
plan_schema.fields().len()
);
}
// coerce data type and nullablity for each field
for (union_datatype, union_nullable, plan_field) in izip!(
union_datatypes.iter_mut(),
union_nullabilities.iter_mut(),
plan_schema.fields()
) {
let coerced_type =
comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
|| {
plan_datafusion_err!(
"Incompatible inputs for Union: Previous inputs were \
of type {}, but got incompatible type {} on column '{}'",
union_datatype,
plan_field.data_type(),
plan_field.name()
)
},
)?;
*union_datatype = coerced_type;
*union_nullable = *union_nullable || plan_field.is_nullable();

// Safety: Length is checked
unsafe {
Contributor commented:

I think this unsafe block is unnecessary -- this isn't a performance-critical piece of code. I think izip or just manually zipping three times would be better.
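
A safe version along those lines might look roughly like this (a sketch of the reviewer's suggestion, not necessarily the code that was finally merged):

    // Zip the three per-field accumulators with the plan's fields instead of
    // indexing with get_unchecked_mut inside an unsafe block.
    for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
        union_datatypes.iter_mut(),
        union_nullabilities.iter_mut(),
        union_field_meta.iter_mut(),
        plan_schema.fields()
    ) {
        let coerced_type = comparison_coercion(union_datatype, plan_field.data_type())
            .ok_or_else(|| {
                plan_datafusion_err!(
                    "Incompatible inputs for Union: Previous inputs were \
                     of type {}, but got incompatible type {} on column '{}'",
                    union_datatype,
                    plan_field.data_type(),
                    plan_field.name()
                )
            })?;

        *union_datatype = coerced_type;
        *union_nullable = *union_nullable || plan_field.is_nullable();
        union_field_map.extend(plan_field.metadata().clone());
    }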

// coerce data type and nullablity for each field
for (i, plan_field) in plan_schema.fields().iter().enumerate() {
let union_datatype = union_datatypes.get_unchecked_mut(i);
let union_nullable = union_nullabilities.get_unchecked_mut(i);
let union_field_map = union_field_meta.get_unchecked_mut(i);

let coerced_type =
comparison_coercion(union_datatype, plan_field.data_type())
.ok_or_else(|| {
plan_datafusion_err!(
"Incompatible inputs for Union: Previous inputs were \
of type {}, but got incompatible type {} on column '{}'",
union_datatype,
plan_field.data_type(),
plan_field.name()
)
})?;

*union_datatype = coerced_type;
*union_nullable = *union_nullable || plan_field.is_nullable();
union_field_map.extend(plan_field.metadata().clone());
}
}
}
let union_qualified_fields = izip!(
base_schema.iter(),
union_datatypes.into_iter(),
union_nullabilities
union_nullabilities,
union_field_meta.into_iter()
)
.map(|((qualifier, field), datatype, nullable)| {
let field = Arc::new(Field::new(field.name().clone(), datatype, nullable));
(qualifier.cloned(), field)
.map(|((qualifier, field), datatype, nullable, metadata)| {
let mut field = Field::new(field.name().clone(), datatype, nullable);
field.set_metadata(metadata);
(qualifier.cloned(), field.into())
})
.collect::<Vec<_>>();
DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())

DFSchema::new_with_metadata(union_qualified_fields, metadata)
}

/// See `<https://github.com/apache/datafusion/pull/2108>`
14 changes: 13 additions & 1 deletion datafusion/physical-expr-functions-aggregate/src/aggregate.rs
@@ -241,7 +241,15 @@ impl AggregateExpr for AggregateFunctionExpr {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
let args = StateFieldsArgs {
name: &self.name,
input_types: &self.input_types,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
};

self.fun.fields(args)
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -435,6 +443,10 @@ impl AggregateExpr for AggregateFunctionExpr {
.is_descending()
.and_then(|flag| self.field().ok().map(|f| (f, flag)))
}

fn func_name(&self) -> &str {
self.fun.name()
}
}

impl PartialEq<dyn Any> for AggregateFunctionExpr {
12 changes: 12 additions & 0 deletions datafusion/physical-expr/src/window/aggregate.rs
@@ -80,6 +80,14 @@ impl WindowExpr for PlainAggregateWindowExpr {
}

fn field(&self) -> Result<Field> {
// TODO: Fix window function to always return non-null for count
Contributor commented:

I don't understand this comment -- can we please file a ticket to track it (and add the ticket reference to the comments)?

if let Ok(name) = self.func_name() {
if name == "count" {
let field = self.aggregate.field()?;
return Ok(field.with_nullable(true));
}
}

self.aggregate.field()
}

@@ -157,6 +165,10 @@ impl WindowExpr for PlainAggregateWindowExpr {
fn uses_bounded_memory(&self) -> bool {
!self.window_frame.end_bound.is_unbounded()
}

fn func_name(&self) -> Result<&str> {
Ok(self.aggregate.func_name())
}
}

impl AggregateWindowExpr for PlainAggregateWindowExpr {
6 changes: 5 additions & 1 deletion datafusion/physical-expr/src/window/built_in.rs
@@ -32,7 +32,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::evaluate_partition_ranges;
use datafusion_common::{Result, ScalarValue};
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_expr::window_state::{WindowAggState, WindowFrameContext};
use datafusion_expr::WindowFrame;

Contributor commented:

Can you move them back to the top?

@@ -97,6 +97,10 @@ impl BuiltInWindowExpr {
}

impl WindowExpr for BuiltInWindowExpr {
fn func_name(&self) -> Result<&str> {
not_impl_err!("function name not determined")
Contributor commented:

Why wouldn't we implement func_name for a built-in window function? 🤔

jayzhan211 (Contributor, Author) commented:

The reason is that I don't need it -- func_name is only used for the name check in nullable.

}

/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
12 changes: 12 additions & 0 deletions datafusion/physical-expr/src/window/sliding_aggregate.rs
@@ -82,6 +82,14 @@ impl WindowExpr for SlidingAggregateWindowExpr {
}

fn field(&self) -> Result<Field> {
// TODO: Fix window function to always return non-null for count
if let Ok(name) = self.func_name() {
if name == "count" {
let field = self.aggregate.field()?;
return Ok(field.with_nullable(true));
}
}

self.aggregate.field()
}

@@ -166,6 +174,10 @@ impl WindowExpr for SlidingAggregateWindowExpr {
window_frame: Arc::clone(&self.window_frame),
}))
}

fn func_name(&self) -> Result<&str> {
Ok(self.aggregate.func_name())
}
}

impl AggregateWindowExpr for SlidingAggregateWindowExpr {
2 changes: 2 additions & 0 deletions datafusion/physical-expr/src/window/window_expr.rs
@@ -157,6 +157,8 @@ pub trait WindowExpr: Send + Sync + Debug {
) -> Option<Arc<dyn WindowExpr>> {
None
}

fn func_name(&self) -> Result<&str>;
}

/// Stores the physical expressions used inside the `WindowExpr`.