Fix the schema mismatch between logical and physical for aggregate function, add AggregateUDFImpl::is_null
#11989
```diff
@@ -670,6 +670,10 @@ impl DefaultPhysicalPlanner {
             let input_exec = children.one()?;
             let physical_input_schema = input_exec.schema();
             let logical_input_schema = input.as_ref().schema();
+            let physical_input_schema_from_logical: Arc<Schema> =
+                logical_input_schema.as_ref().clone().into();
+
+            debug_assert_eq!(physical_input_schema_from_logical, physical_input_schema, "Physical input schema should be the same as the one converted from logical input schema. Please file an issue or send the PR");

             let groups = self.create_grouping_physical_expr(
                 group_expr,
```
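To make the assertion concrete, here is a minimal, self-contained sketch (using only the `arrow-schema` crate; the column name and types are made up) of the kind of mismatch this `debug_assert_eq!` catches: the schema converted from the logical plan disagrees with the physical plan's schema on nullability.

```rust
use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Schema reported by the physical plan: count(x) is non-nullable.
    let physical_input_schema = Arc::new(Schema::new(vec![Field::new(
        "count(x)",
        DataType::Int64,
        false,
    )]));

    // Schema converted from the logical plan: count(x) claimed nullable.
    let physical_input_schema_from_logical = Arc::new(Schema::new(vec![
        Field::new("count(x)", DataType::Int64, true),
    ]));

    // Schema equality includes nullability, so this panics in debug builds,
    // surfacing the logical/physical mismatch early.
    debug_assert_eq!(physical_input_schema_from_logical, physical_input_schema);
}
```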
```diff
@@ -1548,7 +1552,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
     e: &Expr,
     name: Option<String>,
     logical_input_schema: &DFSchema,
-    _physical_input_schema: &Schema,
+    physical_input_schema: &Schema,
     execution_props: &ExecutionProps,
 ) -> Result<AggregateExprWithOptionalArgs> {
     match e {
```
```diff
@@ -1599,11 +1603,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
                 let ordering_reqs: Vec<PhysicalSortExpr> =
                     physical_sort_exprs.clone().unwrap_or(vec![]);

-                let schema: Schema = logical_input_schema.clone().into();
-
                 let agg_expr =
                     AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec())
                         .order_by(ordering_reqs.to_vec())
-                        .schema(Arc::new(schema))
+                        .schema(Arc::new(physical_input_schema.to_owned()))
                         .alias(name)
                         .with_ignore_nulls(ignore_nulls)
                         .with_distinct(*distinct)
```

Review comment (on the removed logical-to-physical conversion): workaround cleanup
```diff
@@ -328,10 +328,45 @@ impl ExprSchemable for Expr {
                     Ok(true)
                 }
             }
+            Expr::WindowFunction(WindowFunction { fun, .. }) => {
+                match fun {
+                    WindowFunctionDefinition::BuiltInWindowFunction(func) => {
+                        if func.name() == "ROW_NUMBER"
+                            || func.name() == "RANK"
+                            || func.name() == "NTILE"
+                            || func.name() == "CUME_DIST"
+                        {
+                            Ok(false)
+                        } else {
+                            Ok(true)
+                        }
+                    }
+                    WindowFunctionDefinition::AggregateUDF(func) => {
+                        // TODO: UDF should be able to customize nullability
+                        if func.name() == "count" {
+                            // TODO: there is an unsolved issue for count with window, should return false
+                            Ok(true)
+                        } else {
+                            Ok(true)
+                        }
+                    }
+                    _ => Ok(true),
+                }
+            }
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                // If all the elements in coalesce are non-null, the result is non-null
+                if func.name() == "coalesce"
+                    && args
+                        .iter()
+                        .all(|e| !e.nullable(input_schema).ok().unwrap_or(true))
+                {
+                    return Ok(false);
+                }
+
+                Ok(true)
+            }
             Expr::ScalarVariable(_, _)
             | Expr::TryCast { .. }
-            | Expr::ScalarFunction(..)
-            | Expr::WindowFunction { .. }
             | Expr::Unnest(_)
             | Expr::Placeholder(_) => Ok(true),
             Expr::IsNull(_)
```

Review comment (on the new `WindowFunction` arm): Is this change required for this PR or is it a "drive by" improvement?
Reply: Required.

Review comment (on the count TODO): Not so familiar with window functions yet; leaving it as a TODO.
Reply: Perhaps we can file a ticket to track this -- ideally it would eventually be part of the window function definition itself rather than relying on names.

Review comment (on the coalesce check): We should probably add an API to `ScalarUDFImpl` to signal its null/non-nullness (as a follow-on PR) instead of hard-coding this function name.
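A standalone sketch of the coalesce rule above, with a toy `Expr` type standing in for DataFusion's (the real code consults the input schema): the result is reported non-nullable only when every argument is non-nullable.

```rust
// Toy expression type standing in for DataFusion's `Expr`.
enum Expr {
    Column { nullable: bool },
    Coalesce(Vec<Expr>),
}

fn nullable(e: &Expr) -> bool {
    match e {
        Expr::Column { nullable } => *nullable,
        // Mirrors the rule in the diff: coalesce is non-nullable only when
        // all of its arguments are non-nullable.
        Expr::Coalesce(args) => !args.iter().all(|a| !nullable(a)),
    }
}

fn main() {
    let all_non_null = Expr::Coalesce(vec![
        Expr::Column { nullable: false },
        Expr::Column { nullable: false },
    ]);
    let mixed = Expr::Coalesce(vec![
        Expr::Column { nullable: true },
        Expr::Column { nullable: false },
    ]);
    assert!(!nullable(&all_non_null)); // provably non-null
    assert!(nullable(&mixed)); // conservative: one nullable arg keeps it nullable
}
```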
```diff
@@ -196,6 +196,10 @@ impl AggregateUDF {
         self.inner.state_fields(args)
     }

+    pub fn fields(&self, args: StateFieldsArgs) -> Result<Field> {
+        self.inner.field(args)
+    }
+
     /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
     pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
         self.inner.groups_accumulator_supported(args)
```

Review comment: Could we document this function and what it is for (also in `AggregateUDFImpl`)? Also, the name is strange to me -- it is `fields` here but delegates to `field` on the inner impl.

```diff
@@ -383,6 +387,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
             .collect())
     }

+    fn field(&self, args: StateFieldsArgs) -> Result<Field> {
+        Ok(Field::new(args.name, args.return_type.clone(), true))
+    }
+
     /// If the aggregate expression has a specialized
     /// [`GroupsAccumulator`] implementation. If this returns true,
     /// `[Self::create_groups_accumulator]` will be called.
```
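As a hedged sketch of how the new hook could be used (the trait and argument struct below are pared-down stand-ins for the real `AggregateUDFImpl` and `StateFieldsArgs`, keeping only the pieces the diff shows): a count-like UDAF overrides `field` to declare its output non-nullable, while the default stays nullable.

```rust
use arrow_schema::{DataType, Field};

// Pared-down stand-in for DataFusion's `StateFieldsArgs`.
struct StateFieldsArgs<'a> {
    name: &'a str,
    return_type: &'a DataType,
}

trait AggregateUdfImplSketch {
    // Default mirrors the one added in this PR: nullable by default.
    fn field(&self, args: StateFieldsArgs) -> Field {
        Field::new(args.name, args.return_type.clone(), true)
    }
}

struct Count;

impl AggregateUdfImplSketch for Count {
    // `count` never produces NULL, so report a non-nullable field.
    fn field(&self, args: StateFieldsArgs) -> Field {
        Field::new(args.name, args.return_type.clone(), false)
    }
}

fn main() {
    let f = Count.field(StateFieldsArgs {
        name: "count(x)",
        return_type: &DataType::Int64,
    });
    assert!(!f.is_nullable());
}
```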
```diff
@@ -171,6 +171,9 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
     fn get_minmax_desc(&self) -> Option<(Field, bool)> {
         None
     }
+
+    /// Get the function's name; for example `count(x)` returns `count`
+    fn func_name(&self) -> &str;
 }

 /// Stores the physical expressions used inside the `AggregateExpr`.
```

Review comment: is there a reason this isn't …
Reply: This is to identify the function (i.e. `count`); there is …
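To illustrate the distinction the doc comment draws (the struct and `name` method below are illustrative stand-ins, not the real trait): for an aggregate built from `count(x)`, the expression's display name is the full `count(x)`, while `func_name` identifies just the function.

```rust
// Illustrative only: a minimal aggregate expression carrying both names.
struct CountExpr;

impl CountExpr {
    /// Display name of the whole expression, e.g. for output columns.
    fn name(&self) -> &str {
        "count(x)"
    }
    /// Identifies the function itself, as the new trait method does.
    fn func_name(&self) -> &str {
        "count"
    }
}

fn main() {
    let e = CountExpr;
    assert_eq!(e.name(), "count(x)");
    assert_eq!(e.func_name(), "count");
}
```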
```diff
@@ -17,7 +17,6 @@

 //! Optimizer rule for type validation and coercion

 use std::collections::HashMap;
 use std::sync::Arc;

 use itertools::izip;
```

```diff
@@ -821,9 +820,18 @@ fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
         .iter()
         .map(|f| f.is_nullable())
         .collect::<Vec<_>>();
+    let mut union_field_meta = base_schema
+        .fields()
+        .iter()
+        .map(|f| f.metadata().clone())
+        .collect::<Vec<_>>();
+
+    let mut metadata = base_schema.metadata().clone();
+
     for (i, plan) in inputs.iter().enumerate().skip(1) {
         let plan_schema = plan.schema();
+        metadata.extend(plan_schema.metadata().clone());
+
         if plan_schema.fields().len() != base_schema.fields().len() {
             return plan_err!(
                 "Union schemas have different number of fields: \
```

```diff
@@ -833,39 +841,47 @@ fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
                 plan_schema.fields().len()
             );
         }
-        // coerce data type and nullability for each field
-        for (union_datatype, union_nullable, plan_field) in izip!(
-            union_datatypes.iter_mut(),
-            union_nullabilities.iter_mut(),
-            plan_schema.fields()
-        ) {
-            let coerced_type =
-                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
-                    || {
-                        plan_datafusion_err!(
-                            "Incompatible inputs for Union: Previous inputs were \
-                            of type {}, but got incompatible type {} on column '{}'",
-                            union_datatype,
-                            plan_field.data_type(),
-                            plan_field.name()
-                        )
-                    },
-                )?;
-            *union_datatype = coerced_type;
-            *union_nullable = *union_nullable || plan_field.is_nullable();
+        // Safety: Length is checked
+        unsafe {
+            // coerce data type and nullability for each field
+            for (i, plan_field) in plan_schema.fields().iter().enumerate() {
+                let union_datatype = union_datatypes.get_unchecked_mut(i);
+                let union_nullable = union_nullabilities.get_unchecked_mut(i);
+                let union_field_map = union_field_meta.get_unchecked_mut(i);
+
+                let coerced_type =
+                    comparison_coercion(union_datatype, plan_field.data_type())
+                        .ok_or_else(|| {
+                            plan_datafusion_err!(
+                                "Incompatible inputs for Union: Previous inputs were \
+                                of type {}, but got incompatible type {} on column '{}'",
+                                union_datatype,
+                                plan_field.data_type(),
+                                plan_field.name()
+                            )
+                        })?;
+
+                *union_datatype = coerced_type;
+                *union_nullable = *union_nullable || plan_field.is_nullable();
+                union_field_map.extend(plan_field.metadata().clone());
+            }
         }
     }
     let union_qualified_fields = izip!(
         base_schema.iter(),
         union_datatypes.into_iter(),
-        union_nullabilities
+        union_nullabilities,
+        union_field_meta.into_iter()
     )
-    .map(|((qualifier, field), datatype, nullable)| {
-        let field = Arc::new(Field::new(field.name().clone(), datatype, nullable));
-        (qualifier.cloned(), field)
+    .map(|((qualifier, field), datatype, nullable, metadata)| {
+        let mut field = Field::new(field.name().clone(), datatype, nullable);
+        field.set_metadata(metadata);
+        (qualifier.cloned(), field.into())
    })
     .collect::<Vec<_>>();
-    DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())
+
+    DFSchema::new_with_metadata(union_qualified_fields, metadata)
 }

 /// See <https://github.com/apache/datafusion/pull/2108>
```

Review comment (on the unsafe block): I think this unsafe block is unnecessary -- this isn't a performance critical piece of code. I think …
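A minimal sketch of the per-field merge this hunk adds (using `arrow-schema`'s `Field` directly; the data-type coercion via `comparison_coercion` is elided, and `merge_union_field` is a hypothetical helper): a later union input's field metadata is folded into the base field's metadata, and nullability is OR-ed.

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field};

// Merge one union input's field into the running "base" field: keep the base
// data type (type coercion elided in this sketch), OR the nullability, and
// extend the metadata map with the other field's entries.
fn merge_union_field(base: &Field, other: &Field) -> Field {
    let mut metadata: HashMap<String, String> = base.metadata().clone();
    metadata.extend(other.metadata().clone());

    let mut field = Field::new(
        base.name().clone(),
        base.data_type().clone(),
        base.is_nullable() || other.is_nullable(),
    );
    field.set_metadata(metadata);
    field
}

fn main() {
    let base = Field::new("a", DataType::Int32, false)
        .with_metadata(HashMap::from([("k1".to_string(), "v1".to_string())]));
    let other = Field::new("a", DataType::Int32, true)
        .with_metadata(HashMap::from([("k2".to_string(), "v2".to_string())]));

    let merged = merge_union_field(&base, &other);
    assert!(merged.is_nullable()); // nullability is OR-ed
    assert_eq!(merged.metadata().len(), 2); // metadata from both inputs
}
```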
```diff
@@ -80,6 +80,14 @@ impl WindowExpr for PlainAggregateWindowExpr {
     }

     fn field(&self) -> Result<Field> {
+        // TODO: Fix window function to always return non-null for count
+        if let Ok(name) = self.func_name() {
+            if name == "count" {
+                let field = self.aggregate.field()?;
+                return Ok(field.with_nullable(true));
+            }
+        }
+
         self.aggregate.field()
     }
```

Review comment (on the TODO): I don't understand this comment -- can we please file a ticket to track it (and add the ticket reference to the comments)?

```diff
@@ -157,6 +165,10 @@ impl WindowExpr for PlainAggregateWindowExpr {
     fn uses_bounded_memory(&self) -> bool {
         !self.window_frame.end_bound.is_unbounded()
     }
+
+    fn func_name(&self) -> Result<&str> {
+        Ok(self.aggregate.func_name())
+    }
 }

 impl AggregateWindowExpr for PlainAggregateWindowExpr {
```
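A simplified sketch of the workaround above, written as a free function rather than the trait method (`with_nullable` is the real arrow `Field` API; the rest of the wiring is a stand-in): if the wrapped aggregate is `count`, the reported window output field is forced back to nullable until the underlying issue is resolved.

```rust
use arrow_schema::{DataType, Field};

// Sketch of what PlainAggregateWindowExpr::field does, as a free function.
fn window_field(func_name: &str, aggregate_field: Field) -> Field {
    if func_name == "count" {
        // Temporary workaround: count over a window is reported nullable
        // even though count itself never returns NULL.
        return aggregate_field.with_nullable(true);
    }
    aggregate_field
}

fn main() {
    let agg = Field::new("count(x)", DataType::Int64, false);
    let windowed = window_field("count", agg);
    assert!(windowed.is_nullable());
}
```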
```diff
@@ -32,7 +32,7 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::Field;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::utils::evaluate_partition_ranges;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{not_impl_err, Result, ScalarValue};
 use datafusion_expr::window_state::{WindowAggState, WindowFrameContext};
 use datafusion_expr::WindowFrame;
```

Review comment (on the imports): Can you move them back to the top?

```diff
@@ -97,6 +97,10 @@ impl BuiltInWindowExpr {
 }

 impl WindowExpr for BuiltInWindowExpr {
+    fn func_name(&self) -> Result<&str> {
+        not_impl_err!("function name not determined")
+    }
+
     /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
         self
```

Review comment: why wouldn't we implement func_name for a built in window function 🤔
Reply: The reason is because I don't need it -- for name checking in …
Review comment (author, on the new debug assertion): The main goal of the change is to ensure they are the same. And we pass `physical_input_schema` through the functions that require the input's schema.

Review comment: Nice! Did you consider making this function return an `internal_error` rather than a `debug_assert`? If we are concerned about breaking existing tests, we could add a config setting like `datafusion.optimizer.skip_failed_rules` to let users bypass the check.

Reply: The objective here is to ensure that the logical schema from `ExprSchemable` and the physical schema from `ExecutionPlan.schema()` are equivalent. If they are not, it indicates a potential schema mismatch issue. This is also why the code changes in this PR are mostly fixing schema-related things, and they are all required, so I don't think we should let users bypass the check 🤔 If we encounter inconsistent schemas, it raises an important question: which schema should we use?
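For reference, a sketch of the `internal_error` alternative discussed above, assuming `datafusion_common`'s `internal_err!` macro (the function name and wiring here are hypothetical, not part of the PR):

```rust
use arrow_schema::SchemaRef;
use datafusion_common::{internal_err, Result};

// Hypothetical hard check replacing the debug_assert_eq!: fail planning
// with an internal error whenever the two schemas diverge.
fn check_input_schemas(
    from_logical: &SchemaRef,
    physical: &SchemaRef,
) -> Result<()> {
    if from_logical != physical {
        return internal_err!(
            "Physical input schema {:?} differs from the schema converted \
             from the logical plan {:?}",
            physical,
            from_logical
        );
    }
    Ok(())
}
```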
Review comment: It looks good to me.