diff --git a/ibis/backends/base/sql/__init__.py b/ibis/backends/base/sql/__init__.py index 48a5f2b70a6f..9a4485fc30e6 100644 --- a/ibis/backends/base/sql/__init__.py +++ b/ibis/backends/base/sql/__init__.py @@ -4,7 +4,7 @@ import contextlib import os from functools import lru_cache -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import toolz @@ -255,10 +255,13 @@ def _register_udfs(self, expr: ir.Expr) -> None: if self.supports_python_udfs: raise NotImplementedError(self.name) + def _gen_udf_name(self, name: str, schema: Optional[str]) -> str: + return ".".join(filter(None, (schema, name))) + def _gen_udf_rule(self, op: ops.ScalarUDF): @self.add_operation(type(op)) def _(t, op): - func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__))) + func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__) return f"{func}({', '.join(map(t.translate, op.args))})" def _gen_udaf_rule(self, op: ops.AggUDF): @@ -266,7 +269,7 @@ def _gen_udaf_rule(self, op: ops.AggUDF): @self.add_operation(type(op)) def _(t, op): - func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__))) + func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__) args = ", ".join( t.translate( ops.IfElse(where, arg, NA) diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 402a7fdf9797..960c7a7f4556 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -9,7 +9,7 @@ import re import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional from urllib.parse import parse_qs, urlparse import google.auth.credentials @@ -785,6 +785,12 @@ def to_pyarrow_batches( ) return pa.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter) + def _gen_udf_name(self, name: str, schema: Optional[str]) -> str: + func = ".".join(filter(None, (schema, name))) + if "." in func: + return ".".join(f"`{part}`" for part in func.split(".")) + return func + def get_schema(self, name, schema: str | None = None, database: str | None = None): table_ref = bq.TableReference( bq.DatasetReference( diff --git a/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_bqutil_fn_from_hex/out.sql b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_bqutil_fn_from_hex/out.sql new file mode 100644 index 000000000000..2cabc41aa447 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_bqutil_fn_from_hex/out.sql @@ -0,0 +1,2 @@ +SELECT + `bqutil`.`fn`.from_hex('face') AS `from_hex_'face'` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_farm_fingerprint/out.sql b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_farm_fingerprint/out.sql new file mode 100644 index 000000000000..9128e636a8a6 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/udf/snapshots/test_builtin/test_farm_fingerprint/out.sql @@ -0,0 +1,2 @@ +SELECT + farm_fingerprint(b'Hello, World!') AS `farm_fingerprint_b'Hello_ World_'` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/udf/test_builtin.py b/ibis/backends/bigquery/tests/unit/udf/test_builtin.py new file mode 100644 index 000000000000..5f5042f96574 --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/udf/test_builtin.py @@ -0,0 +1,30 @@ + +import ibis + +to_sql = ibis.bigquery.compile + + +@ibis.udf.scalar.builtin +def farm_fingerprint(value: bytes) -> int: + ... + + +@ibis.udf.scalar.builtin(schema="bqutil.fn") +def from_hex(value: str) -> int: + """Community function to convert from hex string to integer. + + See: + https://github.com/GoogleCloudPlatform/bigquery-utils/tree/master/udfs/community#from_hexvalue-string + """ + + +def test_bqutil_fn_from_hex(snapshot): + # Project ID should be enclosed in backticks. + expr = from_hex("face") + snapshot.assert_match(to_sql(expr), "out.sql") + + +def test_farm_fingerprint(snapshot): + # No backticks needed if there is no schema defined. + expr = farm_fingerprint(b"Hello, World!") + snapshot.assert_match(to_sql(expr), "out.sql")