diff --git a/common/datavalues/src/array_value.rs b/common/datavalues/src/array_value.rs index 037d69caeafa8..0168bf9b02590 100644 --- a/common/datavalues/src/array_value.rs +++ b/common/datavalues/src/array_value.rs @@ -31,6 +31,13 @@ impl ArrayValue { pub fn new(values: Vec) -> Self { Self { values } } + + pub fn inner_type(&self) -> Option { + if let Some(value) = self.values.get(0) { + return Some(value.max_data_type()); + } + None + } } impl From for ArrayValue { diff --git a/common/datavalues/src/columns/series.rs b/common/datavalues/src/columns/series.rs index 28961ae948841..180a8ca9d83b8 100644 --- a/common/datavalues/src/columns/series.rs +++ b/common/datavalues/src/columns/series.rs @@ -221,6 +221,22 @@ impl SeriesFrom>, Vec>> for Series } } +impl SeriesFrom, Vec> for Series { + fn from_data(vals: Vec) -> ColumnRef { + let inner_data_type = match vals.iter().find(|&x| x.inner_type().is_some()) { + Some(array_value) => array_value.inner_type().unwrap(), + None => Int64Type::new_impl(), + }; + let mut builder = MutableArrayColumn::with_capacity_meta(vals.len(), ColumnMeta::Array { + data_type: inner_data_type, + }); + for val in vals { + builder.append_value(val); + } + builder.finish().arc() + } +} + macro_rules! impl_from_option_iterator { ([], $( { $S: ident} ),*) => { $( diff --git a/common/datavalues/src/macros.rs b/common/datavalues/src/macros.rs index 369d0ff6943fa..c5db6d58e25f9 100644 --- a/common/datavalues/src/macros.rs +++ b/common/datavalues/src/macros.rs @@ -29,6 +29,7 @@ macro_rules! for_all_scalar_types { { f64 }, { bool }, { Vu8 }, + { ArrayValue }, { VariantValue } } }; diff --git a/common/functions/src/scalars/semi_structureds/array_get.rs b/common/functions/src/scalars/semi_structureds/array_get.rs new file mode 100644 index 0000000000000..6a665bac7e92a --- /dev/null +++ b/common/functions/src/scalars/semi_structureds/array_get.rs @@ -0,0 +1,148 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +use common_datavalues::prelude::*; +use common_datavalues::with_match_scalar_types_error; +use common_exception::ErrorCode; +use common_exception::Result; + +use crate::scalars::semi_structureds::get::build_path_keys; +use crate::scalars::semi_structureds::get::parse_path_keys; +use crate::scalars::Function; +use crate::scalars::FunctionContext; +use crate::scalars::FunctionDescription; +use crate::scalars::FunctionFeatures; + +pub type ArrayGetFunction = ArrayGetFunctionImpl; + +pub type ArrayGetPathFunction = ArrayGetFunctionImpl; + +#[derive(Clone)] +pub struct ArrayGetFunctionImpl { + array_type: ArrayType, + display_name: String, +} + +impl ArrayGetFunctionImpl { + pub fn try_create(display_name: &str, args: &[&DataTypeImpl]) -> Result> { + let data_type = args[0]; + let path_type = args[1]; + + if !data_type.data_type_id().is_array() + || (BY_PATH && !path_type.data_type_id().is_string()) + || (!BY_PATH && !path_type.data_type_id().is_unsigned_integer()) + { + return Err(ErrorCode::IllegalDataType(format!( + "Invalid argument types for function '{}': ({:?}, {:?})", + display_name.to_uppercase(), + data_type.data_type_id(), + path_type.data_type_id() + ))); + } + + let array_type = if let DataTypeImpl::Array(array_type) = data_type { + array_type + } else { + unreachable!() + }; + + Ok(Box::new(ArrayGetFunctionImpl:: { + array_type: array_type.clone(), + display_name: display_name.to_string(), + })) + } + + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(2)) + } +} + +impl Function for ArrayGetFunctionImpl { + fn name(&self) -> &str { + &*self.display_name + } + + fn return_type(&self) -> DataTypeImpl { + // TODO(b41sh): Support multi-dimensional array access + NullableType::new_impl(self.array_type.inner_type().clone()) + } + + fn eval( + &self, + _func_ctx: FunctionContext, + columns: &ColumnsWithField, + input_rows: usize, + ) -> Result { + let path_keys = if BY_PATH { + parse_path_keys(columns[1].column())? + } else { + build_path_keys(columns[1].column())? + }; + + let array_column: &ArrayColumn = if columns[0].column().is_const() { + let const_column: &ConstColumn = Series::check_get(columns[0].column())?; + Series::check_get(const_column.inner())? + } else { + Series::check_get(columns[0].column())? + }; + + let inner_type = self.array_type.inner_type().data_type_id(); + with_match_scalar_types_error!(inner_type.to_physical_type(), |$T| { + let inner_column: &<$T as Scalar>::ColumnType = Series::check_get(array_column.values())?; + let mut builder = NullableColumnBuilder::<$T>::with_capacity(input_rows); + + for path_key in path_keys.iter() { + // TODO(b41sh): Support multi-dimensional array access + if path_key.is_empty() || path_key.len() > 1 { + return Err(ErrorCode::BadArguments(format!( + "Array column don't support accessed by path: {:?}", + path_key + ))); + } + let key = &path_key[0]; + if let DataValue::UInt64(k) = key { + let index = *k as usize; + let mut offset = 0; + for row in 0..array_column.len() { + let len = array_column.size_at_index(row); + if index >= len { + return Err(ErrorCode::BadArguments(format!( + "Index out of array column bounds: the len is {} but the index is {}", + len, index + ))); + } else { + builder.append(inner_column.get_data(offset + index), true); + } + offset += len; + } + } else { + return Err(ErrorCode::IllegalDataType(format!( + "Array column only support accessed by index, but got {:#?}", + key + ))); + } + } + Ok(builder.build(input_rows)) + }) + } +} + +impl fmt::Display for ArrayGetFunctionImpl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.display_name.to_uppercase()) + } +} diff --git a/common/functions/src/scalars/semi_structureds/get.rs b/common/functions/src/scalars/semi_structureds/get.rs index 6b9fe21b7556b..62c56fd9d81c1 100644 --- a/common/functions/src/scalars/semi_structureds/get.rs +++ b/common/functions/src/scalars/semi_structureds/get.rs @@ -22,6 +22,8 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; use sqlparser::tokenizer::Tokenizer; +use crate::scalars::semi_structureds::array_get::ArrayGetFunction; +use crate::scalars::semi_structureds::array_get::ArrayGetPathFunction; use crate::scalars::Function; use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; @@ -43,6 +45,14 @@ impl GetFunctionImpl Result<()> { test_scalar_functions("get_path", &tests) } + +#[test] +fn test_array_get_function() -> Result<()> { + let tests = vec![ + ScalarFunctionTest { + name: "array_get_int64", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec![0_u32, 1_u32]), + ], + expect: Series::from_data(vec![ + Some(1_i64), + Some(4_i64), + Some(7_i64), + Some(2_i64), + Some(5_i64), + Some(8_i64), + ]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_string", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![ + "a1".as_bytes().into(), + "a2".as_bytes().into(), + "a3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "b1".as_bytes().into(), + "b2".as_bytes().into(), + "b3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "c1".as_bytes().into(), + "c2".as_bytes().into(), + "c3".as_bytes().into(), + ]), + ]), + Series::from_data(vec![0_u32, 1_u32]), + ], + expect: Series::from_data(vec![ + Some("a1"), + Some("b1"), + Some("c1"), + Some("a2"), + Some("b2"), + Some("c2"), + ]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_out_of_bounds", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec![3_u32]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Index out of array column bounds: the len is 3 but the index is 3", + }, + ScalarFunctionTest { + name: "array_get_error_type", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec!["a", "b"]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Invalid argument types for function 'GET': (Array, String)", + }, + ]; + + test_scalar_functions("get", &tests) +} + +#[test] +fn test_array_get_path_function() -> Result<()> { + let tests = vec![ + ScalarFunctionTest { + name: "array_get_path_int64", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec!["[0]", "[1]"]), + ], + expect: Series::from_data(vec![ + Some(1_i64), + Some(4_i64), + Some(7_i64), + Some(2_i64), + Some(5_i64), + Some(8_i64), + ]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_path_string", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![ + "a1".as_bytes().into(), + "a2".as_bytes().into(), + "a3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "b1".as_bytes().into(), + "b2".as_bytes().into(), + "b3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "c1".as_bytes().into(), + "c2".as_bytes().into(), + "c3".as_bytes().into(), + ]), + ]), + Series::from_data(vec!["[0]", "[1]"]), + ], + expect: Series::from_data(vec![ + Some("a1"), + Some("b1"), + Some("c1"), + Some("a2"), + Some("b2"), + Some("c2"), + ]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_path_out_of_bounds", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec!["[3]"]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Index out of array column bounds: the len is 3 but the index is 3", + }, + ScalarFunctionTest { + name: "array_get_path_error_type", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec![0_i32, 1_i32]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Invalid argument types for function 'GET_PATH': (Array, Int32)", + }, + ]; + + test_scalar_functions("get_path", &tests) +} diff --git a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result index ec30b1ac96480..8e133e8a94aca 100644 --- a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result +++ b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result @@ -48,3 +48,12 @@ NULL 2 2 1 NULL 2 2 +==get from array table== +1 10 +2 50 +1 20 +2 60 +1 10 +2 50 +1 20 +2 60 diff --git a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql index 5ea281df80644..deb4a618870f2 100644 --- a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql +++ b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql @@ -42,6 +42,10 @@ CREATE TABLE IF NOT EXISTS t3(id Int null, str String null) Engine = Memory; insert into t3 values(1, '[1,2,3,["a","b","c"]]'), (2, '{"a":1,"b":{"c":2}}'); +CREATE TABLE IF NOT EXISTS t4(id Int null, arr Array(Int64) null) Engine = Memory; + +insert into t4 values(1, [10,20,30,40]), (2, [50,60,70,80]); + select '==get from table=='; select get(arr, 0) from t1; select get(arr, 'a') from t1; @@ -68,4 +72,14 @@ select id, json_extract_path_text(str, '["a"]') from t3; select id, json_extract_path_text(str, 'b.c') from t3; select id, json_extract_path_text(str, '["b"]["c"]') from t3; +select '==get from array table=='; +select id, get(arr, 0) from t4; +select id, get(arr, 1) from t4; +select id, get(arr, 4) from t4; -- {ErrorCode 1006} +select id, get(arr, 'a') from t4; -- {ErrorCode 1007} +select id, get_path(arr, '[0]') from t4; +select id, get_path(arr, '[1]') from t4; +select id, get_path(arr, '[4]') from t4; -- {ErrorCode 1006} +select id, get_path(arr, 'a') from t4; -- {ErrorCode 1007} + DROP DATABASE db1; diff --git a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result index a55250f9d3a53..8c73e9b5723a1 100644 --- a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result +++ b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result @@ -1,47 +1,77 @@ ==Array(UInt8)== 1 [1, 2, 3] 2 [254, 255] +1 2 +254 255 ==Array(UInt16)== 1 [1, 2, 3] 2 [65534, 65535] +1 2 +65534 65535 ==Array(UInt32)== 1 [1, 2, 3] 2 [4294967294, 4294967295] +1 2 +4294967294 4294967295 ==Array(UInt64)== 1 [1, 2, 3] 2 [18446744073709551614, 18446744073709551615] +1 2 +18446744073709551614 18446744073709551615 ==Array(Int8)== 1 [1, 2, 3] 2 [-128, 127] +1 2 +-128 127 ==Array(Int16)== 1 [1, 2, 3] 2 [-32768, 32767] +1 2 +-32768 32767 ==Array(Int32)== 1 [1, 2, 3] 2 [-2147483648, 2147483647] +1 2 +-2147483648 2147483647 ==Array(Int64)== 1 [1, 2, 3] 2 [-9223372036854775808, 9223372036854775807] +1 2 +-9223372036854775808 9223372036854775807 ==Array(Float32)== 1 [1.100000023841858, 1.2000000476837158, 1.2999999523162842] 2 [-1.100000023841858, -1.2000000476837158, -1.2999999523162842] +1.100000023841858 1.2000000476837158 +-1.100000023841858 -1.2000000476837158 ==Array(Float64)== 1 [1.1, 1.2, 1.3] 2 [-1.1, -1.2, -1.3] +1.1 1.2 +-1.1 -1.2 ==Array(Boolean)== 1 [1, 1] 2 [0, 0] 3 [1, 0] 4 [0, 1] +1 1 +0 0 +1 0 +0 1 ==Array(Date)== 1 ['2021-01-01', '2022-01-01'] 2 ['1990-12-01', '2030-01-12'] +2021-01-01 2022-01-01 +1990-12-01 2030-01-12 ==Array(Timestamp)== 1 ['2021-01-01 01:01:01', '2022-01-01 01:01:01'] 2 ['1990-12-01 10:11:12', '2030-01-12 22:00:00'] +2021-01-01 01:01:01.000000 2022-01-01 01:01:01.000000 +1990-12-01 10:11:12.000000 2030-01-12 22:00:00.000000 ==Array(String)== 1 ['aa', 'bb'] 2 ['cc', 'dd'] +aa bb +cc dd ==Array(String) Nullable== 1 ['aa', 'bb'] 2 ['cc', 'dd'] diff --git a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql index 20cd6de6f0ad1..0e0a1b6f3891e 100644 --- a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql +++ b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql @@ -9,6 +9,7 @@ CREATE TABLE IF NOT EXISTS t1(id Int, arr Array(UInt8)) Engine = Memory; INSERT INTO t1 (id, arr) VALUES(1, [1,2,3]), (2, [254,255]); select * from t1; +select arr[0], arr[1] from t1; select '==Array(UInt16)=='; @@ -17,6 +18,7 @@ CREATE TABLE IF NOT EXISTS t2(id Int, arr Array(UInt16)) Engine = Memory; INSERT INTO t2 (id, arr) VALUES(1, [1,2,3]), (2, [65534,65535]); select * from t2; +select arr[0], arr[1] from t2; select '==Array(UInt32)=='; @@ -25,6 +27,7 @@ CREATE TABLE IF NOT EXISTS t3(id Int, arr Array(UInt32)) Engine = Memory; INSERT INTO t3 (id, arr) VALUES(1, [1,2,3]), (2, [4294967294,4294967295]); select * from t3; +select arr[0], arr[1] from t3; select '==Array(UInt64)=='; @@ -33,6 +36,7 @@ CREATE TABLE IF NOT EXISTS t4(id Int, arr Array(UInt64)) Engine = Memory; INSERT INTO t4 (id, arr) VALUES(1, [1,2,3]), (2, [18446744073709551614,18446744073709551615]); select * from t4; +select arr[0], arr[1] from t4; select '==Array(Int8)=='; @@ -41,6 +45,7 @@ CREATE TABLE IF NOT EXISTS t5(id Int, arr Array(Int8)) Engine = Memory; INSERT INTO t5 (id, arr) VALUES(1, [1,2,3]), (2, [-128,127]); select * from t5; +select arr[0], arr[1] from t5; select '==Array(Int16)=='; @@ -49,6 +54,7 @@ CREATE TABLE IF NOT EXISTS t6(id Int, arr Array(Int16)) Engine = Memory; INSERT INTO t6 (id, arr) VALUES(1, [1,2,3]), (2, [-32768,32767]); select * from t6; +select arr[0], arr[1] from t6; select '==Array(Int32)=='; @@ -57,6 +63,7 @@ CREATE TABLE IF NOT EXISTS t7(id Int, arr Array(Int32)) Engine = Memory; INSERT INTO t7 (id, arr) VALUES(1, [1,2,3]), (2, [-2147483648,2147483647]); select * from t7; +select arr[0], arr[1] from t7; select '==Array(Int64)=='; @@ -65,6 +72,7 @@ CREATE TABLE IF NOT EXISTS t8(id Int, arr Array(Int64)) Engine = Memory; INSERT INTO t8 (id, arr) VALUES(1, [1,2,3]), (2, [-9223372036854775808,9223372036854775807]); select * from t8; +select arr[0], arr[1] from t8; select '==Array(Float32)=='; @@ -73,6 +81,7 @@ CREATE TABLE IF NOT EXISTS t9(id Int, arr Array(Float32)) Engine = Memory; INSERT INTO t9 (id, arr) VALUES(1, [1.1,1.2,1.3]), (2, [-1.1,-1.2,-1.3]); select * from t9; +select arr[0], arr[1] from t9; select '==Array(Float64)=='; @@ -81,6 +90,7 @@ CREATE TABLE IF NOT EXISTS t10(id Int, arr Array(Float64)) Engine = Memory; INSERT INTO t10 (id, arr) VALUES(1, [1.1,1.2,1.3]), (2, [-1.1,-1.2,-1.3]); select * from t10; +select arr[0], arr[1] from t10; select '==Array(Boolean)=='; @@ -89,6 +99,7 @@ CREATE TABLE IF NOT EXISTS t11(id Int, arr Array(Bool)) Engine = Memory; INSERT INTO t11 (id, arr) VALUES(1, [true, true]), (2, [false, false]), (3, [true, false]), (4, [false, true]); select * from t11; +select arr[0], arr[1] from t11; select '==Array(Date)=='; @@ -98,6 +109,7 @@ INSERT INTO t12 (id, arr) VALUES(1, ['2021-01-01', '2022-01-01']), (2, ['1990-12 INSERT INTO t12 (id, arr) VALUES(3, ['1000000-01-01', '2000000-01-01']); -- {ErrorCode 1010} select * from t12; +select arr[0], arr[1] from t12; select '==Array(Timestamp)=='; @@ -107,6 +119,7 @@ INSERT INTO t13 (id, arr) VALUES(1, ['2021-01-01 01:01:01', '2022-01-01 01:01:01 INSERT INTO t13 (id, arr) VALUES(3, ['1000000-01-01 01:01:01', '2000000-01-01 01:01:01']); -- {ErrorCode 1010} select * from t13; +select arr[0], arr[1] from t13; select '==Array(String)=='; @@ -115,6 +128,7 @@ CREATE TABLE IF NOT EXISTS t14(id Int, arr Array(String)) Engine = Memory; INSERT INTO t14 (id, arr) VALUES(1, ['aa', 'bb']), (2, ['cc', 'dd']); select * from t14; +select arr[0], arr[1] from t14; select '==Array(String) Nullable=='; @@ -123,5 +137,6 @@ CREATE TABLE IF NOT EXISTS t15(id Int, arr Array(String) Null) Engine = Memory; INSERT INTO t15 (id, arr) VALUES(1, ['aa', 'bb']), (2, ['cc', 'dd']), (3, null), (4, ['ee', 'ff']); select * from t15; +select arr[0], arr[1] from t15; -- {ErrorCode 1006} DROP DATABASE db1;