Skip to content

Commit

Permalink
use FactoryCreatorWithTypes for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyass committed Apr 11, 2022
1 parent b251533 commit 77294eb
Show file tree
Hide file tree
Showing 17 changed files with 155 additions and 210 deletions.
4 changes: 2 additions & 2 deletions common/functions/src/scalars/conditionals/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl ConditionalFunction {
factory.register_typed("if", IfFunction::desc());
factory.register_typed("isNull", IsNullFunction::desc());
factory.register_typed("isNotNull", IsNotNullFunction::desc());
factory.register("in", InFunction::<false>::desc());
factory.register("not_in", InFunction::<true>::desc());
factory.register_typed("in", InFunction::<false>::desc());
factory.register_typed("not_in", InFunction::<true>::desc());
}
}
68 changes: 27 additions & 41 deletions common/functions/src/scalars/conditionals/in_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,36 @@ use ordered_float::OrderedFloat;

use crate::scalars::cast_column_field;
use crate::scalars::Function;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;
use crate::scalars::TypedFunctionDescription;

#[derive(Clone)]
pub struct InFunction<const NEGATED: bool>;
pub struct InFunction<const NEGATED: bool> {
is_null: bool,
}

impl<const NEGATED: bool> InFunction<NEGATED> {
pub fn try_create(_display_name: &str) -> Result<Box<dyn Function>> {
Ok(Box::new(InFunction::<NEGATED> {}))
pub fn try_create(_display_name: &str, args: &[&DataTypePtr]) -> Result<Box<dyn Function>> {
for dt in args {
let type_id = remove_nullable(dt).data_type_id();
if type_id.is_date_or_date_time()
|| type_id.is_interval()
|| type_id.is_array()
|| type_id.is_struct()
{
return Err(ErrorCode::UnexpectedError(format!(
"{} type is not supported for IN now",
type_id
)));
}
}

let is_null = args[0].data_type_id() == TypeID::Null;
Ok(Box::new(InFunction::<NEGATED> { is_null }))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create)).features(
pub fn desc() -> TypedFunctionDescription {
TypedFunctionDescription::creator(Box::new(Self::try_create)).features(
FunctionFeatures::default()
.bool_function()
.disable_passthrough_null()
Expand Down Expand Up @@ -93,46 +110,15 @@ impl<const NEGATED: bool> Function for InFunction<NEGATED> {
"InFunction"
}

fn return_type(&self, args: &[&DataTypePtr]) -> Result<DataTypePtr> {
for dt in args {
let type_id = remove_nullable(dt).data_type_id();
if type_id.is_date_or_date_time()
|| type_id.is_interval()
|| type_id.is_array()
|| type_id.is_struct()
{
return Err(ErrorCode::UnexpectedError(format!(
"{} type is not supported for IN now",
type_id
)));
}
}
let input_dt = remove_nullable(args[0]).data_type_id();
if input_dt == TypeID::Null {
fn return_type(&self, _args: &[&DataTypePtr]) -> Result<DataTypePtr> {
if self.is_null {
return Ok(NullType::arc());
}
Ok(BooleanType::arc())
}

fn eval(&self, columns: &ColumnsWithField, input_rows: usize) -> Result<ColumnRef> {
for col in columns {
let dt = col.column().data_type();
let type_id = remove_nullable(&dt).data_type_id();
if type_id.is_date_or_date_time()
|| type_id.is_interval()
|| type_id.is_array()
|| type_id.is_struct()
{
return Err(ErrorCode::UnexpectedError(format!(
"{} type is not supported for IN now",
type_id
)));
}
}

let input_col = &columns[0];
let input_dt = remove_nullable(input_col.data_type()).data_type_id();
if input_dt == TypeID::Null {
if self.is_null {
let col = NullType::arc().create_constant_column(&DataValue::Null, input_rows)?;
return Ok(col);
}
Expand All @@ -141,7 +127,7 @@ impl<const NEGATED: bool> Function for InFunction<NEGATED> {
let least_super_dt = aggregate_types(&types)?;
let least_super_type_id = remove_nullable(&least_super_dt).data_type_id();

let input_col = cast_column_field(input_col, &least_super_dt)?;
let input_col = cast_column_field(&columns[0], &least_super_dt)?;

match least_super_type_id {
TypeID::Boolean => {
Expand Down
27 changes: 8 additions & 19 deletions common/functions/src/scalars/function_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,6 @@ pub struct FunctionAdapter {
}

impl FunctionAdapter {
pub fn create(inner: Box<dyn Function>, passthrough_null: bool) -> Box<dyn Function> {
Box::new(Self {
inner: Some(inner),
passthrough_null,
})
}

pub fn create_some(
inner: Option<Box<dyn Function>>,
passthrough_null: bool,
) -> Box<dyn Function> {
Box::new(Self {
inner,
passthrough_null,
})
}

pub fn try_create_by_typed(
desc: &TypedFunctionDescription,
name: &str,
Expand All @@ -72,7 +55,10 @@ impl FunctionAdapter {
let inner = if passthrough_null {
// one is null, result is null
if args.iter().any(|v| v.data_type_id() == TypeID::Null) {
return Ok(Self::create_some(None, true));
return Ok(Box::new(Self {
inner: None,
passthrough_null: true,
}));
}
let types = args.iter().map(|v| remove_nullable(v)).collect::<Vec<_>>();
let types = types.iter().collect::<Vec<_>>();
Expand All @@ -81,7 +67,10 @@ impl FunctionAdapter {
(desc.typed_function_creator)(name, args)?
};

Ok(Self::create(inner, passthrough_null))
Ok(Box::new(Self {
inner: Some(inner),
passthrough_null,
}))
}
}

Expand Down
81 changes: 15 additions & 66 deletions common/functions/src/scalars/function_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,9 @@ use super::TupleClassFunction;
use crate::scalars::DateFunction;
use crate::scalars::UUIDFunction;

pub type FactoryCreator = Box<dyn Fn(&str) -> Result<Box<dyn Function>> + Send + Sync>;

pub type FactoryCreatorWithTypes =
Box<dyn Fn(&str, &[&DataTypePtr]) -> Result<Box<dyn Function>> + Send + Sync>;

pub struct FunctionDescription {
pub(crate) features: FunctionFeatures,
function_creator: FactoryCreator,
}

impl FunctionDescription {
pub fn creator(creator: FactoryCreator) -> FunctionDescription {
FunctionDescription {
function_creator: creator,
features: FunctionFeatures::default(),
}
}

#[must_use]
pub fn features(mut self, features: FunctionFeatures) -> FunctionDescription {
self.features = features;
self
}
}

pub struct TypedFunctionDescription {
pub(crate) features: FunctionFeatures,
pub typed_function_creator: FactoryCreatorWithTypes,
Expand All @@ -84,7 +62,6 @@ impl TypedFunctionDescription {
}

pub struct FunctionFactory {
case_insensitive_desc: HashMap<String, FunctionDescription>,
case_insensitive_typed_desc: HashMap<String, TypedFunctionDescription>,
}

Expand Down Expand Up @@ -112,7 +89,6 @@ static FUNCTION_FACTORY: Lazy<Arc<FunctionFactory>> = Lazy::new(|| {
impl FunctionFactory {
pub(in crate::scalars::function_factory) fn create() -> FunctionFactory {
FunctionFactory {
case_insensitive_desc: Default::default(),
case_insensitive_typed_desc: Default::default(),
}
}
Expand All @@ -121,11 +97,6 @@ impl FunctionFactory {
FUNCTION_FACTORY.as_ref()
}

pub fn register(&mut self, name: &str, desc: FunctionDescription) {
let case_insensitive_desc = &mut self.case_insensitive_desc;
case_insensitive_desc.insert(name.to_lowercase(), desc);
}

pub fn register_typed(&mut self, name: &str, desc: TypedFunctionDescription) {
let case_insensitive_typed_desc = &mut self.case_insensitive_typed_desc;
case_insensitive_typed_desc.insert(name.to_lowercase(), desc);
Expand All @@ -135,72 +106,50 @@ impl FunctionFactory {
let origin_name = name.as_ref();
let lowercase_name = origin_name.to_lowercase();

match self.case_insensitive_desc.get(&lowercase_name) {
// TODO(Winter): we should write similar function names into error message if function name is not found.
None => match self.case_insensitive_typed_desc.get(&lowercase_name) {
None => Err(ErrorCode::UnknownFunction(format!(
"Unsupported Function: {}",
origin_name
))),
Some(desc) => FunctionAdapter::try_create_by_typed(desc, origin_name, args),
},
Some(desc) => {
let inner = (desc.function_creator)(origin_name)?;
Ok(FunctionAdapter::create(
inner,
desc.features.passthrough_null,
))
}
// TODO(Winter): we should write similar function names into error message if function name is not found.
match self.case_insensitive_typed_desc.get(&lowercase_name) {
Some(desc) => FunctionAdapter::try_create_by_typed(desc, origin_name, args),
None => Err(ErrorCode::UnknownFunction(format!(
"Unsupported Function: {}",
origin_name
))),
}
}

pub fn get_features(&self, name: impl AsRef<str>) -> Result<FunctionFeatures> {
let origin_name = name.as_ref();
let lowercase_name = origin_name.to_lowercase();

match self.case_insensitive_desc.get(&lowercase_name) {
// TODO(Winter): we should write similar function names into error message if function name is not found.
None => match self.case_insensitive_typed_desc.get(&lowercase_name) {
None => Err(ErrorCode::UnknownFunction(format!(
"Unsupported Function: {}",
origin_name
))),
Some(desc) => Ok(desc.features.clone()),
},
// TODO(Winter): we should write similar function names into error message if function name is not found.
match self.case_insensitive_typed_desc.get(&lowercase_name) {
Some(desc) => Ok(desc.features.clone()),
None => Err(ErrorCode::UnknownFunction(format!(
"Unsupported Function: {}",
origin_name
))),
}
}

pub fn check(&self, name: impl AsRef<str>) -> bool {
let origin_name = name.as_ref();
let lowercase_name = origin_name.to_lowercase();

if self.case_insensitive_desc.contains_key(&lowercase_name) {
return true;
}
self.case_insensitive_typed_desc
.contains_key(&lowercase_name)
}

pub fn registered_names(&self) -> Vec<String> {
self.case_insensitive_desc
self.case_insensitive_typed_desc
.keys()
.chain(self.case_insensitive_typed_desc.keys())
.cloned()
.collect::<Vec<_>>()
}

pub fn registered_features(&self) -> Vec<FunctionFeatures> {
self.case_insensitive_desc
self.case_insensitive_typed_desc
.values()
.into_iter()
.map(|v| &v.features)
.chain(
self.case_insensitive_typed_desc
.values()
.into_iter()
.map(|v| &v.features),
)
.cloned()
.collect::<Vec<_>>()
}
Expand Down
47 changes: 22 additions & 25 deletions common/functions/src/scalars/tuples/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,40 @@
use std::fmt;
use std::sync::Arc;

use common_datavalues::DataTypePtr;
use common_datavalues::StructColumn;
use common_datavalues::StructType;
use common_exception::Result;

use crate::scalars::Function;
use crate::scalars::FunctionDescription;
use crate::scalars::FunctionFeatures;
use crate::scalars::TypedFunctionDescription;

#[derive(Clone)]
pub struct TupleFunction {
_display_name: String,
result_type: DataTypePtr,
}

impl TupleFunction {
pub fn try_create_func(_display_name: &str) -> Result<Box<dyn Function>> {
pub fn try_create_func(
_display_name: &str,
args: &[&common_datavalues::DataTypePtr],
) -> Result<Box<dyn Function>> {
let names = (0..args.len())
.map(|i| format!("item_{}", i))
.collect::<Vec<_>>();
let types = args.iter().map(|x| (*x).clone()).collect::<Vec<_>>();
let result_type = Arc::new(StructType::create(names, types));

Ok(Box::new(TupleFunction {
_display_name: "tuple".to_string(),
result_type,
}))
}

pub fn desc() -> FunctionDescription {
FunctionDescription::creator(Box::new(Self::try_create_func)).features(
pub fn desc() -> TypedFunctionDescription {
TypedFunctionDescription::creator(Box::new(Self::try_create_func)).features(
FunctionFeatures::default()
.deterministic()
.disable_passthrough_null()
Expand All @@ -52,36 +64,21 @@ impl Function for TupleFunction {

fn return_type(
&self,
args: &[&common_datavalues::DataTypePtr],
_args: &[&common_datavalues::DataTypePtr],
) -> Result<common_datavalues::DataTypePtr> {
let names = (0..args.len())
.map(|i| format!("item_{}", i))
.collect::<Vec<_>>();
let types = args.iter().map(|x| (*x).clone()).collect::<Vec<_>>();
let t = Arc::new(StructType::create(names, types));
Ok(t)
Ok(self.result_type.clone())
}

fn eval(
&self,
columns: &common_datavalues::ColumnsWithField,
_input_rows: usize,
) -> Result<common_datavalues::ColumnRef> {
let mut cols = vec![];
let mut types = vec![];

let names = (0..columns.len())
.map(|i| format!("item_{}", i))
let cols = columns
.iter()
.map(|v| v.column().clone())
.collect::<Vec<_>>();

for c in columns {
cols.push(c.column().clone());
types.push(c.data_type().clone());
}

let t = Arc::new(StructType::create(names, types));

let arr: StructColumn = StructColumn::from_data(cols, t);
let arr: StructColumn = StructColumn::from_data(cols, self.result_type.clone());
Ok(Arc::new(arr))
}
}
Expand Down
Loading

0 comments on commit 77294eb

Please sign in to comment.