From 45935b78922f09ab5be60aef1a1efaf204bc7f4d Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 26 Jun 2023 07:36:54 -0400 Subject: [PATCH] feat(datafusion): add support for scalar pyarrow UDFs --- ibis/backends/datafusion/compiler.py | 18 ++++++++++++++++++ ibis/backends/tests/test_udf.py | 10 ++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index a3bcb6aed115..f0fca32ad38c 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -12,6 +12,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowType @@ -467,6 +468,23 @@ def elementwise_udf(op): return udf(*args) +@translate.register(ops.ScalarUDF) +def scalar_udf(op): + if (input_type := op.__input_type__) != InputType.PYARROW: + raise NotImplementedError( + f"DataFusion only supports pyarrow UDFs: got a {input_type.name.lower()} UDF" + ) + udf = df.udf( + op.__func__, + input_types=[PyArrowType.from_ibis(arg.output_dtype) for arg in op.args], + return_type=PyArrowType.from_ibis(op.output_dtype), + volatility="volatile", + ) + args = map(translate, op.args) + + return udf(*args) + + @translate.register(ops.StringConcat) def string_concat(op): return df.functions.concat(*map(translate, op.arg)) diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index 5cfca0c1f07c..7619d5ff8181 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -13,7 +13,6 @@ "bigquery", "clickhouse", "dask", - "datafusion", "druid", "impala", "mssql", @@ -29,6 +28,7 @@ @no_python_udfs +@mark.notyet(["datafusion"], raises=NotImplementedError) def test_udf(batting): @udf.scalar.python def num_vowels(s: str, include_y: bool = False) -> int: @@ -49,6 +49,7 @@ def num_vowels(s: str, include_y: bool = False) -> int: @mark.notyet( ["postgres"], raises=TypeError, reason="postgres only supports map" ) +@mark.notyet(["datafusion"], raises=NotImplementedError) @mark.xfail( sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8" ) @@ -73,6 +74,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]: @mark.notyet( ["postgres"], raises=TypeError, reason="postgres only supports map" ) +@mark.notyet(["datafusion"], raises=NotImplementedError) @mark.xfail( sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8" ) @@ -141,9 +143,9 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type add_one_pandas, marks=[ mark.notyet( - ["duckdb"], + ["duckdb", "datafusion"], raises=NotImplementedError, - reason="duckdb doesn't support pandas UDFs", + reason="backend doesn't support pandas UDFs", ), ], ), @@ -153,7 +155,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type mark.notyet( ["snowflake"], raises=NotImplementedError, - reason="snowflake doesn't support pyarrow UDFs", + reason="backend doesn't support pyarrow UDFs", ) ], ),