Skip to content

Commit

Permalink
feat(sqlite): add scalar python udf support to sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and jcrist committed Jul 6, 2023
1 parent 578f875 commit 92f29e6
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
42 changes: 42 additions & 0 deletions ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import annotations

import inspect
import sqlite3
from typing import TYPE_CHECKING, Callable, Iterator

import sqlalchemy as sa
import toolz
from sqlalchemy.dialects.sqlite import TIMESTAMP

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis import util
Expand All @@ -32,11 +34,15 @@
if TYPE_CHECKING:
from pathlib import Path

import ibis.expr.operations as ops
import ibis.expr.types as ir


class Backend(BaseAlchemyBackend):
name = 'sqlite'
compiler = SQLiteCompiler
supports_create_or_replace = False
supports_python_udfs = True

def __getstate__(self) -> dict:
r = super().__getstate__()
Expand Down Expand Up @@ -194,6 +200,42 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a SQLite SQL string."""
return sch.Schema.from_tuples(self._metadata(query))

def _register_udfs(self, expr: ir.Expr) -> None:
    """Register every scalar UDF referenced by `expr` on a connection.

    For each `ops.ScalarUDF` node found in the expression, the matching
    `_compile_<input type>_udf` method on this backend is looked up by
    name, called with the node, and the callable it returns is invoked
    with the open connection.

    Parameters
    ----------
    expr
        The expression whose operation tree is searched for scalar UDFs.
    """
    import ibis.expr.operations as ops

    with self.begin() as con:
        for node in expr.op().find(ops.ScalarUDF):
            # The UDF's input type name selects the compile method,
            # e.g. `_compile_python_udf`.
            input_kind = node.__input_type__.name.lower()
            compile_udf = getattr(self, f"_compile_{input_kind}_udf")

            register = compile_udf(node)
            register(con)

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> Callable[..., None]:
    """Validate a scalar Python UDF and build its registration callable.

    Parameters
    ----------
    udf_node
        The scalar UDF operation. Its `__func__` is the Python callable
        to register; `argnames` and `args` describe its inputs.

    Returns
    -------
    Callable
        A function that takes a connection object `con` and registers
        the UDF through `con.connection.create_function`. Nothing is
        registered until that function is called.

    Raises
    ------
    com.IbisTypeError
        If any argument's type is not string, binary, numeric or boolean.
    """
    func = udf_node.__func__
    name = func.__name__

    # Only the argument types are validated here; the UDF's return type
    # is not checked by this method.
    for argname, arg in zip(udf_node.argnames, udf_node.args):
        dtype = arg.output_dtype
        if not (
            dtype.is_string()
            or dtype.is_binary()
            or dtype.is_numeric()
            or dtype.is_boolean()
        ):
            raise com.IbisTypeError(
                "SQLite only supports strings, bytes, booleans and numbers as UDF input and output, "
                f"got argument `{argname}` with unsupported type {dtype}"
            )

    def register_udf(con):
        # The arity passed to SQLite is the number of parameters in the
        # Python signature; the function is wrapped with
        # `udf.ignore_nulls` before being registered.
        return con.connection.create_function(
            name, len(inspect.signature(func).parameters), udf.ignore_nulls(func)
        )

    return register_udf

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
Expand Down
18 changes: 11 additions & 7 deletions ibis/backends/sqlite/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
_SQLITE_UDAF_REGISTRY = set()


def ignore_nulls(f):
    """Wrap `f` so that it returns ``None`` if any of its inputs is ``None``.

    Parameters
    ----------
    f
        The callable to wrap.

    Returns
    -------
    callable
        A wrapper that returns ``None`` without calling `f` when any
        positional or keyword argument is ``None``, and otherwise returns
        ``f(*args, **kwargs)``. Metadata of `f` is preserved via
        `functools.wraps`.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Keyword values are inputs too: check them alongside positional
        # arguments so a None passed by keyword never reaches `f`.
        if any(value is None for value in (*args, *kwargs.values())):
            return None
        return f(*args, **kwargs)

    return wrapper


def udf(f):
"""Create a SQLite scalar UDF from `f`.
Expand All @@ -31,13 +41,7 @@ def udf(f):
A callable object that returns ``None`` if any of its inputs are
``None``.
"""

@functools.wraps(f)
def wrapper(*args):
if any(arg is None for arg in args):
return None
return f(*args)

wrapper = ignore_nulls(f)
_SQLITE_UDF_REGISTRY.add(wrapper)
return wrapper

Expand Down
12 changes: 9 additions & 3 deletions ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import pandas.testing as tm
import sqlalchemy as sa
from pytest import mark, param

from ibis import _, udf
Expand All @@ -18,7 +19,6 @@
"pandas",
"polars",
"pyspark",
"sqlite",
"trino",
]
)
Expand Down Expand Up @@ -49,6 +49,11 @@ def num_vowels(s: str, include_y: bool = False) -> int:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(
["sqlite"],
raises=sa.exc.OperationalError,
reason="sqlite doesn't support map types",
)
def test_map_udf(batting):
@udf.scalar.python
def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
Expand All @@ -73,6 +78,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.notyet(["sqlite"], raises=TypeError, reason="sqlite doesn't support map types")
def test_map_merge_udf(batting):
@udf.scalar.python
def vowels_map(s: str) -> dict[str, int]:
Expand Down Expand Up @@ -140,7 +146,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
add_one_pandas,
marks=[
mark.notyet(
["duckdb", "datafusion"],
["duckdb", "datafusion", "sqlite"],
raises=NotImplementedError,
reason="backend doesn't support pandas UDFs",
),
Expand All @@ -150,7 +156,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
add_one_pyarrow,
marks=[
mark.notyet(
["snowflake"],
["snowflake", "sqlite"],
raises=NotImplementedError,
reason="backend doesn't support pyarrow UDFs",
)
Expand Down

0 comments on commit 92f29e6

Please sign in to comment.