Skip to content

Commit

Permalink
feat: support Decimal256 for the abs function (#7904)
Browse files Browse the repository at this point in the history
* feat: support Decimal256 for the abs function

* Remove useless comment

* use wrapping_abs
  • Loading branch information
jonahgao authored Oct 23, 2023
1 parent ae85a67 commit eee790f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
28 changes: 17 additions & 11 deletions datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
use arrow::array::ArrayRef;
use arrow::array::{
BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array,
BooleanArray, Decimal128Array, Decimal256Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array,
};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
Expand Down Expand Up @@ -701,6 +701,18 @@ macro_rules! make_try_abs_function {
}};
}

macro_rules! make_decimal_abs_function {
($ARRAY_TYPE:ident) => {{
|args: &[ArrayRef]| {
let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
let res: $ARRAY_TYPE = array
.unary(|x| x.wrapping_abs())
.with_data_type(args[0].data_type().clone());
Ok(Arc::new(res) as ArrayRef)
}
}};
}

/// Abs SQL function
/// Return different implementations based on input datatype to reduce branches during execution
pub(super) fn create_abs_function(
Expand All @@ -723,15 +735,9 @@ pub(super) fn create_abs_function(
| DataType::UInt32
| DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())),

// Decimal should keep the same precision and scale by using `with_data_type()`.
// https://github.com/apache/arrow-rs/issues/4644
DataType::Decimal128(_, _) => Ok(|args: &[ArrayRef]| {
let array = downcast_arg!(&args[0], "abs arg", Decimal128Array);
let res: Decimal128Array = array
.unary(i128::abs)
.with_data_type(args[0].data_type().clone());
Ok(Arc::new(res) as ArrayRef)
}),
// Decimal types
DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)),
DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)),

other => not_impl_err!("Unsupported data type {other:?} for function abs"),
}
Expand Down
63 changes: 41 additions & 22 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ NaN NaN

# abs: return type
query TT rowsort
SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_float limit 1
SELECT arrow_typeof(abs(c1)), arrow_typeof(abs(c2)) FROM test_nullable_float limit 1
----
Float32 Float64

Expand Down Expand Up @@ -466,34 +466,48 @@ drop table test_non_nullable_float

statement ok
CREATE TABLE test_nullable_decimal(
c1 DECIMAL(10, 2),
c2 DECIMAL(38, 10)
) AS VALUES (0, 0), (NULL, NULL);

query RR
c1 DECIMAL(10, 2), /* Decimal128 */
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
c3 DECIMAL(40, 2), /* Decimal256 */
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
) AS VALUES
(0, 0, 0, 0),
(NULL, NULL, NULL, NULL);

query RRRR
INSERT into test_nullable_decimal values
(-99999999.99, '-9999999999999999999999999999.9999999999'),
(99999999.99, '9999999999999999999999999999.9999999999');
(
-99999999.99,
'-9999999999999999999999999999.9999999999',
'-99999999999999999999999999999999999999.99',
'-999999999999999999999999999999999999999999999999999999999999999999.9999999999'
),
(
99999999.99,
'9999999999999999999999999999.9999999999',
'99999999999999999999999999999999999999.99',
'999999999999999999999999999999999999999999999999999999999999999999.9999999999'
)
----
2


query R rowsort
query R
SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL;
----
NULL

query R rowsort
query R
SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL;
----
NULL

query R rowsort
query R
SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL;
----
NULL

query R rowsort
query R
SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
----
0
Expand All @@ -507,19 +521,24 @@ query error DataFusion error: Arrow error: Divide by zero error
SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;

# abs: return type
query TT rowsort
SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_decimal limit 1
query TTTT
SELECT
arrow_typeof(abs(c1)),
arrow_typeof(abs(c2)),
arrow_typeof(abs(c3)),
arrow_typeof(abs(c4))
FROM test_nullable_decimal limit 1
----
Decimal128(10, 2) Decimal128(38, 10)
Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10)

# abs: Decimal128
query RR rowsort
SELECT abs(c1), abs(c2) FROM test_nullable_decimal
# abs: decimals
query RRRR rowsort
SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal
----
0 0
99999999.99 9999999999999999999999999999.9999999999
99999999.99 9999999999999999999999999999.9999999999
NULL NULL
0 0 0 0
99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
NULL NULL NULL NULL

statement ok
drop table test_nullable_decimal
Expand Down

0 comments on commit eee790f

Please sign in to comment.