diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 4f32f4c17776..cee1a57bc6d9 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -17,11 +17,10 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, + OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::exec_err; use datafusion_common::Result; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -71,17 +70,7 @@ impl ScalarUDFImpl for CharacterLengthFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(character_length::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(character_length::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function character_length") - } - } + make_scalar_function(character_length, vec![])(args) } fn aliases(&self) -> &[String] { @@ -92,15 +81,32 @@ impl ScalarUDFImpl for CharacterLengthFunc { /// Returns number of characters in the string. /// character_length('josé') = 4 /// The implementation counts UTF-8 code points to count the number of characters -fn character_length(args: &[ArrayRef]) -> Result +fn character_length(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + character_length_general::(string_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + character_length_general::(string_array) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + character_length_general::(string_array) + } + _ => unreachable!(), + } +} + +fn character_length_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( + array: V, +) -> Result where T::Native: OffsetSizeTrait, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() + let iter = ArrayIter::new(array); + let result = iter .map(|string| { string.map(|string: &str| { T::Native::from_usize(string.chars().count()) @@ -116,55 +122,54 @@ where mod tests { use crate::unicode::character_length::CharacterLengthFunc; use crate::utils::test::test_function; - use arrow::array::{Array, Int32Array}; - use arrow::datatypes::DataType::Int32; + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + macro_rules! test_character_length { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + #[test] fn test_functions() -> Result<()> { #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("chars") - )))], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("josé") - )))], - Ok(Some(4)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("") - )))], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); + { + test_character_length!(Some(String::from("chars")), Ok(Some(5))); + test_character_length!(Some(String::from("josé")), Ok(Some(4))); + // test long strings (more than 12 bytes for StringView) + test_character_length!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16))); + test_character_length!(Some(String::from("")), Ok(Some(0))); + test_character_length!(None, Ok(None)); + } + #[cfg(not(feature = "unicode_expressions"))] test_function!( CharacterLengthFunc::new(),