Skip to content

Commit

Permalink
Consolidate coercion code in datafusion_expr::type_coercion and sub…
Browse files Browse the repository at this point in the history
…modules (#3728)

* Move function coercion to its own module

* Move binary into type coercion

* fmt

* More updates

* consolidate some more

* Move aggregates
  • Loading branch information
alamb authored Oct 6, 2022
1 parent 23682f6 commit 7c5c2e5
Show file tree
Hide file tree
Showing 20 changed files with 1,076 additions and 1,002 deletions.
698 changes: 7 additions & 691 deletions datafusion/expr/src/aggregate_function.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
// under the License.

use super::Expr;
use crate::binary_rule::binary_operator_data_type;
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::binary_operator_data_type;
use crate::{aggregate_function, function, window_function};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Function module contains typing and signature for built-in and user defined functions.
use crate::nullif::SUPPORTED_NULLIF_TYPES;
use crate::type_coercion::data_types;
use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
array_expressions, conditional_expressions, struct_expressions, Accumulator,
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
mod accumulator;
pub mod aggregate_function;
pub mod array_expressions;
pub mod binary_rule;
mod built_in_function;
mod columnar_value;
pub mod conditional_expressions;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

//! This module provides a builder for creating LogicalPlans
use crate::binary_rule::comparison_coercion;
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{
columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
};
Expand Down
266 changes: 34 additions & 232 deletions datafusion/expr/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,247 +15,49 @@
// specific language governing permissions and limitations
// under the License.

//! Type coercion rules for functions with multiple valid signatures
//! Type coercion rules for DataFusion
//!
//! Coercion is performed automatically by DataFusion when the types
//! of arguments passed to a function do not exacty match the types
//! required by that function. In this case, DataFusion will attempt to
//! *coerce* the arguments to types accepted by the function by
//! inserting CAST operations.
//! of arguments passed to a function or needed by operators do not
//! exacty match the types required by that function / operator. In
//! this case, DataFusion will attempt to *coerce* the arguments to
//! types accepted by the function by inserting CAST operations.
//!
//! CAST operations added by coercion are lossless and never discard
//! information. For example coercion from i32 -> i64 might be
//! information.
//!
//! For example coercion from i32 -> i64 might be
//! performed because all valid i32 values can be represented using an
//! i64. However, i64 -> i32 is never performed as there are i64
//! values which can not be represented by i32 values.
//!
use crate::{Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{DataFusionError, Result};

/// Returns the data types that each argument must be coerced to match
/// `signature`.
///
/// See the module level documentation for more detail on coercion.
pub fn data_types(
current_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
if current_types.is_empty() {
return Ok(vec![]);
}
let valid_types = get_valid_types(&signature.type_signature, current_types)?;

if valid_types
.iter()
.any(|data_type| data_type == current_types)
{
return Ok(current_types.to_vec());
}

for valid_types in valid_types {
if let Some(types) = maybe_data_types(&valid_types, current_types) {
return Ok(types);
}
}
// none possible -> Error
Err(DataFusionError::Plan(format!(
"Coercion from {:?} to the signature {:?} failed.",
current_types, &signature.type_signature
)))
use arrow::datatypes::DataType;

/// Determine if a DataType is signed numeric or not
pub fn is_signed_numeric(dt: &DataType) -> bool {
matches!(
dt,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
)
}

fn get_valid_types(
signature: &TypeSignature,
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
// one entry with the same len as current_types, whose type is `current_types[0]`.
vec![current_types
.iter()
.map(|_| current_types[0].clone())
.collect()]
}
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::Any(number) => {
if current_types.len() != *number {
return Err(DataFusionError::Plan(format!(
"The function expected {} arguments but received {}",
number,
current_types.len()
)));
}
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
TypeSignature::OneOf(types) => types
.iter()
.filter_map(|t| get_valid_types(t, current_types).ok())
.flatten()
.collect::<Vec<_>>(),
};

Ok(valid_types)
}

/// Try to coerce current_types into valid_types.
fn maybe_data_types(
valid_types: &[DataType],
current_types: &[DataType],
) -> Option<Vec<DataType>> {
if valid_types.len() != current_types.len() {
return None;
}

let mut new_type = Vec::with_capacity(valid_types.len());
for (i, valid_type) in valid_types.iter().enumerate() {
let current_type = &current_types[i];

if current_type == valid_type {
new_type.push(current_type.clone())
} else {
// attempt to coerce
if can_coerce_from(valid_type, current_type) {
new_type.push(valid_type.clone())
} else {
// not possible
return None;
}
}
}
Some(new_type)
/// Determine if a DataType is numeric or not
pub fn is_numeric(dt: &DataType) -> bool {
is_signed_numeric(dt)
|| matches!(
dt,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
)
}

/// Return true if a value of type `type_from` can be coerced
/// (losslessly converted) into a value of `type_to`
///
/// See the module level documentation for more detail on coercion.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
use self::DataType::*;
// Null can convert to most of types
match type_into {
Int8 => matches!(type_from, Null | Int8),
Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8),
Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16),
Int64 => matches!(
type_from,
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
),
UInt8 => matches!(type_from, Null | UInt8),
UInt16 => matches!(type_from, Null | UInt8 | UInt16),
UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32),
UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64),
Float32 => matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
),
Float64 => matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
| Decimal128(_, _)
),
Timestamp(TimeUnit::Nanosecond, None) => {
matches!(type_from, Null | Timestamp(_, None))
}
Utf8 | LargeUtf8 => true,
Null => can_cast_types(type_from, type_into),
_ => false,
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;

#[test]
fn test_maybe_data_types() {
// this vec contains: arg1, arg2, expected result
let cases = vec![
// 2 entries, same values
(
vec![DataType::UInt8, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
Some(vec![DataType::UInt8, DataType::UInt16]),
),
// 2 entries, can coerse values
(
vec![DataType::UInt16, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
Some(vec![DataType::UInt16, DataType::UInt16]),
),
// 0 entries, all good
(vec![], vec![], Some(vec![])),
// 2 entries, can't coerce
(
vec![DataType::Boolean, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
None,
),
// u32 -> u16 is possible
(
vec![DataType::Boolean, DataType::UInt32],
vec![DataType::Boolean, DataType::UInt16],
Some(vec![DataType::Boolean, DataType::UInt32]),
),
];

for case in cases {
assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
}
}

#[test]
fn test_get_valid_types_one_of() -> Result<()> {
let signature =
TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

let invalid_types = get_valid_types(
&signature,
&[DataType::Int32, DataType::Int32, DataType::Int32],
)?;
assert_eq!(invalid_types.len(), 0);

let args = vec![DataType::Int32, DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

let args = vec![DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

Ok(())
}
}
pub mod aggregates;
pub mod binary;
pub mod functions;
pub mod other;
Loading

0 comments on commit 7c5c2e5

Please sign in to comment.