Skip to content

Commit

Permalink
feat(duckdb): add support for native and pyarrow UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 26, 2023
1 parent f1a775b commit 7e56fc4
Show file tree
Hide file tree
Showing 17 changed files with 291 additions and 141 deletions.
5 changes: 5 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,9 +768,14 @@ def _register_in_memory_tables(self, expr: ir.Expr):

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
    """Backend-specific hooks to run before an expression is executed.

    Runs, in order: UDF translation-rule definition, UDF registration,
    and in-memory table registration for `expr`.
    """
    self._define_udf_translation_rules(expr)
    self._register_udfs(expr)
    self._register_in_memory_tables(expr)

def _define_udf_translation_rules(self, expr):
if self.supports_in_memory_tables:
raise NotImplementedError(self.name)

def compile(
self,
expr: ir.Expr,
Expand Down
9 changes: 6 additions & 3 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,14 @@ def to_pyarrow_batches(

return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), batches)

def _compile_udfs(self, expr: ir.Expr) -> Iterable[str]:
def _register_udfs(self, expr: ir.Expr) -> None:
    """Register the UDFs contained in `expr` with the backend.

    The base implementation raises for backends that advertise Python
    UDF support but do not override this hook; it is a no-op otherwise.
    """
    # The old docstring ("Return an iterator of DDL strings...") and the
    # trailing `return []` were leftovers from the pre-rename
    # `_compile_udfs`; this hook is declared `-> None`, so the stale
    # return value is dropped.
    if self.supports_python_udfs:
        raise NotImplementedError(self.name)

def _define_udf_translation_rules(self, expr: ir.Expr) -> None:
    """Define compiler translation rules for the UDFs in `expr`.

    Backends that advertise Python UDF support must override this hook.
    """
    if not self.supports_python_udfs:
        return
    raise NotImplementedError(self.name)

def execute(
self,
Expand Down Expand Up @@ -311,7 +314,7 @@ def compile(
The output of compilation. The type of this value depends on the
backend.
"""
util.consume(self._compile_udfs(expr))
self._define_udf_translation_rules(expr)
return self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
Expand Down
43 changes: 20 additions & 23 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,27 +707,6 @@ def insert(
f"The given obj is of type {type(obj).__name__} ."
)

def _compile_udfs(self, expr: ir.Expr) -> Iterable[str]:
for udf_node in expr.op().find(ops.ScalarUDF):
udf_node_type = type(udf_node)

if udf_node_type not in self.compiler.translator_class._registry:

@self.add_operation(udf_node_type)
def _(t, op):
generator = sa.func
if (namespace := op.__udf_namespace__) is not None:
generator = getattr(generator, namespace)
func = getattr(generator, type(op).__name__)
return func(*map(t.translate, op.args))

compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
compiled = compile_func(udf_node)
if compiled is not None:
yield compiled

def _compile_opaque_udf(self, udf_node: ops.ScalarUDF) -> str:
    """Return registration DDL for an opaque UDF.

    Opaque UDFs need no registration DDL, so this always returns None
    (callers treat a falsy result as "nothing to execute").
    """
    # NOTE(review): the `-> str` annotation disagrees with the None
    # return — consider `-> str | None` if lazy annotations are enabled.
    return None

Expand All @@ -749,10 +728,28 @@ def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> str:
f"The {self.name} backend does not support PyArrow-based vectorized scalar UDFs"
)

def _define_udf_translation_rules(self, expr):
    """Install a generic SQLAlchemy translation rule for every
    `ops.ScalarUDF` subclass in `expr` the compiler doesn't know yet."""

    def _translate_udf(t, op):
        # Start from the bare SQL function generator and qualify it with
        # the UDF's namespace when one is set.
        generator = sa.func
        namespace = op.__udf_namespace__
        if namespace is not None:
            generator = getattr(generator, namespace)
        sa_func = getattr(generator, type(op).__name__)
        return sa_func(*(t.translate(arg) for arg in op.args))

    registry = self.compiler.translator_class._registry
    for udf_node in expr.op().find(ops.ScalarUDF):
        node_type = type(udf_node)
        if node_type not in registry:
            self.add_operation(node_type)(_translate_udf)

def _register_udfs(self, expr: ir.Expr) -> None:
    """Compile and execute registration DDL for each scalar UDF in `expr`."""
    with self.begin() as con:
        for udf_node in expr.op().find(ops.ScalarUDF):
            # Dispatch on the UDF's declared input type
            # (e.g. python/pandas/pyarrow/opaque).
            method_name = f"_compile_{udf_node.__input_type__.name.lower()}_udf"
            sql = getattr(self, method_name)(udf_node)
            # A falsy result means no DDL is required for this UDF.
            if sql:
                con.exec_driver_sql(sql)

def _quote(self, name: str) -> str:
"""Quote an identifier."""
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ class Unknown(sa.Text):


class AlchemyType(TypeMapper):
@classmethod
def to_string(cls, dtype: dt.DataType):
    """Render the ibis `dtype` as this dialect's SQL type string."""
    sa_type = sa.types.to_instance(cls.from_ibis(dtype))
    dialect = sa.dialects.registry.load(cls.dialect)()
    compiled = sa_type.compile(dialect=dialect)
    return str(compiled)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
"""Convert an Ibis type to a SQLAlchemy type.
Expand Down
55 changes: 54 additions & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, MutableMapping
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Iterator,
Mapping,
MutableMapping,
)

import duckdb
import pyarrow as pa
Expand All @@ -25,6 +32,7 @@
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
from ibis.backends.duckdb.datatypes import DuckDBType, parse
from ibis.expr.operations.relations import PandasDataFrameProxy
from ibis.expr.operations.udf import InputType
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,6 +63,12 @@ def _format_kwargs(kwargs: Mapping[str, Any]):
return sa.text(", ".join(pieces)).bindparams(*bindparams)


# Map ibis UDF input kinds onto duckdb's function types: pyarrow-based
# UDFs register as ARROW functions, plain-Python UDFs as NATIVE ones.
# (Pandas UDFs are absent deliberately — duckdb does not support them.)
_UDF_INPUT_TYPE_MAPPING = {
    InputType.PYARROW: duckdb.functional.ARROW,
    InputType.PYTHON: duckdb.functional.NATIVE,
}


class Backend(BaseAlchemyBackend):
name = "duckdb"
compiler = DuckDBSQLCompiler
Expand Down Expand Up @@ -914,6 +928,45 @@ def _get_temp_view_definition(
) -> str:
yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"

def _register_udfs(self, expr: ir.Expr) -> None:
    """Register every scalar UDF referenced by `expr` on the raw duckdb
    connection underlying this backend."""
    import ibis.expr.operations as ops

    with self.begin() as con:
        for udf_node in expr.op().find(ops.ScalarUDF):
            # Dispatch to the compiler for this UDF's input type
            # (python/pyarrow/pandas).
            compile_func = getattr(
                self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
            )
            # Remove any previously registered function of the same name;
            # suppress the error duckdb raises when none exists.
            # NOTE(review): removal keys on the node class name while
            # registration (in `_compile_udf`) keys on `func.__name__` —
            # confirm these always coincide.
            with contextlib.suppress(duckdb.InvalidInputException):
                con.connection.driver_connection.remove_function(
                    udf_node.__class__.__name__
                )

            # `compile_func` returns a closure that performs the actual
            # registration when handed the connection.
            registration_func = compile_func(udf_node)
            registration_func(con)

def _compile_udf(self, udf_node: ops.ScalarUDF) -> None:
    """Build a callable that registers `udf_node` on a duckdb connection.

    The returned closure takes a SQLAlchemy connection and invokes
    duckdb's `create_function` with the UDF's name, Python callable,
    and duckdb-rendered input/output type strings.
    """
    py_func = udf_node.__func__
    udf_name = py_func.__name__
    arg_types = []
    for arg in udf_node.args:
        arg_types.append(DuckDBType.to_string(arg.output_dtype))
    ret_type = DuckDBType.to_string(udf_node.output_dtype)

    def register_udf(con):
        raw_con = con.connection.driver_connection
        return raw_con.create_function(
            udf_name,
            py_func,
            arg_types,
            ret_type,
            type=_UDF_INPUT_TYPE_MAPPING[udf_node.__input_type__],
        )

    return register_udf

_compile_python_udf = _compile_udf
_compile_pyarrow_udf = _compile_udf

def _compile_pandas_udf(self, _: ops.ScalarUDF) -> None:
    """Pandas-based vectorized UDFs are not supported by this backend."""
    message = "duckdb doesn't support pandas UDFs"
    raise NotImplementedError(message)

def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable):
# TODO: remove this once duckdb supports CTAS prepared statements
return super()._get_compiled_statement(
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/duckdb/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:


class DuckDBType(AlchemyType):
dialect = "duckdb"

@classmethod
def to_ibis(cls, typ, nullable=True):
if dtype := _from_duckdb_types.get(type(typ)):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/mssql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def _type_from_result_set_info(col: _FieldDescription) -> dt.DataType:


class MSSQLType(AlchemyType):
dialect = "mssql"

@classmethod
def to_ibis(cls, typ, nullable=True):
if dtype := _from_mssql_types.get(type(typ)):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ def result_processor(self, *_):


class MySQLType(AlchemyType):
dialect = "mysql"

@classmethod
def from_ibis(cls, dtype):
try:
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/oracle/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


class OracleType(AlchemyType):
dialect = "oracle"

@classmethod
def to_ibis(cls, typ, nullable=True):
if isinstance(typ, oracle.ROWID):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/postgres/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def _get_type(typestr: str) -> dt.DataType:


class PostgresType(AlchemyType):
dialect = "postgresql"

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
if dtype.is_floating():
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/snowflake/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def parse(text: str) -> dt.DataType:


class SnowflakeType(AlchemyType):
dialect = "snowflake"

@classmethod
def from_ibis(cls, dtype):
if dtype.is_array():
Expand Down
113 changes: 0 additions & 113 deletions ibis/backends/snowflake/tests/test_udf.py

This file was deleted.

2 changes: 2 additions & 0 deletions ibis/backends/sqlite/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def parse(text: str) -> dt.DataType:


class SqliteType(AlchemyType):
dialect = "sqlite"

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sat.TypeEngine:
if dtype.is_floating():
Expand Down
Loading

0 comments on commit 7e56fc4

Please sign in to comment.