diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index a9dc0bd58f15..96f611e2b7b4 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -19,8 +19,8 @@ use arrow::array::ArrayRef; use arrow::array::{ - BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, + BooleanArray, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; @@ -701,6 +701,18 @@ macro_rules! make_try_abs_function { }}; } +macro_rules! make_decimal_abs_function { + ($ARRAY_TYPE:ident) => {{ + |args: &[ArrayRef]| { + let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE); + let res: $ARRAY_TYPE = array + .unary(|x| x.wrapping_abs()) + .with_data_type(args[0].data_type().clone()); + Ok(Arc::new(res) as ArrayRef) + } + }}; +} + /// Abs SQL function /// Return different implementations based on input datatype to reduce branches during execution pub(super) fn create_abs_function( @@ -723,15 +735,9 @@ pub(super) fn create_abs_function( | DataType::UInt32 | DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())), - // Decimal should keep the same precision and scale by using `with_data_type()`. - // https://github.com/apache/arrow-rs/issues/4644 - DataType::Decimal128(_, _) => Ok(|args: &[ArrayRef]| { - let array = downcast_arg!(&args[0], "abs arg", Decimal128Array); - let res: Decimal128Array = array - .unary(i128::abs) - .with_data_type(args[0].data_type().clone()); - Ok(Arc::new(res) as ArrayRef) - }), + // Decimal types + DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), + DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), other => not_impl_err!("Unsupported data type {other:?} for function abs"), } diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index a3ee307f4940..ee1e345f946a 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -395,7 +395,7 @@ NaN NaN # abs: return type query TT rowsort -SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_float limit 1 +SELECT arrow_typeof(abs(c1)), arrow_typeof(abs(c2)) FROM test_nullable_float limit 1 ---- Float32 Float64 @@ -466,34 +466,48 @@ drop table test_non_nullable_float statement ok CREATE TABLE test_nullable_decimal( - c1 DECIMAL(10, 2), - c2 DECIMAL(38, 10) - ) AS VALUES (0, 0), (NULL, NULL); - -query RR + c1 DECIMAL(10, 2), /* Decimal128 */ + c2 DECIMAL(38, 10), /* Decimal128 with max precision */ + c3 DECIMAL(40, 2), /* Decimal256 */ + c4 DECIMAL(76, 10) /* Decimal256 with max precision */ + ) AS VALUES + (0, 0, 0, 0), + (NULL, NULL, NULL, NULL); + +query RRRR INSERT into test_nullable_decimal values - (-99999999.99, '-9999999999999999999999999999.9999999999'), - (99999999.99, '9999999999999999999999999999.9999999999'); + ( + -99999999.99, + '-9999999999999999999999999999.9999999999', + '-99999999999999999999999999999999999999.99', + '-999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ), + ( + 99999999.99, + '9999999999999999999999999999.9999999999', + '99999999999999999999999999999999999999.99', + '999999999999999999999999999999999999999999999999999999999999999999.9999999999' + ) ---- 2 -query R rowsort +query R SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL; ---- NULL -query R rowsort +query R SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; ---- 0 @@ -507,19 +521,24 @@ query error DataFusion error: Arrow error: Divide by zero error SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL; # abs: return type -query TT rowsort -SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_decimal limit 1 +query TTTT +SELECT + arrow_typeof(abs(c1)), + arrow_typeof(abs(c2)), + arrow_typeof(abs(c3)), + arrow_typeof(abs(c4)) +FROM test_nullable_decimal limit 1 ---- -Decimal128(10, 2) Decimal128(38, 10) +Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10) -# abs: Decimal128 -query RR rowsort -SELECT abs(c1), abs(c2) FROM test_nullable_decimal +# abs: decimals +query RRRR rowsort +SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal ---- -0 0 -99999999.99 9999999999999999999999999999.9999999999 -99999999.99 9999999999999999999999999999.9999999999 -NULL NULL +0 0 0 0 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999 +NULL NULL NULL NULL statement ok drop table test_nullable_decimal