diff --git a/ibis/expr/operations/udf.py b/ibis/expr/operations/udf.py index 461ffe318ac6..dcca4f603243 100644 --- a/ibis/expr/operations/udf.py +++ b/ibis/expr/operations/udf.py @@ -16,6 +16,7 @@ from ibis import util from ibis.common.annotations import Argument from ibis.common.collections import FrozenDict +from ibis.expr.deferred import deferrable if TYPE_CHECKING: import ibis.expr.types as ir @@ -50,11 +51,13 @@ def _wrap( **kwargs: Any, ) -> Callable: """Wrap a function `fn` with `wrapper`, allowing zero arguments when used as part of a decorator.""" - if fn is None: - return lambda fn: functools.update_wrapper( - wrapper(input_type, fn, *args, **kwargs), fn + + def wrap(fn): + return functools.update_wrapper( + deferrable(wrapper(input_type, fn, *args, **kwargs)), fn ) - return functools.update_wrapper(wrapper(input_type, fn, *args, **kwargs), fn) + + return wrap(fn) if fn is not None else wrap S = TypeVar("S", bound=ops.Value) diff --git a/ibis/tests/expr/test_udf.py b/ibis/tests/expr/test_udf.py index e59be3f9337c..d81dd995eb01 100644 --- a/ibis/tests/expr/test_udf.py +++ b/ibis/tests/expr/test_udf.py @@ -6,7 +6,9 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir +from ibis import _ from ibis.common.annotations import ValidationError +from ibis.expr.deferred import Deferred @pytest.fixture @@ -79,3 +81,24 @@ def test_vectorized_udf_operations(table, klass, output_type): input_type=[dt.int8(), dt.string(), dt.boolean()], return_type=table, ) + + +@pytest.mark.parametrize( + "dec", + [ + pytest.param(ibis.udf.scalar.builtin, id="scalar-builtin"), + pytest.param(ibis.udf.scalar.pandas, id="scalar-pandas"), + pytest.param(ibis.udf.scalar.pyarrow, id="scalar-pyarrow"), + pytest.param(ibis.udf.scalar.python, id="scalar-python"), + pytest.param(ibis.udf.agg.builtin, id="agg-builtin"), + ], +) +def test_udf_deferred(dec, table): + @dec + def myfunc(x: int) -> int: + ... + + expr = myfunc(_.a) + assert isinstance(expr, Deferred) + assert repr(expr) == "myfunc(_.a)" + assert expr.resolve(table).equals(myfunc(table.a))