Skip to content

Commit

Permalink
refactor: DfUdfAdapter to bridge ScalarUdf (#3814)
Browse files Browse the repository at this point in the history
* refactor: DfUdfAdapter to bridge ScalarUdf

Signed-off-by: tison <wander4096@gmail.com>

* tidy impl

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

* for more

Signed-off-by: tison <wander4096@gmail.com>

---------

Signed-off-by: tison <wander4096@gmail.com>
  • Loading branch information
tisonkun authored Apr 28, 2024
1 parent ed8b136 commit e154dc5
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 190 deletions.
14 changes: 5 additions & 9 deletions src/common/macro/src/range_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,12 @@ fn build_struct(
}

pub fn scalar_udf() -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
datafusion_expr::create_udf(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(Self::calc) as _),
Self::input_type(),
Arc::new(Self::return_type()),
Volatility::Immutable,
Arc::new(Self::calc) as _,
)
}

Expand Down
1 change: 1 addition & 0 deletions src/common/query/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub use self::udf::ScalarUdf;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
use crate::logical_plan::accumulator::*;
use crate::signature::{Signature, Volatility};

/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
Expand Down
113 changes: 53 additions & 60 deletions src/common/query/src/logical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,74 +91,67 @@ impl AggregateFunction {
}
}

impl From<AggregateFunction> for DfAggregateUdf {
fn from(udaf: AggregateFunction) -> Self {
struct DfUdafAdapter {
name: String,
signature: datafusion_expr::Signature,
return_type_func: datafusion_expr::ReturnTypeFunction,
accumulator: AccumulatorFactoryFunction,
creator: AggregateFunctionCreatorRef,
}
/// Adapter holding the parts of an `AggregateFunction` that are needed to
/// implement DataFusion's `AggregateUDFImpl` trait on its behalf.
struct DfUdafAdapter {
/// Name returned from `AggregateUDFImpl::name`.
name: String,
/// Signature returned from `AggregateUDFImpl::signature`.
signature: datafusion_expr::Signature,
/// Callback invoked with the argument types to resolve the return type.
return_type_func: datafusion_expr::ReturnTypeFunction,
/// Factory invoked with the `AccumulatorArgs` to build an accumulator.
accumulator: AccumulatorFactoryFunction,
/// Creator whose `state_types()` are turned into the state fields.
creator: AggregateFunctionCreatorRef,
}

impl Debug for DfUdafAdapter {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("DfUdafAdapter")
.field("name", &self.name)
.field("signature", &self.signature)
.finish()
}
}
impl Debug for DfUdafAdapter {
    /// Formats the adapter as a `DfUdafAdapter` struct showing only `name` and
    /// `signature`; the remaining fields are left out of the output
    /// (presumably because the callback fields are not `Debug` — confirm).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DfUdafAdapter");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.finish()
    }
}

impl AggregateUDFImpl for DfUdafAdapter {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
}

fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
(self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
name: &str,
_value_type: ArrowDataType,
_ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
self.creator
.state_types()
.map(|x| {
(0..x.len())
.zip(x)
.map(|(i, t)| {
Field::new(format!("{}_{}", name, i), t.as_arrow_type(), true)
})
.collect::<Vec<_>>()
})
.map_err(|e| e.into())
}
}
impl AggregateUDFImpl for DfUdafAdapter {
fn as_any(&self) -> &dyn Any {
self
}

DfUdafAdapter {
fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
}

fn return_type(&self, arg_types: &[ArrowDataType]) -> Result<ArrowDataType> {
(self.return_type_func)(arg_types).map(|x| x.as_ref().clone())
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(acc_args)
}

/// Builds the intermediate-state fields for this aggregate.
///
/// One nullable field is produced per entry of `self.creator.state_types()`,
/// named `{name}_{index}` and typed with the entry's Arrow data type. The
/// value type and ordering fields passed by the caller are ignored.
///
/// # Errors
///
/// Propagates the error returned by `state_types()`.
fn state_fields(&self, name: &str, _: ArrowDataType, _: Vec<Field>) -> Result<Vec<Field>> {
    let fields = self
        .creator
        .state_types()?
        .into_iter()
        .enumerate()
        .map(|(index, state_type)| {
            Field::new(format!("{name}_{index}"), state_type.as_arrow_type(), true)
        })
        .collect();
    Ok(fields)
}
}

impl From<AggregateFunction> for DfAggregateUdf {
fn from(udaf: AggregateFunction) -> Self {
DfAggregateUdf::new_from_impl(DfUdafAdapter {
name: udaf.name,
signature: udaf.signature.into(),
return_type_func: to_df_return_type(udaf.return_type),
accumulator: to_df_accumulator_func(udaf.accumulator, udaf.creator.clone()),
creator: udaf.creator,
}
.into()
})
}
}

Expand Down
60 changes: 49 additions & 11 deletions src/common/query/src/logical_plan/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

//! Udf module contains foundational types that are used to represent UDFs.
//! It's modified from datafusion.
use std::any::Any;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use datafusion_expr::{
ColumnarValue as DfColumnarValue,
ScalarFunctionImplementation as DfScalarFunctionImplementation, ScalarUDF as DfScalarUDF,
ScalarUDFImpl,
};
use datatypes::arrow::datatypes::DataType;

use crate::error::Result;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
Expand Down Expand Up @@ -68,25 +71,60 @@ impl ScalarUdf {
}
}

/// Adapter holding the parts of a `ScalarUdf` that are needed to implement
/// DataFusion's `ScalarUDFImpl` trait on its behalf.
#[derive(Clone)]
struct DfUdfAdapter {
/// Name returned from `ScalarUDFImpl::name`.
name: String,
/// Signature returned from `ScalarUDFImpl::signature`.
signature: datafusion_expr::Signature,
/// Callback invoked with the argument types to resolve the return type.
return_type: datafusion_expr::ReturnTypeFunction,
/// Function invoked with the columnar arguments to evaluate the UDF.
fun: DfScalarFunctionImplementation,
}

impl Debug for DfUdfAdapter {
    /// Formats the adapter as a `DfUdfAdapter` struct showing only `name` and
    /// `signature`; the `return_type` and `fun` fields are left out of the
    /// output (presumably because those callbacks are not `Debug` — confirm).
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("DfUdfAdapter");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.finish()
    }
}

/// Implements DataFusion's scalar UDF trait by delegating every method to the
/// corresponding field stored in the adapter.
impl ScalarUDFImpl for DfUdfAdapter {
    /// Returns the adapter itself as a `dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the stored function name.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the stored signature.
    fn signature(&self) -> &datafusion_expr::Signature {
        &self.signature
    }

    /// Calls the stored return-type callback with `arg_types` and returns an
    /// owned copy of the data type it yields; callback errors are propagated.
    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
        let resolved = (self.return_type)(arg_types)?;
        Ok(resolved.as_ref().clone())
    }

    /// Evaluates the UDF by forwarding `args` to the stored function.
    fn invoke(&self, args: &[DfColumnarValue]) -> datafusion_common::Result<DfColumnarValue> {
        (self.fun)(args)
    }
}

impl From<ScalarUdf> for DfScalarUDF {
fn from(udf: ScalarUdf) -> Self {
// TODO(LFC): remove deprecated
#[allow(deprecated)]
DfScalarUDF::new(
&udf.name,
&udf.signature.into(),
&to_df_return_type(udf.return_type),
&to_df_scalar_func(udf.fun),
)
DfScalarUDF::new_from_impl(DfUdfAdapter {
name: udf.name,
signature: udf.signature.into(),
return_type: to_df_return_type(udf.return_type),
fun: to_df_scalar_func(udf.fun),
})
}
}

fn to_df_scalar_func(fun: ScalarFunctionImplementation) -> DfScalarFunctionImplementation {
Arc::new(move |args: &[DfColumnarValue]| {
let args: Result<Vec<_>> = args.iter().map(TryFrom::try_from).collect();

let result = (fun)(&args?);

let result = fun(&args?);
result.map(From::from).map_err(|e| e.into())
})
}
2 changes: 1 addition & 1 deletion src/promql/src/functions/aggr_over_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::compute;
Expand Down
2 changes: 1 addition & 1 deletion src/promql/src/functions/changes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
Expand Down
2 changes: 1 addition & 1 deletion src/promql/src/functions/deriv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use common_macro::range_fn;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;
Expand Down
55 changes: 15 additions & 40 deletions src/promql/src/functions/extrapolate_rate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ use std::sync::Arc;
use datafusion::arrow::array::{Float64Array, TimestampMillisecondArray};
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::logical_expr::{ScalarUDF, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datafusion_expr::create_udf;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;

Expand All @@ -62,19 +63,23 @@ impl<const IS_COUNTER: bool, const IS_RATE: bool> ExtrapolatedRate<IS_COUNTER, I
Self { range_length }
}

fn input_type() -> Vec<DataType> {
vec![
fn scalar_udf_with_name(name: &str, range_length: i64) -> ScalarUDF {
let input_types = vec![
// timestamp range vector
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
// value range vector
RangeArray::convert_data_type(DataType::Float64),
// timestamp vector
DataType::Timestamp(TimeUnit::Millisecond, None),
]
}
];

fn return_type() -> DataType {
DataType::Float64
create_udf(
name,
input_types,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _,
)
}

fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
Expand Down Expand Up @@ -204,17 +209,7 @@ impl ExtrapolatedRate<false, false> {
}

pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}

Expand All @@ -225,17 +220,7 @@ impl ExtrapolatedRate<true, true> {
}

pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}

Expand All @@ -246,17 +231,7 @@ impl ExtrapolatedRate<true, false> {
}

pub fn scalar_udf(range_length: i64) -> ScalarUDF {
// TODO(LFC): Use the new Datafusion UDF impl.
#[allow(deprecated)]
ScalarUDF::new(
Self::name(),
&Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
&(Arc::new(|_: &_| Ok(Arc::new(Self::return_type()))) as _),
&(Arc::new(move |input: &_| Self::new(range_length).calc(input)) as _),
)
Self::scalar_udf_with_name(Self::name(), range_length)
}
}

Expand Down
Loading

0 comments on commit e154dc5

Please sign in to comment.