From abb0bddaca2df6cb31ca7affa52f3a937e1e0f9b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 17 Oct 2023 05:54:13 -0400 Subject: [PATCH] fix(bigquery): move sql code to proper argument --- ibis/backends/bigquery/udf/__init__.py | 89 +++++++++++--------------- ibis/backends/tests/test_export.py | 3 +- 2 files changed, 39 insertions(+), 53 deletions(-) diff --git a/ibis/backends/bigquery/udf/__init__.py b/ibis/backends/bigquery/udf/__init__.py index 85f7ef01beb6..bd4a472a1a23 100644 --- a/ibis/backends/bigquery/udf/__init__.py +++ b/ibis/backends/bigquery/udf/__init__.py @@ -20,24 +20,10 @@ _udf_name_cache: dict[str, Iterable[int]] = collections.defaultdict(itertools.count) -def _create_udf_node(name, fields): - """Create a new UDF node type. - - Parameters - ---------- - name : str - Then name of the UDF node - fields : OrderedDict - Mapping of class member name to definition - - Returns - ------- - result : type - A new BigQueryUDFNode subclass - """ +def _make_udf_name(name): definition = next(_udf_name_cache[name]) external_name = f"{name}_{definition:d}" - return type(external_name, (BigQueryUDFNode,), fields) + return external_name class _BigQueryUDF: @@ -274,24 +260,6 @@ def js( if libraries is None: libraries = [] - udf_node_fields = { - name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_) - for name, type_ in params.items() - } - - udf_node_fields["dtype"] = output_type - udf_node_fields["shape"] = rlz.shape_like("args") - udf_node_fields["__slots__"] = ("sql",) - - udf_node = _create_udf_node(name, udf_node_fields) - - from ibis.backends.bigquery.compiler import compiles - - @compiles(udf_node) - def compiles_udf_node(t, op): - args = ", ".join(map(t.translate, op.args)) - return f"{udf_node.__name__}({args})" - bigquery_signature = ", ".join( f"{name} {BigQueryType.from_ibis(dt.dtype(type_))}" for name, type_ in params.items() @@ -305,16 +273,35 @@ def compiles_udf_node(t, op): False: "NOT DETERMINISTIC\n", None: "", }.get(determinism) + + name = _make_udf_name(name) sql_code = f'''\ -CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature}) +CREATE TEMPORARY FUNCTION {name}({bigquery_signature}) RETURNS {return_type} {determinism_formatted}LANGUAGE js AS """ {body} """{libraries_opts};''' + udf_node_fields = { + name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_) + for name, type_ in params.items() + } + + udf_node_fields["dtype"] = output_type + udf_node_fields["shape"] = rlz.shape_like("args") + udf_node_fields["sql"] = sql_code + + udf_node = type(name, (BigQueryUDFNode,), udf_node_fields) + + from ibis.backends.bigquery.compiler import compiles + + @compiles(udf_node) + def compiles_udf_node(t, op): + args = ", ".join(map(t.translate, op.args)) + return f"{udf_node.__name__}({args})" + def wrapped(*args, **kwargs): node = udf_node(*args, **kwargs) - object.__setattr__(node, "sql", sql_code) return node.to_expr() wrapped.__signature__ = inspect.Signature( @@ -376,19 +363,6 @@ def sql( } return_type = BigQueryType.from_ibis(dt.dtype(output_type)) - udf_node_fields["dtype"] = output_type - udf_node_fields["shape"] = rlz.shape_like("args") - udf_node_fields["__slots__"] = ("sql",) - - udf_node = _create_udf_node(name, udf_node_fields) - - from ibis.backends.bigquery.compiler import compiles - - @compiles(udf_node) - def compiles_udf_node(t, op): - args_formatted = ", ".join(map(t.translate, op.args)) - return f"{udf_node.__name__}({args_formatted})" - bigquery_signature = ", ".join( "{name} {type}".format( name=name, @@ -398,14 +372,27 @@ def compiles_udf_node(t, op): ) for name, type_ in params.items() ) + name = _make_udf_name(name) sql_code = f"""\ -CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature}) +CREATE TEMPORARY FUNCTION {name}({bigquery_signature}) RETURNS {return_type} AS ({sql_expression});""" + udf_node_fields["dtype"] = output_type + udf_node_fields["shape"] = rlz.shape_like("args") + udf_node_fields["sql"] = sql_code + + udf_node = type(name, (BigQueryUDFNode,), udf_node_fields) + + from ibis.backends.bigquery.compiler import compiles + + @compiles(udf_node) + def compiles_udf_node(t, op): + args = ", ".join(map(t.translate, op.args)) + return f"{udf_node.__name__}({args})" + def wrapper(*args, **kwargs): node = udf_node(*args, **kwargs) - object.__setattr__(node, "sql", sql_code) return node.to_expr() return wrapper diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 662f8c0e9468..a9e9d0045e63 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -329,7 +329,6 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): @pytest.mark.notyet( [ - "bigquery", "impala", "mysql", "oracle", @@ -343,7 +342,7 @@ def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): ) @pytest.mark.notyet(["clickhouse"], raises=Exception) @pytest.mark.notyet(["mssql", "pandas"], raises=PyDeltaTableError) -@pytest.mark.notyet(["dask"], raises=NotImplementedError) +@pytest.mark.notyet(["bigquery", "dask"], raises=NotImplementedError) @pytest.mark.notyet( ["druid"], raises=pa.lib.ArrowTypeError,