From 92f29e6b820f207b16a306e5a39f91d396fe818a Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Mon, 26 Jun 2023 07:29:06 -0400
Subject: [PATCH] feat(sqlite): add scalar python udf support to sqlite

---
 ibis/backends/sqlite/__init__.py | 42 ++++++++++++++++++++++++++++++++
 ibis/backends/sqlite/udf.py      | 18 ++++++++------
 ibis/backends/tests/test_udf.py  | 12 ++++++---
 3 files changed, 62 insertions(+), 10 deletions(-)

diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py
index eba423b3e162..707ba9721ffb 100644
--- a/ibis/backends/sqlite/__init__.py
+++ b/ibis/backends/sqlite/__init__.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
+import inspect
 import sqlite3
 from typing import TYPE_CHECKING, Iterator
 
@@ -21,6 +22,7 @@
 import toolz
 from sqlalchemy.dialects.sqlite import TIMESTAMP
 
+import ibis.common.exceptions as com
 import ibis.expr.datatypes as dt
 import ibis.expr.schema as sch
 from ibis import util
@@ -32,11 +34,15 @@
 if TYPE_CHECKING:
     from pathlib import Path
 
+    import ibis.expr.operations as ops
+    import ibis.expr.types as ir
+
 
 class Backend(BaseAlchemyBackend):
     name = 'sqlite'
     compiler = SQLiteCompiler
     supports_create_or_replace = False
+    supports_python_udfs = True
 
     def __getstate__(self) -> dict:
         r = super().__getstate__()
@@ -194,6 +200,42 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
         """Return an ibis Schema from a SQLite SQL string."""
         return sch.Schema.from_tuples(self._metadata(query))
 
+    def _register_udfs(self, expr: ir.Expr) -> None:
+        import ibis.expr.operations as ops
+
+        with self.begin() as con:
+            for udf_node in expr.op().find(ops.ScalarUDF):
+                compile_func = getattr(
+                    self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
+                )
+
+                registration_func = compile_func(udf_node)
+                registration_func(con)
+
+    def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None:
+        func = udf_node.__func__
+        name = func.__name__
+
+        for argname, arg in zip(udf_node.argnames, udf_node.args):
+            dtype = arg.output_dtype
+            if not (
+                dtype.is_string()
+                or dtype.is_binary()
+                or dtype.is_numeric()
+                or dtype.is_boolean()
+            ):
+                raise com.IbisTypeError(
+                    "SQLite only supports strings, bytes, booleans and numbers as UDF input and output, "
+                    f"got argument `{argname}` with unsupported type {dtype}"
+                )
+
+        def register_udf(con):
+            return con.connection.create_function(
+                name, len(inspect.signature(func).parameters), udf.ignore_nulls(func)
+            )
+
+        return register_udf
+
     def _get_temp_view_definition(
         self, name: str, definition: sa.sql.compiler.Compiled
     ) -> str:
diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py
index 8b6f22236e73..830cc253d601 100644
--- a/ibis/backends/sqlite/udf.py
+++ b/ibis/backends/sqlite/udf.py
@@ -17,6 +17,16 @@
 _SQLITE_UDAF_REGISTRY = set()
 
 
+def ignore_nulls(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        if any(arg is None for arg in args):
+            return None
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
 def udf(f):
     """Create a SQLite scalar UDF from `f`.
 
@@ -31,13 +41,7 @@ def udf(f):
         A callable object that returns ``None`` if any of its
         inputs are ``None``.
     """
-
-    @functools.wraps(f)
-    def wrapper(*args):
-        if any(arg is None for arg in args):
-            return None
-        return f(*args)
-
+    wrapper = ignore_nulls(f)
     _SQLITE_UDF_REGISTRY.add(wrapper)
     return wrapper
 
diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py
index dac04bb33386..cc0c58387704 100644
--- a/ibis/backends/tests/test_udf.py
+++ b/ibis/backends/tests/test_udf.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import pandas.testing as tm
+import sqlalchemy as sa
 from pytest import mark, param
 
 from ibis import _, udf
@@ -18,7 +19,6 @@
         "pandas",
         "polars",
         "pyspark",
-        "sqlite",
         "trino",
     ]
 )
@@ -49,6 +49,11 @@ def num_vowels(s: str, include_y: bool = False) -> int:
     ["postgres"], raises=TypeError, reason="postgres only supports map"
 )
 @mark.notyet(["datafusion"], raises=NotImplementedError)
+@mark.notyet(
+    ["sqlite"],
+    raises=sa.exc.OperationalError,
+    reason="sqlite doesn't support map types",
+)
 def test_map_udf(batting):
     @udf.scalar.python
     def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
@@ -73,6 +78,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
     ["postgres"], raises=TypeError, reason="postgres only supports map"
 )
 @mark.notyet(["datafusion"], raises=NotImplementedError)
+@mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
 def test_map_merge_udf(batting):
     @udf.scalar.python
     def vowels_map(s: str) -> dict[str, int]:
@@ -140,7 +146,7 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
             add_one_pandas,
             marks=[
                 mark.notyet(
-                    ["duckdb", "datafusion"],
+                    ["duckdb", "datafusion", "sqlite"],
                     raises=NotImplementedError,
                     reason="backend doesn't support pandas UDFs",
                 ),
@@ -150,7 +156,7 @@ def add_one_pyarrow(s: int) -> int:  # s is series, int is the element type
             add_one_pyarrow,
             marks=[
                 mark.notyet(
-                    ["snowflake"],
+                    ["snowflake", "sqlite"],
                     raises=NotImplementedError,
                     reason="backend doesn't support pyarrow UDFs",
                 )