Skip to content

Commit

Permalink
feat(datafusion): add string functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed May 21, 2023
1 parent 1e99e9f commit 66c0afb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
30 changes: 30 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,33 @@ def regex_extract(op):
name="string_array_get",
)
return string_array_get(df.functions.regexp_match(arg, pattern))


@translate.register(ops.StringReplace)
def string_replace(op):
arg = translate(op.arg)
pattern = translate(op.pattern)
replacement = translate(op.replacement)
return df.functions.replace(arg, pattern, replacement)


@translate.register(ops.RegexReplace)
def regex_replace(op):
arg = translate(op.arg)
pattern = translate(op.pattern)
replacement = translate(op.replacement)
return df.functions.regexp_replace(arg, pattern, replacement, df.lit("g"))


@translate.register(ops.StringFind)
def string_find(op):
arg = translate(op.arg)
pattern = translate(op.substr)

if op.start is not None:
raise NotImplementedError("`start` not yet implemented")

if op.end is not None:
raise NotImplementedError("`end` not yet implemented")

return df.functions.strpos(arg, pattern) - df.lit(1)
13 changes: 4 additions & 9 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def test_string_col_is_unicode(alltypes, df):
id='re_replace_posix',
marks=[
pytest.mark.notimpl(
['datafusion', "mysql", "mssql", "druid", "oracle"],
["mysql", "mssql", "druid", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken(
Expand All @@ -351,7 +351,7 @@ def test_string_col_is_unicode(alltypes, df):
id='re_replace',
marks=[
pytest.mark.notimpl(
["datafusion", "mysql", "mssql", "druid", "oracle"],
["mysql", "mssql", "druid", "oracle"],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken(
Expand Down Expand Up @@ -413,9 +413,7 @@ def test_string_col_is_unicode(alltypes, df):
lambda t: t.string_col.find('a'),
lambda t: t.string_col.str.find('a'),
id='find',
marks=pytest.mark.notimpl(
["datafusion", "polars"], raises=com.OperationNotDefinedError
),
marks=pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
),
param(
lambda t: t.string_col.lpad(10, 'a'),
Expand Down Expand Up @@ -826,9 +824,6 @@ def test_string_col_is_unicode(alltypes, df):
lambda t: t.string_col.replace("1", "42"),
lambda t: t.string_col.str.replace("1", "42"),
id="replace",
marks=pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError
),
),
],
)
Expand All @@ -841,7 +836,7 @@ def test_string(backend, alltypes, df, result_func, expected_func):


@pytest.mark.notimpl(
["datafusion", "mysql", "mssql", "druid", "oracle"],
["mysql", "mssql", "druid", "oracle"],
raises=com.OperationNotDefinedError,
)
def test_re_replace_global(con):
Expand Down

0 comments on commit 66c0afb

Please sign in to comment.