Skip to content

Commit

Permalink
feat(udf): support inputs without type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 12, 2023
1 parent 3491562 commit 99e531d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions docs/how-to/extending/builtin.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,44 @@ And let's take a look at the SQL
```{python}
ibis.to_sql(expr, dialect="snowflake")
```

## Input types

Sometimes the input types of builtin functions are difficult to spell.

Consider a function that computes the length of any array: the elements in the
array can be floats, integers, strings and even other arrays. Spelling that
type is difficult.

Fortunately the `udf.scalar.builtin` decorator doesn't require you to specify
input types in these cases:

```{python}
@udf.scalar.builtin(name="array_size")
def cardinality(arr) -> int:
...
```

::: {.callout-caution}
## The return type annotation **is always required**.
:::

We can pass arrays with different element types to our `cardinality` function:

```{python}
con.execute(cardinality([1, 2, 3]))
```

```{python}
con.execute(cardinality(["a", "b"]))
```

When you bypass input types the errors you get back are backend dependent:

```{python}
#| error: true
con.execute(cardinality("foo"))
```

Here, Snowflake is informing us that the `ARRAY_SIZE` function does not accept
strings as input.
14 changes: 13 additions & 1 deletion ibis/backends/clickhouse/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,24 @@ def array_jaccard_index(a: dt.Array[dt.int64], b: dt.Array[dt.int64]) -> float:
...


@udf.scalar.builtin(name="arrayJaccardIndex")
def array_jaccard_index_no_input_types(a, b) -> float:
...


@udf.scalar.builtin
def arrayJaccardIndex(a: dt.Array[dt.int64], b: dt.Array[dt.int64]) -> float:
...


@pytest.mark.parametrize("func", [array_jaccard_index, arrayJaccardIndex])
@pytest.mark.parametrize(
"func",
[
array_jaccard_index,
arrayJaccardIndex,
array_jaccard_index_no_input_types,
],
)
def test_builtin_udf(con, func):
expr = func([1, 2], [2, 3])
result = con.execute(expr)
Expand Down
9 changes: 5 additions & 4 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,11 @@ def _make_node(
func_name = name or fn.__name__

for arg_name, param in inspect.signature(fn).parameters.items():
if (raw_dtype := annotations.get(arg_name)) is None:
raise exc.MissingParameterAnnotationError(fn, arg_name)

arg = rlz.ValueOf(dt.dtype(raw_dtype))
if (raw_dtype := annotations.get(arg_name)) is not None:
dtype = dt.dtype(raw_dtype)
else:
dtype = raw_dtype
arg = rlz.ValueOf(dtype)
fields[arg_name] = Argument(pattern=arg, default=param.default)

fields["dtype"] = dt.dtype(return_annotation)
Expand Down

0 comments on commit 99e531d

Please sign in to comment.