Skip to content

Commit

Permalink
fix(duckdb): udfs builtins taking zero args
Browse files Browse the repository at this point in the history
chore: add test_builtin_scalr_noargs to test/expr/test_udf.py
  • Loading branch information
ncclementi authored and jcrist committed Feb 9, 2024
1 parent 9d34f82 commit ab39344
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
13 changes: 13 additions & 0 deletions ibis/backends/duckdb/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ def test_builtin_scalar(con, func):
assert con.execute(expr) == expected


def test_builtin_scalar_noargs(con):
@udf.scalar.builtin
def version() -> str:
...

expr = version()

with con.begin() as c:
expected = c.exec_driver_sql("SELECT version()").scalar()

assert con.execute(expr) == expected


@udf.agg.builtin
def product(x, where: bool = True) -> float:
...
Expand Down
15 changes: 13 additions & 2 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from public import public

import ibis.common.exceptions as exc
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
from ibis import util
from ibis.common.annotations import Argument
from ibis.common.annotations import Argument, attribute
from ibis.common.collections import FrozenDict
from ibis.common.deferred import deferrable

Expand All @@ -35,7 +36,16 @@ class InputType(enum.Enum):

@public
class ScalarUDF(ops.Value):
shape = rlz.shape_like("args")
@attribute
def shape(self):
if not (args := getattr(self, "args")): # noqa: B009
# if a udf builtin takes no args then the shape check will fail
# because there are no arguments to grab the shape of. In that case
# default to a scalar shape
return ds.scalar
else:
args = args if util.is_iterable(args) else [args]
return rlz.highest_precedence_shape(args)


@public
Expand Down Expand Up @@ -95,6 +105,7 @@ def _make_node(
)
for arg_name, param in inspect.signature(fn).parameters.items()
}

else:
arg_types, return_annotation = signature
arg_names = list(inspect.signature(fn).parameters)
Expand Down
10 changes: 10 additions & 0 deletions ibis/tests/expr/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,13 @@ def myfunc(x: int) -> int:
assert isinstance(expr, Deferred)
assert repr(expr) == "myfunc(_.a)"
assert expr.resolve(table).equals(myfunc(table.a))


def test_builtin_scalar_noargs():
@ibis.udf.scalar.builtin
def version() -> str:
...

expr = version()
assert expr.type().is_string()
assert type(expr.op().shape) is ibis.expr.datashape.Scalar

0 comments on commit ab39344

Please sign in to comment.