Skip to content

Commit

Permalink
feat(common): add support for start parameter in StringFind
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed May 22, 2023
1 parent e499c7f commit 31ce741
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 15 deletions.
14 changes: 10 additions & 4 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,13 +441,19 @@ def _substring(t, op):

def _gen_string_find(func):
def string_find(t, op):
if op.start is not None:
raise NotImplementedError("`start` not yet implemented")

if op.end is not None:
raise NotImplementedError("`end` not yet implemented")

return func(t.translate(op.arg), t.translate(op.substr)) - 1
arg = t.translate(op.arg)
sub_string = t.translate(op.substr)

if (op_start := op.start) is not None:
start = t.translate(op_start)
arg = sa.func.substr(arg, start + 1)
pos = func(arg, sub_string)
return sa.case((pos > 0, pos - 1 + start), else_=-1)

return func(arg, sub_string) - 1

return string_find

Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,16 @@ def _substring(op, **kw):

@translate_val.register(ops.StringFind)
def _string_find(op, **kw):
if op.start is not None:
raise com.UnsupportedOperationError(
"String find doesn't support start argument"
)
if op.end is not None:
raise com.UnsupportedOperationError("String find doesn't support end argument")

arg = translate_val(op.arg, **kw)
substr = translate_val(op.substr, **kw)

if op.start is not None:
op_start = translate_val(op.start)
return f"locate({arg}, {substr}, {op_start}) - 1"

return f"locate({arg}, {substr}) - 1"


Expand Down
16 changes: 11 additions & 5 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,13 +530,19 @@ def regex_replace(op):

@translate.register(ops.StringFind)
def string_find(op):
if op.end is not None:
raise NotImplementedError("`end` not yet implemented")

arg = translate(op.arg)
pattern = translate(op.substr)

if op.start is not None:
raise NotImplementedError("`start` not yet implemented")

if op.end is not None:
raise NotImplementedError("`end` not yet implemented")
if (op_start := op.start) is not None:
sub_string = ops.Substring(op.arg, op_start)
arg = translate(sub_string)
pos = df.functions.strpos(arg, pattern)
start = translate(op_start)
return df.functions.coalesce(
df.functions.nullif(pos + start, start), df.lit(0)
) - df.lit(1)

return df.functions.strpos(arg, pattern) - df.lit(1)
14 changes: 12 additions & 2 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.backends.base.sql.alchemy.registry import (
_gen_string_find,
geospatial_functions,
)

Expand Down Expand Up @@ -137,6 +136,17 @@ def _regex_extract(arg, pattern, index):
)


def _string_find(t, op):
arg = t.translate(op.arg)
substr = t.translate(op.substr)

if op_start := op.start:
start = t.translate(op_start)
return sa.func.locate(substr, arg, start) - 1

return sa.func.locate(substr, arg) - 1


class _mysql_trim(GenericFunction):
inherit_cache = True

Expand Down Expand Up @@ -164,7 +174,7 @@ def compiles_mysql_trim(element, compiler, **kw):
# static checks are not happy with using "if" as a property
ops.Where: fixed_arity(getattr(sa.func, 'if'), 3),
# strings
ops.StringFind: _gen_string_find(sa.func.locate),
ops.StringFind: _string_find,
ops.FindInSet: (
lambda t, op: (
sa.func.find_in_set(
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,15 @@ def test_string_col_is_unicode(alltypes, df):
id='find',
marks=pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
),
param(
lambda t: t.date_string_col.find('13', 3),
lambda t: t.date_string_col.str.find('13', 3),
id='find_start',
marks=[
pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError),
pytest.mark.notyet(["bigquery"], raises=NotImplementedError),
],
),
param(
lambda t: t.string_col.lpad(10, 'a'),
lambda t: t.string_col.str.pad(10, fillchar='a', side='left'),
Expand Down

0 comments on commit 31ce741

Please sign in to comment.