From 37b6b7f30286e19b5b3f3402f2f0d16c93c776d7 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 19 Dec 2023 06:36:23 -0500
Subject: [PATCH] feat(datafusion): implement `ops.RegexSplit` using pyarrow UDF

---
 ibis/backends/datafusion/compiler/values.py |  1 +
 ibis/backends/datafusion/udfs.py            | 12 +++++++++++-
 ibis/backends/tests/test_string.py          |  1 -
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py
index 1241b04266f1..33cfe78f353c 100644
--- a/ibis/backends/datafusion/compiler/values.py
+++ b/ibis/backends/datafusion/compiler/values.py
@@ -109,6 +109,7 @@ def translate_val(op, **_):
     ops.ArrayContains: "array_contains",
     ops.ArrayLength: "array_length",
     ops.ArrayRemove: "array_remove_all",
+    ops.RegexSplit: "regex_split",
 }
 
 for _op, _name in _simple_ops.items():
diff --git a/ibis/backends/datafusion/udfs.py b/ibis/backends/datafusion/udfs.py
index 7bf52a953623..e9937031633f 100644
--- a/ibis/backends/datafusion/udfs.py
+++ b/ibis/backends/datafusion/udfs.py
@@ -7,7 +7,8 @@
 import pyarrow.compute as pc
 import pyarrow_hotfix  # noqa: F401
 
-import ibis.expr.datatypes as dt  # noqa: TCH001
+import ibis.common.exceptions as com
+import ibis.expr.datatypes as dt
 
 
 def _extract_epoch_seconds(array) -> dt.int32:
@@ -113,3 +114,12 @@ def extract_minute_timestamp(array: dt.Timestamp(scale=9)) -> dt.int32:
 
 def extract_hour_time(array: dt.time) -> dt.int32:
     return pc.cast(pc.hour(array), pa.int32())
+
+
+def regex_split(s: str, pattern: str) -> list[str]:
+    patterns = pattern.to_pylist()
+    if len(patterns) != 1:
+        raise com.IbisError(
+            "Only a single scalar pattern is supported for DataFusion re_split"
+        )
+    return pc.split_pattern_regex(s, patterns[0])
diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py
index 5e65aa7f9ee5..5e2cdc394917 100644
--- a/ibis/backends/tests/test_string.py
+++ b/ibis/backends/tests/test_string.py
@@ -1103,7 +1103,6 @@ def test_non_match_regex_search_is_false(con):
 @pytest.mark.notimpl(
     [
         "dask",
-        "datafusion",
         "impala",
         "mysql",
         "sqlite",