diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 5b727250b5863..667fa4688d7ae 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -29,6 +29,7 @@ use std::iter::repeat;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use crate::arrow_datafusion_err;
 use crate::cast::{
     as_decimal128_array, as_decimal256_array, as_dictionary_array,
     as_fixed_size_binary_array, as_fixed_size_list_array,
@@ -1168,6 +1169,15 @@ impl ScalarValue {
 
     /// Calculate arithmetic negation for a scalar value
     pub fn arithmetic_negate(&self) -> Result<Self> {
+        // Checked negation: a failure from `neg_checked` is wrapped in a
+        // DataFusion error; `ctx` is only invoked on that error path.
+        fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
+            v: T,
+            ctx: impl Fn() -> String,
+        ) -> Result<T> {
+            v.neg_checked()
+                .map_err(|e| arrow_datafusion_err!(e).context(ctx()))
+        }
         match self {
             ScalarValue::Int8(None)
             | ScalarValue::Int16(None)
@@ -1177,40 +1187,91 @@ impl ScalarValue {
             | ScalarValue::Float64(None) => Ok(self.clone()),
             ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
             ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
-            ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))),
-            ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
-            ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
-            ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
-            ScalarValue::IntervalYearMonth(Some(v)) => {
-                Ok(ScalarValue::IntervalYearMonth(Some(-v)))
-            }
+            ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
+            ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))),
+            ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))),
+            ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))),
+            ScalarValue::IntervalYearMonth(Some(v)) => Ok(
+                ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || {
+                    format!("In negation of IntervalYearMonth({v})")
+                })?)),
+            ),
             ScalarValue::IntervalDayTime(Some(v)) => {
                 let (days, ms) = IntervalDayTimeType::to_parts(*v);
-                let val = IntervalDayTimeType::make_value(-days, -ms);
+                let val = IntervalDayTimeType::make_value(
+                    neg_checked_with_ctx(days, || {
+                        format!("In negation of days {days} in IntervalDayTime")
+                    })?,
+                    neg_checked_with_ctx(ms, || {
+                        format!("In negation of milliseconds {ms} in IntervalDayTime")
+                    })?,
+                );
                 Ok(ScalarValue::IntervalDayTime(Some(val)))
             }
             ScalarValue::IntervalMonthDayNano(Some(v)) => {
                 let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
-                let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
+                let val = IntervalMonthDayNanoType::make_value(
+                    neg_checked_with_ctx(months, || {
+                        format!("In negation of months {months} of IntervalMonthDayNano")
+                    })?,
+                    neg_checked_with_ctx(days, || {
+                        format!("In negation of days {days} of IntervalMonthDayNano")
+                    })?,
+                    neg_checked_with_ctx(nanos, || {
+                        format!("In negation of nanos {nanos} of IntervalMonthDayNano")
+                    })?,
+                );
                 Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
             }
             ScalarValue::Decimal128(Some(v), precision, scale) => {
-                Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
+                Ok(ScalarValue::Decimal128(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of Decimal128({v}, {precision}, {scale})")
+                    })?),
+                    *precision,
+                    *scale,
+                ))
+            }
+            ScalarValue::Decimal256(Some(v), precision, scale) => {
+                Ok(ScalarValue::Decimal256(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of Decimal256({v}, {precision}, {scale})")
+                    })?),
+                    *precision,
+                    *scale,
+                ))
             }
-            ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
-                ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
-            ),
             ScalarValue::TimestampSecond(Some(v), tz) => {
-                Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone()))
+                Ok(ScalarValue::TimestampSecond(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of TimestampSecond({v})")
+                    })?),
+                    tz.clone(),
+                ))
             }
             ScalarValue::TimestampNanosecond(Some(v), tz) => {
-                Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone()))
+                Ok(ScalarValue::TimestampNanosecond(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of TimestampNanoSecond({v})")
+                    })?),
+                    tz.clone(),
+                ))
             }
             ScalarValue::TimestampMicrosecond(Some(v), tz) => {
-                Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone()))
+                Ok(ScalarValue::TimestampMicrosecond(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of TimestampMicroSecond({v})")
+                    })?),
+                    tz.clone(),
+                ))
             }
             ScalarValue::TimestampMillisecond(Some(v), tz) => {
-                Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone()))
+                Ok(ScalarValue::TimestampMillisecond(
+                    Some(neg_checked_with_ctx(*v, || {
+                        format!("In negation of TimestampMilliSecond({v})")
+                    })?),
+                    tz.clone(),
+                ))
             }
             value => _internal_err!(
                 "Can not run arithmetic negative on scalar value {value:?}"
@@ -3501,6 +3562,7 @@ mod tests {
     use crate::assert_batches_eq;
     use arrow::buffer::OffsetBuffer;
     use arrow::compute::{is_null, kernels};
+    use arrow::error::ArrowError;
     use arrow::util::pretty::pretty_format_columns;
     use arrow_buffer::Buffer;
     use arrow_schema::Fields;
@@ -5494,6 +5556,101 @@ mod tests {
         Ok(())
     }
 
+    /// Negating the minimum value of each integer-backed scalar (integers,
+    /// decimals, intervals, timestamps) must surface an Arrow `ComputeError`;
+    /// negating the float extremes must simply swap `MIN` and `MAX`.
+    #[test]
+    #[allow(arithmetic_overflow)] // we want to test them
+    fn test_scalar_negative_overflows() -> Result<()> {
+        macro_rules! test_overflow_on_value {
+            ($($val:expr),* $(,)?) => {$(
+                {
+                    let value: ScalarValue = $val;
+                    // `expect_err` takes a plain `&str`, so `{value:?}` would not be
+                    // interpolated there; `panic!` formats the offending scalar.
+                    let err = match value.arithmetic_negate() {
+                        Ok(negated) => panic!(
+                            "Should receive overflow error on negating {value:?}, got {negated:?}"
+                        ),
+                        Err(err) => err,
+                    };
+                    let root_err = err.find_root();
+                    match root_err {
+                        DataFusionError::ArrowError(
+                            ArrowError::ComputeError(_),
+                            _,
+                        ) => {}
+                        _ => return Err(err),
+                    };
+                }
+            )*};
+        }
+        test_overflow_on_value!(
+            // the integers
+            i8::MIN.into(),
+            i16::MIN.into(),
+            i32::MIN.into(),
+            i64::MIN.into(),
+            // for decimals, only value needs to be tested
+            ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?,
+            ScalarValue::Decimal256(Some(i256::MIN), 20, 5),
+            // interval, check all possible values
+            ScalarValue::IntervalYearMonth(Some(i32::MIN)),
+            ScalarValue::new_interval_dt(i32::MIN, 999),
+            ScalarValue::new_interval_dt(1, i32::MIN),
+            ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456),
+            ScalarValue::new_interval_mdn(12, i32::MIN, 123_456),
+            ScalarValue::new_interval_mdn(12, 15, i64::MIN),
+            // tz doesn't matter when negating
+            ScalarValue::TimestampSecond(Some(i64::MIN), None),
+            ScalarValue::TimestampMillisecond(Some(i64::MIN), None),
+            ScalarValue::TimestampMicrosecond(Some(i64::MIN), None),
+            ScalarValue::TimestampNanosecond(Some(i64::MIN), None),
+        );
+
+        let float_cases = [
+            (
+                ScalarValue::Float16(Some(f16::MIN)),
+                ScalarValue::Float16(Some(f16::MAX)),
+            ),
+            (
+                ScalarValue::Float16(Some(f16::MAX)),
+                ScalarValue::Float16(Some(f16::MIN)),
+            ),
+            (f32::MIN.into(), f32::MAX.into()),
+            (f32::MAX.into(), f32::MIN.into()),
+            (f64::MIN.into(), f64::MAX.into()),
+            (f64::MAX.into(), f64::MIN.into()),
+        ];
+        // skip float 16 because they aren't supported
+        for (test, expected) in float_cases.into_iter().skip(2) {
+            assert_eq!(test.arithmetic_negate()?, expected);
+        }
+        Ok(())
+    }
+
+    /// Float16 negation is currently unsupported: `arithmetic_negate` is expected
+    /// to return its internal error, which `unwrap` turns into the checked panic.
+    #[test]
+    #[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
+    fn f16_test_overflow() {
+        // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
+        let cases = [
+            (
+                ScalarValue::Float16(Some(f16::MIN)),
+                ScalarValue::Float16(Some(f16::MAX)),
+            ),
+            (
+                ScalarValue::Float16(Some(f16::MAX)),
+                ScalarValue::Float16(Some(f16::MIN)),
+            ),
+        ];
+
+        for (test, expected) in cases {
+            assert_eq!(test.arithmetic_negate().unwrap(), expected);
+        }
+    }
+
     macro_rules! expect_operation_error {
         ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
             #[test]