diff --git a/ibis/expr/operations/udf.py b/ibis/expr/operations/udf.py index efb6a9295151..1097a51e55f6 100644 --- a/ibis/expr/operations/udf.py +++ b/ibis/expr/operations/udf.py @@ -78,33 +78,47 @@ def _make_node( input_type: InputType, name: str | None = None, schema: str | None = None, + signature: tuple[tuple, Any] | None = None, **kwargs, ) -> type[S]: """Construct a scalar user-defined function that is built-in to the backend.""" - annotations = typing.get_type_hints(fn) - if (return_annotation := annotations.pop("return", None)) is None: - raise exc.MissingReturnAnnotationError(fn) + if signature is None: + annotations = typing.get_type_hints(fn) + if (return_annotation := annotations.pop("return", None)) is None: + raise exc.MissingReturnAnnotationError(fn) + fields = { + arg_name: Argument( + pattern=rlz.ValueOf(annotations.get(arg_name)), + default=param.default, + typehint=annotations.get(arg_name, Any), + ) + for arg_name, param in inspect.signature(fn).parameters.items() + } + else: + arg_types, return_annotation = signature + arg_names = list(inspect.signature(fn).parameters) + fields = { + arg_name: Argument(pattern=rlz.ValueOf(typ), typehint=typ) + for arg_name, typ in zip(arg_names, arg_types) + } func_name = name if name is not None else fn.__name__ - fields = { - arg_name: Argument( - pattern=rlz.ValueOf(annotations.get(arg_name)), default=param.default - ) - for arg_name, param in inspect.signature(fn).parameters.items() - } | { - "dtype": dt.dtype(return_annotation), - "__input_type__": input_type, - # must wrap `fn` in a `property` otherwise `fn` is assumed to be a - # method - "__func__": property(fget=lambda _, fn=fn: fn), - "__config__": FrozenDict(kwargs), - "__udf_namespace__": schema, - "__module__": fn.__module__, - "__func_name__": func_name, - "__full_name__": ".".join(filter(None, (schema, func_name))), - } + fields.update( + { + "dtype": dt.dtype(return_annotation), + "__input_type__": input_type, + # must wrap `fn` in a `property` otherwise `fn` is assumed to be a + # method + "__func__": property(fget=lambda _, fn=fn: fn), + "__config__": FrozenDict(kwargs), + "__udf_namespace__": schema, + "__module__": fn.__module__, + "__func_name__": func_name, + "__full_name__": ".".join(filter(None, (schema, func_name))), + } + ) return type(fn.__name__, (cls._base,), fields) @@ -144,23 +158,30 @@ def builtin( *, name: str | None = None, schema: str | None = None, + signature: tuple[tuple[Any, ...], Any] | None = None, **kwargs: Any, ) -> Callable[[Callable], Callable[..., ir.Value]]: ... @util.experimental @classmethod - def builtin(cls, fn=None, *, name=None, schema=None, **kwargs): + def builtin(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs): """Construct a scalar user-defined function that is built-in to the backend. Parameters ---------- fn - The The function to wrap. + The function to wrap. name The name of the UDF in the backend if different from the function name. schema The schema in which the builtin function resides. + signature + An optional signature to use for the UDF. If present, should be a + tuple containing a tuple of argument types and a return type. For + example, a function taking an int and a float and returning a + string would be `((int, float), str)`. If not present, the argument + types will be derived from the wrapped function. kwargs Additional backend-specific configuration arguments for the UDF. @@ -181,6 +202,7 @@ def builtin(cls, fn=None, *, name=None, schema=None, **kwargs): fn, name=name, schema=schema, + signature=signature, **kwargs, ) @@ -196,13 +218,14 @@ def python( *, name: str | None = None, schema: str | None = None, + signature: tuple[tuple[Any, ...], Any] | None = None, **kwargs: Any, ) -> Callable[[Callable], Callable[..., ir.Value]]: ... @util.experimental @classmethod - def python(cls, fn=None, *, name=None, schema=None, **kwargs): + def python(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs): """Construct a **non-vectorized** scalar user-defined function that accepts Python scalar values as inputs. ::: {.callout-warning collapse="true"} @@ -221,11 +244,17 @@ def python(cls, fn=None, *, name=None, schema=None, **kwargs): Parameters ---------- fn - The The function to wrap. + The function to wrap. name The name of the UDF in the backend if different from the function name. schema The schema in which to create the UDF. + signature + An optional signature to use for the UDF. If present, should be a + tuple containing a tuple of argument types and a return type. For + example, a function taking an int and a float and returning a + string would be `((int, float), str)`. If not present, the argument + types will be derived from the wrapped function. kwargs Additional backend-specific configuration arguments for the UDF. @@ -251,6 +280,7 @@ def python(cls, fn=None, *, name=None, schema=None, **kwargs): fn, name=name, schema=schema, + signature=signature, **kwargs, ) @@ -266,23 +296,30 @@ def pandas( *, name: str | None = None, schema: str | None = None, + signature: tuple[tuple[Any, ...], Any] | None = None, **kwargs: Any, ) -> Callable[[Callable], Callable[..., ir.Value]]: ... @util.experimental @classmethod - def pandas(cls, fn=None, *, name=None, schema=None, **kwargs): + def pandas(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs): """Construct a **vectorized** scalar user-defined function that accepts pandas Series' as inputs. Parameters ---------- fn - The The function to wrap. + The function to wrap. name The name of the UDF in the backend if different from the function name. schema The schema in which to create the UDF. + signature + An optional signature to use for the UDF. If present, should be a + tuple containing a tuple of argument types and a return type. For + example, a function taking an int and a float and returning a + string would be `((int, float), str)`. If not present, the argument + types will be derived from the wrapped function. kwargs Additional backend-specific configuration arguments for the UDF. @@ -310,6 +347,7 @@ def pandas(cls, fn=None, *, name=None, schema=None, **kwargs): fn, name=name, schema=schema, + signature=signature, **kwargs, ) @@ -325,23 +363,30 @@ def pyarrow( *, name: str | None = None, schema: str | None = None, + signature: tuple[tuple[Any, ...], Any] | None = None, **kwargs: Any, ) -> Callable[[Callable], Callable[..., ir.Value]]: ... @util.experimental @classmethod - def pyarrow(cls, fn=None, *, name=None, schema=None, **kwargs): + def pyarrow(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs): """Construct a **vectorized** scalar user-defined function that accepts PyArrow Arrays as input. Parameters ---------- fn - The The function to wrap. + The function to wrap. name The name of the UDF in the backend if different from the function name. schema The schema in which to create the UDF. + signature + An optional signature to use for the UDF. If present, should be a + tuple containing a tuple of argument types and a return type. For + example, a function taking an int and a float and returning a + string would be `((int, float), str)`. If not present, the argument + types will be derived from the wrapped function. kwargs Additional backend-specific configuration arguments for the UDF. @@ -368,6 +413,7 @@ def pyarrow(cls, fn=None, *, name=None, schema=None, **kwargs): fn, name=name, schema=schema, + signature=signature, **kwargs, ) @@ -389,23 +435,30 @@ def builtin( *, name: str | None = None, schema: str | None = None, + signature: tuple[tuple[Any, ...], Any] | None = None, **kwargs: Any, ) -> Callable[[Callable], Callable[..., ir.Value]]: ... @util.experimental @classmethod - def builtin(cls, fn=None, *, name=None, schema=None, **kwargs): + def builtin(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs): """Construct an aggregate user-defined function that is built-in to the backend. Parameters ---------- fn - The The function to wrap. + The function to wrap. name The name of the UDF in the backend if different from the function name. schema The schema in which the builtin function resides. + signature + An optional signature to use for the UDF. If present, should be a + tuple containing a tuple of argument types and a return type. For + example, a function taking an int and a float and returning a + string would be `((int, float), str)`. If not present, the argument + types will be derived from the wrapped function. kwargs Additional backend-specific configuration arguments for the UDF. @@ -427,5 +480,6 @@ def builtin(cls, fn=None, *, name=None, schema=None, **kwargs): fn, name=name, schema=schema, + signature=signature, **kwargs, ) diff --git a/ibis/tests/expr/test_udf.py b/ibis/tests/expr/test_udf.py index 5d63528f1087..bf82bad00164 100644 --- a/ibis/tests/expr/test_udf.py +++ b/ibis/tests/expr/test_udf.py @@ -83,6 +83,50 @@ def test_vectorized_udf_operations(table, klass, output_type): ) +@pytest.mark.parametrize( + "dec", + [ + pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"), + pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"), + pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"), + pytest.param(ibis.udf.scalar.python, id="scalar-python"), + pytest.param(ibis.udf.agg.builtin, id="agg-builtin"), + ], +) +def test_udf_from_annotations(dec, table): + @dec + def myfunc(x: int, y: str) -> float: + ... + + assert myfunc(table.a, table.b).type().is_floating() + + with pytest.raises(ValidationError): + # Wrong arg types + myfunc(table.b, table.a) + + +@pytest.mark.parametrize( + "dec", + [ + pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"), + pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"), + pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"), + pytest.param(ibis.udf.scalar.python, id="scalar-python"), + pytest.param(ibis.udf.agg.builtin, id="agg-builtin"), + ], +) +def test_udf_from_sig(dec, table): + @dec(signature=((int, str), float)) + def myfunc(x, y): + ... + + assert myfunc(table.a, table.b).type().is_floating() + + with pytest.raises(ValidationError): + # Wrong arg types + myfunc(table.b, table.a) + + @pytest.mark.parametrize( "dec", [