Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(function): Support generic Array access elements by index #5244

Merged
merged 1 commit into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/datavalues/src/array_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ impl ArrayValue {
pub fn new(values: Vec<DataValue>) -> Self {
Self { values }
}

/// Returns the data type of the array's elements, derived from the first value only.
///
/// The type is whatever `max_data_type` reports for the first element; later elements
/// are not inspected. Returns `None` when the array holds no values.
pub fn inner_type(&self) -> Option<DataTypeImpl> {
    self.values.first().map(|value| value.max_data_type())
}
}

impl From<DataValue> for ArrayValue {
Expand Down
16 changes: 16 additions & 0 deletions common/datavalues/src/columns/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ impl SeriesFrom<Vec<Option<VariantValue>>, Vec<Option<VariantValue>>> for Series
}
}

/// Builds an array column from a list of `ArrayValue`s.
impl SeriesFrom<Vec<ArrayValue>, Vec<ArrayValue>> for Series {
    /// Appends every value of `vals` to a `MutableArrayColumn` and returns the finished column.
    ///
    /// The column's element type is taken from the first array that reports an inner type;
    /// when no array does (for example, all arrays are empty) it falls back to `Int64`.
    fn from_data(vals: Vec<ArrayValue>) -> ColumnRef {
        // `find_map` yields the inner type directly, so no second lookup or `unwrap` is needed.
        let inner_data_type = vals
            .iter()
            .find_map(|array_value| array_value.inner_type())
            .unwrap_or_else(Int64Type::new_impl);

        let meta = ColumnMeta::Array {
            data_type: inner_data_type,
        };
        let mut builder = MutableArrayColumn::with_capacity_meta(vals.len(), meta);
        for val in vals {
            builder.append_value(val);
        }
        builder.finish().arc()
    }
}

macro_rules! impl_from_option_iterator {
([], $( { $S: ident} ),*) => {
$(
Expand Down
51 changes: 51 additions & 0 deletions common/datavalues/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ macro_rules! for_all_scalar_types {
{ f64 },
{ bool },
{ Vu8 },
{ ArrayValue },
{ VariantValue }
}
};
Expand Down Expand Up @@ -278,6 +279,56 @@ macro_rules! with_match_primitive_types_error {
}};
}

/// Dispatches on an integer `TypeID` and expands `$body` once per integer type.
///
/// `$key_type` is matched against the eight integer variants (`Int8` .. `UInt64`). For the
/// matching variant, `$body` is expanded with the caller's type placeholder (written as
/// `|$T|` at the call site) standing for the corresponding Rust primitive (`i8` .. `u64`).
/// Any other value evaluates to `$nbody`.
///
/// `TypeID` is referenced unqualified, so it must be in scope where the macro is invoked.
#[macro_export]
macro_rules! with_match_integer_type_id {
// `$_:tt` captures the literal `$` token from the call site and `$T` captures the
// placeholder name, so together they can be replayed as a metavariable below.
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
// Helper macro whose single parameter is the caller's placeholder; invoking it with a
// concrete primitive type substitutes that type throughout `$body`.
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}

match $key_type {
TypeID::Int8 => __with_ty__! { i8 },
TypeID::Int16 => __with_ty__! { i16 },
TypeID::Int32 => __with_ty__! { i32 },
TypeID::Int64 => __with_ty__! { i64 },
TypeID::UInt8 => __with_ty__! { u8 },
TypeID::UInt16 => __with_ty__! { u16 },
TypeID::UInt32 => __with_ty__! { u32 },
TypeID::UInt64 => __with_ty__! { u64 },

// Every non-integer type id takes the caller-supplied fallback.
_ => $nbody,
}
}};
}

/// Dispatches on an integer `TypeID`, producing an error for every non-integer type.
///
/// `$key_type` is matched against the eight integer variants (`Int8` .. `UInt64`). For the
/// matching variant, `$body` is expanded with the caller's type placeholder (written as
/// `|$T|` at the call site) standing for the corresponding Rust primitive (`i8` .. `u64`).
/// Any other value evaluates to `Err(ErrorCode::BadDataValueType(..))`, so `$body` must
/// itself evaluate to a compatible `Result`.
///
/// `TypeID` and `ErrorCode` are referenced unqualified, so both must be in scope where the
/// macro is invoked.
#[macro_export]
macro_rules! with_match_integer_types_error {
// `$_:tt` captures the literal `$` token from the call site and `$T` captures the
// placeholder name, so together they can be replayed as a metavariable below.
($key_type:expr, | $_:tt $T:ident | $body:tt) => {{
// Helper macro whose single parameter is the caller's placeholder; invoking it with a
// concrete primitive type substitutes that type throughout `$body`.
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}

match $key_type {
TypeID::Int8 => __with_ty__! { i8 },
TypeID::Int16 => __with_ty__! { i16 },
TypeID::Int32 => __with_ty__! { i32 },
TypeID::Int64 => __with_ty__! { i64 },
TypeID::UInt8 => __with_ty__! { u8 },
TypeID::UInt16 => __with_ty__! { u16 },
TypeID::UInt32 => __with_ty__! { u32 },
TypeID::UInt64 => __with_ty__! { u64 },
// Non-integer type ids are reported with their debug representation.
v => Err(ErrorCode::BadDataValueType(format!(
"Ops is not support on datatype: {:?}",
v
))),
}
}};
}

#[macro_export]
macro_rules! with_match_date_type_error {
($key_type:expr, | $_:tt $T:ident | $body:tt) => {{
Expand Down
6 changes: 6 additions & 0 deletions common/exception/src/exception_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ impl From<std::num::ParseFloatError> for ErrorCode {
}
}

impl From<std::num::TryFromIntError> for ErrorCode {
fn from(error: std::num::TryFromIntError) -> Self {
ErrorCode::from_std_error(error)
}
}

impl From<common_arrow::arrow::error::ArrowError> for ErrorCode {
fn from(error: common_arrow::arrow::error::ArrowError) -> Self {
ErrorCode::from_std_error(error)
Expand Down
143 changes: 143 additions & 0 deletions common/functions/src/scalars/semi_structureds/array_get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright 2022 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

use common_datavalues::prelude::*;
use common_datavalues::with_match_integer_types_error;
use common_datavalues::with_match_scalar_types_error;
use common_exception::ErrorCode;
use common_exception::Result;

use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;

/// Scalar function that reads one element out of an array value by integer index.
#[derive(Clone)]
pub struct ArrayGetFunction {
// Type of the array argument; its inner type is what `return_type` wraps as nullable.
array_type: ArrayType,
// Name returned by `name()`; `Display` prints it upper-cased.
display_name: String,
}

impl ArrayGetFunction {
/// Creates an `ArrayGetFunction` for the given argument types.
///
/// `args[0]` is taken as the array type and `args[1]` as the index type.
///
/// # Errors
///
/// Returns `ErrorCode::IllegalDataType` unless the first argument is an array type and
/// the second is an integer type. A failure converting the first argument into an
/// `ArrayType` is propagated as well.
pub fn try_create(display_name: &str, args: &[&DataTypeImpl]) -> Result<Box<dyn Function>> {
// NOTE(review): direct indexing assumes two arguments are always supplied; presumably
// guaranteed by `num_arguments(2)` in `desc()` — confirm against the function factory.
let data_type = args[0];
let path_type = args[1];

// Both conditions must hold: array as the container, integer as the index.
if !data_type.data_type_id().is_array() || !path_type.data_type_id().is_integer() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if index is negative?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently, we don't support index by a negative number, a error will return.

mysql> select get(v, -1) from arr;
ERROR 1105 (HY000): Code: 1010, displayText = Unexpected type:Int64 to get u64 number (while in processor thread 0).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I have not seen as_u64 method.

return Err(ErrorCode::IllegalDataType(format!(
"Invalid argument types for function '{}': ({:?}, {:?})",
display_name.to_uppercase(),
data_type.data_type_id(),
path_type.data_type_id()
)));
}

// Keep the concrete array type so the element type is available later.
let array_type: ArrayType = data_type.clone().try_into()?;
Ok(Box::new(ArrayGetFunction {
array_type,
display_name: display_name.to_string(),
}))
}

/// Describes this function for registration: created through `try_create`,
/// deterministic, and taking exactly two arguments.
pub fn desc() -> FunctionDescription {
    let features = FunctionFeatures::default().deterministic().num_arguments(2);
    let creator = Box::new(Self::try_create);
    FunctionDescription::creator(creator).features(features)
}
}

impl Function for ArrayGetFunction {
/// Returns the display name this function was created with.
fn name(&self) -> &str {
    &self.display_name
}

/// Result type of the function: the array's inner type wrapped as nullable.
fn return_type(&self) -> DataTypeImpl {
    let element_type = self.array_type.inner_type().clone();
    NullableType::new_impl(element_type)
}

/// Looks up one element per row from the array column (`columns[0]`) using the
/// integer index column (`columns[1]`).
///
/// Three layouts are handled separately:
/// * constant array: every index addresses the single array, whose length is read once;
/// * constant index with a regular array column: the same position is read from each row;
/// * regular array and index columns: row `i` is read at the `i`-th index.
///
/// For non-constant arrays the element values are stored back to back, so a running
/// `offset` (the sum of the preceding rows' lengths) locates each row's first element.
///
/// # Errors
///
/// * `ErrorCode::BadArguments` (from `check_index`) when an index is not smaller than the
///   length of the addressed array.
/// * A conversion error when an index cannot be represented as `usize`, e.g. a negative value.
/// * Errors from the column downcasts (`Series::check_get`) and from the type-dispatch
///   macros when the element or index type is not supported.
fn eval(
    &self,
    _func_ctx: FunctionContext,
    columns: &ColumnsWithField,
    input_rows: usize,
) -> Result<ColumnRef> {
    // Constness is a property of the inputs, not of the dispatched types: query it once.
    let array_is_const = columns[0].column().is_const();
    let index_is_const = columns[1].column().is_const();

    // Unwrap a constant column to reach the underlying array column.
    let array_column: &ArrayColumn = if array_is_const {
        let const_column: &ConstColumn = Series::check_get(columns[0].column())?;
        Series::check_get(const_column.inner())?
    } else {
        Series::check_get(columns[0].column())?
    };

    let inner_type = self.array_type.inner_type().data_type_id();
    let index_type = columns[1].data_type().data_type_id();
    with_match_scalar_types_error!(inner_type.to_physical_type(), |$T1| {
        with_match_integer_types_error!(index_type, |$T2| {
            let inner_column: &<$T1 as Scalar>::ColumnType = Series::check_get(array_column.values())?;
            let mut builder = NullableColumnBuilder::<$T1>::with_capacity(input_rows);
            if array_is_const {
                let index_column: &PrimitiveColumn<$T2> = if index_is_const {
                    let const_column: &ConstColumn = Series::check_get(columns[1].column())?;
                    Series::check_get(const_column.inner())?
                } else {
                    Series::check_get(columns[1].column())?
                };
                // NOTE(review): when both inputs are constant only the inner index values
                // are appended; presumably `build(input_rows)` accounts for that — confirm.
                let len = array_column.size_at_index(0);
                for index in index_column.iter() {
                    let index = usize::try_from(*index)?;
                    check_index(index, len)?;
                    builder.append(inner_column.get_data(index), true);
                }
            } else if index_is_const {
                let index_column: &ConstColumn = Series::check_get(columns[1].column())?;
                // Checked conversion, matching the other branches, instead of a truncating cast.
                let index = usize::try_from(index_column.get(0).as_u64()?)?;
                let mut offset = 0;
                for i in 0..input_rows {
                    let len = array_column.size_at_index(i);
                    check_index(index, len)?;
                    builder.append(inner_column.get_data(offset + index), true);
                    offset += len;
                }
            } else {
                let index_column: &PrimitiveColumn<$T2> = Series::check_get(columns[1].column())?;
                let mut offset = 0;
                for (i, index) in index_column.iter().enumerate() {
                    let index = usize::try_from(*index)?;
                    let len = array_column.size_at_index(i);
                    check_index(index, len)?;
                    builder.append(inner_column.get_data(offset + index), true);
                    offset += len;
                }
            }
            Ok(builder.build(input_rows))
        })
    })
}
}

/// Prints the function as its upper-cased display name.
impl fmt::Display for ArrayGetFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let upper_name = self.display_name.to_uppercase();
        write!(f, "{}", upper_name)
    }
}

/// Verifies that `index` addresses an element of an array holding `len` items.
///
/// Returns `Ok(())` when `index < len`; otherwise an `ErrorCode::BadArguments`
/// whose message reports both the length and the offending index.
fn check_index(index: usize, len: usize) -> Result<()> {
    if index < len {
        return Ok(());
    }
    let message = format!(
        "Index out of array column bounds: the len is {} but the index is {}",
        len, index
    );
    Err(ErrorCode::BadArguments(message))
}
5 changes: 5 additions & 0 deletions common/functions/src/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::Tokenizer;

use crate::scalars::semi_structureds::array_get::ArrayGetFunction;
use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
Expand All @@ -43,6 +44,10 @@ impl<const BY_PATH: bool, const IGNORE_CASE: bool> GetFunctionImpl<BY_PATH, IGNO
let data_type = args[0];
let path_type = args[1];

if data_type.data_type_id().is_array() {
return ArrayGetFunction::try_create(display_name, args);
}

if (IGNORE_CASE
&& (!data_type.data_type_id().is_variant_or_object()
|| !path_type.data_type_id().is_string()))
Expand Down
2 changes: 2 additions & 0 deletions common/functions/src/scalars/semi_structureds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod array_get;
mod array_length;
mod check_json;
mod get;
Expand All @@ -20,6 +21,7 @@ mod length;
mod parse_json;
mod semi_structured;

pub use array_get::ArrayGetFunction;
pub use array_length::ArrayLengthFunction;
pub use check_json::CheckJsonFunction;
pub use get::GetFunction;
Expand Down
72 changes: 72 additions & 0 deletions common/functions/tests/it/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,75 @@ fn test_get_path_function() -> Result<()> {

test_scalar_functions("get_path", &tests)
}

/// Covers `get` on array columns: Int64 and string element access, an out-of-bounds
/// index, and a non-integer index type.
#[test]
fn test_array_get_function() -> Result<()> {
    // Three rows of three Int64 elements each; several cases use the same input.
    let int64_arrays = || {
        Series::from_data(vec![
            ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
            ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
            ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
        ])
    };

    let int64_case = ScalarFunctionTest {
        name: "array_get_int64",
        columns: vec![int64_arrays(), Series::from_data(vec![0_u32, 0, 0])],
        expect: Series::from_data(vec![Some(1_i64), Some(4_i64), Some(7_i64)]),
        error: "",
    };

    let string_arrays = Series::from_data(vec![
        ArrayValue::new(vec![
            "a1".as_bytes().into(),
            "a2".as_bytes().into(),
            "a3".as_bytes().into(),
        ]),
        ArrayValue::new(vec![
            "b1".as_bytes().into(),
            "b2".as_bytes().into(),
            "b3".as_bytes().into(),
        ]),
        ArrayValue::new(vec![
            "c1".as_bytes().into(),
            "c2".as_bytes().into(),
            "c3".as_bytes().into(),
        ]),
    ]);
    let string_case = ScalarFunctionTest {
        name: "array_get_string",
        columns: vec![string_arrays, Series::from_data(vec![0_u32, 0, 0])],
        expect: Series::from_data(vec![Some("a1"), Some("b1"), Some("c1")]),
        error: "",
    };

    // Index 3 is one past the end of every three-element array.
    let out_of_bounds_case = ScalarFunctionTest {
        name: "array_get_out_of_bounds",
        columns: vec![int64_arrays(), Series::from_data(vec![3_u32, 3, 3])],
        expect: Series::from_data(vec![None::<&str>]),
        error: "Index out of array column bounds: the len is 3 but the index is 3",
    };

    // A string index column must be rejected when the function is created.
    let error_type_case = ScalarFunctionTest {
        name: "array_get_error_type",
        columns: vec![int64_arrays(), Series::from_data(vec!["a", "a", "a"])],
        expect: Series::from_data(vec![None::<&str>]),
        error: "Invalid argument types for function 'GET': (Array, String)",
    };

    let tests = vec![
        int64_case,
        string_case,
        out_of_bounds_case,
        error_type_case,
    ];

    test_scalar_functions("get", &tests)
}
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ExpressionChain {

let arg_types2: Vec<&DataTypeImpl> = arg_types.iter().collect();

let func_name = "get_path";
let func_name = "get";
let func = FunctionFactory::instance().get(func_name, &arg_types2)?;
let return_type = func.return_type();

Expand Down
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ impl ExpressionVisitor for ExpressionDataTypeVisitor {
self.stack.push(inner_type.clone());
Ok(self)
}
Expression::MapAccess { args, .. } => self.visit_function("get_path", args.len()),
Expression::MapAccess { args, .. } => self.visit_function("get", args.len()),
Expression::Alias(_, _) | Expression::Sort { .. } => Ok(self),
}
}
Expand Down
Loading