From db241738052b137f714716dcf7d619cd4c1b2f78 Mon Sep 17 00:00:00 2001 From: Krzysztof Date: Wed, 18 Jan 2023 10:15:43 +0100 Subject: [PATCH] feat(bigquery): add SQL UDF support --- ibis/backends/bigquery/compiler.py | 2 +- .../tests/system/udf/test_udf_execute.py | 32 +++++- .../snapshots/test_usage/test_udf_sql/out.sql | 5 + .../bigquery/tests/unit/udf/test_usage.py | 17 +++ ibis/backends/bigquery/udf/__init__.py | 100 ++++++++++++++++-- ibis/backends/bigquery/udf/core.py | 2 +- 6 files changed, 145 insertions(+), 13 deletions(-) create mode 100644 ibis/backends/bigquery/tests/unit/udf/snapshots/test_usage/test_udf_sql/out.sql diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index b846e17056fb..b2d09b607563 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -22,7 +22,7 @@ def __init__(self, expr, context): def compile(self): """Generate UDF string from definition.""" - return self.expr.op().js + return self.expr.op().sql class BigQueryUnion(sql_compiler.Union): diff --git a/ibis/backends/bigquery/tests/system/udf/test_udf_execute.py b/ibis/backends/bigquery/tests/system/udf/test_udf_execute.py index b78f05cb6992..4126580003f7 100644 --- a/ibis/backends/bigquery/tests/system/udf/test_udf_execute.py +++ b/ibis/backends/bigquery/tests/system/udf/test_udf_execute.py @@ -3,6 +3,7 @@ import pandas as pd import pandas.testing as tm import pytest +from pytest import param import ibis import ibis.expr.datatypes as dt @@ -60,7 +61,7 @@ def __init__(self, width, height): return Rectangle(a, b) - result = my_struct_thing.js + result = my_struct_thing.sql snapshot.assert_match(result, "out.sql") expr = my_struct_thing(alltypes.double_col, alltypes.double_col) @@ -108,7 +109,7 @@ def my_str_len(s): add = expr.op() # generated javascript is identical - assert add.left.op().js == add.right.op().js + assert add.left.op().sql == add.right.op().sql assert client.execute(expr) == 8.0 @@ -142,3 +143,30 @@ 
def my_array_len(x): assert client.execute(my_str_len("aaa")) == 3 assert client.execute(my_array_len(["aaa", "bb"])) == 2 + + +@pytest.mark.parametrize( + ("argument_type",), + [ + param( + dt.string, + id="string", + ), + param( + "ANY TYPE", + id="any_type", + ), + ], +) +def test_udf_sql(client, argument_type): + format_t = udf.sql( + "format_t", + params={'input': argument_type}, + output_type=dt.string, + sql_expression="FORMAT('%T', input)", + ) + + s = ibis.literal("abcd") + expr = format_t(s) + + client.execute(expr) diff --git a/ibis/backends/bigquery/tests/unit/udf/snapshots/test_usage/test_udf_sql/out.sql b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_usage/test_udf_sql/out.sql new file mode 100644 index 000000000000..09c0f57d2338 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_usage/test_udf_sql/out.sql @@ -0,0 +1,5 @@ +CREATE TEMPORARY FUNCTION format_t_0(input STRING) +RETURNS STRING +AS (FORMAT('%T', input)); + +SELECT format_t_0('abcd') AS `tmp` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/udf/test_usage.py b/ibis/backends/bigquery/tests/unit/udf/test_usage.py index 77b20b01d9b5..c70dae40095e 100644 --- a/ibis/backends/bigquery/tests/unit/udf/test_usage.py +++ b/ibis/backends/bigquery/tests/unit/udf/test_usage.py @@ -49,6 +49,23 @@ def my_len(s): snapshot.assert_match(sql, "out.sql") +def test_udf_sql(snapshot): + _udf_name_cache.clear() + + format_t = udf.sql( + "format_t", + params={'input': dt.string}, + output_type=dt.string, + sql_expression="FORMAT('%T', input)", + ) + + s = ibis.literal("abcd") + expr = format_t(s) + + sql = ibis.bigquery.compile(expr) + snapshot.assert_match(sql, "out.sql") + + @pytest.mark.parametrize( ("argument_type", "return_type"), [ diff --git a/ibis/backends/bigquery/udf/__init__.py b/ibis/backends/bigquery/udf/__init__.py index 383ecf93acb3..a6959eb562a6 100644 --- a/ibis/backends/bigquery/udf/__init__.py +++ b/ibis/backends/bigquery/udf/__init__.py @@ 
-3,7 +3,7 @@ import collections import inspect import itertools -from typing import Callable, Iterable, Mapping +from typing import Callable, Iterable, Literal, Mapping import ibis.expr.datatypes as dt import ibis.expr.rules as rlz @@ -84,7 +84,7 @@ def python( >>> @udf.python(input_type=[dt.double], output_type=dt.double) ... def add_one(x): ... return x + 1 - >>> print(add_one.js) + >>> print(add_one.sql) CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64) RETURNS FLOAT64 LANGUAGE js AS """ @@ -106,7 +106,7 @@ def python( ... for value in gen(start, stop): ... result.append(value) ... return result - >>> print(my_range.js) + >>> print(my_range.sql) CREATE TEMPORARY FUNCTION my_range_0(start FLOAT64, stop FLOAT64) RETURNS ARRAY LANGUAGE js AS """ @@ -147,7 +147,7 @@ def python( ... return 2 * (self.width + self.height) ... ... return Rectangle(width, height) - >>> print(my_rectangle.js) + >>> print(my_rectangle.sql) CREATE TEMPORARY FUNCTION my_rectangle_0(width FLOAT64, height FLOAT64) RETURNS STRUCT LANGUAGE js AS """ @@ -246,7 +246,7 @@ def js( ... output_type=dt.double, ... body="return x + 1" ... 
) - >>> print(add_one.js) + >>> print(add_one.sql) CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64) RETURNS FLOAT64 LANGUAGE js AS """ @@ -272,7 +272,7 @@ def js( udf_node_fields["output_dtype"] = output_type udf_node_fields["output_shape"] = rlz.shape_like("args") - udf_node_fields["__slots__"] = ("js",) + udf_node_fields["__slots__"] = ("sql",) udf_node = _create_udf_node(name, udf_node_fields) @@ -299,7 +299,7 @@ def compiles_udf_node(t, op): False: 'NOT DETERMINISTIC\n', None: '', }.get(determinism) - js = f'''\ + sql_code = f'''\ CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature}) RETURNS {return_type} {determinism_formatted}LANGUAGE js AS """ @@ -308,7 +308,7 @@ def compiles_udf_node(t, op): def wrapped(*args, **kwargs): node = udf_node(*args, **kwargs) - object.__setattr__(node, "js", js) + object.__setattr__(node, "sql", sql_code) return node.to_expr() wrapped.__signature__ = inspect.Signature( @@ -320,8 +320,92 @@ def wrapped(*args, **kwargs): ] ) wrapped.__name__ = name - wrapped.js = js + wrapped.sql = sql_code return wrapped + @staticmethod + def sql( + name: str, + params: Mapping[str, dt.DataType | Literal["ANY TYPE"]], + output_type: dt.DataType, + sql_expression: str, + ) -> Callable: + """Define a SQL UDF for BigQuery. + + Parameters + ---------- + name + The name of the function. + params + Mapping of parameter names to their types; use the literal `"ANY TYPE"` for a templated parameter. + output_type + Return type of the UDF. + sql_expression + The SQL expression that defines the function. + + Returns + ------- + Callable + The wrapped user-defined function. + + Examples + -------- + >>> from ibis.backends.bigquery import udf + >>> import ibis.expr.datatypes as dt + >>> add_one = udf.sql( + ... name="add_one", + ... params={'x': dt.double}, + ... output_type=dt.double, + ... sql_expression="x + 1" + ... 
) + >>> print(add_one.sql) + CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64) + RETURNS FLOAT64 + AS (x + 1); + """ + validate_output_type(output_type) + udf_node_fields = { + name: rlz.any if type_ == "ANY TYPE" else rlz.value(type_) + for name, type_ in params.items() + } + + return_type = ibis_type_to_bigquery_type(dt.dtype(output_type)) + + udf_node_fields["output_dtype"] = output_type + udf_node_fields["output_shape"] = rlz.shape_like("args") + udf_node_fields["__slots__"] = ("sql",) + + udf_node = _create_udf_node(name, udf_node_fields) + + from ibis.backends.bigquery.compiler import compiles + + @compiles(udf_node) + def compiles_udf_node(t, op): + args_formatted = ", ".join(map(t.translate, op.args)) + return f"{udf_node.__name__}({args_formatted})" + + bigquery_signature = ", ".join( + "{name} {type}".format( + name=name, + type="ANY TYPE" + if type_ == "ANY TYPE" + else ibis_type_to_bigquery_type(dt.dtype(type_)), + ) + for name, type_ in params.items() + ) + sql_code = f'''\ +CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature}) +RETURNS {return_type} +AS ({sql_expression});''' + + def wrapper(*args, **kwargs): + node = udf_node(*args, **kwargs) + object.__setattr__(node, "sql", sql_code) + return node.to_expr() + + wrapper.__name__ = name + wrapper.sql = sql_code + return wrapper + + udf = _BigQueryUDF() diff --git a/ibis/backends/bigquery/udf/core.py b/ibis/backends/bigquery/udf/core.py index 93cd788e26c6..f6aba56f6e9e 100644 --- a/ibis/backends/bigquery/udf/core.py +++ b/ibis/backends/bigquery/udf/core.py @@ -602,4 +602,4 @@ def range(n): nnn = len(values) return [sum(values) - a + b * y**-x, z, foo.width, nnn] - print(my_func.js) # noqa: T201 + print(my_func.sql) # noqa: T201