Consolidate coercion code in datafusion_expr::type_coercion and submodules #3728

Merged
merged 6 commits into from
Oct 6, 2022
698 changes: 7 additions & 691 deletions datafusion/expr/src/aggregate_function.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
@@ -16,8 +16,8 @@
// under the License.

use super::Expr;
use crate::binary_rule::binary_operator_data_type;
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::binary_operator_data_type;
use crate::{aggregate_function, function, window_function};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
2 changes: 1 addition & 1 deletion datafusion/expr/src/function.rs
@@ -18,7 +18,7 @@
//! Function module contains typing and signature for built-in and user defined functions.

use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::type_coercion::data_types;
use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
array_expressions, conditional_expressions, struct_expressions, Accumulator,
1 change: 0 additions & 1 deletion datafusion/expr/src/lib.rs
@@ -28,7 +28,6 @@
mod accumulator;
pub mod aggregate_function;
pub mod array_expressions;
pub mod binary_rule;
mod built_in_function;
mod columnar_value;
pub mod conditional_expressions;
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -17,10 +17,10 @@

//! This module provides a builder for creating LogicalPlans

use crate::binary_rule::comparison_coercion;
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{
columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
};
266 changes: 34 additions & 232 deletions datafusion/expr/src/type_coercion.rs
@@ -15,247 +15,49 @@
// specific language governing permissions and limitations
// under the License.

//! Type coercion rules for functions with multiple valid signatures
//! Type coercion rules for DataFusion
//!
//! Coercion is performed automatically by DataFusion when the types
//! of arguments passed to a function do not exactly match the types
//! required by that function. In this case, DataFusion will attempt to
//! *coerce* the arguments to types accepted by the function by
//! inserting CAST operations.
//! of arguments passed to a function or needed by operators do not
//! exactly match the types required by that function / operator. In
//! this case, DataFusion will attempt to *coerce* the arguments to
//! types accepted by the function by inserting CAST operations.
//!
//! CAST operations added by coercion are lossless and never discard
//! information. For example coercion from i32 -> i64 might be
//! information.
//!
//! For example coercion from i32 -> i64 might be
//! performed because all valid i32 values can be represented using an
//! i64. However, i64 -> i32 is never performed as there are i64
//! values which cannot be represented by i32 values.
//!

use crate::{Signature, TypeSignature};
Contributor Author

Moved this code to `type_coercion/functions.rs` as it is related to function coercion, despite the generic `data_types` name.
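
To illustrate what the moved function-coercion code does per argument, here is a minimal, self-contained sketch. It is not the DataFusion implementation: `Ty` is a hypothetical stand-in for arrow's `DataType` reduced to four variants, `maybe_coerce` is an illustrative name, and the widening table is a simplified assumption covering only unsigned integers and booleans.

```rust
// Hypothetical stand-in for arrow's `DataType`, reduced for illustration.
#[derive(Debug, Clone, PartialEq)]
enum Ty {
    Boolean,
    UInt8,
    UInt16,
    UInt32,
}

// Simplified rule (an assumption of this sketch): an unsigned integer may be
// coerced to a type of equal or greater width; Boolean only to itself.
fn can_coerce_from(into: &Ty, from: &Ty) -> bool {
    match into {
        Ty::Boolean => matches!(from, Ty::Boolean),
        Ty::UInt8 => matches!(from, Ty::UInt8),
        Ty::UInt16 => matches!(from, Ty::UInt8 | Ty::UInt16),
        Ty::UInt32 => matches!(from, Ty::UInt8 | Ty::UInt16 | Ty::UInt32),
    }
}

// Try to coerce every entry of `current` to the matching entry of `valid`.
// Returns None if the lengths differ or any single argument cannot be coerced.
fn maybe_coerce(valid: &[Ty], current: &[Ty]) -> Option<Vec<Ty>> {
    if valid.len() != current.len() {
        return None;
    }
    valid
        .iter()
        .zip(current)
        .map(|(v, c)| can_coerce_from(v, c).then(|| v.clone()))
        .collect()
}
```

With these definitions, coercing `[UInt8, UInt16]` against the signature `[UInt16, UInt16]` yields the signature's types, while passing `UInt32` where `Boolean` is required yields `None`.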

use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{DataFusionError, Result};

/// Returns the data types that each argument must be coerced to match
/// `signature`.
///
/// See the module level documentation for more detail on coercion.
pub fn data_types(
current_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
if current_types.is_empty() {
return Ok(vec![]);
}
let valid_types = get_valid_types(&signature.type_signature, current_types)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

for valid_types in valid_types {
if let Some(types) = maybe_data_types(&valid_types, current_types) {
return Ok(types);
}
}

// none possible -> Error
Err(DataFusionError::Plan(format!(
"Coercion from {:?} to the signature {:?} failed.",
current_types, &signature.type_signature
)))
use arrow::datatypes::DataType;

/// Determine if a DataType is signed numeric or not
pub fn is_signed_numeric(dt: &DataType) -> bool {
matches!(
dt,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
)
}

fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
// one entry with the same len as current_types, whose type is `current_types[0]`.
vec![current_types
.iter()
.map(|_| current_types[0].clone())
.collect()]
}
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::Any(number) => {
if current_types.len() != *number {
return Err(DataFusionError::Plan(format!(
"The function expected {} arguments but received {}",
number,
current_types.len()
)));
}
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
TypeSignature::OneOf(types) => types
.iter()
.filter_map(|t| get_valid_types(t, current_types).ok())
.flatten()
.collect::<Vec<_>>(),
};

Ok(valid_types)
}

/// Try to coerce current_types into valid_types.
fn maybe_data_types(
valid_types: &[DataType],
current_types: &[DataType],
) -> Option<Vec<DataType>> {
if valid_types.len() != current_types.len() {
return None;
}

let mut new_type = Vec::with_capacity(valid_types.len());
for (i, valid_type) in valid_types.iter().enumerate() {
let current_type = &current_types[i];

if current_type == valid_type {
new_type.push(current_type.clone())
} else {
// attempt to coerce
if can_coerce_from(valid_type, current_type) {
new_type.push(valid_type.clone())
} else {
// not possible
return None;
}
}
}
Some(new_type)
/// Determine if a DataType is numeric or not
pub fn is_numeric(dt: &DataType) -> bool {
is_signed_numeric(dt)
|| matches!(
dt,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
)
}

/// Return true if a value of type `type_from` can be coerced
/// (losslessly converted) into a value of `type_into`
///
/// See the module level documentation for more detail on coercion.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
use self::DataType::*;
// Null can convert to most types
match type_into {
Int8 => matches!(type_from, Null | Int8),
Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8),
Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16),
Int64 => matches!(
type_from,
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
),
UInt8 => matches!(type_from, Null | UInt8),
UInt16 => matches!(type_from, Null | UInt8 | UInt16),
UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32),
UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64),
Float32 => matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
),
Float64 => matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
| Decimal128(_, _)
),
Timestamp(TimeUnit::Nanosecond, None) => {
matches!(type_from, Null | Timestamp(_, None))
}
Utf8 | LargeUtf8 => true,
Null => can_cast_types(type_from, type_into),
_ => false,
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;

#[test]
fn test_maybe_data_types() {
// this vec contains: arg1, arg2, expected result
let cases = vec![
// 2 entries, same values
(
vec![DataType::UInt8, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
Some(vec![DataType::UInt8, DataType::UInt16]),
),
// 2 entries, can coerce values
(
vec![DataType::UInt16, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
Some(vec![DataType::UInt16, DataType::UInt16]),
),
// 0 entries, all good
(vec![], vec![], Some(vec![])),
// 2 entries, can't coerce
(
vec![DataType::Boolean, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
None,
),
// u16 -> u32 is possible
(
vec![DataType::Boolean, DataType::UInt32],
vec![DataType::Boolean, DataType::UInt16],
Some(vec![DataType::Boolean, DataType::UInt32]),
),
];

for case in cases {
assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
}
}

#[test]
fn test_get_valid_types_one_of() -> Result<()> {
let signature =
TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

let invalid_types = get_valid_types(
&signature,
&[DataType::Int32, DataType::Int32, DataType::Int32],
)?;
assert_eq!(invalid_types.len(), 0);

let args = vec![DataType::Int32, DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

let args = vec![DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

Ok(())
}
}
pub mod aggregates;
Contributor Author

Now the code for various coercion lives in a module that helps identify what it is for

pub mod binary;
pub mod functions;
pub mod other;
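
The module documentation's i32 -> i64 example can be checked with plain standard-library conversions. The sketch below is independent of DataFusion and arrow; `widen` and `narrow` are hypothetical helper names used only to show why widening is a safe coercion while narrowing is not.

```rust
use std::convert::TryFrom;

// Widening i32 -> i64 is infallible: every i32 value has an i64 representation.
fn widen(v: i32) -> i64 {
    i64::from(v)
}

// Narrowing i64 -> i32 can fail for out-of-range values, which is why a
// lossless coercion scheme never inserts such a cast automatically.
fn narrow(v: i64) -> Option<i32> {
    i32::try_from(v).ok()
}
```

Here `narrow` returns `None` for any value outside the i32 range, such as `i64::from(i32::MAX) + 1`, while `widen` has no failure case at all.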