Skip to content

Commit

Permalink
feat(api): support specifying signature in udf definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Oct 23, 2023
1 parent 3518b78 commit 764977e
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 30 deletions.
114 changes: 84 additions & 30 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,47 @@ def _make_node(
input_type: InputType,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple, Any] | None = None,
**kwargs,
) -> type[S]:
"""Construct a scalar user-defined function that is built-in to the backend."""

annotations = typing.get_type_hints(fn)
if (return_annotation := annotations.pop("return", None)) is None:
raise exc.MissingReturnAnnotationError(fn)
if signature is None:
annotations = typing.get_type_hints(fn)
if (return_annotation := annotations.pop("return", None)) is None:
raise exc.MissingReturnAnnotationError(fn)
fields = {
arg_name: Argument(
pattern=rlz.ValueOf(annotations.get(arg_name)),
default=param.default,
typehint=annotations.get(arg_name, Any),
)
for arg_name, param in inspect.signature(fn).parameters.items()
}
else:
arg_types, return_annotation = signature
arg_names = list(inspect.signature(fn).parameters)
fields = {
arg_name: Argument(pattern=rlz.ValueOf(typ), typehint=typ)
for arg_name, typ in zip(arg_names, arg_types)
}

func_name = name if name is not None else fn.__name__

fields = {
arg_name: Argument(
pattern=rlz.ValueOf(annotations.get(arg_name)), default=param.default
)
for arg_name, param in inspect.signature(fn).parameters.items()
} | {
"dtype": dt.dtype(return_annotation),
"__input_type__": input_type,
# must wrap `fn` in a `property` otherwise `fn` is assumed to be a
# method
"__func__": property(fget=lambda _, fn=fn: fn),
"__config__": FrozenDict(kwargs),
"__udf_namespace__": schema,
"__module__": fn.__module__,
"__func_name__": func_name,
"__full_name__": ".".join(filter(None, (schema, func_name))),
}
fields.update(
{
"dtype": dt.dtype(return_annotation),
"__input_type__": input_type,
# must wrap `fn` in a `property` otherwise `fn` is assumed to be a
# method
"__func__": property(fget=lambda _, fn=fn: fn),
"__config__": FrozenDict(kwargs),
"__udf_namespace__": schema,
"__module__": fn.__module__,
"__func_name__": func_name,
"__full_name__": ".".join(filter(None, (schema, func_name))),
}
)

return type(fn.__name__, (cls._base,), fields)

Expand Down Expand Up @@ -144,23 +158,30 @@ def builtin(
*,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple[Any, ...], Any] | None = None,
**kwargs: Any,
) -> Callable[[Callable], Callable[..., ir.Value]]:
...

@util.experimental
@classmethod
def builtin(cls, fn=None, *, name=None, schema=None, **kwargs):
def builtin(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs):
"""Construct a scalar user-defined function that is built-in to the backend.
Parameters
----------
fn
The The function to wrap.
The function to wrap.
name
The name of the UDF in the backend if different from the function name.
schema
The schema in which the builtin function resides.
signature
An optional signature to use for the UDF. If present, should be a
tuple containing a tuple of argument types and a return type. For
example, a function taking an int and a float and returning a
string would be `((int, float), str)`. If not present, the argument
types will be derived from the wrapped function.
kwargs
Additional backend-specific configuration arguments for the UDF.
Expand All @@ -181,6 +202,7 @@ def builtin(cls, fn=None, *, name=None, schema=None, **kwargs):
fn,
name=name,
schema=schema,
signature=signature,
**kwargs,
)

Expand All @@ -196,13 +218,14 @@ def python(
*,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple[Any, ...], Any] | None = None,
**kwargs: Any,
) -> Callable[[Callable], Callable[..., ir.Value]]:
...

@util.experimental
@classmethod
def python(cls, fn=None, *, name=None, schema=None, **kwargs):
def python(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs):
"""Construct a **non-vectorized** scalar user-defined function that accepts Python scalar values as inputs.
::: {.callout-warning collapse="true"}
Expand All @@ -221,11 +244,17 @@ def python(cls, fn=None, *, name=None, schema=None, **kwargs):
Parameters
----------
fn
The The function to wrap.
The function to wrap.
name
The name of the UDF in the backend if different from the function name.
schema
The schema in which to create the UDF.
signature
An optional signature to use for the UDF. If present, should be a
tuple containing a tuple of argument types and a return type. For
example, a function taking an int and a float and returning a
string would be `((int, float), str)`. If not present, the argument
types will be derived from the wrapped function.
kwargs
Additional backend-specific configuration arguments for the UDF.
Expand All @@ -251,6 +280,7 @@ def python(cls, fn=None, *, name=None, schema=None, **kwargs):
fn,
name=name,
schema=schema,
signature=signature,
**kwargs,
)

Expand All @@ -266,23 +296,30 @@ def pandas(
*,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple[Any, ...], Any] | None = None,
**kwargs: Any,
) -> Callable[[Callable], Callable[..., ir.Value]]:
...

@util.experimental
@classmethod
def pandas(cls, fn=None, *, name=None, schema=None, **kwargs):
def pandas(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs):
"""Construct a **vectorized** scalar user-defined function that accepts pandas Series' as inputs.
Parameters
----------
fn
The The function to wrap.
The function to wrap.
name
The name of the UDF in the backend if different from the function name.
schema
The schema in which to create the UDF.
signature
An optional signature to use for the UDF. If present, should be a
tuple containing a tuple of argument types and a return type. For
example, a function taking an int and a float and returning a
string would be `((int, float), str)`. If not present, the argument
types will be derived from the wrapped function.
kwargs
Additional backend-specific configuration arguments for the UDF.
Expand Down Expand Up @@ -310,6 +347,7 @@ def pandas(cls, fn=None, *, name=None, schema=None, **kwargs):
fn,
name=name,
schema=schema,
signature=signature,
**kwargs,
)

Expand All @@ -325,23 +363,30 @@ def pyarrow(
*,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple[Any, ...], Any] | None = None,
**kwargs: Any,
) -> Callable[[Callable], Callable[..., ir.Value]]:
...

@util.experimental
@classmethod
def pyarrow(cls, fn=None, *, name=None, schema=None, **kwargs):
def pyarrow(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs):
"""Construct a **vectorized** scalar user-defined function that accepts PyArrow Arrays as input.
Parameters
----------
fn
The The function to wrap.
The function to wrap.
name
The name of the UDF in the backend if different from the function name.
schema
The schema in which to create the UDF.
signature
An optional signature to use for the UDF. If present, should be a
tuple containing a tuple of argument types and a return type. For
example, a function taking an int and a float and returning a
string would be `((int, float), str)`. If not present, the argument
types will be derived from the wrapped function.
kwargs
Additional backend-specific configuration arguments for the UDF.
Expand All @@ -368,6 +413,7 @@ def pyarrow(cls, fn=None, *, name=None, schema=None, **kwargs):
fn,
name=name,
schema=schema,
signature=signature,
**kwargs,
)

Expand All @@ -389,23 +435,30 @@ def builtin(
*,
name: str | None = None,
schema: str | None = None,
signature: tuple[tuple[Any, ...], Any] | None = None,
**kwargs: Any,
) -> Callable[[Callable], Callable[..., ir.Value]]:
...

@util.experimental
@classmethod
def builtin(cls, fn=None, *, name=None, schema=None, **kwargs):
def builtin(cls, fn=None, *, name=None, schema=None, signature=None, **kwargs):
"""Construct an aggregate user-defined function that is built-in to the backend.
Parameters
----------
fn
The The function to wrap.
The function to wrap.
name
The name of the UDF in the backend if different from the function name.
schema
The schema in which the builtin function resides.
signature
An optional signature to use for the UDF. If present, should be a
tuple containing a tuple of argument types and a return type. For
example, a function taking an int and a float and returning a
string would be `((int, float), str)`. If not present, the argument
types will be derived from the wrapped function.
kwargs
Additional backend-specific configuration arguments for the UDF.
Expand All @@ -427,5 +480,6 @@ def builtin(cls, fn=None, *, name=None, schema=None, **kwargs):
fn,
name=name,
schema=schema,
signature=signature,
**kwargs,
)
44 changes: 44 additions & 0 deletions ibis/tests/expr/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,50 @@ def test_vectorized_udf_operations(table, klass, output_type):
)


@pytest.mark.parametrize(
    "dec",
    [
        pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"),
        pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"),
        pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"),
        pytest.param(ibis.udf.scalar.python, id="scalar-python"),
        pytest.param(ibis.udf.agg.builtin, id="agg-builtin"),
    ],
)
def test_udf_from_annotations(dec, table):
    """Check that a UDF decorator derives its signature from annotations.

    The decorator under test is applied bare (no arguments) to a function
    annotated as ``(int, str) -> float``.
    """

    @dec
    def myfunc(x: int, y: str) -> float:
        ...

    # Calling with (a, b) must be accepted and yield a floating-point type,
    # matching the ``-> float`` return annotation.
    expr = myfunc(table.a, table.b)
    assert expr.type().is_floating()

    # Swapping the two columns no longer matches the annotated argument
    # types, so validation is expected to reject the call.
    with pytest.raises(ValidationError):
        myfunc(table.b, table.a)


@pytest.mark.parametrize(
    "dec",
    [
        pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"),
        pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"),
        pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"),
        pytest.param(ibis.udf.scalar.python, id="scalar-python"),
        pytest.param(ibis.udf.agg.builtin, id="agg-builtin"),
    ],
)
def test_udf_from_sig(dec, table):
    """Check that a UDF decorator honours an explicit ``signature=`` argument.

    The wrapped function carries no annotations; its types come solely from
    the ``((int, str), float)`` signature passed to the decorator.
    """
    explicit_signature = ((int, str), float)

    @dec(signature=explicit_signature)
    def myfunc(x, y):
        ...

    # Calling with (a, b) must be accepted and yield a floating-point type,
    # matching the return type given in the signature.
    expr = myfunc(table.a, table.b)
    assert expr.type().is_floating()

    # Swapping the two columns no longer matches the declared argument
    # types, so validation is expected to reject the call.
    with pytest.raises(ValidationError):
        myfunc(table.b, table.a)


@pytest.mark.parametrize(
"dec",
[
Expand Down

0 comments on commit 764977e

Please sign in to comment.