Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite array_ndims to fix List(Null) handling #8320

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, LargeListArray, ListArray};
use arrow_schema::DataType;
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -402,6 +403,37 @@ pub fn arrays_into_list_array(
))
}

/// Strip every `List` wrapper from a data type and return what is left.
///
/// Nested `DataType::List` layers are peeled off one at a time until a
/// type that is not a `DataType::List` is reached; that innermost type is
/// returned as an owned value. A type that is not a `DataType::List` to
/// begin with is returned unchanged (cloned).
///
/// # Example
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use datafusion_common::utils::base_type;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// assert_eq!(base_type(&data_type), DataType::Int32);
///
/// let data_type = DataType::Int32;
/// assert_eq!(base_type(&data_type), DataType::Int32);
/// ```
pub fn base_type(data_type: &DataType) -> DataType {
    // Follow the chain of list element types down to the first non-list type.
    let mut innermost = data_type;
    while let DataType::List(field) = innermost {
        innermost = field.data_type();
    }
    innermost.clone()
}

/// Count how many `DataType::List` layers wrap a data type.
///
/// Returns `0` when `data_type` is not a `DataType::List`, `1` for a list
/// whose element type is not a list, `2` for a list of such lists, and so on.
pub fn list_ndims(data_type: &DataType) -> u64 {
    let mut depth: u64 = 0;
    let mut current = data_type;
    // Each `List` layer contributes one dimension; stop at the first
    // element type that is not itself a `List`.
    while let DataType::List(field) = current {
        depth += 1;
        current = field.data_type();
    }
    depth
}

/// An extension trait for smart pointers. Provides an interface to get a
/// raw pointer to the data (with metadata stripped away).
///
Expand Down
76 changes: 27 additions & 49 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_common::cast::{
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
as_null_array, as_string_array,
};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::utils::{array_into_list_array, list_ndims};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
DataFusionError, Result,
Expand Down Expand Up @@ -103,6 +103,7 @@ fn compare_element_to_list(
) -> Result<BooleanArray> {
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;

// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let res = match element_array_row.data_type() {
Expand Down Expand Up @@ -176,35 +177,6 @@ fn compute_array_length(
}
}

/// Returns only the dimension count computed by
/// `compute_array_ndims_with_datatype`, discarding the element data type.
///
/// Any error from `compute_array_ndims_with_datatype` is propagated as-is.
fn compute_array_ndims(arr: Option<ArrayRef>) -> Result<Option<u64>> {
    compute_array_ndims_with_datatype(arr).map(|(ndims, _element_type)| ndims)
}

/// Returns the dimension count of `arr` together with the data type found
/// once no further `List` nesting remains.
///
/// * `arr` is `None`, or is an empty array: returns `(None, DataType::Null)`.
/// * Otherwise the count starts at 1 and is incremented for every
///   `DataType::List` level encountered while repeatedly descending into
///   the first entry (`value(0)`) of the current array. The data type of the
///   array reached at the end is returned alongside the count.
///
/// An error is returned if `downcast_arg!` fails to view the current array
/// as a `ListArray`.
fn compute_array_ndims_with_datatype(
    arr: Option<ArrayRef>,
) -> Result<(Option<u64>, DataType)> {
    // Missing and zero-length inputs share the same "unknown" answer.
    let mut value = match arr {
        Some(inner) if !inner.is_empty() => inner,
        _ => return Ok((None, DataType::Null)),
    };

    let mut ndims: u64 = 1;
    // Only the first entry of each level is inspected.
    while matches!(value.data_type(), DataType::List(..)) {
        value = downcast_arg!(value, ListArray).value(0);
        ndims += 1;
    }

    Ok((Some(ndims), value.data_type().clone()))
}

/// Returns the length of each array dimension
fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
let mut value = match arr {
Expand Down Expand Up @@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
let args_ndim = args
.iter()
.map(|arg| compute_array_ndims(Some(arg.to_owned())))
.collect::<Result<Vec<_>>>()?
.into_iter()
.map(|x| x.unwrap_or(0))
.map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
.collect::<Vec<_>>();
let max_ndim = args_ndim.iter().max().unwrap_or(&0);

Expand Down Expand Up @@ -919,18 +888,19 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);

Ok(Arc::new(list_arr))
}

/// Array_concat/Array_cat SQL function
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut new_args = vec![];
for arg in args {
let (ndim, lower_data_type) =
compute_array_ndims_with_datatype(Some(arg.clone()))?;
if ndim.is_none() || ndim == Some(1) {
return not_impl_err!("Array is not type '{lower_data_type:?}'.");
} else if !lower_data_type.equals_datatype(&DataType::Null) {
let ndim = list_ndims(arg.data_type());
let base_type = datafusion_common::utils::base_type(arg.data_type());
if ndim == 0 {
return not_impl_err!("Array is not type '{base_type:?}'.");
} else if !base_type.eq(&DataType::Null) {
new_args.push(arg.clone());
}
}
Expand Down Expand Up @@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_ndims SQL function
///
/// Reads only `args[0]` and returns a `UInt64Array` wrapped in an `ArrayRef`.
///
/// NOTE(review): the lines below show the body before and after a change,
/// interleaved, so two implementations are visible in this span:
/// - the earlier body casts `args[0]` with `as_list_array` (propagating the
///   cast error) and maps `compute_array_ndims` over each row;
/// - the later body checks `args[0].as_list_opt::<i32>()`. When that
///   succeeds, `list_ndims` is computed once from the list array's data type
///   and emitted as `Some(ndims)` for each row that is `Some`, and `None` for
///   each row that is `None`. When it does not succeed, the result holds `0`
///   for every row of `args[0]`.
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
if let Some(list_array) = args[0].as_list_opt::<i32>() {
// The dimension count depends only on the data type, so it is computed once.
let ndims = datafusion_common::utils::list_ndims(list_array.data_type());

let result = list_array
.iter()
.map(compute_array_ndims)
.collect::<Result<UInt64Array>>()?;
let mut data = vec![];
// One output entry per input row; rows that are `None` stay `None`.
for arr in list_array.iter() {
if arr.is_some() {
data.push(Some(ndims))
} else {
data.push(None)
}
}

Ok(Arc::new(result) as ArrayRef)
Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
} else {
// `args[0]` could not be viewed as a list array with `i32` offsets: report 0 for every row.
Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef)
}
}

/// Array_has SQL function
Expand Down Expand Up @@ -2034,10 +2012,10 @@ mod tests {
.unwrap();

let expected = as_list_array(&array2d_1).unwrap();
let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
compute_array_ndims(Some(res[0].clone())).unwrap(),
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);

Expand All @@ -2047,10 +2025,10 @@ mod tests {
align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap();

let expected = as_list_array(&array3d_1).unwrap();
let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap();
let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
assert_eq!(
compute_array_ndims(Some(res[0].clone())).unwrap(),
datafusion_common::utils::list_ndims(res[0].data_type()),
expected_dim
);
}
Expand Down
42 changes: 38 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2479,10 +2479,44 @@ NULL [3] [4]
## array_ndims (aliases: `list_ndims`)

# array_ndims scalar function #1

query III
select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]]));
select
array_ndims(1),
array_ndims(null),
array_ndims([2, 3]);
----
1 2 5
0 0 1

statement ok
CREATE TABLE array_ndims_table
AS VALUES
(1, [1, 2, 3], [[7]], [[[[[10]]]]]),
(2, [4, 5], [[8]], [[[[[10]]]]]),
(null, [6], [[9]], [[[[[10]]]]]),
(3, [6], [[9]], [[[[[10]]]]])
;

query IIII
select
array_ndims(column1),
array_ndims(column2),
array_ndims(column3),
array_ndims(column4)
from array_ndims_table;
----
0 1 2 5
0 1 2 5
0 1 2 5
0 1 2 5

statement ok
drop table array_ndims_table;

query I
select array_ndims(arrow_cast([null], 'List(List(List(Int64)))'));
----
3

# array_ndims scalar function #2
query II
Expand All @@ -2494,7 +2528,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_
query II
select array_ndims(make_array()), array_ndims(make_array(make_array()))
----
NULL 2
1 2

# list_ndims scalar function #4 (function alias `array_ndims`)
query III
Expand All @@ -2505,7 +2539,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])),
query II
select array_ndims(make_array()), array_ndims(make_array(make_array()))
----
NULL 2
1 2

# array_ndims with columns
query III
Expand Down