Commit

Improve ScalarUDFImpl docs
alamb committed Jan 23, 2025
1 parent 0228bee commit 02ba5ad
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions datafusion/expr/src/udf.rs
@@ -193,6 +193,10 @@ impl ScalarUDF {
self.inner.return_type_from_exprs(args, schema, arg_types)
}

/// Return the datatype this function returns given the input argument types.
///
/// See [`ScalarUDFImpl::return_type_from_args`] for more details.
pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
self.inner.return_type_from_args(args)
}
@@ -433,7 +437,6 @@ impl ReturnInfo {
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
///
/// /// This struct for a simple UDF that adds one to an int32
/// #[derive(Debug)]
/// struct AddOne {
@@ -494,7 +497,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Returns this function's name
fn name(&self) -> &str;

/// Returns the user-defined display name of the UDF given the arguments
/// Returns the user-defined display name of the function, given the arguments
///
/// This can be used to customize the output column name generated by this
/// function.
///
/// Defaults to `name(args[0], args[1], ...)`
fn display_name(&self, args: &[Expr]) -> Result<String> {
let names: Vec<String> = args.iter().map(ToString::to_string).collect();
// TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
@@ -522,7 +530,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// # Notes
///
/// If you provide an implementation for [`Self::return_type_from_args`],
/// DataFusion will not call `return_type` (this function). In this case it
/// DataFusion will not call `return_type` (this function). In such cases it
/// is recommended to return [`DataFusionError::Internal`].
///
/// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
@@ -538,18 +546,24 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
self.return_type(arg_types)
}

/// What [`DataType`] will be returned by this function, given the
/// arguments?
///
/// Note most UDFs should implement [`Self::return_type`] and not this
/// function. The output type for most functions only depends on the types
/// of their inputs (e.g. `sqrt(f32)` is always `f32`).
/// What type will be returned by this function, given the arguments?
///
/// By default, this function calls [`Self::return_type`] with the
/// types of each argument.
///
/// This method can be overridden for functions that return different
/// *types* based on the *values* of their arguments.
/// # Notes
///
/// Most UDFs should implement [`Self::return_type`] and not this
/// function as the output type for most functions only depends on the types
/// of their inputs (e.g. `sqrt(f32)` is always `f32`).
///
/// This function can be used for more advanced cases such as:
///
/// 1. specifying nullability
/// 2. return types based on the **values** of the arguments (rather than
/// their **types**).
///
/// # Output Type based on Values
///
/// For example, the following two function calls get the same argument
/// types (something and a `Utf8` string) but return different types based
@@ -558,9 +572,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// * `arrow_cast(x, 'Int16')` --> `Int16`
/// * `arrow_cast(x, 'Float32')` --> `Float32`
///
/// # Notes:
/// # Requirements
///
/// This function must consistently return the same type for the same
/// This function **must** consistently return the same type for the same
/// logical input even if the input is simplified (e.g. it must return the same
/// value for `('foo' | 'bar')` as it does for `('foobar')`).
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
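
The split between `return_type` and `return_type_from_args` described in the doc comments above can be illustrated with a small standalone model. This is only a sketch and does not use the DataFusion API: the `DataType` enum, `Args`, `ReturnInfo`, and `ScalarFn` below are simplified, hypothetical stand-ins, and the choice to report every result as nullable by default is an assumption made for illustration. It shows a `sqrt`-like function that only needs the argument types, and an `arrow_cast`-like function whose output type depends on the value of a string literal argument.

```rust
// Standalone model: every type here is a hypothetical stand-in,
// not the real DataFusion API.
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Int16,
    Int32,
    Float32,
    Utf8,
}

/// Hypothetical per-call information: the argument types, plus the text of
/// any argument that is a string literal known at planning time.
struct Args<'a> {
    arg_types: &'a [DataType],
    literal_args: &'a [Option<&'a str>],
}

/// Hypothetical result: the output type and whether it may be null.
#[derive(Debug, PartialEq)]
struct ReturnInfo {
    return_type: DataType,
    nullable: bool,
}

trait ScalarFn {
    /// Output type computed from the argument types only.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String>;

    /// By default, delegate to `return_type` with the argument types.
    /// Assumption for this sketch: the result is reported as nullable.
    fn return_type_from_args(&self, args: Args<'_>) -> Result<ReturnInfo, String> {
        let return_type = self.return_type(args.arg_types)?;
        Ok(ReturnInfo { return_type, nullable: true })
    }
}

/// Common case: the output type depends only on the input type.
struct Sqrt;

impl ScalarFn for Sqrt {
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String> {
        arg_types
            .first()
            .cloned()
            .ok_or_else(|| "sqrt expects one argument".to_string())
    }
}

/// Advanced case: the output type depends on the *value* of the second argument.
struct ArrowCast;

impl ScalarFn for ArrowCast {
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType, String> {
        // Never expected to be called, because return_type_from_args is overridden.
        Err("internal error: use return_type_from_args".to_string())
    }

    fn return_type_from_args(&self, args: Args<'_>) -> Result<ReturnInfo, String> {
        // The second argument must be a string literal naming the target type.
        let target = args
            .literal_args
            .get(1)
            .copied()
            .flatten()
            .ok_or_else(|| "second argument must be a string literal".to_string())?;
        let return_type = match target {
            "Int16" => DataType::Int16,
            "Int32" => DataType::Int32,
            "Float32" => DataType::Float32,
            "Utf8" => DataType::Utf8,
            other => return Err(format!("unsupported type name: {other}")),
        };
        Ok(ReturnInfo { return_type, nullable: true })
    }
}
```

In this model, two calls to `ArrowCast` with identical argument types (`Int32` and `Utf8`) yield `Int16` or `Float32` depending on the literal, while `Sqrt` never needs to look at argument values and relies on the default method.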