-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Consolidate coercion code in datafusion_expr::type_coercion
and submodules
#3728
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,247 +15,49 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
//! Type coercion rules for functions with multiple valid signatures | ||
//! Type coercion rules for DataFusion | ||
//! | ||
//! Coercion is performed automatically by DataFusion when the types | ||
//! of arguments passed to a function do not exacty match the types | ||
//! required by that function. In this case, DataFusion will attempt to | ||
//! *coerce* the arguments to types accepted by the function by | ||
//! inserting CAST operations. | ||
//! of arguments passed to a function or needed by operators do not | ||
//! exacty match the types required by that function / operator. In | ||
//! this case, DataFusion will attempt to *coerce* the arguments to | ||
//! types accepted by the function by inserting CAST operations. | ||
//! | ||
//! CAST operations added by coercion are lossless and never discard | ||
//! information. For example coercion from i32 -> i64 might be | ||
//! information. | ||
//! | ||
//! For example coercion from i32 -> i64 might be | ||
//! performed because all valid i32 values can be represented using an | ||
//! i64. However, i64 -> i32 is never performed as there are i64 | ||
//! values which can not be represented by i32 values. | ||
//! | ||
|
||
use crate::{Signature, TypeSignature}; | ||
use arrow::{ | ||
compute::can_cast_types, | ||
datatypes::{DataType, TimeUnit}, | ||
}; | ||
use datafusion_common::{DataFusionError, Result}; | ||
|
||
/// Returns the data types that each argument must be coerced to match | ||
/// `signature`. | ||
/// | ||
/// See the module level documentation for more detail on coercion. | ||
pub fn data_types( | ||
current_types: &[DataType], | ||
signature: &Signature, | ||
) -> Result<Vec<DataType>> { | ||
if current_types.is_empty() { | ||
return Ok(vec![]); | ||
} | ||
let valid_types = get_valid_types(&signature.type_signature, current_types)?; | ||
|
||
if valid_types | ||
.iter() | ||
.any(|data_type| data_type == current_types) | ||
{ | ||
return Ok(current_types.to_vec()); | ||
} | ||
|
||
for valid_types in valid_types { | ||
if let Some(types) = maybe_data_types(&valid_types, current_types) { | ||
return Ok(types); | ||
} | ||
} | ||
|
||
// none possible -> Error | ||
Err(DataFusionError::Plan(format!( | ||
"Coercion from {:?} to the signature {:?} failed.", | ||
current_types, &signature.type_signature | ||
))) | ||
use arrow::datatypes::DataType; | ||
|
||
/// Determine if a DataType is signed numeric or not | ||
pub fn is_signed_numeric(dt: &DataType) -> bool { | ||
matches!( | ||
dt, | ||
DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Float16 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
| DataType::Decimal128(_, _) | ||
) | ||
} | ||
|
||
fn get_valid_types( | ||
signature: &TypeSignature, | ||
current_types: &[DataType], | ||
) -> Result<Vec<Vec<DataType>>> { | ||
let valid_types = match signature { | ||
TypeSignature::Variadic(valid_types) => valid_types | ||
.iter() | ||
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) | ||
.collect(), | ||
TypeSignature::Uniform(number, valid_types) => valid_types | ||
.iter() | ||
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) | ||
.collect(), | ||
TypeSignature::VariadicEqual => { | ||
// one entry with the same len as current_types, whose type is `current_types[0]`. | ||
vec![current_types | ||
.iter() | ||
.map(|_| current_types[0].clone()) | ||
.collect()] | ||
} | ||
TypeSignature::Exact(valid_types) => vec![valid_types.clone()], | ||
TypeSignature::Any(number) => { | ||
if current_types.len() != *number { | ||
return Err(DataFusionError::Plan(format!( | ||
"The function expected {} arguments but received {}", | ||
number, | ||
current_types.len() | ||
))); | ||
} | ||
vec![(0..*number).map(|i| current_types[i].clone()).collect()] | ||
} | ||
TypeSignature::OneOf(types) => types | ||
.iter() | ||
.filter_map(|t| get_valid_types(t, current_types).ok()) | ||
.flatten() | ||
.collect::<Vec<_>>(), | ||
}; | ||
|
||
Ok(valid_types) | ||
} | ||
|
||
/// Try to coerce current_types into valid_types. | ||
fn maybe_data_types( | ||
valid_types: &[DataType], | ||
current_types: &[DataType], | ||
) -> Option<Vec<DataType>> { | ||
if valid_types.len() != current_types.len() { | ||
return None; | ||
} | ||
|
||
let mut new_type = Vec::with_capacity(valid_types.len()); | ||
for (i, valid_type) in valid_types.iter().enumerate() { | ||
let current_type = ¤t_types[i]; | ||
|
||
if current_type == valid_type { | ||
new_type.push(current_type.clone()) | ||
} else { | ||
// attempt to coerce | ||
if can_coerce_from(valid_type, current_type) { | ||
new_type.push(valid_type.clone()) | ||
} else { | ||
// not possible | ||
return None; | ||
} | ||
} | ||
} | ||
Some(new_type) | ||
/// Determine if a DataType is numeric or not | ||
pub fn is_numeric(dt: &DataType) -> bool { | ||
is_signed_numeric(dt) | ||
|| matches!( | ||
dt, | ||
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 | ||
) | ||
} | ||
|
||
/// Return true if a value of type `type_from` can be coerced | ||
/// (losslessly converted) into a value of `type_to` | ||
/// | ||
/// See the module level documentation for more detail on coercion. | ||
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { | ||
use self::DataType::*; | ||
// Null can convert to most of types | ||
match type_into { | ||
Int8 => matches!(type_from, Null | Int8), | ||
Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8), | ||
Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16), | ||
Int64 => matches!( | ||
type_from, | ||
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | ||
), | ||
UInt8 => matches!(type_from, Null | UInt8), | ||
UInt16 => matches!(type_from, Null | UInt8 | UInt16), | ||
UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32), | ||
UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64), | ||
Float32 => matches!( | ||
type_from, | ||
Null | Int8 | ||
| Int16 | ||
| Int32 | ||
| Int64 | ||
| UInt8 | ||
| UInt16 | ||
| UInt32 | ||
| UInt64 | ||
| Float32 | ||
), | ||
Float64 => matches!( | ||
type_from, | ||
Null | Int8 | ||
| Int16 | ||
| Int32 | ||
| Int64 | ||
| UInt8 | ||
| UInt16 | ||
| UInt32 | ||
| UInt64 | ||
| Float32 | ||
| Float64 | ||
| Decimal128(_, _) | ||
), | ||
Timestamp(TimeUnit::Nanosecond, None) => { | ||
matches!(type_from, Null | Timestamp(_, None)) | ||
} | ||
Utf8 | LargeUtf8 => true, | ||
Null => can_cast_types(type_from, type_into), | ||
_ => false, | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use arrow::datatypes::DataType; | ||
|
||
#[test] | ||
fn test_maybe_data_types() { | ||
// this vec contains: arg1, arg2, expected result | ||
let cases = vec![ | ||
// 2 entries, same values | ||
( | ||
vec![DataType::UInt8, DataType::UInt16], | ||
vec![DataType::UInt8, DataType::UInt16], | ||
Some(vec![DataType::UInt8, DataType::UInt16]), | ||
), | ||
// 2 entries, can coerse values | ||
( | ||
vec![DataType::UInt16, DataType::UInt16], | ||
vec![DataType::UInt8, DataType::UInt16], | ||
Some(vec![DataType::UInt16, DataType::UInt16]), | ||
), | ||
// 0 entries, all good | ||
(vec![], vec![], Some(vec![])), | ||
// 2 entries, can't coerce | ||
( | ||
vec![DataType::Boolean, DataType::UInt16], | ||
vec![DataType::UInt8, DataType::UInt16], | ||
None, | ||
), | ||
// u32 -> u16 is possible | ||
( | ||
vec![DataType::Boolean, DataType::UInt32], | ||
vec![DataType::Boolean, DataType::UInt16], | ||
Some(vec![DataType::Boolean, DataType::UInt32]), | ||
), | ||
]; | ||
|
||
for case in cases { | ||
assert_eq!(maybe_data_types(&case.0, &case.1), case.2) | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_get_valid_types_one_of() -> Result<()> { | ||
let signature = | ||
TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]); | ||
|
||
let invalid_types = get_valid_types( | ||
&signature, | ||
&[DataType::Int32, DataType::Int32, DataType::Int32], | ||
)?; | ||
assert_eq!(invalid_types.len(), 0); | ||
|
||
let args = vec![DataType::Int32, DataType::Int32]; | ||
let valid_types = get_valid_types(&signature, &args)?; | ||
assert_eq!(valid_types.len(), 1); | ||
assert_eq!(valid_types[0], args); | ||
|
||
let args = vec![DataType::Int32]; | ||
let valid_types = get_valid_types(&signature, &args)?; | ||
assert_eq!(valid_types.len(), 1); | ||
assert_eq!(valid_types[0], args); | ||
|
||
Ok(()) | ||
} | ||
} | ||
pub mod aggregates; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now the code for various coercion lives in a module that helps identify what it is for |
||
pub mod binary; | ||
pub mod functions; | ||
pub mod other; |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved this code into
Moved to
type_coercion/function.rsas it is related to function coercion, despite the
data_types` generic name