feat(pyspark): implement RegexSplit
cpcloud authored and gforsyth committed Dec 19, 2023
1 parent c955b6a commit cfe0329
Showing 2 changed files with 84 additions and 1 deletion.
11 changes: 11 additions & 0 deletions ibis/backends/pyspark/compiler.py
@@ -2124,3 +2124,14 @@ def compile_timestamp_range(t, op, **kwargs):
    step = F.expr(f"INTERVAL {step_value} {unit}")

    return _build_sequence(start, stop, step, _zero_value(op.step.dtype))


@compiles(ops.RegexSplit)
def compile_regex_split(t, op, **kwargs):
    src_column = t.translate(op.arg, **kwargs)
    if not isinstance(op.pattern, ops.Literal):
        raise com.UnsupportedOperationError(
            "`pattern` argument of re_split must be a literal"
        )
    pattern = t.translate(op.pattern, raw=True, **kwargs)
    return F.split(src_column, pattern)
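
Note: Spark's `F.split` already interprets its second argument as a Java regular expression, so the rule above is a direct mapping. A minimal standalone sketch of the resulting behavior, assuming a local SparkSession (the `spark` and `df` names here are illustrative, not part of this commit):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(",a,,,,c",)], ["s"])

    # ",+" is treated as a regex, so runs of commas collapse into one split;
    # this mirrors the ["", "a", "c"] expectation in the test below.
    df.select(F.split(df.s, ",+").alias("parts")).show(truncate=False)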
74 changes: 73 additions & 1 deletion ibis/backends/tests/test_string.py
@@ -1122,4 +1122,76 @@ def test_re_split(con):
    lit = ibis.literal(",a,,,,c")
    expr = lit.re_split(",+")
    result = con.execute(expr)
-    assert result == ["", "a", "c"]
+    assert list(result) == ["", "a", "c"]


@pytest.mark.notimpl(
    [
        "dask",
        "impala",
        "mysql",
        "sqlite",
        "mssql",
        "druid",
        "oracle",
        "flink",
        "exasol",
        "pandas",
        "bigquery",
    ],
    raises=com.OperationNotDefinedError,
)
def test_re_split_column(alltypes):
    expr = alltypes.limit(5).string_col.re_split(r"\d+")
    result = expr.execute()
    assert all(not any(element) for element in result)
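
The assertion holds because `string_col` in the test fixture contains digit-only strings, so splitting on `\d+` leaves nothing but empty strings. A quick stdlib sketch of the same idea:

    import re

    # Splitting a digit-only value on \d+ yields only empty strings,
    # so `not any(parts)` holds for every row.
    parts = re.split(r"\d+", "5")
    assert parts == ["", ""]
    assert not any(parts)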


@pytest.mark.notimpl(
    [
        "dask",
        "impala",
        "mysql",
        "sqlite",
        "mssql",
        "druid",
        "oracle",
        "flink",
        "exasol",
        "pandas",
        "bigquery",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["clickhouse"],
    raises=ClickhouseDatabaseError,
    reason="clickhouse only supports pattern constants",
)
@pytest.mark.notyet(
    ["polars"],
    raises=BaseException,  # yikes, panic exception
    reason="pyarrow doesn't support splitting on a pattern per row",
)
@pytest.mark.notyet(
    ["datafusion"],
    raises=Exception,
    reason="pyarrow doesn't support splitting on a pattern per row",
)
@pytest.mark.notyet(
    ["pyspark"],
    raises=com.UnsupportedOperationError,
    reason="pyspark only supports pattern constants",
)
def test_re_split_column_multiple_patterns(alltypes):
    expr = (
        alltypes.filter(lambda t: t.string_col.isin(("1", "2")))
        .select(
            splits=lambda t: t.string_col.re_split(
                ibis.ifelse(t.string_col == "1", "0|1", r"\d+")
            )
        )
        .splits
    )
    result = expr.execute()
    assert all(not any(element) for element in result)
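
On PySpark, the marker above exercises the guard added in this commit: `F.split` takes its pattern as a Python string rather than a column, so a per-row pattern cannot be compiled and the backend rejects it up front. A hedged sketch, assuming a PySpark-backed ibis connection named `con`:

    import ibis

    t = con.table("functional_alltypes")
    pattern = ibis.ifelse(t.string_col == "1", "0|1", r"\d+")

    # `pattern` is an expression, not an ops.Literal, so the compiler raises
    # com.UnsupportedOperationError before anything reaches Spark.
    t.string_col.re_split(pattern).execute()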
