Skip to content

Commit

Permalink
support array access elements
Browse files Browse the repository at this point in the history
  • Loading branch information
b41sh committed May 11, 2022
1 parent 01a7dbd commit 2ea8ba3
Show file tree
Hide file tree
Showing 15 changed files with 391 additions and 27 deletions.
7 changes: 7 additions & 0 deletions common/datavalues/src/array_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ impl ArrayValue {
pub fn new(values: Vec<DataValue>) -> Self {
Self { values }
}

pub fn inner_type(&self) -> Option<DataTypeImpl> {
if let Some(value) = self.values.get(0) {
return Some(value.max_data_type());
}
None
}
}

impl From<DataValue> for ArrayValue {
Expand Down
16 changes: 16 additions & 0 deletions common/datavalues/src/columns/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ impl SeriesFrom<Vec<Option<VariantValue>>, Vec<Option<VariantValue>>> for Series
}
}

impl SeriesFrom<Vec<ArrayValue>, Vec<ArrayValue>> for Series {
fn from_data(vals: Vec<ArrayValue>) -> ColumnRef {
let inner_data_type = match vals.iter().find(|&x| x.inner_type().is_some()) {
Some(array_value) => array_value.inner_type().unwrap(),
None => Int64Type::new_impl(),
};
let mut builder = MutableArrayColumn::with_capacity_meta(vals.len(), ColumnMeta::Array {
data_type: inner_data_type,
});
for val in vals {
builder.append_value(val);
}
builder.finish().arc()
}
}

macro_rules! impl_from_option_iterator {
([], $( { $S: ident} ),*) => {
$(
Expand Down
51 changes: 51 additions & 0 deletions common/datavalues/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ macro_rules! for_all_scalar_types {
{ f64 },
{ bool },
{ Vu8 },
{ ArrayValue },
{ VariantValue }
}
};
Expand Down Expand Up @@ -278,6 +279,56 @@ macro_rules! with_match_primitive_types_error {
}};
}

#[macro_export]
macro_rules! with_match_integer_type_id {
($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}

match $key_type {
TypeID::Int8 => __with_ty__! { i8 },
TypeID::Int16 => __with_ty__! { i16 },
TypeID::Int32 => __with_ty__! { i32 },
TypeID::Int64 => __with_ty__! { i64 },
TypeID::UInt8 => __with_ty__! { u8 },
TypeID::UInt16 => __with_ty__! { u16 },
TypeID::UInt32 => __with_ty__! { u32 },
TypeID::UInt64 => __with_ty__! { u64 },

_ => $nbody,
}
}};
}

#[macro_export]
macro_rules! with_match_integer_types_error {
($key_type:expr, | $_:tt $T:ident | $body:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}

match $key_type {
TypeID::Int8 => __with_ty__! { i8 },
TypeID::Int16 => __with_ty__! { i16 },
TypeID::Int32 => __with_ty__! { i32 },
TypeID::Int64 => __with_ty__! { i64 },
TypeID::UInt8 => __with_ty__! { u8 },
TypeID::UInt16 => __with_ty__! { u16 },
TypeID::UInt32 => __with_ty__! { u32 },
TypeID::UInt64 => __with_ty__! { u64 },
v => Err(ErrorCode::BadDataValueType(format!(
"Ops is not support on datatype: {:?}",
v
))),
}
}};
}

#[macro_export]
macro_rules! with_match_date_type_error {
($key_type:expr, | $_:tt $T:ident | $body:tt) => {{
Expand Down
6 changes: 6 additions & 0 deletions common/exception/src/exception_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ impl From<std::num::ParseFloatError> for ErrorCode {
}
}

impl From<std::num::TryFromIntError> for ErrorCode {
fn from(error: std::num::TryFromIntError) -> Self {
ErrorCode::from_std_error(error)
}
}

impl From<common_arrow::arrow::error::ArrowError> for ErrorCode {
fn from(error: common_arrow::arrow::error::ArrowError) -> Self {
ErrorCode::from_std_error(error)
Expand Down
143 changes: 143 additions & 0 deletions common/functions/src/scalars/semi_structureds/array_get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright 2022 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

use common_datavalues::prelude::*;
use common_datavalues::with_match_integer_types_error;
use common_datavalues::with_match_scalar_types_error;
use common_exception::ErrorCode;
use common_exception::Result;

use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;

#[derive(Clone)]
pub struct ArrayGetFunction {
array_type: ArrayType,
display_name: String,
}

impl ArrayGetFunction {
pub fn try_create(display_name: &str, args: &[&DataTypeImpl]) -> Result<Box<dyn Function>> {
let data_type = args[0];
let path_type = args[1];

if !data_type.data_type_id().is_array() || !path_type.data_type_id().is_integer() {
return Err(ErrorCode::IllegalDataType(format!(
"Invalid argument types for function '{}': ({:?}, {:?})",
display_name.to_uppercase(),
data_type.data_type_id(),
path_type.data_type_id()
)));
}

let array_type: ArrayType = data_type.clone().try_into()?;
Ok(Box::new(ArrayGetFunction {
array_type,
display_name: display_name.to_string(),
}))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(2))
}
}

impl Function for ArrayGetFunction {
fn name(&self) -> &str {
&*self.display_name
}

fn return_type(&self) -> DataTypeImpl {
NullableType::new_impl(self.array_type.inner_type().clone())
}

fn eval(
&self,
_func_ctx: FunctionContext,
columns: &ColumnsWithField,
input_rows: usize,
) -> Result<ColumnRef> {
let array_column: &ArrayColumn = if columns[0].column().is_const() {
let const_column: &ConstColumn = Series::check_get(columns[0].column())?;
Series::check_get(const_column.inner())?
} else {
Series::check_get(columns[0].column())?
};

let inner_type = self.array_type.inner_type().data_type_id();
let index_type = columns[1].data_type().data_type_id();
with_match_scalar_types_error!(inner_type.to_physical_type(), |$T1| {
with_match_integer_types_error!(index_type, |$T2| {
let inner_column: &<$T1 as Scalar>::ColumnType = Series::check_get(array_column.values())?;
let mut builder = NullableColumnBuilder::<$T1>::with_capacity(input_rows);
if columns[0].column().is_const() {
let index_column: &PrimitiveColumn<$T2> = if columns[1].column().is_const() {
let const_column: &ConstColumn = Series::check_get(columns[1].column())?;
Series::check_get(const_column.inner())?
} else {
Series::check_get(columns[1].column())?
};
let len = array_column.size_at_index(0);
for (i, index) in index_column.iter().enumerate() {
let index = usize::try_from(*index)?;
let _ = check_index(index, len)?;
builder.append(inner_column.get_data(index), true);
}
} else if columns[1].column().is_const() {
let index_column: &ConstColumn = Series::check_get(columns[1].column())?;
let index = index_column.get(0).as_u64()? as usize;
let mut offset = 0;
for i in 0..input_rows {
let len = array_column.size_at_index(i);
let _ = check_index(index, len)?;
builder.append(inner_column.get_data(offset + index), true);
offset += len;
}
} else {
let index_column: &PrimitiveColumn<$T2> = Series::check_get(columns[1].column())?;
let mut offset = 0;
for (i, index) in index_column.iter().enumerate() {
let index = usize::try_from(*index)?;
let len = array_column.size_at_index(i);
let _ = check_index(index, len)?;
builder.append(inner_column.get_data(offset + index), true);
offset += len;
}
}
Ok(builder.build(input_rows))
})
})
}
}

impl fmt::Display for ArrayGetFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.display_name.to_uppercase())
}
}

fn check_index(index: usize, len: usize) -> Result<()> {
if index >= len {
return Err(ErrorCode::BadArguments(format!(
"Index out of array column bounds: the len is {} but the index is {}",
len, index
)));
}
Ok(())
}
5 changes: 5 additions & 0 deletions common/functions/src/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::Tokenizer;

use crate::scalars::semi_structureds::array_get::ArrayGetFunction;
use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
Expand All @@ -43,6 +44,10 @@ impl<const BY_PATH: bool, const IGNORE_CASE: bool> GetFunctionImpl<BY_PATH, IGNO
let data_type = args[0];
let path_type = args[1];

if data_type.data_type_id().is_array() {
return ArrayGetFunction::try_create(display_name, args);
}

if (IGNORE_CASE
&& (!data_type.data_type_id().is_variant_or_object()
|| !path_type.data_type_id().is_string()))
Expand Down
2 changes: 2 additions & 0 deletions common/functions/src/scalars/semi_structureds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod array_get;
mod array_length;
mod check_json;
mod get;
Expand All @@ -20,6 +21,7 @@ mod length;
mod parse_json;
mod semi_structured;

pub use array_get::ArrayGetFunction;
pub use array_length::ArrayLengthFunction;
pub use check_json::CheckJsonFunction;
pub use get::GetFunction;
Expand Down
72 changes: 72 additions & 0 deletions common/functions/tests/it/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,75 @@ fn test_get_path_function() -> Result<()> {

test_scalar_functions("get_path", &tests)
}

#[test]
fn test_array_get_function() -> Result<()> {
let tests = vec![
ScalarFunctionTest {
name: "array_get_int64",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec![0_u32, 0, 0]),
],
expect: Series::from_data(vec![Some(1_i64), Some(4_i64), Some(7_i64)]),
error: "",
},
ScalarFunctionTest {
name: "array_get_string",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![
"a1".as_bytes().into(),
"a2".as_bytes().into(),
"a3".as_bytes().into(),
]),
ArrayValue::new(vec![
"b1".as_bytes().into(),
"b2".as_bytes().into(),
"b3".as_bytes().into(),
]),
ArrayValue::new(vec![
"c1".as_bytes().into(),
"c2".as_bytes().into(),
"c3".as_bytes().into(),
]),
]),
Series::from_data(vec![0_u32, 0, 0]),
],
expect: Series::from_data(vec![Some("a1"), Some("b1"), Some("c1")]),
error: "",
},
ScalarFunctionTest {
name: "array_get_out_of_bounds",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec![3_u32, 3, 3]),
],
expect: Series::from_data(vec![None::<&str>]),
error: "Index out of array column bounds: the len is 3 but the index is 3",
},
ScalarFunctionTest {
name: "array_get_error_type",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec!["a", "a", "a"]),
],
expect: Series::from_data(vec![None::<&str>]),
error: "Invalid argument types for function 'GET': (Array, String)",
},
];

test_scalar_functions("get", &tests)
}
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ExpressionChain {

let arg_types2: Vec<&DataTypeImpl> = arg_types.iter().collect();

let func_name = "get_path";
let func_name = "get";
let func = FunctionFactory::instance().get(func_name, &arg_types2)?;
let return_type = func.return_type();

Expand Down
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ impl ExpressionVisitor for ExpressionDataTypeVisitor {
self.stack.push(inner_type.clone());
Ok(self)
}
Expression::MapAccess { args, .. } => self.visit_function("get_path", args.len()),
Expression::MapAccess { args, .. } => self.visit_function("get", args.len()),
Expression::Alias(_, _) | Expression::Sort { .. } => Ok(self),
}
}
Expand Down
Loading

0 comments on commit 2ea8ba3

Please sign in to comment.