From 43996c080b8e2c9c53e838b90743878be9764ae7 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 11 Mar 2022 07:23:26 -0500 Subject: [PATCH] feat(sqlalchemy): implement ilike --- ibis/backends/base/sql/alchemy/registry.py | 12 ++++++---- ibis/backends/tests/test_string.py | 28 ++++++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index ebcd6a64cb0f..27419d144eb7 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -1,3 +1,4 @@ +import functools import operator from typing import Any, Dict @@ -290,10 +291,10 @@ def unary(sa_func): return fixed_arity(sa_func, 1) -def _string_like(t, expr): - arg, pattern, escape = expr.op().args - result = t.translate(arg).like(t.translate(pattern), escape=escape) - return result +def _string_like(method_name, t, expr): + op = expr.op() + method = getattr(t.translate(op.arg), method_name) + return method(t.translate(op.pattern), escape=op.escape) def _startswith(t, expr): @@ -482,7 +483,8 @@ def _sort_key(t, expr): ops.StringAscii: unary(sa.func.ascii), ops.StringLength: unary(sa.func.length), ops.StringReplace: fixed_arity(sa.func.replace, 3), - ops.StringSQLLike: _string_like, + ops.StringSQLLike: functools.partial(_string_like, "like"), + ops.StringSQLILike: functools.partial(_string_like, "ilike"), ops.StartsWith: _startswith, ops.EndsWith: _endswith, ops.StringConcat: varargs(sa.func.concat), diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index 5bcbdd8b3670..f24823708b7c 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -52,32 +52,46 @@ def test_string_col_is_unicode(backend, alltypes, df): [ "clickhouse", "datafusion", - "duckdb", "impala", - "mysql", - "postgres", "pyspark", - "sqlite", ] ), ), + param( + lambda t: t.string_col.re_search(r'[[:digit:]]+'), + lambda t: t.string_col.str.contains(r'\d+'), + id='re_search_posix', + marks=pytest.mark.notimpl(["datafusion", "pyspark"]), + ), + param( + lambda t: t.string_col.re_extract(r'([[:digit:]]+)', 0), + lambda t: t.string_col.str.extract(r'(\d+)', expand=False), + id='re_extract_posix', + marks=pytest.mark.notimpl(["mysql", "pyspark"]), + ), + param( + lambda t: t.string_col.re_replace(r'[[:digit:]]+', 'a'), + lambda t: t.string_col.str.replace(r'\d+', 'a', regex=True), + id='re_replace_posix', + marks=pytest.mark.notimpl(['datafusion', "mysql", "pyspark"]), + ), param( lambda t: t.string_col.re_search(r'\d+'), lambda t: t.string_col.str.contains(r'\d+'), id='re_search', - marks=pytest.mark.notimpl(["datafusion"]), + marks=pytest.mark.notimpl(["impala", "datafusion"]), ), param( lambda t: t.string_col.re_extract(r'(\d+)', 0), lambda t: t.string_col.str.extract(r'(\d+)', expand=False), id='re_extract', - marks=pytest.mark.notimpl(["mysql"]), + marks=pytest.mark.notimpl(["impala", "mysql"]), ), param( lambda t: t.string_col.re_replace(r'\d+', 'a'), lambda t: t.string_col.str.replace(r'\d+', 'a', regex=True), id='re_replace', - marks=pytest.mark.notimpl(['datafusion', "mysql"]), + marks=pytest.mark.notimpl(["impala", "datafusion", "mysql"]), ), param( lambda t: t.string_col.repeat(2),