Skip to content

Commit

Permalink
support array access elements
Browse files Browse the repository at this point in the history
  • Loading branch information
b41sh committed May 10, 2022
1 parent ed635e9 commit 22786cf
Show file tree
Hide file tree
Showing 14 changed files with 341 additions and 27 deletions.
7 changes: 7 additions & 0 deletions common/datavalues/src/array_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ impl ArrayValue {
pub fn new(values: Vec<DataValue>) -> Self {
Self { values }
}

pub fn inner_type(&self) -> Option<DataTypeImpl> {
if let Some(value) = self.values.get(0) {
return Some(value.max_data_type());
}
None
}
}

impl From<DataValue> for ArrayValue {
Expand Down
16 changes: 16 additions & 0 deletions common/datavalues/src/columns/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ impl SeriesFrom<Vec<Option<VariantValue>>, Vec<Option<VariantValue>>> for Series
}
}

impl SeriesFrom<Vec<ArrayValue>, Vec<ArrayValue>> for Series {
fn from_data(vals: Vec<ArrayValue>) -> ColumnRef {
let inner_data_type = match vals.iter().find(|&x| x.inner_type().is_some()) {
Some(array_value) => array_value.inner_type().unwrap(),
None => Int64Type::new_impl(),
};
let mut builder = MutableArrayColumn::with_capacity_meta(vals.len(), ColumnMeta::Array {
data_type: inner_data_type,
});
for val in vals {
builder.append_value(val);
}
builder.finish().arc()
}
}

macro_rules! impl_from_option_iterator {
([], $( { $S: ident} ),*) => {
$(
Expand Down
1 change: 1 addition & 0 deletions common/datavalues/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ macro_rules! for_all_scalar_types {
{ f64 },
{ bool },
{ Vu8 },
{ ArrayValue },
{ VariantValue }
}
};
Expand Down
135 changes: 135 additions & 0 deletions common/functions/src/scalars/semi_structureds/array_get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright 2022 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

use common_datavalues::prelude::*;
use common_datavalues::with_match_scalar_types_error;
use common_exception::ErrorCode;
use common_exception::Result;

use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;

#[derive(Clone)]
pub struct ArrayGetFunction {
array_type: ArrayType,
display_name: String,
}

impl ArrayGetFunction {
pub fn try_create(display_name: &str, args: &[&DataTypeImpl]) -> Result<Box<dyn Function>> {
let data_type = args[0];
let path_type = args[1];

if !data_type.data_type_id().is_array() || !path_type.data_type_id().is_integer() {
return Err(ErrorCode::IllegalDataType(format!(
"Invalid argument types for function '{}': ({:?}, {:?})",
display_name.to_uppercase(),
data_type.data_type_id(),
path_type.data_type_id()
)));
}

let array_type: ArrayType = data_type.clone().try_into()?;
Ok(Box::new(ArrayGetFunction {
array_type,
display_name: display_name.to_string(),
}))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(2))
}
}

impl Function for ArrayGetFunction {
fn name(&self) -> &str {
&*self.display_name
}

fn return_type(&self) -> DataTypeImpl {
NullableType::new_impl(self.array_type.inner_type().clone())
}

fn eval(
&self,
_func_ctx: FunctionContext,
columns: &ColumnsWithField,
input_rows: usize,
) -> Result<ColumnRef> {
let indexes = build_path_indexes(columns[1].column())?;

let array_column: &ArrayColumn = if columns[0].column().is_const() {
let const_column: &ConstColumn = Series::check_get(columns[0].column())?;
Series::check_get(const_column.inner())?
} else {
Series::check_get(columns[0].column())?
};

let inner_type = self.array_type.inner_type().data_type_id();
with_match_scalar_types_error!(inner_type.to_physical_type(), |$T| {
let inner_column: &<$T as Scalar>::ColumnType = Series::check_get(array_column.values())?;
let mut builder = NullableColumnBuilder::<$T>::with_capacity(input_rows);

for index in indexes.iter() {
let mut offset = 0;
for row in 0..array_column.len() {
let len = array_column.size_at_index(row);
if *index >= len {
return Err(ErrorCode::BadArguments(format!(
"Index out of array column bounds: the len is {} but the index is {}",
len, index
)));
} else {
builder.append(inner_column.get_data(offset + index), true);
}
offset += len;
}
}
Ok(builder.build(input_rows))
})
}
}

impl fmt::Display for ArrayGetFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.display_name.to_uppercase())
}
}

fn build_path_indexes(column: &ColumnRef) -> Result<Vec<usize>> {
if column.is_const() {
let const_column: &ConstColumn = Series::check_get(column)?;
return build_path_indexes(const_column.inner());
}

let mut path_indexes: Vec<usize> = Vec::with_capacity(column.len());
for i in 0..column.len() {
let val = column.get(i);
match val.as_u64() {
Ok(index) => path_indexes.push(index as usize),
Err(_) => {
return Err(ErrorCode::IllegalDataType(format!(
"Array column only support accessed by index, but got {}",
val
)))
}
}
}
Ok(path_indexes)
}
5 changes: 5 additions & 0 deletions common/functions/src/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::Tokenizer;

use crate::scalars::semi_structureds::array_get::ArrayGetFunction;
use crate::scalars::Function;
use crate::scalars::FunctionContext;
use crate::scalars::FunctionDescription;
Expand All @@ -43,6 +44,10 @@ impl<const BY_PATH: bool, const IGNORE_CASE: bool> GetFunctionImpl<BY_PATH, IGNO
let data_type = args[0];
let path_type = args[1];

if data_type.data_type_id().is_array() {
return ArrayGetFunction::try_create(display_name, args);
}

if (IGNORE_CASE
&& (!data_type.data_type_id().is_variant_or_object()
|| !path_type.data_type_id().is_string()))
Expand Down
2 changes: 2 additions & 0 deletions common/functions/src/scalars/semi_structureds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod array_get;
mod check_json;
mod get;
mod json_extract_path_text;
mod parse_json;
mod semi_structured;

pub use array_get::ArrayGetFunction;
pub use check_json::CheckJsonFunction;
pub use get::GetFunction;
pub use get::GetIgnoreCaseFunction;
Expand Down
86 changes: 86 additions & 0 deletions common/functions/tests/it/scalars/semi_structureds/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,89 @@ fn test_get_path_function() -> Result<()> {

test_scalar_functions("get_path", &tests)
}

#[test]
fn test_array_get_function() -> Result<()> {
let tests = vec![
ScalarFunctionTest {
name: "array_get_int64",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec![0_u32, 1_u32]),
],
expect: Series::from_data(vec![
Some(1_i64),
Some(4_i64),
Some(7_i64),
Some(2_i64),
Some(5_i64),
Some(8_i64),
]),
error: "",
},
ScalarFunctionTest {
name: "array_get_string",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![
"a1".as_bytes().into(),
"a2".as_bytes().into(),
"a3".as_bytes().into(),
]),
ArrayValue::new(vec![
"b1".as_bytes().into(),
"b2".as_bytes().into(),
"b3".as_bytes().into(),
]),
ArrayValue::new(vec![
"c1".as_bytes().into(),
"c2".as_bytes().into(),
"c3".as_bytes().into(),
]),
]),
Series::from_data(vec![0_u32, 1_u32]),
],
expect: Series::from_data(vec![
Some("a1"),
Some("b1"),
Some("c1"),
Some("a2"),
Some("b2"),
Some("c2"),
]),
error: "",
},
ScalarFunctionTest {
name: "array_get_out_of_bounds",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec![3_u32]),
],
expect: Series::from_data(vec![None::<&str>]),
error: "Index out of array column bounds: the len is 3 but the index is 3",
},
ScalarFunctionTest {
name: "array_get_error_type",
columns: vec![
Series::from_data(vec![
ArrayValue::new(vec![1_i64.into(), 2_i64.into(), 3_i64.into()]),
ArrayValue::new(vec![4_i64.into(), 5_i64.into(), 6_i64.into()]),
ArrayValue::new(vec![7_i64.into(), 8_i64.into(), 9_i64.into()]),
]),
Series::from_data(vec!["a", "b"]),
],
expect: Series::from_data(vec![None::<&str>]),
error: "Invalid argument types for function 'GET': (Array, String)",
},
];

test_scalar_functions("get", &tests)
}
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl ExpressionChain {

let arg_types2: Vec<&DataTypeImpl> = arg_types.iter().collect();

let func_name = "get_path";
let func_name = "get";
let func = FunctionFactory::instance().get(func_name, &arg_types2)?;
let return_type = func.return_type();

Expand Down
2 changes: 1 addition & 1 deletion common/planners/src/plan_expression_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ impl ExpressionVisitor for ExpressionDataTypeVisitor {
self.stack.push(inner_type.clone());
Ok(self)
}
Expression::MapAccess { args, .. } => self.visit_function("get_path", args.len()),
Expression::MapAccess { args, .. } => self.visit_function("get", args.len()),
Expression::Alias(_, _) | Expression::Sort { .. } => Ok(self),
}
}
Expand Down
52 changes: 27 additions & 25 deletions query/src/sql/statements/analyzer_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,37 +385,39 @@ impl ExpressionAnalyzer {
"MapAccess operator must be one children.",
)),
Some(inner_expr) => {
let path_name: String = keys
.iter()
.enumerate()
.map(|(i, k)| match k {
k @ Value::Number(_, _) => format!("[{}]", k),
Value::SingleQuotedString(s) => format!("[\"{}\"]", s),
Value::ColonString(s) => {
let key = if i == 0 {
s.to_string()
} else {
format!(":{}", s)
};
key
}
Value::PeriodString(s) => format!(".{}", s),
_ => format!("[{}]", k),
})
.collect();

let name = match keys[0] {
Value::ColonString(_) => format!("{}:{}", inner_expr.column_name(), path_name),
_ => format!("{}{}", inner_expr.column_name(), path_name),
let name = match &keys[0] {
k @ Value::Number(_, _) => format!("{}[{}]", inner_expr.column_name(), k),
Value::SingleQuotedString(s) => {
format!("{}['{}']", inner_expr.column_name(), s)
}
Value::DoubleQuotedString(s) => {
format!("{}[\"{}\"]", inner_expr.column_name(), s)
}
Value::ColonString(s) => format!("{}:{}", inner_expr.column_name(), s),
Value::PeriodString(s) => format!("{}.{}", inner_expr.column_name(), s),
_ => format!("{}[{}]", inner_expr.column_name(), keys[0]),
};
let path =
Expression::create_literal(DataValue::String(path_name.as_bytes().to_vec()));
let arguments = vec![inner_expr, path];

let path_expr = match &keys[0] {
Value::Number(value, _) => Expression::create_literal(
DataValue::try_from_literal(value, None).unwrap(),
),
Value::SingleQuotedString(s)
| Value::DoubleQuotedString(s)
| Value::ColonString(s)
| Value::PeriodString(s) => Expression::create_literal(s.as_bytes().into()),
_ => Expression::create_literal(keys[0].to_string().as_bytes().into()),
};
let arguments = vec![inner_expr, path_expr];

args.push(Expression::MapAccess {
name,
args: arguments,
});
// convert map access v[0][1] to function get(get(v, 0), 1)
if keys.len() > 1 {
self.analyze_map_access(&keys[1..], args)?;
}
Ok(())
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ NULL
2 2
1 NULL
2 2
==get from array table==
1 10
2 50
1 20
2 60
Loading

0 comments on commit 22786cf

Please sign in to comment.