# Patch summary (descriptive preamble placed ahead of the first `diff --git` header):
#
# ibis/backends/datafusion/compiler.py
#   * registers a translation for ops.StringReplace that calls
#     df.functions.replace(arg, pattern, replacement)
#   * registers a translation for ops.RegexReplace that calls
#     df.functions.regexp_replace with an extra df.lit("g") flag argument
#   * registers a translation for ops.StringFind that returns
#     df.functions.strpos(arg, substr) - 1 and raises NotImplementedError
#     when either `start` or `end` is supplied
#     NOTE(review): the "- 1" presumably converts a 1-based strpos result
#     into a 0-based index -- confirm against the DataFusion documentation.
#
# ibis/backends/tests/test_string.py
#   * removes "datafusion" from the notimpl marks of the re_replace_posix,
#     re_replace, find and replace parameters and of test_re_replace_global
#
diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py
index 9de6279b44dc..a6aa12794f6e 100644
--- a/ibis/backends/datafusion/compiler.py
+++ b/ibis/backends/datafusion/compiler.py
@@ -510,3 +510,33 @@ def regex_extract(op):
         name="string_array_get",
     )
     return string_array_get(df.functions.regexp_match(arg, pattern))
+
+
+@translate.register(ops.StringReplace)
+def string_replace(op):
+    arg = translate(op.arg)
+    pattern = translate(op.pattern)
+    replacement = translate(op.replacement)
+    return df.functions.replace(arg, pattern, replacement)
+
+
+@translate.register(ops.RegexReplace)
+def regex_replace(op):
+    arg = translate(op.arg)
+    pattern = translate(op.pattern)
+    replacement = translate(op.replacement)
+    return df.functions.regexp_replace(arg, pattern, replacement, df.lit("g"))
+
+
+@translate.register(ops.StringFind)
+def string_find(op):
+    arg = translate(op.arg)
+    pattern = translate(op.substr)
+
+    if op.start is not None:
+        raise NotImplementedError("`start` not yet implemented")
+
+    if op.end is not None:
+        raise NotImplementedError("`end` not yet implemented")
+
+    return df.functions.strpos(arg, pattern) - df.lit(1)
diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py
index 7858a8d09161..633186276c1c 100644
--- a/ibis/backends/tests/test_string.py
+++ b/ibis/backends/tests/test_string.py
@@ -336,7 +336,7 @@ def test_string_col_is_unicode(alltypes, df):
             id='re_replace_posix',
             marks=[
                 pytest.mark.notimpl(
-                    ['datafusion', "mysql", "mssql", "druid", "oracle"],
+                    ["mysql", "mssql", "druid", "oracle"],
                     raises=com.OperationNotDefinedError,
                 ),
                 pytest.mark.broken(
@@ -351,7 +351,7 @@ def test_string_col_is_unicode(alltypes, df):
             id='re_replace',
             marks=[
                 pytest.mark.notimpl(
-                    ["datafusion", "mysql", "mssql", "druid", "oracle"],
+                    ["mysql", "mssql", "druid", "oracle"],
                     raises=com.OperationNotDefinedError,
                 ),
                 pytest.mark.broken(
@@ -413,9 +413,7 @@ def test_string_col_is_unicode(alltypes, df):
             lambda t: t.string_col.find('a'),
             lambda t: t.string_col.str.find('a'),
             id='find',
-            marks=pytest.mark.notimpl(
-                ["datafusion", "polars"], raises=com.OperationNotDefinedError
-            ),
+            marks=pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
         ),
         param(
             lambda t: t.string_col.lpad(10, 'a'),
@@ -826,9 +824,6 @@ def test_string_col_is_unicode(alltypes, df):
             lambda t: t.string_col.replace("1", "42"),
             lambda t: t.string_col.str.replace("1", "42"),
             id="replace",
-            marks=pytest.mark.notimpl(
-                ["datafusion"], raises=com.OperationNotDefinedError
-            ),
         ),
     ],
 )
@@ -841,7 +836,7 @@ def test_string(backend, alltypes, df, result_func, expected_func):
 
 
 @pytest.mark.notimpl(
-    ["datafusion", "mysql", "mssql", "druid", "oracle"],
+    ["mysql", "mssql", "druid", "oracle"],
     raises=com.OperationNotDefinedError,
 )
 def test_re_replace_global(con):