diff --git a/common/datavalues/src/array_value.rs b/common/datavalues/src/array_value.rs index 037d69caeafa8..0168bf9b02590 100644 --- a/common/datavalues/src/array_value.rs +++ b/common/datavalues/src/array_value.rs @@ -31,6 +31,13 @@ impl ArrayValue { pub fn new(values: Vec) -> Self { Self { values } } + + pub fn inner_type(&self) -> Option { + if let Some(value) = self.values.get(0) { + return Some(value.max_data_type()); + } + None + } } impl From for ArrayValue { diff --git a/common/datavalues/src/columns/series.rs b/common/datavalues/src/columns/series.rs index 28961ae948841..180a8ca9d83b8 100644 --- a/common/datavalues/src/columns/series.rs +++ b/common/datavalues/src/columns/series.rs @@ -221,6 +221,22 @@ impl SeriesFrom>, Vec>> for Series } } +impl SeriesFrom, Vec> for Series { + fn from_data(vals: Vec) -> ColumnRef { + let inner_data_type = match vals.iter().find(|&x| x.inner_type().is_some()) { + Some(array_value) => array_value.inner_type().unwrap(), + None => Int64Type::new_impl(), + }; + let mut builder = MutableArrayColumn::with_capacity_meta(vals.len(), ColumnMeta::Array { + data_type: inner_data_type, + }); + for val in vals { + builder.append_value(val); + } + builder.finish().arc() + } +} + macro_rules! impl_from_option_iterator { ([], $( { $S: ident} ),*) => { $( diff --git a/common/datavalues/src/macros.rs b/common/datavalues/src/macros.rs index 369d0ff6943fa..c5db6d58e25f9 100644 --- a/common/datavalues/src/macros.rs +++ b/common/datavalues/src/macros.rs @@ -29,6 +29,7 @@ macro_rules! for_all_scalar_types { { f64 }, { bool }, { Vu8 }, + { ArrayValue }, { VariantValue } } }; diff --git a/common/functions/src/scalars/semi_structureds/array_get.rs b/common/functions/src/scalars/semi_structureds/array_get.rs new file mode 100644 index 0000000000000..5d8726a3cc127 --- /dev/null +++ b/common/functions/src/scalars/semi_structureds/array_get.rs @@ -0,0 +1,139 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +use common_datavalues::prelude::*; +use common_datavalues::with_match_scalar_types_error; +use common_exception::ErrorCode; +use common_exception::Result; + +use crate::scalars::Function; +use crate::scalars::FunctionContext; +use crate::scalars::FunctionDescription; +use crate::scalars::FunctionFeatures; + +#[derive(Clone)] +pub struct ArrayGetFunction { + array_type: ArrayType, + display_name: String, +} + +impl ArrayGetFunction { + pub fn try_create(display_name: &str, args: &[&DataTypeImpl]) -> Result> { + let data_type = args[0]; + let path_type = args[1]; + + if !data_type.data_type_id().is_array() || !path_type.data_type_id().is_integer() { + return Err(ErrorCode::IllegalDataType(format!( + "Invalid argument types for function '{}': ({:?}, {:?})", + display_name.to_uppercase(), + data_type.data_type_id(), + path_type.data_type_id() + ))); + } + + let array_type: ArrayType = data_type.clone().try_into()?; + Ok(Box::new(ArrayGetFunction { + array_type, + display_name: display_name.to_string(), + })) + } + + pub fn desc() -> FunctionDescription { + FunctionDescription::creator(Box::new(Self::try_create)) + .features(FunctionFeatures::default().deterministic().num_arguments(2)) + } +} + +impl Function for ArrayGetFunction { + fn name(&self) -> &str { + &*self.display_name + } + + fn return_type(&self) -> DataTypeImpl { + NullableType::new_impl(self.array_type.inner_type().clone()) + } + + fn eval( + &self, + _func_ctx: FunctionContext, + columns: &ColumnsWithField, + input_rows: usize, + ) -> Result { + let array_column: &ArrayColumn = if columns[0].column().is_const() { + let const_column: &ConstColumn = Series::check_get(columns[0].column())?; + Series::check_get(const_column.inner())? + } else { + Series::check_get(columns[0].column())? + }; + + let inner_type = self.array_type.inner_type().data_type_id(); + with_match_scalar_types_error!(inner_type.to_physical_type(), |$T| { + let inner_column: &<$T as Scalar>::ColumnType = Series::check_get(array_column.values())?; + let mut builder = NullableColumnBuilder::<$T>::with_capacity(input_rows); + if columns[0].column().is_const() { + let index_column: &PrimitiveColumn = if columns[1].column().is_const() { + let const_column: &ConstColumn = Series::check_get(columns[1].column())?; + Series::check_get(const_column.inner())? + } else { + Series::check_get(columns[1].column())? + }; + let len = array_column.size_at_index(0); + for i in 0..input_rows { + let index = index_column.get(i).as_u64()? as usize; + let _ = check_index(index, len)?; + builder.append(inner_column.get_data(index), true); + } + } else if columns[1].column().is_const() { + let index_column: &ConstColumn = Series::check_get(columns[1].column())?; + let index = index_column.get(0).as_u64()? as usize; + let mut offset = 0; + for i in 0..input_rows { + let len = array_column.size_at_index(i); + let _ = check_index(index, len)?; + builder.append(inner_column.get_data(offset + index), true); + offset += len; + } + } else { + let index_column: &PrimitiveColumn = Series::check_get(columns[1].column())?; + let mut offset = 0; + for i in 0..input_rows { + let index = index_column.get(i).as_u64()? as usize; + let len = array_column.size_at_index(i); + let _ = check_index(index, len)?; + builder.append(inner_column.get_data(offset + index), true); + offset += len; + } + } + Ok(builder.build(input_rows)) + }) + } +} + +impl fmt::Display for ArrayGetFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.display_name.to_uppercase()) + } +} + +fn check_index(index: usize, len: usize) -> Result<()> { + if index >= len { + return Err(ErrorCode::BadArguments(format!( + "Index out of array column bounds: the len is {} but the index is {}", + len, index + ))); + } + Ok(()) +} diff --git a/common/functions/src/scalars/semi_structureds/get.rs b/common/functions/src/scalars/semi_structureds/get.rs index 6b9fe21b7556b..d8bd66344dde1 100644 --- a/common/functions/src/scalars/semi_structureds/get.rs +++ b/common/functions/src/scalars/semi_structureds/get.rs @@ -22,6 +22,7 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; use sqlparser::tokenizer::Tokenizer; +use crate::scalars::semi_structureds::array_get::ArrayGetFunction; use crate::scalars::Function; use crate::scalars::FunctionContext; use crate::scalars::FunctionDescription; @@ -43,6 +44,10 @@ impl GetFunctionImpl Result<()> { test_scalar_functions("get_path", &tests) } + +#[test] +fn test_array_get_function() -> Result<()> { + let tests = vec![ + ScalarFunctionTest { + name: "array_get_int64", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec![0_u32, 0, 0]), + ], + expect: Series::from_data(vec![Some(1_i64), Some(4_i64), Some(7_i64)]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_string", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![ + "a1".as_bytes().into(), + "a2".as_bytes().into(), + "a3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "b1".as_bytes().into(), + "b2".as_bytes().into(), + "b3".as_bytes().into(), + ]), + ArrayValue::new(vec![ + "c1".as_bytes().into(), + "c2".as_bytes().into(), + "c3".as_bytes().into(), + ]), + ]), + Series::from_data(vec![0_u32, 0, 0]), + ], + expect: Series::from_data(vec![Some("a1"), Some("b1"), Some("c1")]), + error: "", + }, + ScalarFunctionTest { + name: "array_get_out_of_bounds", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec![3_u32, 3, 3]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Index out of array column bounds: the len is 3 but the index is 3", + }, + ScalarFunctionTest { + name: "array_get_error_type", + columns: vec![ + Series::from_data(vec![ + ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]), + ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]), + ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]), + ]), + Series::from_data(vec!["a", "a", "a"]), + ], + expect: Series::from_data(vec![None::<&str>]), + error: "Invalid argument types for function 'GET': (Array, String)", + }, + ]; + + test_scalar_functions("get", &tests) +} diff --git a/common/planners/src/plan_expression_chain.rs b/common/planners/src/plan_expression_chain.rs index 14aaad519fa9f..8640d74a217cd 100644 --- a/common/planners/src/plan_expression_chain.rs +++ b/common/planners/src/plan_expression_chain.rs @@ -228,7 +228,7 @@ impl ExpressionChain { let arg_types2: Vec<&DataTypeImpl> = arg_types.iter().collect(); - let func_name = "get_path"; + let func_name = "get"; let func = FunctionFactory::instance().get(func_name, &arg_types2)?; let return_type = func.return_type(); diff --git a/common/planners/src/plan_expression_common.rs b/common/planners/src/plan_expression_common.rs index 6ed808029a7a7..2ef5efb340539 100644 --- a/common/planners/src/plan_expression_common.rs +++ b/common/planners/src/plan_expression_common.rs @@ -496,7 +496,7 @@ impl ExpressionVisitor for ExpressionDataTypeVisitor { self.stack.push(inner_type.clone()); Ok(self) } - Expression::MapAccess { args, .. } => self.visit_function("get_path", args.len()), + Expression::MapAccess { args, .. } => self.visit_function("get", args.len()), Expression::Alias(_, _) | Expression::Sort { .. } => Ok(self), } } diff --git a/query/src/sql/statements/analyzer_expr.rs b/query/src/sql/statements/analyzer_expr.rs index aac2c4cfff420..f096d90a21486 100644 --- a/query/src/sql/statements/analyzer_expr.rs +++ b/query/src/sql/statements/analyzer_expr.rs @@ -385,37 +385,39 @@ impl ExpressionAnalyzer { "MapAccess operator must be one children.", )), Some(inner_expr) => { - let path_name: String = keys - .iter() - .enumerate() - .map(|(i, k)| match k { - k @ Value::Number(_, _) => format!("[{}]", k), - Value::SingleQuotedString(s) => format!("[\"{}\"]", s), - Value::ColonString(s) => { - let key = if i == 0 { - s.to_string() - } else { - format!(":{}", s) - }; - key - } - Value::PeriodString(s) => format!(".{}", s), - _ => format!("[{}]", k), - }) - .collect(); - - let name = match keys[0] { - Value::ColonString(_) => format!("{}:{}", inner_expr.column_name(), path_name), - _ => format!("{}{}", inner_expr.column_name(), path_name), + let name = match &keys[0] { + k @ Value::Number(_, _) => format!("{}[{}]", inner_expr.column_name(), k), + Value::SingleQuotedString(s) => { + format!("{}['{}']", inner_expr.column_name(), s) + } + Value::DoubleQuotedString(s) => { + format!("{}[\"{}\"]", inner_expr.column_name(), s) + } + Value::ColonString(s) => format!("{}:{}", inner_expr.column_name(), s), + Value::PeriodString(s) => format!("{}.{}", inner_expr.column_name(), s), + _ => format!("{}[{}]", inner_expr.column_name(), keys[0]), }; - let path = - Expression::create_literal(DataValue::String(path_name.as_bytes().to_vec())); - let arguments = vec![inner_expr, path]; + + let path_expr = match &keys[0] { + Value::Number(value, _) => Expression::create_literal( + DataValue::try_from_literal(value, None).unwrap(), + ), + Value::SingleQuotedString(s) + | Value::DoubleQuotedString(s) + | Value::ColonString(s) + | Value::PeriodString(s) => Expression::create_literal(s.as_bytes().into()), + _ => Expression::create_literal(keys[0].to_string().as_bytes().into()), + }; + let arguments = vec![inner_expr, path_expr]; args.push(Expression::MapAccess { name, args: arguments, }); + // convert map access v[0][1] to function get(get(v, 0), 1) + if keys.len() > 1 { + self.analyze_map_access(&keys[1..], args)?; + } Ok(()) } } diff --git a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result index ec30b1ac96480..3863ce8435fc1 100644 --- a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result +++ b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.result @@ -48,3 +48,8 @@ NULL 2 2 1 NULL 2 2 +==get from array table== +1 10 +2 50 +1 20 +2 60 diff --git a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql index 5ea281df80644..fb49876362aa0 100644 --- a/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql +++ b/tests/suites/0_stateless/02_function/02_0051_function_semi_structureds_get.sql @@ -42,6 +42,10 @@ CREATE TABLE IF NOT EXISTS t3(id Int null, str String null) Engine = Memory; insert into t3 values(1, '[1,2,3,["a","b","c"]]'), (2, '{"a":1,"b":{"c":2}}'); +CREATE TABLE IF NOT EXISTS t4(id Int null, arr Array(Int64) null) Engine = Memory; + +insert into t4 values(1, [10,20,30,40]), (2, [50,60,70,80]); + select '==get from table=='; select get(arr, 0) from t1; select get(arr, 'a') from t1; @@ -68,4 +72,10 @@ select id, json_extract_path_text(str, '["a"]') from t3; select id, json_extract_path_text(str, 'b.c') from t3; select id, json_extract_path_text(str, '["b"]["c"]') from t3; +select '==get from array table=='; +select id, get(arr, 0) from t4; +select id, get(arr, 1) from t4; +select id, get(arr, 4) from t4; -- {ErrorCode 1006} +select id, get(arr, 'a') from t4; -- {ErrorCode 1007} + DROP DATABASE db1; diff --git a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result index a55250f9d3a53..8c73e9b5723a1 100644 --- a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result +++ b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.result @@ -1,47 +1,77 @@ ==Array(UInt8)== 1 [1, 2, 3] 2 [254, 255] +1 2 +254 255 ==Array(UInt16)== 1 [1, 2, 3] 2 [65534, 65535] +1 2 +65534 65535 ==Array(UInt32)== 1 [1, 2, 3] 2 [4294967294, 4294967295] +1 2 +4294967294 4294967295 ==Array(UInt64)== 1 [1, 2, 3] 2 [18446744073709551614, 18446744073709551615] +1 2 +18446744073709551614 18446744073709551615 ==Array(Int8)== 1 [1, 2, 3] 2 [-128, 127] +1 2 +-128 127 ==Array(Int16)== 1 [1, 2, 3] 2 [-32768, 32767] +1 2 +-32768 32767 ==Array(Int32)== 1 [1, 2, 3] 2 [-2147483648, 2147483647] +1 2 +-2147483648 2147483647 ==Array(Int64)== 1 [1, 2, 3] 2 [-9223372036854775808, 9223372036854775807] +1 2 +-9223372036854775808 9223372036854775807 ==Array(Float32)== 1 [1.100000023841858, 1.2000000476837158, 1.2999999523162842] 2 [-1.100000023841858, -1.2000000476837158, -1.2999999523162842] +1.100000023841858 1.2000000476837158 +-1.100000023841858 -1.2000000476837158 ==Array(Float64)== 1 [1.1, 1.2, 1.3] 2 [-1.1, -1.2, -1.3] +1.1 1.2 +-1.1 -1.2 ==Array(Boolean)== 1 [1, 1] 2 [0, 0] 3 [1, 0] 4 [0, 1] +1 1 +0 0 +1 0 +0 1 ==Array(Date)== 1 ['2021-01-01', '2022-01-01'] 2 ['1990-12-01', '2030-01-12'] +2021-01-01 2022-01-01 +1990-12-01 2030-01-12 ==Array(Timestamp)== 1 ['2021-01-01 01:01:01', '2022-01-01 01:01:01'] 2 ['1990-12-01 10:11:12', '2030-01-12 22:00:00'] +2021-01-01 01:01:01.000000 2022-01-01 01:01:01.000000 +1990-12-01 10:11:12.000000 2030-01-12 22:00:00.000000 ==Array(String)== 1 ['aa', 'bb'] 2 ['cc', 'dd'] +aa bb +cc dd ==Array(String) Nullable== 1 ['aa', 'bb'] 2 ['cc', 'dd'] diff --git a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql index 20cd6de6f0ad1..0e0a1b6f3891e 100644 --- a/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql +++ b/tests/suites/0_stateless/03_dml/03_0023_insert_into_array.sql @@ -9,6 +9,7 @@ CREATE TABLE IF NOT EXISTS t1(id Int, arr Array(UInt8)) Engine = Memory; INSERT INTO t1 (id, arr) VALUES(1, [1,2,3]), (2, [254,255]); select * from t1; +select arr[0], arr[1] from t1; select '==Array(UInt16)=='; @@ -17,6 +18,7 @@ CREATE TABLE IF NOT EXISTS t2(id Int, arr Array(UInt16)) Engine = Memory; INSERT INTO t2 (id, arr) VALUES(1, [1,2,3]), (2, [65534,65535]); select * from t2; +select arr[0], arr[1] from t2; select '==Array(UInt32)=='; @@ -25,6 +27,7 @@ CREATE TABLE IF NOT EXISTS t3(id Int, arr Array(UInt32)) Engine = Memory; INSERT INTO t3 (id, arr) VALUES(1, [1,2,3]), (2, [4294967294,4294967295]); select * from t3; +select arr[0], arr[1] from t3; select '==Array(UInt64)=='; @@ -33,6 +36,7 @@ CREATE TABLE IF NOT EXISTS t4(id Int, arr Array(UInt64)) Engine = Memory; INSERT INTO t4 (id, arr) VALUES(1, [1,2,3]), (2, [18446744073709551614,18446744073709551615]); select * from t4; +select arr[0], arr[1] from t4; select '==Array(Int8)=='; @@ -41,6 +45,7 @@ CREATE TABLE IF NOT EXISTS t5(id Int, arr Array(Int8)) Engine = Memory; INSERT INTO t5 (id, arr) VALUES(1, [1,2,3]), (2, [-128,127]); select * from t5; +select arr[0], arr[1] from t5; select '==Array(Int16)=='; @@ -49,6 +54,7 @@ CREATE TABLE IF NOT EXISTS t6(id Int, arr Array(Int16)) Engine = Memory; INSERT INTO t6 (id, arr) VALUES(1, [1,2,3]), (2, [-32768,32767]); select * from t6; +select arr[0], arr[1] from t6; select '==Array(Int32)=='; @@ -57,6 +63,7 @@ CREATE TABLE IF NOT EXISTS t7(id Int, arr Array(Int32)) Engine = Memory; INSERT INTO t7 (id, arr) VALUES(1, [1,2,3]), (2, [-2147483648,2147483647]); select * from t7; +select arr[0], arr[1] from t7; select '==Array(Int64)=='; @@ -65,6 +72,7 @@ CREATE TABLE IF NOT EXISTS t8(id Int, arr Array(Int64)) Engine = Memory; INSERT INTO t8 (id, arr) VALUES(1, [1,2,3]), (2, [-9223372036854775808,9223372036854775807]); select * from t8; +select arr[0], arr[1] from t8; select '==Array(Float32)=='; @@ -73,6 +81,7 @@ CREATE TABLE IF NOT EXISTS t9(id Int, arr Array(Float32)) Engine = Memory; INSERT INTO t9 (id, arr) VALUES(1, [1.1,1.2,1.3]), (2, [-1.1,-1.2,-1.3]); select * from t9; +select arr[0], arr[1] from t9; select '==Array(Float64)=='; @@ -81,6 +90,7 @@ CREATE TABLE IF NOT EXISTS t10(id Int, arr Array(Float64)) Engine = Memory; INSERT INTO t10 (id, arr) VALUES(1, [1.1,1.2,1.3]), (2, [-1.1,-1.2,-1.3]); select * from t10; +select arr[0], arr[1] from t10; select '==Array(Boolean)=='; @@ -89,6 +99,7 @@ CREATE TABLE IF NOT EXISTS t11(id Int, arr Array(Bool)) Engine = Memory; INSERT INTO t11 (id, arr) VALUES(1, [true, true]), (2, [false, false]), (3, [true, false]), (4, [false, true]); select * from t11; +select arr[0], arr[1] from t11; select '==Array(Date)=='; @@ -98,6 +109,7 @@ INSERT INTO t12 (id, arr) VALUES(1, ['2021-01-01', '2022-01-01']), (2, ['1990-12 INSERT INTO t12 (id, arr) VALUES(3, ['1000000-01-01', '2000000-01-01']); -- {ErrorCode 1010} select * from t12; +select arr[0], arr[1] from t12; select '==Array(Timestamp)=='; @@ -107,6 +119,7 @@ INSERT INTO t13 (id, arr) VALUES(1, ['2021-01-01 01:01:01', '2022-01-01 01:01:01 INSERT INTO t13 (id, arr) VALUES(3, ['1000000-01-01 01:01:01', '2000000-01-01 01:01:01']); -- {ErrorCode 1010} select * from t13; +select arr[0], arr[1] from t13; select '==Array(String)=='; @@ -115,6 +128,7 @@ CREATE TABLE IF NOT EXISTS t14(id Int, arr Array(String)) Engine = Memory; INSERT INTO t14 (id, arr) VALUES(1, ['aa', 'bb']), (2, ['cc', 'dd']); select * from t14; +select arr[0], arr[1] from t14; select '==Array(String) Nullable=='; @@ -123,5 +137,6 @@ CREATE TABLE IF NOT EXISTS t15(id Int, arr Array(String) Null) Engine = Memory; INSERT INTO t15 (id, arr) VALUES(1, ['aa', 'bb']), (2, ['cc', 'dd']), (3, null), (4, ['ee', 'ff']); select * from t15; +select arr[0], arr[1] from t15; -- {ErrorCode 1006} DROP DATABASE db1;