Skip to content

Commit

Permalink
feat(sqlalchemy): support builtin aggregate functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
cpcloud committed Sep 14, 2023
1 parent ebc8eae commit 3b27e23
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,24 @@ def _(t, op):
func = getattr(generator, op.__func_name__)
return func(*map(t.translate, op.args))

def _gen_udaf_rule(self, op: ops.AggUDF):
    """Register a translation rule for the aggregate UDF type of ``op``.

    The registered rule renders the operation as a call to
    ``sa.func.<name>``, or ``sa.func.<namespace>.<name>`` when the
    operation declares a namespace.  An optional ``where`` filter on the
    operation is applied either through ``.filter(...)`` on the call or by
    wrapping each argument in ``ops.Where``.

    Parameters
    ----------
    op
        An aggregate UDF operation; the rule is registered for ``type(op)``.
    """
    from ibis import NA

    @self.add_operation(type(op))
    def _(t, op):
        # ``where`` is the filter predicate, not an argument of the function
        # being called, so it is excluded from the call arguments.
        call_args = (
            value for key, value in zip(op.argnames, op.args) if key != "where"
        )

        namespace = op.__udf_namespace__
        generator = sa.func if namespace is None else getattr(sa.func, namespace)
        func = getattr(generator, op.__func_name__)

        predicate = op.where
        if predicate is None:
            return func(*map(t.translate, call_args))

        if t._has_reduction_filter_syntax:
            # The translator advertises filter syntax: attach the predicate
            # to the function call itself.
            return func(*map(t.translate, call_args)).filter(t.translate(predicate))

        # Fallback: replace every argument with ``Where(predicate, arg, NA)``
        # so values from rows failing the predicate are passed as NA.
        masked = (t.translate(ops.Where(predicate, value, NA)) for value in call_args)
        return func(*masked)

def _register_udfs(self, expr: ir.Expr) -> None:
with self.begin() as con:
for udf_node in expr.op().find(ops.ScalarUDF):
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _window_function(t, window):
# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(func, t._require_order_by) and not window.frame.order_by:
order_by = t.translate(func.arg) # .args[0])
order_by = t.translate(func.args[0])
else:
order_by = [t.translate(arg) for arg in window.frame.order_by]

Expand Down
35 changes: 34 additions & 1 deletion ibis/backends/snowflake/tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import pandas.testing as tm
import pytest
from pytest import param

import ibis.expr.datatypes as dt
from ibis import udf


Expand Down Expand Up @@ -34,7 +36,7 @@ def compress_bytes(data: bytes, method: str) -> bytes:
param(jarowinkler_similarity, ("snow", "show"), id="jarowinkler_similarity"),
],
)
def test_builtin(con, func, args):
def test_builtin_scalar_udf(con, func, args):
expr = func(*args)

query = f"SELECT {func.__name__}({', '.join(map(repr, args))})"
Expand All @@ -59,3 +61,34 @@ def test_compress(con, func, pyargs, snowargs):
expected = c.exec_driver_sql(query).scalar()

assert con.execute(expr) == expected


# Signature-only declaration registered through ``udf.agg.builtin``; the body
# is just a placeholder.  ``test_builtin_agg_udf`` compares it against a raw
# SQL call to ``minhash(100, "string_col")``, so ``x`` receives the integer
# literal and ``y`` the column.
# NOTE(review): presumably ``x`` is the number of hash functions expected by
# Snowflake's MINHASH -- confirm against the Snowflake docs.
@udf.agg.builtin
def minhash(x, y) -> dt.json:
    ...


# Signature-only declaration registered through ``udf.agg.builtin``; the body
# is just a placeholder.  ``test_builtin_agg_udf`` feeds it the ``mh`` column
# produced by ``minhash`` and compares the result against a raw SQL call to
# ``approximate_jaccard_index("mh")``.
@udf.agg.builtin
def approximate_jaccard_index(a) -> float:
    ...


def test_builtin_agg_udf(con):
    """Check builtin aggregate UDFs against the equivalent hand-written SQL.

    ``minhash`` is used as a window function partitioned by
    ``date_string_col`` and ``approximate_jaccard_index`` as a plain
    aggregate over its output; the ibis result must equal the frame
    returned by running the raw query through the driver.
    """
    base = con.tables.FUNCTIONAL_ALLTYPES.limit(2)
    windowed_minhash = minhash(100, base.string_col).over(
        group_by=base.date_string_col
    )
    hashed = base.select(mh=windowed_minhash)
    ibis_expr = hashed.agg(aji=approximate_jaccard_index(hashed.mh))

    actual = ibis_expr.execute()
    raw_sql = """
SELECT approximate_jaccard_index("mh") AS "aji"
FROM (
SELECT minhash(100, "string_col") OVER (PARTITION BY "date_string_col") AS "mh"
FROM (
SELECT * FROM "FUNCTIONAL_ALLTYPES" LIMIT 2
)
)
"""
    with con.begin() as conn:
        driver_cursor = conn.exec_driver_sql(raw_sql).cursor
        expected = driver_cursor.fetch_pandas_all()

    tm.assert_frame_equal(actual, expected)

0 comments on commit 3b27e23

Please sign in to comment.