From dd0a04c31d38523f67ff78e09aa418578d239daa Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 9 Feb 2023 07:19:21 -0500 Subject: [PATCH] fix(sqlalchemy): make `strip` family of functions behave like Python --- ibis/backends/mysql/registry.py | 27 +++++++++++++++++++++++++++ ibis/backends/postgres/registry.py | 3 +++ 2 files changed, 30 insertions(+) diff --git a/ibis/backends/mysql/registry.py b/ibis/backends/mysql/registry.py index 60c6cdb8cd03..acbebbbf9f01 100644 --- a/ibis/backends/mysql/registry.py +++ b/ibis/backends/mysql/registry.py @@ -1,9 +1,13 @@ from __future__ import annotations import contextlib +import functools import operator +import string import sqlalchemy as sa +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import GenericFunction import ibis import ibis.common.exceptions as com @@ -116,6 +120,26 @@ def _json_get_item(t, op): return sa.func.json_extract(arg, path) +class _mysql_trim(GenericFunction): + inherit_cache = True + + def __init__(self, input, side: str) -> None: + super().__init__(input) + self.type = sa.VARCHAR() + self.side = side + + +@compiles(_mysql_trim, "mysql") +def compiles_mysql_trim(element, compiler, **kw): + arg = compiler.function_argspec(element, **kw) + side = element.side.upper() + # has to be called once for every whitespace character because mysql + # interprets `char` literally, not as a set of characters like Python + return functools.reduce( + lambda arg, char: f"TRIM({side} '{char}' FROM {arg})", string.whitespace, arg + ) + + operation_registry.update( { ops.Literal: _literal, @@ -176,6 +200,9 @@ def _json_get_item(t, op): ), ops.DayOfWeekName: fixed_arity(lambda arg: sa.func.dayname(arg), 1), ops.JSONGetItem: _json_get_item, + ops.Strip: unary(lambda arg: _mysql_trim(arg, "both")), + ops.LStrip: unary(lambda arg: _mysql_trim(arg, "leading")), + ops.RStrip: unary(lambda arg: _mysql_trim(arg, "trailing")), } ) diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index 75d739c0fe75..d29d8229b92a 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -620,5 +620,8 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: ops.ArrayStringJoin: fixed_arity( lambda sep, arr: sa.func.array_to_string(arr, sep), 2 ), + ops.Strip: unary(lambda arg: sa.func.trim(arg, string.whitespace)), + ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)), + ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), } )