Skip to content

Commit

Permalink
feat(datafusion): add support for scalar pyarrow UDFs
Browse files: browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 27, 2023
1 parent 3283333 commit 45935b7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType


Expand Down Expand Up @@ -467,6 +468,23 @@ def elementwise_udf(op):
return udf(*args)


@translate.register(ops.ScalarUDF)
def scalar_udf(op):
    """Translate a scalar UDF operation into a call to a DataFusion UDF.

    The wrapped Python function (``op.__func__``) is registered through
    ``df.udf`` with argument and return types converted from the operation's
    ibis dtypes via ``PyArrowType.from_ibis``, and the resulting UDF is then
    applied to the translated arguments of ``op``.

    Raises
    ------
    NotImplementedError
        If ``op.__input_type__`` is anything other than ``InputType.PYARROW``.
    """
    # Guard clause: only pyarrow-flavoured UDFs are handled here.
    input_type = op.__input_type__
    if input_type != InputType.PYARROW:
        raise NotImplementedError(
            f"DataFusion only supports pyarrow UDFs: got a {input_type.name.lower()} UDF"
        )

    wrapped_func = op.__func__
    arg_types = [PyArrowType.from_ibis(arg.output_dtype) for arg in op.args]
    result_type = PyArrowType.from_ibis(op.output_dtype)

    # The UDF is always declared "volatile", exactly as before.
    datafusion_udf = df.udf(
        wrapped_func,
        input_types=arg_types,
        return_type=result_type,
        volatility="volatile",
    )

    # Arguments are translated only now, right before the call.
    return datafusion_udf(*[translate(arg) for arg in op.args])


@translate.register(ops.StringConcat)
def string_concat(op):
    """Translate a string concatenation into ``df.functions.concat``.

    Every element of ``op.arg`` is translated and the results are passed, in
    order, as positional arguments to ``df.functions.concat``.
    """
    translated_parts = [translate(part) for part in op.arg]
    return df.functions.concat(*translated_parts)
Expand Down
10 changes: 6 additions & 4 deletions ibis/backends/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"bigquery",
"clickhouse",
"dask",
"datafusion",
"druid",
"impala",
"mssql",
Expand All @@ -29,6 +28,7 @@


@no_python_udfs
@mark.notyet(["datafusion"], raises=NotImplementedError)
def test_udf(batting):
@udf.scalar.python
def num_vowels(s: str, include_y: bool = False) -> int:
Expand All @@ -49,6 +49,7 @@ def num_vowels(s: str, include_y: bool = False) -> int:
@mark.notyet(
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.xfail(
sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8"
)
Expand All @@ -73,6 +74,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
@mark.notyet(
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
)
@mark.notyet(["datafusion"], raises=NotImplementedError)
@mark.xfail(
sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8"
)
Expand Down Expand Up @@ -141,9 +143,9 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
add_one_pandas,
marks=[
mark.notyet(
["duckdb"],
["duckdb", "datafusion"],
raises=NotImplementedError,
reason="duckdb doesn't support pandas UDFs",
reason="backend doesn't support pandas UDFs",
),
],
),
Expand All @@ -153,7 +155,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
mark.notyet(
["snowflake"],
raises=NotImplementedError,
reason="snowflake doesn't support pyarrow UDFs",
reason="backend doesn't support pyarrow UDFs",
)
],
),
Expand Down

0 comments on commit 45935b7

Please sign in to comment.