Skip to content

Commit

Permalink
Overflow in negate operator (apache#11084)
Browse files Browse the repository at this point in the history
* Do checked negative op instead of unchecked

* add tests for checking if overflow error occurs

* add context to negating complexer ScalarValues

* put format! call to create error message in closure

* seperate test case for f16 that should panic with not implemented
  • Loading branch information
LorrensP-2158466 authored and findepi committed Jul 16, 2024
1 parent 027ded3 commit b65fc50
Showing 1 changed file with 160 additions and 17 deletions.
177 changes: 160 additions & 17 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
as_decimal128_array, as_decimal256_array, as_dictionary_array,
as_fixed_size_binary_array, as_fixed_size_list_array,
Expand Down Expand Up @@ -1168,6 +1169,13 @@ impl ScalarValue {

/// Calculate arithmetic negation for a scalar value
pub fn arithmetic_negate(&self) -> Result<Self> {
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
v: T,
ctx: impl Fn() -> String,
) -> Result<T> {
v.neg_checked()
.map_err(|e| arrow_datafusion_err!(e).context(ctx()))
}
match self {
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
Expand All @@ -1177,40 +1185,91 @@ impl ScalarValue {
| ScalarValue::Float64(None) => Ok(self.clone()),
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
ScalarValue::IntervalYearMonth(Some(v)) => {
Ok(ScalarValue::IntervalYearMonth(Some(-v)))
}
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))),
ScalarValue::IntervalYearMonth(Some(v)) => Ok(
ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || {
format!("In negation of IntervalYearMonth({v})")
})?)),
),
ScalarValue::IntervalDayTime(Some(v)) => {
let (days, ms) = IntervalDayTimeType::to_parts(*v);
let val = IntervalDayTimeType::make_value(-days, -ms);
let val = IntervalDayTimeType::make_value(
neg_checked_with_ctx(days, || {
format!("In negation of days {days} in IntervalDayTime")
})?,
neg_checked_with_ctx(ms, || {
format!("In negation of milliseconds {ms} in IntervalDayTime")
})?,
);
Ok(ScalarValue::IntervalDayTime(Some(val)))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
let val = IntervalMonthDayNanoType::make_value(
neg_checked_with_ctx(months, || {
format!("In negation of months {months} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(days, || {
format!("In negation of days {days} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(nanos, || {
format!("In negation of nanos {nanos} of IntervalMonthDayNano")
})?,
);
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
Ok(ScalarValue::Decimal128(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal128({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal256(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal256({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
),
ScalarValue::TimestampSecond(Some(v), tz) => {
Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampSecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampNanosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampNanosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampNanoSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMicrosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMicrosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMicroSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMillisecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMillisecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMilliSecond({v})")
})?),
tz.clone(),
))
}
value => _internal_err!(
"Can not run arithmetic negative on scalar value {value:?}"
Expand Down Expand Up @@ -3501,6 +3560,7 @@ mod tests {
use crate::assert_batches_eq;
use arrow::buffer::OffsetBuffer;
use arrow::compute::{is_null, kernels};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use arrow_buffer::Buffer;
use arrow_schema::Fields;
Expand Down Expand Up @@ -5494,6 +5554,89 @@ mod tests {
Ok(())
}

#[test]
#[allow(arithmetic_overflow)] // we want to test them
fn test_scalar_negative_overflows() -> Result<()> {
macro_rules! test_overflow_on_value {
($($val:expr),* $(,)?) => {$(
{
let value: ScalarValue = $val;
let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}");
let root_err = err.find_root();
match root_err{
DataFusionError::ArrowError(
ArrowError::ComputeError(_),
_,
) => {}
_ => return Err(err),
};
}
)*};
}
test_overflow_on_value!(
// the integers
i8::MIN.into(),
i16::MIN.into(),
i32::MIN.into(),
i64::MIN.into(),
// for decimals, only value needs to be tested
ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?,
ScalarValue::Decimal256(Some(i256::MIN), 20, 5),
// interval, check all possible values
ScalarValue::IntervalYearMonth(Some(i32::MIN)),
ScalarValue::new_interval_dt(i32::MIN, 999),
ScalarValue::new_interval_dt(1, i32::MIN),
ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456),
ScalarValue::new_interval_mdn(12, i32::MIN, 123_456),
ScalarValue::new_interval_mdn(12, 15, i64::MIN),
// tz doesn't matter when negating
ScalarValue::TimestampSecond(Some(i64::MIN), None),
ScalarValue::TimestampMillisecond(Some(i64::MIN), None),
ScalarValue::TimestampMicrosecond(Some(i64::MIN), None),
ScalarValue::TimestampNanosecond(Some(i64::MIN), None),
);

let float_cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
(f32::MIN.into(), f32::MAX.into()),
(f32::MAX.into(), f32::MIN.into()),
(f64::MIN.into(), f64::MAX.into()),
(f64::MAX.into(), f64::MIN.into()),
];
// skip float 16 because they aren't supported
for (test, expected) in float_cases.into_iter().skip(2) {
assert_eq!(test.arithmetic_negate()?, expected);
}
Ok(())
}

#[test]
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
fn f16_test_overflow() {
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
let cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
];

for (test, expected) in cases {
assert_eq!(test.arithmetic_negate().unwrap(), expected);
}
}

macro_rules! expect_operation_error {
($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
#[test]
Expand Down

0 comments on commit b65fc50

Please sign in to comment.