From f11d7a0b14ab2c7ca3fefb923cd1aa9507db3dd9 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 25 Nov 2023 17:26:12 +0800 Subject: [PATCH 1/3] done Signed-off-by: jayzhan211 --- datafusion/common/src/utils.rs | 32 ++++++++ .../physical-expr/src/array_expressions.rs | 77 +++++++------------ datafusion/sqllogictest/test_files/array.slt | 37 ++++++++- 3 files changed, 93 insertions(+), 53 deletions(-) diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index 12d4f516b4d0..7f2dc61c07bf 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; use arrow_array::{Array, LargeListArray, ListArray}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -402,6 +403,37 @@ pub fn arrays_into_list_array( )) } +/// Get the base type of a data type. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::base_type; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// +/// let data_type = DataType::Int32; +/// assert_eq!(base_type(&data_type), DataType::Int32); +/// ``` +pub fn base_type(data_type: &DataType) -> DataType { + if let DataType::List(field) = data_type { + base_type(field.data_type()) + } else { + data_type.to_owned() + } +} + +/// Compute the number of dimensions in a list data type. +pub fn list_ndims(data_type: &DataType) -> u64 { + if let DataType::List(field) = data_type { + 1 + list_ndims(field.data_type()) + } else { + 0 + } +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 8968bcf2ea4e..0d799af5c259 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -32,7 +32,7 @@ use arrow_schema::FieldRef; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::{array_into_list_array, list_ndims}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, @@ -102,6 +102,7 @@ fn compare_element_to_list( ) -> Result { let indices = UInt32Array::from(vec![row_index as u32]); let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` let res = match element_array_row.data_type() { @@ -175,35 +176,6 @@ fn compute_array_length( } } -/// Returns the dimension of the array -fn compute_array_ndims(arr: Option) -> Result> { - Ok(compute_array_ndims_with_datatype(arr)?.0) -} - -/// Returns the dimension and the datatype of elements of the array -fn compute_array_ndims_with_datatype( - arr: Option, -) -> Result<(Option, DataType)> { - let mut res: u64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok((None, DataType::Null)), - }; - if value.is_empty() { - return Ok((None, DataType::Null)); - } - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res += 1; - } - data_type => return Ok((Some(res), data_type.clone())), - } - } -} - /// Returns the length of each array dimension fn compute_array_dims(arr: Option) -> Result>>> { let mut value = match arr { @@ -818,10 +790,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { fn align_array_dimensions(args: Vec) -> Result> { let args_ndim = args .iter() - .map(|arg| compute_array_ndims(Some(arg.to_owned()))) - .collect::>>()? - .into_iter() - .map(|x| x.unwrap_or(0)) + .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) .collect::>(); let max_ndim = args_ndim.iter().max().unwrap_or(&0); @@ -912,6 +881,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); + Ok(Arc::new(list_arr)) } @@ -919,11 +889,11 @@ fn concat_internal(args: &[ArrayRef]) -> Result { pub fn array_concat(args: &[ArrayRef]) -> Result { let mut new_args = vec![]; for arg in args { - let (ndim, lower_data_type) = - compute_array_ndims_with_datatype(Some(arg.clone()))?; - if ndim.is_none() || ndim == Some(1) { - return not_impl_err!("Array is not type '{lower_data_type:?}'."); - } else if !lower_data_type.equals_datatype(&DataType::Null) { + let ndim = list_ndims(arg.data_type()); + let base_type = datafusion_common::utils::base_type(arg.data_type()); + if ndim == 0 { + return not_impl_err!("Array is not type '{base_type:?}'."); + } else if !base_type.eq(&DataType::Null) { new_args.push(arg.clone()); } } @@ -1748,14 +1718,23 @@ pub fn array_dims(args: &[ArrayRef]) -> Result { /// Array_ndims SQL function pub fn array_ndims(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; + if let Some(list_array) = args[0].as_list_opt::() { + let ndims = datafusion_common::utils::list_ndims(list_array.data_type()); - let result = list_array - .iter() - .map(compute_array_ndims) - .collect::>()?; + let mut data = vec![]; + for arr in list_array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) + } + } - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + } else { + println!("args: {:?}", args); + Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) + } } /// Array_has SQL function @@ -2017,10 +1996,10 @@ mod tests { .unwrap(); let expected = as_list_array(&array2d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); @@ -2030,10 +2009,10 @@ mod tests { align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); - let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap(); + let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); assert_ne!(as_list_array(&res[0]).unwrap(), expected); assert_eq!( - compute_array_ndims(Some(res[0].clone())).unwrap(), + datafusion_common::utils::list_ndims(res[0].data_type()), expected_dim ); } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index d33555509e6c..143e61c789c4 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2473,10 +2473,39 @@ NULL [3] [4] ## array_ndims (aliases: `list_ndims`) # array_ndims scalar function #1 + query III -select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); +select + array_ndims(1), + array_ndims(null), + array_ndims([2, 3]); ---- -1 2 5 +0 0 1 + +statement ok +CREATE TABLE array_ndims_table +AS VALUES + (1, [1, 2, 3], [[7]], [[[[[10]]]]]), + (2, [4, 5], [[8]], [[[[[10]]]]]), + (null, [6], [[9]], [[[[[10]]]]]), + (3, [6], [[9]], [[[[[10]]]]]) +; + +query IIII +select + array_ndims(column1), + array_ndims(column2), + array_ndims(column3), + array_ndims(column4) +from array_ndims_table; +---- +0 1 2 5 +0 1 2 5 +0 1 2 5 +0 1 2 5 + +statement ok +drop table array_ndims_table; # array_ndims scalar function #2 query II @@ -2488,7 +2517,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # list_ndims scalar function #4 (function alias `array_ndims`) query III @@ -2499,7 +2528,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])), query II select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- -NULL 2 +1 2 # array_ndims with columns query III From cd1e4a757729e3bb0d7d9c7f4d024fac2c54ddc4 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 25 Nov 2023 17:30:04 +0800 Subject: [PATCH 2/3] add more test Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/array.slt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 143e61c789c4..2608becdce71 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2507,6 +2507,11 @@ from array_ndims_table; statement ok drop table array_ndims_table; +query I +select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); +---- +3 + # array_ndims scalar function #2 query II select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); From 22b2a509abff63243e09ed14e4263a6c4d8565ab Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 25 Nov 2023 17:32:37 +0800 Subject: [PATCH 3/3] cleanup Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/array_expressions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 0d799af5c259..a373b25c960a 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1732,7 +1732,6 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) } else { - println!("args: {:?}", args); Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef) } }