Skip to content

Commit

Permalink
use FactoryCreatorWithTypes for semi_structureds functions
Browse files: browse the repository at this point in the history
  • Loading branch information
zhyass committed Apr 10, 2022
1 parent ffda6ec commit 6f831c2
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 183 deletions.
24 changes: 18 additions & 6 deletions common/functions/src/scalars/function_adapter.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -142,7 +142,7 @@ impl Function for FunctionAdapter {
.collect::<Vec<_>>();

let col = self.eval(&columns, 1)?;
let col = if col.is_const() && col.len() == 1 {
let col = if col.is_const() && col.len() != input_rows {
col.replicate(&[input_rows])
} else if col.is_null() {
NullColumn::new(input_rows).arc()
Expand Down Expand Up @@ -171,7 +171,7 @@ impl Function for FunctionAdapter {
if is_all_null {
// If only null, return null directly.
let args = columns.iter().map(|v| v.data_type()).collect::<Vec<_>>();
let inner_type = inner.return_type(args.as_slice())?;
let inner_type = remove_nullable(&inner.return_type(args.as_slice())?);
return Ok(Arc::new(NullableColumn::new(
inner_type
.create_constant_column(&inner_type.default_value(), input_rows)?,
Expand Down Expand Up @@ -206,12 +206,24 @@ impl Function for FunctionAdapter {
});

let col = if col.is_nullable() {
let nullable_column: &NullableColumn = Series::check_get(&col)?;
NullableColumn::new(nullable_column.inner().clone(), validity)
// Constant(Nullable(column))?
if col.is_const() {
let c: &ConstColumn = unsafe { Series::static_cast(&col) };
let nullable_column: &NullableColumn =
unsafe { Series::static_cast(c.inner()) };
ConstColumn::new(
NullableColumn::new(nullable_column.inner().clone(), validity).arc(),
input_rows,
)
.arc()
} else {
let nullable_column: &NullableColumn = Series::check_get(&col)?;
NullableColumn::new(nullable_column.inner().clone(), validity).arc()
}
} else {
NullableColumn::new(col, validity)
NullableColumn::new(col, validity).arc()
};
return Ok(Arc::new(col));
return Ok(col);
}
}

Expand Down
37 changes: 8 additions & 29 deletions common/functions/src/scalars/others/inet_aton.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,44 +34,23 @@ pub type InetAtonFunction = InetAtonFunctionImpl<false>;
#[derive(Clone)]
pub struct InetAtonFunctionImpl<const SUPPRESS_PARSE_ERROR: bool> {
display_name: String,
result_type: DataTypePtr,
}

impl<const SUPPRESS_PARSE_ERROR: bool> InetAtonFunctionImpl<SUPPRESS_PARSE_ERROR> {
pub fn try_create(
display_name: &str,
args: &[&common_datavalues::DataTypePtr],
) -> Result<Box<dyn Function>> {
let result_type = if SUPPRESS_PARSE_ERROR {
let input_type = remove_nullable(args[0]);
match input_type.data_type_id() {
TypeID::Null => NullType::arc(),
// For invalid input, we suppress parse error and return null. So the return type must be nullable.
TypeID::String => NullableType::arc(UInt32Type::arc()),
_ => {
return Err(ErrorCode::IllegalDataType(format!(
"Expected string or null type, but got {}",
args[0].name()
)))
}
}
} else {
assert_string(args[0])?;
UInt32Type::arc()
};
assert_string(args[0])?;

Ok(Box::new(InetAtonFunctionImpl::<SUPPRESS_PARSE_ERROR> {
display_name: display_name.to_string(),
result_type,
}))
}

pub fn desc() -> TypedFunctionDescription {
let mut features = FunctionFeatures::default().deterministic().num_arguments(1);
if SUPPRESS_PARSE_ERROR {
features = features.disable_passthrough_null()
}
TypedFunctionDescription::creator(Box::new(Self::try_create)).features(features)
TypedFunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(1))
}
}

Expand All @@ -81,14 +60,14 @@ impl<const SUPPRESS_PARSE_ERROR: bool> Function for InetAtonFunctionImpl<SUPPRES
}

fn return_type(&self, _args: &[&DataTypePtr]) -> Result<DataTypePtr> {
Ok(self.result_type.clone())
if SUPPRESS_PARSE_ERROR {
Ok(NullableType::arc(UInt32Type::arc()))
} else {
Ok(UInt32Type::arc())
}
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
if columns[0].column().data_type_id() == TypeID::Null {
return NullType::arc().create_constant_column(&DataValue::Null, input_rows);
}

let viewer = Vu8::try_create_viewer(columns[0].column())?;
let viewer_iter = viewer.iter();

Expand Down
33 changes: 8 additions & 25 deletions common/functions/src/scalars/others/inet_ntoa.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,37 +38,20 @@ pub type InetNtoaFunction = InetNtoaFunctionImpl<false>;
#[derive(Clone)]
pub struct InetNtoaFunctionImpl<const SUPPRESS_CAST_ERROR: bool> {
display_name: String,
result_type: DataTypePtr,
}

impl<const SUPPRESS_CAST_ERROR: bool> InetNtoaFunctionImpl<SUPPRESS_CAST_ERROR> {
pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result<Box<dyn Function>> {
let result_type = if SUPPRESS_CAST_ERROR {
if args[0].data_type_id() == TypeID::Null {
NullType::arc()
} else {
let input_type = remove_nullable(args[0]);
assert_numeric(&input_type)?;
// For invalid input, the function should return null. So the return type must be nullable.
NullableType::arc(StringType::arc())
}
} else {
assert_numeric(args[0])?;
StringType::arc()
};
assert_numeric(args[0])?;

Ok(Box::new(InetNtoaFunctionImpl::<SUPPRESS_CAST_ERROR> {
display_name: display_name.to_string(),
result_type,
}))
}

pub fn desc() -> TypedFunctionDescription {
let mut features = FunctionFeatures::default().deterministic().num_arguments(1);
if SUPPRESS_CAST_ERROR {
features = features.disable_passthrough_null()
}
TypedFunctionDescription::creator(Box::new(Self::try_create)).features(features)
TypedFunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(1))
}
}

Expand All @@ -78,14 +61,14 @@ impl<const SUPPRESS_CAST_ERROR: bool> Function for InetNtoaFunctionImpl<SUPPRESS
}

fn return_type(&self, _args: &[&DataTypePtr]) -> Result<DataTypePtr> {
Ok(self.result_type.clone())
if SUPPRESS_CAST_ERROR {
Ok(NullableType::arc(StringType::arc()))
} else {
Ok(StringType::arc())
}
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
if columns[0].column().data_type_id() == TypeID::Null {
return NullType::arc().create_constant_column(&DataValue::Null, input_rows);
}

if SUPPRESS_CAST_ERROR {
let cast_to: DataTypePtr = Arc::new(NullableType::create(UInt32Type::arc()));
let cast_options = CastOptions {
Expand Down
45 changes: 9 additions & 36 deletions common/functions/src/scalars/semi_structureds/check_json.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,35 +15,30 @@
use std::fmt;
use std::sync::Arc;

use common_arrow::arrow::bitmap::Bitmap;
use common_datavalues::prelude::*;
use common_exception::ErrorCode;
use common_exception::Result;
use serde_json::Value as JsonValue;

use crate::scalars::Function;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;
use crate::scalars::TypedFunctionDescription;

#[derive(Clone)]
pub struct CheckJsonFunction {
display_name: String,
}

impl CheckJsonFunction {
pub fn try_create(display_name: &str) -> Result<Box<dyn Function>> {
pub fn try_create(display_name: &str, _args: &[&DataTypePtr]) -> Result<Box<dyn Function>> {
Ok(Box::new(CheckJsonFunction {
display_name: display_name.to_string(),
}))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create)).features(
FunctionFeatures::default()
.deterministic()
.monotonicity()
.num_arguments(1),
)
pub fn desc() -> TypedFunctionDescription {
TypedFunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(1))
}
}

Expand All @@ -52,28 +47,13 @@ impl Function for CheckJsonFunction {
&*self.display_name
}

fn return_type(&self, args: &[&DataTypePtr]) -> Result<DataTypePtr> {
if args[0].data_type_id() == TypeID::Null {
return Ok(NullType::arc());
}

fn return_type(&self, _args: &[&DataTypePtr]) -> Result<DataTypePtr> {
Ok(Arc::new(NullableType::create(StringType::arc())))
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
let data_type = remove_nullable(columns[0].field().data_type());
let mut column = columns[0].column();
let mut _all_null = false;
let mut source_valids: Option<&Bitmap> = None;
if column.is_nullable() {
(_all_null, source_valids) = column.validity();
let nullable_column: &NullableColumn = Series::check_get(column)?;
column = nullable_column.inner();
}

if data_type.data_type_id() == TypeID::Null {
return NullType::arc().create_constant_column(&DataValue::Null, input_rows);
}
let data_type = columns[0].field().data_type();
let column = columns[0].column();

let mut builder = NullableColumnBuilder::<Vu8>::with_capacity(input_rows);

Expand All @@ -83,14 +63,7 @@ impl Function for CheckJsonFunction {
}
} else if data_type.data_type_id() == TypeID::String {
let c: &StringColumn = Series::check_get(column)?;
for (i, v) in c.iter().enumerate() {
if let Some(source_valids) = source_valids {
if !source_valids.get_bit(i) {
builder.append_null();
continue;
}
}

for v in c.iter() {
match std::str::from_utf8(v) {
Ok(v) => match serde_json::from_str::<JsonValue>(v) {
Ok(_v) => builder.append_null(),
Expand Down
47 changes: 23 additions & 24 deletions common/functions/src/scalars/semi_structureds/get.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::fmt;
use std::sync::Arc;

use common_datavalues::prelude::*;
use common_exception::ErrorCode;
Expand All @@ -25,8 +24,8 @@ use sqlparser::parser::Parser;
use sqlparser::tokenizer::Tokenizer;

use crate::scalars::Function;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;
use crate::scalars::TypedFunctionDescription;

pub type GetFunction = GetFunctionImpl<false, false>;

Expand All @@ -40,26 +39,7 @@ pub struct GetFunctionImpl<const BY_PATH: bool, const IGNORE_CASE: bool> {
}

impl<const BY_PATH: bool, const IGNORE_CASE: bool> GetFunctionImpl<BY_PATH, IGNORE_CASE> {
pub fn try_create(display_name: &str) -> Result<Box<dyn Function>> {
Ok(Box::new(GetFunctionImpl::<BY_PATH, IGNORE_CASE> {
display_name: display_name.to_string(),
}))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(2))
}
}

impl<const BY_PATH: bool, const IGNORE_CASE: bool> Function
for GetFunctionImpl<BY_PATH, IGNORE_CASE>
{
fn name(&self) -> &str {
&*self.display_name
}

fn return_type(&self, args: &[&DataTypePtr]) -> Result<DataTypePtr> {
pub fn try_create(display_name: &str, args: &[&DataTypePtr]) -> Result<Box<dyn Function>> {
let data_type = args[0];
let path_type = args[1];

Expand All @@ -75,13 +55,32 @@ impl<const BY_PATH: bool, const IGNORE_CASE: bool> Function
{
return Err(ErrorCode::IllegalDataType(format!(
"Invalid argument types for function '{}': ({:?}, {:?})",
self.display_name.to_uppercase(),
display_name.to_uppercase(),
data_type,
path_type
)));
}

Ok(Arc::new(NullableType::create(VariantType::arc())))
Ok(Box::new(GetFunctionImpl::<BY_PATH, IGNORE_CASE> {
display_name: display_name.to_string(),
}))
}

pub fn desc() -> TypedFunctionDescription {
TypedFunctionDescription::creator(Box::new(Self::try_create))
.features(FunctionFeatures::default().deterministic().num_arguments(2))
}
}

impl<const BY_PATH: bool, const IGNORE_CASE: bool> Function
for GetFunctionImpl<BY_PATH, IGNORE_CASE>
{
fn name(&self) -> &str {
&*self.display_name
}

fn return_type(&self, _args: &[&DataTypePtr]) -> Result<DataTypePtr> {
Ok(NullableType::arc(VariantType::arc()))
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
Expand Down
Loading

0 comments on commit 6f831c2

Please sign in to comment.