diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py
index 1241b04266f1..33cfe78f353c 100644
--- a/ibis/backends/datafusion/compiler/values.py
+++ b/ibis/backends/datafusion/compiler/values.py
@@ -109,6 +109,7 @@ def translate_val(op, **_):
     ops.ArrayContains: "array_contains",
     ops.ArrayLength: "array_length",
     ops.ArrayRemove: "array_remove_all",
+    ops.RegexSplit: "regex_split",
 }
 
 for _op, _name in _simple_ops.items():
diff --git a/ibis/backends/datafusion/udfs.py b/ibis/backends/datafusion/udfs.py
index 7bf52a953623..e9937031633f 100644
--- a/ibis/backends/datafusion/udfs.py
+++ b/ibis/backends/datafusion/udfs.py
@@ -7,7 +7,8 @@
 import pyarrow.compute as pc
 import pyarrow_hotfix  # noqa: F401
 
-import ibis.expr.datatypes as dt  # noqa: TCH001
+import ibis.common.exceptions as com
+import ibis.expr.datatypes as dt
 
 
 def _extract_epoch_seconds(array) -> dt.int32:
@@ -113,3 +114,26 @@ def extract_minute_timestamp(array: dt.Timestamp(scale=9)) -> dt.int32:
 
 def extract_hour_time(array: dt.time) -> dt.int32:
     return pc.cast(pc.hour(array), pa.int32())
+
+
+def regex_split(s: str, pattern: str) -> list[str]:
+    """Split every string in ``s`` on matches of one regular expression.
+
+    ``pattern`` is received as an array of values rather than a Python
+    string, so it is reduced to its distinct values first. Exactly one
+    distinct pattern is required; an array that merely repeats the same
+    pattern for every row is accepted.
+
+    Raises
+    ------
+    IbisError
+        If ``pattern`` does not hold exactly one distinct value.
+    """
+    # ``pc.split_pattern_regex`` applies one pattern to the whole input, so
+    # per-row patterns cannot be honoured; collapse repeats before checking.
+    patterns = pc.unique(pattern).to_pylist()
+    if len(patterns) != 1:
+        raise com.IbisError(
+            "Only a single scalar pattern is supported for DataFusion re_split"
+        )
+    return pc.split_pattern_regex(s, patterns[0])
diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py
index 5e65aa7f9ee5..5e2cdc394917 100644
--- a/ibis/backends/tests/test_string.py
+++ b/ibis/backends/tests/test_string.py
@@ -1103,7 +1103,6 @@ def test_non_match_regex_search_is_false(con):
 @pytest.mark.notimpl(
     [
         "dask",
-        "datafusion",
         "impala",
         "mysql",
         "sqlite",