Skip to content

Commit

Permalink
feat(bigquery): add SQL UDF support
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztof-kwitt authored and cpcloud committed Jan 18, 2023
1 parent fb33bf9 commit db24173
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, expr, context):

def compile(self):
"""Generate UDF string from definition."""
return self.expr.op().js
return self.expr.op().sql


class BigQueryUnion(sql_compiler.Union):
Expand Down
32 changes: 30 additions & 2 deletions ibis/backends/bigquery/tests/system/udf/test_udf_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import pandas.testing as tm
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(self, width, height):

return Rectangle(a, b)

result = my_struct_thing.js
result = my_struct_thing.sql
snapshot.assert_match(result, "out.sql")

expr = my_struct_thing(alltypes.double_col, alltypes.double_col)
Expand Down Expand Up @@ -108,7 +109,7 @@ def my_str_len(s):
add = expr.op()

# generated javascript is identical
assert add.left.op().js == add.right.op().js
assert add.left.op().sql == add.right.op().sql
assert client.execute(expr) == 8.0


Expand Down Expand Up @@ -142,3 +143,30 @@ def my_array_len(x):

assert client.execute(my_str_len("aaa")) == 3
assert client.execute(my_array_len(["aaa", "bb"])) == 2


@pytest.mark.parametrize(
("argument_type",),
[
param(
dt.string,
id="string",
),
param(
"ANY TYPE",
id="string",
),
],
)
def test_udf_sql(client, argument_type):
format_t = udf.sql(
"format_t",
params={'input': argument_type},
output_type=dt.string,
sql_expression="FORMAT('%T', input)",
)

s = ibis.literal("abcd")
expr = format_t(s)

client.execute(expr)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TEMPORARY FUNCTION format_t_0(input STRING)
RETURNS FLOAT64
AS (FORMAT('%T', input));

SELECT format_t_0('abcd') AS `tmp`
17 changes: 17 additions & 0 deletions ibis/backends/bigquery/tests/unit/udf/test_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ def my_len(s):
snapshot.assert_match(sql, "out.sql")


def test_udf_sql(snapshot):
_udf_name_cache.clear()

format_t = udf.sql(
"format_t",
params={'input': dt.string},
output_type=dt.double,
sql_expression="FORMAT('%T', input)",
)

s = ibis.literal("abcd")
expr = format_t(s)

sql = ibis.bigquery.compile(expr)
snapshot.assert_match(sql, "out.sql")


@pytest.mark.parametrize(
("argument_type", "return_type"),
[
Expand Down
100 changes: 91 additions & 9 deletions ibis/backends/bigquery/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections
import inspect
import itertools
from typing import Callable, Iterable, Mapping
from typing import Callable, Iterable, Literal, Mapping

import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
Expand Down Expand Up @@ -84,7 +84,7 @@ def python(
>>> @udf.python(input_type=[dt.double], output_type=dt.double)
... def add_one(x):
... return x + 1
>>> print(add_one.js)
>>> print(add_one.sql)
CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
RETURNS FLOAT64
LANGUAGE js AS """
Expand All @@ -106,7 +106,7 @@ def python(
... for value in gen(start, stop):
... result.append(value)
... return result
>>> print(my_range.js)
>>> print(my_range.sql)
CREATE TEMPORARY FUNCTION my_range_0(start FLOAT64, stop FLOAT64)
RETURNS ARRAY<FLOAT64>
LANGUAGE js AS """
Expand Down Expand Up @@ -147,7 +147,7 @@ def python(
... return 2 * (self.width + self.height)
...
... return Rectangle(width, height)
>>> print(my_rectangle.js)
>>> print(my_rectangle.sql)
CREATE TEMPORARY FUNCTION my_rectangle_0(width FLOAT64, height FLOAT64)
RETURNS STRUCT<width FLOAT64, height FLOAT64>
LANGUAGE js AS """
Expand Down Expand Up @@ -246,7 +246,7 @@ def js(
... output_type=dt.double,
... body="return x + 1"
... )
>>> print(add_one.js)
>>> print(add_one.sql)
CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
RETURNS FLOAT64
LANGUAGE js AS """
Expand All @@ -272,7 +272,7 @@ def js(

udf_node_fields["output_dtype"] = output_type
udf_node_fields["output_shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("js",)
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

Expand All @@ -299,7 +299,7 @@ def compiles_udf_node(t, op):
False: 'NOT DETERMINISTIC\n',
None: '',
}.get(determinism)
js = f'''\
sql_code = f'''\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
RETURNS {return_type}
{determinism_formatted}LANGUAGE js AS """
Expand All @@ -308,7 +308,7 @@ def compiles_udf_node(t, op):

def wrapped(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "js", js)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

wrapped.__signature__ = inspect.Signature(
Expand All @@ -320,8 +320,90 @@ def wrapped(*args, **kwargs):
]
)
wrapped.__name__ = name
wrapped.js = js
wrapped.sql = sql_code
return wrapped

@staticmethod
def sql(
name: str,
params: Mapping[str, dt.DataType | Literal["ANY TYPE"]],
output_type: dt.DataType,
sql_expression: str,
) -> Callable:
"""Define a SQL UDF for BigQuery.
Parameters
----------
name:
The name of the function.
params
Mapping of names and types of parameters
output_type
Return type of the UDF.
sql_expression
The SQL expression that defines the function.
Returns
-------
Callable
The wrapped user-defined function.
Examples
--------
>>> from ibis.backends.bigquery import udf
>>> import ibis.expr.datatypes as dt
>>> add_one = udf.sql(
... name="add_one",
... params={'x': dt.double},
... output_type=dt.double,
... sql_expression="x + 1"
... )
>>> print(add_one.sql)
CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
RETURNS FLOAT64
AS (x + 1)
"""
validate_output_type(output_type)
udf_node_fields = {
name: rlz.any if type_ == "ANY TYPE" else rlz.value(type_)
for name, type_ in params.items()
}

return_type = ibis_type_to_bigquery_type(dt.dtype(output_type))

udf_node_fields["output_dtype"] = output_type
udf_node_fields["output_shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args_formatted = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args_formatted})"

bigquery_signature = ", ".join(
"{name} {type}".format(
name=name,
type="ANY TYPE"
if type_ == "ANY TYPE"
else ibis_type_to_bigquery_type(dt.dtype(type_)),
)
for name, type_ in params.items()
)
sql_code = f'''\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
RETURNS {return_type}
AS ({sql_expression});'''

def wrapper(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

return wrapper


udf = _BigQueryUDF()
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/udf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,4 +602,4 @@ def range(n):
nnn = len(values)
return [sum(values) - a + b * y**-x, z, foo.width, nnn]

print(my_func.js) # noqa: T201
print(my_func.sql) # noqa: T201

0 comments on commit db24173

Please sign in to comment.