From f6f63b97cda46b5edabd95466cb4aae05a13a889 Mon Sep 17 00:00:00 2001 From: Lorrens Pantelis <100197010+LorrensP-2158466@users.noreply.github.com> Date: Wed, 26 Jun 2024 19:39:59 +0200 Subject: [PATCH] Fix overflow in factorial (#11134) * fix: Return overflow error instead of panicking * refactor: change the way we do checked multiplication and the error returned * cleaner order of operations on factorial results * test: add test for factorial overflow --- datafusion/functions/src/math/factorial.rs | 44 +++++++++++---------- datafusion/sqllogictest/test_files/math.slt | 4 ++ 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index dc481da79069..74ad2c738a93 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, Int64Array}; +use arrow::{ + array::{ArrayRef, Int64Array}, + error::ArrowError, +}; use std::any::Any; use std::sync::Arc; @@ -23,7 +26,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -67,28 +70,27 @@ impl ScalarUDFImpl for FactorialFunc { } } -macro_rules! 
make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> { match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Int64Array, - { |value: i64| { (1..=value).product() } } - )) as ArrayRef), + DataType::Int64 => { + let arg = downcast_arg!((&args[0]), "value", Int64Array); + Ok(arg + .iter() + .map(|a| match a { + Some(a) => (2..=a) + .try_fold(1i64, i64::checked_mul) + .ok_or_else(|| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Overflow happened on FACTORIAL({a})" + ))) + }) + .map(Some), + _ => Ok(None), + }) + .collect::<Result<Int64Array>>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function factorial."), } } diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index a2ce3834a87a..19de6560c26f 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -592,3 +592,7 @@ select lcm(-9223372036854775808, -9223372036854775808); query error DataFusion error: Arrow error: Compute error: Overflow happened on: 2107754225 \^ 1221660777 select power(2107754225, 1221660777); + +# factorial overflow +query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) +select FACTORIAL(350943270);