Skip to content

Commit

Permalink
Use LogicalType for TypeSignature Numeric and String, Coercible (
Browse files Browse the repository at this point in the history
…#13240)

* use logical type for signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt & clippy

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* numeric

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix numeric

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* deprecate coercible

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* introduce numeric and numeric string

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add back coercible

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rename

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm numeric string signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* typo

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* improve doc and err msg

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Nov 6, 2024
1 parent 345117b commit 6686e03
Show file tree
Hide file tree
Showing 21 changed files with 199 additions and 117 deletions.
6 changes: 6 additions & 0 deletions datafusion/common/src/types/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType {
}
}

impl std::fmt::Display for dyn LogicalType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.signature().eq(&other.signature())
Expand Down
45 changes: 41 additions & 4 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::compute::can_cast_types;
use arrow_schema::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
};
use std::sync::Arc;
use std::{fmt::Display, sync::Arc};

/// Representation of a type that DataFusion can handle natively. It is a subset
/// of the physical variants in Arrow's native [`DataType`].
Expand Down Expand Up @@ -183,6 +183,12 @@ pub enum NativeType {
Map(LogicalFieldRef),
}

impl Display for NativeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "NativeType::{self:?}")
}
}

impl LogicalType for NativeType {
fn native(&self) -> &NativeType {
self
Expand Down Expand Up @@ -348,6 +354,12 @@ impl LogicalType for NativeType {
// mapping solutions to provide backwards compatibility while transitioning from
// the purely physical system to a logical / physical system.

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
}
}

impl From<DataType> for NativeType {
fn from(value: DataType) -> Self {
use NativeType::*;
Expand Down Expand Up @@ -392,8 +404,33 @@ impl From<DataType> for NativeType {
}
}

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal(_, _)
)
}

#[inline]
pub fn is_integer(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
)
}
}
16 changes: 11 additions & 5 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.
use arrow::datatypes::DataType;
use datafusion_common::types::LogicalTypeRef;

/// Constant that is used as a placeholder for any valid timezone.
/// This is used where a function can accept a timestamp type with any
Expand Down Expand Up @@ -106,10 +107,10 @@ pub enum TypeSignature {
/// Exact number of arguments of an exact type
Exact(Vec<DataType>),
/// The number of arguments that can be coerced to in order
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// For example, `Coercible(vec![logical_float64()])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<DataType>),
Coercible(Vec<LogicalTypeRef>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
Expand All @@ -123,7 +124,9 @@ pub enum TypeSignature {
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
/// See [`NativeType::is_numeric`] to know which type is considered numeric
///
/// [`NativeType::is_numeric`]: datafusion_common
Numeric(usize),
/// Fixed number of arguments of all the same string types.
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
Expand Down Expand Up @@ -201,7 +204,10 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
Expand Down Expand Up @@ -322,7 +328,7 @@ impl Signature {
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
Expand Down
149 changes: 94 additions & 55 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::{LogicalType, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
Expand Down Expand Up @@ -395,40 +396,56 @@ fn get_valid_types(
}
}

fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
if length < 1 {
return plan_err!(
"The signature expected at least one argument but received {expected_length}"
);
}

if length != expected_length {
return plan_err!(
"The signature expected {length} arguments but received {expected_length}"
);
}

Ok(())
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
function_length_check(current_types.len(), *number)?;

let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();
if logical_data_type == NativeType::String {
new_types.push(data_type.to_owned());
} else if logical_data_type == NativeType::Null {
// TODO: Switch to Utf8View if all the string functions supports Utf8View
new_types.push(DataType::Utf8);
} else {
return plan_err!(
"The signature expected NativeType::String but received {logical_data_type}"
);
}
}

fn coercion_rule(
// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
(DataType::Null, data_type) | (data_type, DataType::Null) => {
coercion_rule(data_type, &DataType::Utf8)
}
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
coercion_rule(lhs, rhs)
find_common_type(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
Expand All @@ -444,15 +461,13 @@ fn get_valid_types(
}

// Length checked above, safe to unwrap
let mut coerced_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
coerced_type = coercion_rule(&coerced_type, t)?;
let mut coerced_type = new_types.first().unwrap().to_owned();
for t in new_types.iter().skip(1) {
coerced_type = find_common_type(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if data_type.is_null() {
DataType::Utf8
} else if let DataType::Dictionary(_, v) = data_type {
if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
Expand All @@ -462,22 +477,22 @@ fn get_valid_types(
vec![vec![base_type_or_default_type(&coerced_type); *number]]
}
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}
function_length_check(current_types.len(), *number)?;

let mut valid_type = current_types.first().unwrap().clone();
// Find common numeric type amongs given types except string
let mut valid_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
let logical_data_type: NativeType = t.into();
if logical_data_type == NativeType::Null {
continue;
}

if !logical_data_type.is_numeric() {
return plan_err!(
"The signature expected NativeType::Numeric but received {logical_data_type}"
);
}

if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
valid_type = coerced_type;
} else {
Expand All @@ -489,31 +504,55 @@ fn get_valid_types(
}
}

let logical_data_type: NativeType = valid_type.clone().into();
// Fallback to default type if we don't know which type to coerced to
// f64 is chosen since most of the math functions utilize Signature::numeric,
// and their default type is double precision
if logical_data_type == NativeType::Null {
valid_type = DataType::Float64;
}

vec![vec![valid_type; *number]]
}
TypeSignature::Coercible(target_types) => {
if target_types.is_empty() {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if target_types.len() != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
target_types.len(),
current_types.len()
);
function_length_check(current_types.len(), target_types.len())?;

// Aim to keep this logic as SIMPLE as possible!
// Make sure the corresponding test is covered
// If this function becomes COMPLEX, create another new signature!
fn can_coerce_to(
logical_type: &NativeType,
target_type: &NativeType,
) -> bool {
if logical_type == target_type {
return true;
}

if logical_type == &NativeType::Null {
return true;
}

if target_type.is_integer() && logical_type.is_integer() {
return true;
}

false
}

for (data_type, target_type) in current_types.iter().zip(target_types.iter())
let mut new_types = Vec::with_capacity(current_types.len());
for (current_type, target_type) in
current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
let logical_type: NativeType = current_type.into();
let target_logical_type = target_type.native();
if can_coerce_to(&logical_type, target_logical_type) {
let target_type =
target_logical_type.default_cast_for(current_type)?;
new_types.push(target_type);
}
}

vec![target_types.to_owned()]
vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
Expand Down
24 changes: 4 additions & 20 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr,
ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature,
SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
Expand Down Expand Up @@ -79,15 +79,7 @@ impl Default for FirstValue {
impl FirstValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
requirement_satisfied: false,
}
}
Expand Down Expand Up @@ -406,15 +398,7 @@ impl Default for LastValue {
impl LastValue {
pub fn new() -> Self {
Self {
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Numeric(1),
TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
signature: Signature::any(1, Volatility::Immutable),
requirement_satisfied: false,
}
}
Expand Down
Loading

0 comments on commit 6686e03

Please sign in to comment.