Skip to content

Commit

Permalink
feat(duckdb): implement regexp replace and extract
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 14, 2022
1 parent 0adc67f commit 18d16a7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 115 deletions.
37 changes: 31 additions & 6 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import operator

import numpy as np
import sqlalchemy as sa
Expand All @@ -25,13 +26,9 @@ def _round(t, expr):
if digits is None:
return sa.func.round(sa_arg)

result = sa.func.round(sa_arg, t.translate(digits))
return result
return sa.func.round(sa_arg, t.translate(digits))


def _mod(t, expr):
left, right = map(t.translate, expr.op().args)
return left % right


def _log(t, expr):
Expand Down Expand Up @@ -111,6 +108,32 @@ def _struct_field(t, expr):
)


def _regex_extract(t, expr):
    """Translate an ibis ``RegexExtract`` op for the DuckDB backend.

    Returns a ``CASE`` expression: when the pattern matches, the requested
    capture group is extracted with ``regexp_extract``; otherwise the empty
    string is returned.
    """
    arg, pattern, index = map(t.translate, expr.op().args)
    # DuckDB requires the group index to be a constant, so compile the
    # (index + 1) expression with its literals inlined and splice the
    # resulting text directly into the query via sa.text
    constant_index = sa.text(
        str((index + 1).compile(compile_kwargs=dict(literal_binds=True)))
    )
    extracted = sa.func.regexp_extract(arg, pattern, constant_index)
    return sa.case(
        [(sa.func.regexp_matches(arg, pattern), extracted)],
        else_="",
    )


operation_registry.update(
{
ops.ArrayColumn: _array_column,
Expand All @@ -120,7 +143,7 @@ def _struct_field(t, expr):
ops.Log2: unary(sa.func.log2),
ops.Log: _log,
# TODO: map operations, but DuckDB's maps are multimaps
ops.Modulus: _mod,
ops.Modulus: fixed_arity(operator.mod, 2),
ops.Round: _round,
ops.StructField: _struct_field,
ops.TableColumn: _table_column,
Expand All @@ -129,5 +152,7 @@ def _struct_field(t, expr):
ops.Translate: fixed_arity('replace', 3),
ops.TimestampNow: fixed_arity('now', 0),
ops.ArrayIndex: fixed_arity('list_element', 2),
ops.RegexExtract: _regex_extract,
ops.RegexReplace: fixed_arity("regexp_replace", 3),
}
)
115 changes: 6 additions & 109 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,126 +61,23 @@ def test_string_col_is_unicode(backend, alltypes, df):
]
),
),
param(
lambda t: t.string_col.re_search(r'[[:digit:]]+'),
lambda t: t.string_col.str.contains(r'\d+'),
id='re_search',
marks=pytest.mark.notimpl(["datafusion", "pyspark"]),
),
param(
lambda t: t.string_col.re_extract(r'([[:digit:]]+)', 0),
lambda t: t.string_col.str.extract(r'(\d+)', expand=False),
id='re_extract',
marks=pytest.mark.notimpl(["duckdb", "mysql", "pyspark"]),
),
param(
lambda t: t.string_col.re_replace(r'[[:digit:]]+', 'a'),
lambda t: t.string_col.str.replace(r'\d+', 'a', regex=True),
id='re_replace',
marks=pytest.mark.notimpl(
['datafusion', "duckdb", "mysql", "pyspark"]
),
),
param(
lambda t: t.string_col.re_search(r'\\d+'),
lambda t: t.string_col.str.contains(r'\d+'),
id='re_search_spark_raw',
marks=(
# TODO: check if this test should pass with pyspark
# and if not, remove it
pytest.mark.notimpl(["pyspark"], reason="regression"),
pytest.mark.never(
[
"dask",
"datafusion",
"duckdb",
"mysql",
"pandas",
"postgres",
"sqlite",
],
reason="not spark",
),
),
),
param(
lambda t: t.string_col.re_extract(r'(\\d+)', 0),
lambda t: t.string_col.str.extract(r'(\d+)', expand=False),
id='re_extract_spark',
marks=(
# TODO: check if this test should pass with pyspark
# and if not, remove it
pytest.mark.notimpl(["pyspark"], reason="regression"),
pytest.mark.never(
[
"dask",
"datafusion",
"duckdb",
"mysql",
"pandas",
"postgres",
"sqlite",
],
reason="not spark",
),
),
),
param(
lambda t: t.string_col.re_replace(r'\\d+', 'a'),
lambda t: t.string_col.str.replace(r'\d+', 'a', regex=True),
id='re_replace_spark_raw',
marks=(
# TODO: check if this test should pass with pyspark
# and if not, remove it
pytest.mark.notimpl(["pyspark"], reason="regression"),
pytest.mark.never(
[
"dask",
"datafusion",
"duckdb",
"mysql",
"pandas",
"postgres",
"sqlite",
],
reason="not spark",
),
),
),
param(
lambda t: t.string_col.re_search(r'\d+'),
lambda t: t.string_col.str.contains(r'\d+'),
id='re_search_spark',
marks=(
pytest.mark.notimpl(['impala']),
pytest.mark.never(["datafusion"], reason="not spark"),
),
id='re_search',
marks=pytest.mark.notimpl(["datafusion"]),
),
param(
lambda t: t.string_col.re_extract(r'(\d+)', 0),
lambda t: t.string_col.str.extract(r'(\d+)', expand=False),
id='re_extract_spark_raw',
marks=(
pytest.mark.notimpl(
[
'impala',
"duckdb",
]
),
pytest.mark.never(["mysql"], reason="not spark"),
),
id='re_extract',
marks=pytest.mark.notimpl(["mysql"]),
),
param(
lambda t: t.string_col.re_replace(r'\d+', 'a'),
lambda t: t.string_col.str.replace(r'\d+', 'a', regex=True),
id='re_replace_spark',
marks=(
pytest.mark.notimpl(['impala']),
pytest.mark.never(
["datafusion", "duckdb", "mysql"],
reason="not spark",
),
),
id='re_replace',
marks=pytest.mark.notimpl(['datafusion', "mysql"]),
),
param(
lambda t: t.string_col.repeat(2),
Expand Down

0 comments on commit 18d16a7

Please sign in to comment.