Skip to content

Commit

Permalink
refactor: move asin() to function crate (#9379)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveLauC authored Feb 28, 2024
1 parent 5f90ead commit 96abac8
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 23 deletions.
10 changes: 2 additions & 8 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ use strum_macros::EnumIter;
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
pub enum BuiltinScalarFunction {
// math functions
/// asin
Asin,
/// atan
Atan,
/// atan2
Expand Down Expand Up @@ -359,7 +357,6 @@ impl BuiltinScalarFunction {
pub fn volatility(&self) -> Volatility {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Asin => Volatility::Immutable,
BuiltinScalarFunction::Atan => Volatility::Immutable,
BuiltinScalarFunction::Atan2 => Volatility::Immutable,
BuiltinScalarFunction::Acosh => Volatility::Immutable,
Expand Down Expand Up @@ -858,8 +855,7 @@ impl BuiltinScalarFunction {
utf8_to_int_type(&input_expr_types[0], "levenshtein")
}

BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
| BuiltinScalarFunction::Atanh
Expand Down Expand Up @@ -1321,8 +1317,7 @@ impl BuiltinScalarFunction {
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
self.volatility(),
),
BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
| BuiltinScalarFunction::Atanh
Expand Down Expand Up @@ -1413,7 +1408,6 @@ impl BuiltinScalarFunction {
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Acosh => &["acosh"],
BuiltinScalarFunction::Asin => &["asin"],
BuiltinScalarFunction::Asinh => &["asinh"],
BuiltinScalarFunction::Atan => &["atan"],
BuiltinScalarFunction::Atanh => &["atanh"],
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,6 @@ scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
scalar_expr!(Asin, asin, num, "inverse sine");
scalar_expr!(Atan, atan, num, "inverse tangent");
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
Expand Down Expand Up @@ -1332,7 +1331,6 @@ mod test {
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
test_unary_scalar_expr!(Asin, asin);
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Asinh, asinh);
test_unary_scalar_expr!(Acosh, acosh);
Expand Down
110 changes: 110 additions & 0 deletions datafusion/functions/src/math/asin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `asin()`.
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{
utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct AsinFunc {
signature: Signature,
}

impl AsinFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::uniform(
1,
vec![Float64, Float32],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for AsinFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"asin"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}

let arg_type = &arg_types[0];

match arg_type {
DataType::Float64 => Ok(DataType::Float64),
DataType::Float32 => Ok(DataType::Float32),

// For other types (possible values null/int), use Float 64
_ => Ok(DataType::Float64),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float64Array,
Float64Array,
{ f64::asin }
)),
DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float32Array,
Float32Array,
{ f32::asin }
)),
other => {
return exec_err!(
"Unsupported data type {other:?} for function {}",
self.name()
)
}
};
Ok(ColumnarValue::Array(arr))
}
}
7 changes: 7 additions & 0 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
mod abs;
mod acos;
mod asin;
mod nans;

// create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);
make_udf_function!(acos::AcosFunc, ACOS, acos);
make_udf_function!(asin::AsinFunc, ASIN, asin);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
Expand All @@ -38,5 +40,10 @@ export_functions!(
acos,
num,
"returns the arc cosine or inverse cosine of a number"
),
(
asin,
num,
"returns the arc sine or inverse sine of a number"
)
);
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ pub fn create_physical_fun(
) -> Result<ScalarFunctionImplementation> {
Ok(match fun {
// math functions
BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin),
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh),
BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ enum ScalarFunction {
// The first enum value must be zero for open enums
unknown = 0;
// 1 was Acos
Asin = 2;
// 2 was Asin
Atan = 3;
Ascii = 4;
Ceil = 5;
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ use datafusion_expr::{
array_length, array_ndims, array_pop_back, array_pop_front, array_position,
array_positions, array_prepend, array_remove, array_remove_all, array_remove_n,
array_repeat, array_replace, array_replace_all, array_replace_n, array_resize,
array_slice, array_sort, array_union, arrow_typeof, ascii, asin, asinh, atan, atan2,
atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
array_slice, array_sort, array_union, arrow_typeof, ascii, asinh, atan, atan2, atanh,
bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce,
concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin,
date_part, date_trunc, degrees, digest, ends_with, exp,
expr::{self, InList, Sort, WindowFunction},
Expand Down Expand Up @@ -449,7 +449,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Cos => Self::Cos,
ScalarFunction::Tan => Self::Tan,
ScalarFunction::Cot => Self::Cot,
ScalarFunction::Asin => Self::Asin,
ScalarFunction::Atan => Self::Atan,
ScalarFunction::Sinh => Self::Sinh,
ScalarFunction::Cosh => Self::Cosh,
Expand Down Expand Up @@ -1359,7 +1358,6 @@ pub fn parse_expr(

match scalar_function {
ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")),
ScalarFunction::Asin => Ok(asin(parse_expr(&args[0], registry)?)),
ScalarFunction::Asinh => Ok(asinh(parse_expr(&args[0], registry)?)),
ScalarFunction::Acosh => Ok(acosh(parse_expr(&args[0], registry)?)),
ScalarFunction::Array => Ok(array(
Expand Down
1 change: 0 additions & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Sinh => Self::Sinh,
BuiltinScalarFunction::Cosh => Self::Cosh,
BuiltinScalarFunction::Tanh => Self::Tanh,
BuiltinScalarFunction::Asin => Self::Asin,
BuiltinScalarFunction::Atan => Self::Atan,
BuiltinScalarFunction::Asinh => Self::Asinh,
BuiltinScalarFunction::Acosh => Self::Acosh,
Expand Down

0 comments on commit 96abac8

Please sign in to comment.