Skip to content

Commit

Permalink
feat(api): support deferred arguments to udfs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Sep 27, 2023
1 parent 26ffc68 commit a49d259
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
11 changes: 7 additions & 4 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ibis import util
from ibis.common.annotations import Argument
from ibis.common.collections import FrozenDict
from ibis.expr.deferred import deferrable

if TYPE_CHECKING:
import ibis.expr.types as ir
Expand Down Expand Up @@ -50,11 +51,13 @@ def _wrap(
**kwargs: Any,
) -> Callable:
"""Wrap a function `fn` with `wrapper`, allowing zero arguments when used as part of a decorator."""
if fn is None:
return lambda fn: functools.update_wrapper(
wrapper(input_type, fn, *args, **kwargs), fn

def wrap(fn):
return functools.update_wrapper(
deferrable(wrapper(input_type, fn, *args, **kwargs)), fn
)
return functools.update_wrapper(wrapper(input_type, fn, *args, **kwargs), fn)

return wrap(fn) if fn is not None else wrap


S = TypeVar("S", bound=ops.Value)
Expand Down
23 changes: 23 additions & 0 deletions ibis/tests/expr/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.expr.deferred import Deferred


@pytest.fixture
Expand Down Expand Up @@ -79,3 +81,24 @@ def test_vectorized_udf_operations(table, klass, output_type):
input_type=[dt.int8(), dt.string(), dt.boolean()],
return_type=table,
)


@pytest.mark.parametrize(
"dec",
[
pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"),
pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"),
pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"),
pytest.param(ibis.udf.scalar.python, id="scalar-python"),
pytest.param(ibis.udf.agg.builtin, id="agg-builtin"),
],
)
def test_udf_deferred(dec, table):
@dec
def myfunc(x: int) -> int:
...

expr = myfunc(_.a)
assert isinstance(expr, Deferred)
assert repr(expr) == "myfunc(_.a)"
assert expr.resolve(table).equals(myfunc(table.a))

0 comments on commit a49d259

Please sign in to comment.