Skip to content

Commit

Permalink
fix(sqlalchemy): make strip family of functions behave like Python
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Feb 9, 2023
1 parent 9a5737d commit dd0a04c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
27 changes: 27 additions & 0 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import contextlib
import functools
import operator
import string

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -116,6 +120,26 @@ def _json_get_item(t, op):
return sa.func.json_extract(arg, path)


class _mysql_trim(GenericFunction):
inherit_cache = True

def __init__(self, input, side: str) -> None:
super().__init__(input)
self.type = sa.VARCHAR()
self.side = side


@compiles(_mysql_trim, "mysql")
def compiles_mysql_trim(element, compiler, **kw):
arg = compiler.function_argspec(element, **kw)
side = element.side.upper()
# has to be called once for every whitespace character because mysql
# interprets `char` literally, not as a set of characters like Python
return functools.reduce(
lambda arg, char: f"TRIM({side} '{char}' FROM {arg})", string.whitespace, arg
)


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -176,6 +200,9 @@ def _json_get_item(t, op):
),
ops.DayOfWeekName: fixed_arity(lambda arg: sa.func.dayname(arg), 1),
ops.JSONGetItem: _json_get_item,
ops.Strip: unary(lambda arg: _mysql_trim(arg, "both")),
ops.LStrip: unary(lambda arg: _mysql_trim(arg, "leading")),
ops.RStrip: unary(lambda arg: _mysql_trim(arg, "trailing")),
}
)

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,5 +620,8 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
ops.ArrayStringJoin: fixed_arity(
lambda sep, arr: sa.func.array_to_string(arr, sep), 2
),
ops.Strip: unary(lambda arg: sa.func.trim(arg, string.whitespace)),
ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)),
ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)),
}
)

0 comments on commit dd0a04c

Please sign in to comment.