Skip to content

Commit

Permalink
feat(udf): add support for builtin aggregate UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 14, 2023
1 parent f29a8e7 commit 8ee12bf
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 98 deletions.

Large diffs are not rendered by default.

110 changes: 104 additions & 6 deletions docs/how-to/extending/builtin.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@ freeze: auto

# Reference built-in functions


## Scalar functions

Functions that aren't exposed in ibis directly can be accessed using the
`@ibis.udf.scalar.builtin` decorator.

::: {.callout-tip}
## [Ibis APIs](../../reference/index.qmd) may already exist for your function.
### [Ibis APIs](../../reference/index.qmd) may already exist for your function.

Builtin scalar UDFs are designed to be an escape hatch when Ibis doesn't have
a defined API for a built-in database function.

See [the reference documentation](../../reference/index.qmd) for existing APIs.
:::

## DuckDB
### DuckDB

Ibis doesn't directly expose many of the DuckDB [text similarity
functions](https://duckdb.org/docs/sql/functions/char.html#text-similarity-functions).
Expand All @@ -34,9 +37,9 @@ The [`...`](https://docs.python.org/3/library/constants.html#Ellipsis) is
a visual indicator that the function definition is unknown to Ibis.

::: {.callout-note collapse="true"}
## Ibis does not do anything with the function body.
### Ibis does not do anything with the function body.

Ibis will not inspect the function body or otherwise inspect it. Any code you
Ibis will not execute the function body or otherwise inspect it. Any code you
write in the function body **will be ignored**.
:::

Expand Down Expand Up @@ -91,7 +94,7 @@ pandas_ish.count()

There are a good number of packages that look similar to `pandas`!

## Snowflake
### Snowflake

Similarly we can expose Snowflake's
[`jarowinkler_similarity`](https://docs.snowflake.com/en/sql-reference/functions/jarowinkler_similarity)
Expand Down Expand Up @@ -129,7 +132,7 @@ And let's take a look at the SQL
ibis.to_sql(expr, dialect="snowflake")
```

## Input types
### Input types

Sometimes the input types of builtin functions are difficult to spell.

Expand Down Expand Up @@ -169,3 +172,98 @@ con.execute(cardinality("foo"))

Here, Snowflake is informing us that the `ARRAY_SIZE` function does not accept
strings as input.


## Aggregate functions

Aggregate functions that aren't exposed in ibis directly can be accessed using
the `@ibis.udf.agg.builtin` decorator.

::: {.callout-tip}
### [Ibis APIs](../../reference/index.qmd) may already exist for your function.

Builtin aggregate UDFs are designed to be an escape hatch when Ibis doesn't have
a defined API for a built-in database function.

See [the reference documentation](../../reference/index.qmd) for existing APIs.
:::

Let's the use the DuckDB backend to demonstrate how to access an aggregate
function that isn't exposed in ibis:
[`kurtosis`](https://en.wikipedia.org/wiki/Kurtosis).

### DuckDB

First, define the builtin aggregate function:

```{python}
@udf.agg.builtin
def kurtosis(x: float) -> float: # <1>
...
```

1. Both the input and return type annotations indicate the **element** type of
the input, not the shape (column or scalar). Aggregations can only be called
on column expressions.

One of the powerful features of this API is that you can define your UD(A)Fs at
any point during your analysis. You don't need to connect to the database to
define your functions.

Let's compute the kurtosis of the number of votes across all movies:

```{python}
from ibis import _
expr = (
ibis.examples.imdb_title_ratings.fetch()
.rename("snake_case")
.agg(kurt=lambda t: kurtosis(t.num_votes))
)
expr
```

Since this is an aggregate function, it has the same capabilities as other,
builtin aggregates like `sum`: it can be used in a group by as well as in
a window function expression.

Let's compute kurtosis for all the different types of productions (shorts,
movies, TV, etc):

```{python}
basics = (
ibis.examples.imdb_title_basics.fetch()
.rename("snake_case")
.filter(_.is_adult == 0)
)
ratings = ibis.examples.imdb_title_ratings.fetch().rename("snake_case")
basics_ratings = ratings.join(basics, "tconst")
expr = (
basics_ratings.group_by("title_type")
.agg(kurt=lambda t: kurtosis(t.num_votes))
.order_by(_.kurt.desc())
.head()
)
expr
```

Similarly for window functions:

```{python}
expr = (
basics_ratings.mutate(
kurt=lambda t: kurtosis(t.num_votes).over(group_by="title_type")
)
.relocate("kurt", after="tconst")
.filter(
[
_.original_title.lower().contains("godfather"),
_.title_type == "movie",
_.genres.contains("Crime") & _.genres.contains("Drama"),
]
)
)
expr
```
23 changes: 23 additions & 0 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,36 @@ def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
return f"{func}({', '.join(map(t.translate, op.args))})"

def _gen_udaf_rule(self, op: ops.AggUDF):
from ibis import NA

@self.add_operation(type(op))
def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
args = ", ".join(
t.translate(
ops.Where(where, arg, NA)
if (where := op.where) is not None
else arg
)
for name, arg in zip(op.argnames, op.args)
if name != "where"
)
return f"{func}({args})"

def _define_udf_translation_rules(self, expr):
for udf_node in expr.op().find(ops.ScalarUDF):
udf_node_type = type(udf_node)

if udf_node_type not in self.compiler.translator_class._registry:
self._gen_udf_rule(udf_node)

for udf_node in expr.op().find(ops.AggUDF):
udf_node_type = type(udf_node)

if udf_node_type not in self.compiler.translator_class._registry:
self._gen_udaf_rule(udf_node)

def execute(
self,
expr: ir.Expr,
Expand Down
23 changes: 22 additions & 1 deletion ibis/backends/bigquery/tests/system/udf/test_udf_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_udf_sql(con, argument_type):
param(b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", 80, id="eighty"),
],
)
def test_builtin(con, value, expected):
def test_builtin_scalar(con, value, expected):
from ibis import udf

@udf.scalar.builtin
Expand All @@ -197,3 +197,24 @@ def bit_count(x: bytes) -> int:
expr = bit_count(value)
result = con.execute(expr)
assert result == expected


@pytest.mark.parametrize(
("where", "expected"),
[
param({"where": True}, list("abcdef"), id="where-true"),
param({"where": False}, [], id="where-false"),
param({}, list("abcdef"), id="where-nothing"),
],
)
def test_builtin_agg(con, where, expected):
from ibis import udf

@udf.agg.builtin(name="array_concat_agg")
def concat_agg(x, where: bool = True) -> dt.Array[str]:
...

t = ibis.memtable({"a": [list("abc"), list("def")]})
expr = concat_agg(t.a, **where)
result = con.execute(expr)
assert result == expected
11 changes: 4 additions & 7 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,12 @@ def _get_udf_source(self, udf_node: ops.ScalarUDF):
config = udf_node.__config__["kwargs"]
func = udf_node.__func__
func_name = func.__name__
schema = udf_node.__udf_namespace__
name = udf_node.__func_name__
ident = ".".join(filter(None, [schema, name]))
return dict(
name=name,
ident=ident,
name=udf_node.__func_name__,
ident=udf_node.__full_name__,
signature=", ".join(
f"{name} {self._compile_type(arg.dtype)}"
for name, arg in zip(udf_node.argnames, udf_node.args)
f"{argname} {self._compile_type(arg.dtype)}"
for argname, arg in zip(udf_node.argnames, udf_node.args)
),
return_type=self._compile_type(udf_node.dtype),
language=config.get("language", "plpython3u"),
Expand Down
Loading

0 comments on commit 8ee12bf

Please sign in to comment.