Skip to content

Commit

Permalink
refactor(sqlalchemy): deduplicate and update translation rules
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed May 25, 2022
1 parent ad0447b commit 55f12c8
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 336 deletions.
108 changes: 70 additions & 38 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Dict

import sqlalchemy as sa
import sqlalchemy.sql as sql

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -166,15 +165,27 @@ def _exists_subquery(t, expr):


def _cast(t, expr):
    """Translate a cast expression into a SQLAlchemy expression.

    The operation carries two arguments: the value being cast and the target
    ibis type.  A handful of source/target combinations are special-cased;
    everything else becomes a plain ``CAST(<value> AS <type>)``.

    Parameters
    ----------
    t
        Translator providing ``translate`` and ``get_sqla_type``.
    expr
        The ibis cast expression.

    Returns
    -------
    The SQLAlchemy expression implementing the cast.
    """
    arg, typ = expr.op().args

    sa_arg = t.translate(arg)
    sa_type = t.get_sqla_type(typ)

    # a category value cast to int32 is passed through without a CAST
    if isinstance(arg, ir.CategoryValue) and typ == dt.int32:
        return sa_arg

    # NOTE(review): `timezone`, `to_timestamp` and `encode(..., 'escape')`
    # look like PostgreSQL-flavoured functions -- confirm that every dialect
    # sharing this rule supports them.

    # specialize going from an integer type to a timestamp
    if isinstance(arg.type(), dt.Integer) and isinstance(sa_type, sa.DateTime):
        return sa.func.timezone('UTC', sa.func.to_timestamp(sa_arg))

    if arg.type().equals(dt.binary) and typ.equals(dt.string):
        return sa.func.encode(sa_arg, 'escape')

    if typ.equals(dt.binary):
        # decode yields a column of memoryview which is annoying to deal with
        # in pandas. CAST(expr AS BYTEA) is correct and returns byte strings.
        return sa.cast(sa_arg, sa.LargeBinary())

    return sa.cast(sa_arg, sa_type)


def _contains(func):
Expand Down Expand Up @@ -218,7 +229,7 @@ def _alias(t, expr):
return t.translate(op.arg)


def _literal(t, expr):
def _literal(_, expr):
dtype = expr.type()
value = expr.op().value

Expand Down Expand Up @@ -427,7 +438,7 @@ def _lead(t, expr):
def _ntile(t, expr):
    """Translate an ntile expression into ``ntile(<buckets>)``.

    The operation carries two arguments; only the second one (the bucket
    count) appears in the generated call.

    Parameters
    ----------
    t
        Translator providing ``translate``.
    expr
        The ibis ntile expression.

    Returns
    -------
    The SQLAlchemy ``ntile`` function call.
    """
    # The first operand is not part of the rendered call, so it is unpacked
    # (which still enforces the two-argument shape) and dropped without
    # being translated.
    _, buckets = expr.op().args
    return sa.func.ntile(t.translate(buckets))


Expand Down Expand Up @@ -460,10 +471,41 @@ def _zero_if_null(t, expr):
)


def _substring(t, expr):
    """Translate a substring expression into a ``substr`` call.

    The translated ``start`` is incremented by one before being handed to
    ``substr``.  The length argument is only passed when the operation has
    one.
    """
    op = expr.op()

    string = t.translate(op.arg)
    first = t.translate(op.start) + 1

    length = op.length
    if length is None:
        return sa.func.substr(string, first)

    return sa.func.substr(string, first, t.translate(length))


def _gen_string_find(func):
    """Build a string-find translation rule around *func*.

    The returned rule calls ``func(<translated arg>, <translated substr>)``
    and subtracts one from the result.  It raises ``NotImplementedError``
    when the operation specifies ``start`` or ``end``.
    """

    def string_find(t, expr):
        op = expr.op()

        # reject the optional bounds, checking `start` before `end`
        unsupported = (
            ("start", "`start` not yet implemented"),
            ("end", "`end` not yet implemented"),
        )
        for field, message in unsupported:
            if getattr(op, field) is not None:
                raise NotImplementedError(message)

        haystack = t.translate(op.arg)
        needle = t.translate(op.substr)
        return func(haystack, needle) - 1

    return string_find


def _nth_value(t, expr):
    """Translate an nth-value expression into ``nth_value(arg, nth + 1)``."""
    op = expr.op()

    value = t.translate(op.arg)
    # the translated position is shifted by one before the call
    position = t.translate(op.nth) + 1

    return sa.func.nth_value(value, position)


sqlalchemy_operation_registry: Dict[Any, Any] = {
ops.Alias: _alias,
ops.And: fixed_arity(sql.and_, 2),
ops.Or: fixed_arity(sql.or_, 2),
ops.And: fixed_arity(operator.and_, 2),
ops.Or: fixed_arity(operator.or_, 2),
ops.Xor: fixed_arity(lambda x, y: (x | y) & ~(x & y), 2),
ops.Not: unary(sa.not_),
ops.Abs: unary(sa.func.abs),
Expand Down Expand Up @@ -515,6 +557,7 @@ def _zero_if_null(t, expr):
ops.Lowercase: unary(sa.func.lower),
ops.Uppercase: unary(sa.func.upper),
ops.StringAscii: unary(sa.func.ascii),
ops.StringFind: _gen_string_find(sa.func.strpos),
ops.StringLength: unary(sa.func.length),
ops.StringJoin: _string_join,
ops.StringReplace: fixed_arity(sa.func.replace, 3),
Expand All @@ -523,6 +566,7 @@ def _zero_if_null(t, expr):
ops.StartsWith: _startswith,
ops.EndsWith: _endswith,
ops.StringConcat: varargs(sa.func.concat),
ops.Substring: _substring,
# math
ops.Ln: unary(sa.func.ln),
ops.Exp: unary(sa.func.exp),
Expand Down Expand Up @@ -553,36 +597,28 @@ def _zero_if_null(t, expr):
ops.Degrees: unary(sa.func.degrees),
ops.Radians: unary(sa.func.radians),
ops.ZeroIfNull: _zero_if_null,
}

# TODO: unit tests for each of these
_binary_ops = {
ops.RandomScalar: fixed_arity(sa.func.random, 0),
# Binary arithmetic
ops.Add: operator.add,
ops.Subtract: operator.sub,
ops.Multiply: operator.mul,
ops.Add: fixed_arity(operator.add, 2),
ops.Subtract: fixed_arity(operator.sub, 2),
ops.Multiply: fixed_arity(operator.mul, 2),
# XXX `ops.Divide` is overwritten in `translator.py` with a custom
# function `_true_divide`, but for some reason both are required
ops.Divide: operator.truediv,
ops.Modulus: operator.mod,
ops.Divide: fixed_arity(operator.truediv, 2),
ops.Modulus: fixed_arity(operator.mod, 2),
# Comparisons
ops.Equals: operator.eq,
ops.NotEquals: operator.ne,
ops.Less: operator.lt,
ops.LessEqual: operator.le,
ops.Greater: operator.gt,
ops.GreaterEqual: operator.ge,
ops.IdenticalTo: lambda x, y: x.op('IS NOT DISTINCT FROM')(y),
# Boolean comparisons
# TODO
ops.Equals: fixed_arity(operator.eq, 2),
ops.NotEquals: fixed_arity(operator.ne, 2),
ops.Less: fixed_arity(operator.lt, 2),
ops.LessEqual: fixed_arity(operator.le, 2),
ops.Greater: fixed_arity(operator.gt, 2),
ops.GreaterEqual: fixed_arity(operator.ge, 2),
ops.IdenticalTo: fixed_arity(
sa.sql.expression.ColumnElement.is_not_distinct_from, 2
),
}


def _nth_value(t, expr):
    """Translate an nth-value expression into ``nth_value(arg, nth + 1)``."""
    operation = expr.op()

    sa_arg = t.translate(operation.arg)
    # one is added to the translated `nth` before it is passed on
    sa_nth = t.translate(operation.nth) + 1

    return sa.func.nth_value(sa_arg, sa_nth)


sqlalchemy_window_functions_registry = {
ops.Lag: _lag,
ops.Lead: _lead,
Expand Down Expand Up @@ -666,7 +702,3 @@ def _nth_value(t, expr):
}
else:
_geospatial_functions = {}


# Register every entry of `_binary_ops`, wrapping each callable as a
# two-argument translation rule.
sqlalchemy_operation_registry.update(
    (_k, fixed_arity(_v, 2)) for _k, _v in _binary_ops.items()
)
101 changes: 14 additions & 87 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
import operator

import pandas as pd
import sqlalchemy as sa

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.alchemy import (
fixed_arity,
infix_op,
sqlalchemy_operation_registry,
sqlalchemy_window_functions_registry,
unary,
)
from ibis.backends.base.sql.alchemy.registry import _gen_string_find

# Start from a copy of the generic SQLAlchemy rules and layer the window
# function rules on top of it.
# NOTE: window functions are available from MySQL 8 and MariaDB 10.2
operation_registry = {
    **sqlalchemy_operation_registry,
    **sqlalchemy_window_functions_registry,
}


def _substr(t, expr):
    """Translate a substring expression into a ``substr`` call.

    The translated start is incremented by one; the length is appended to
    the call only when the operation provides one.
    """
    arg, start, length = expr.op().args

    call_args = [t.translate(arg), t.translate(start) + 1]
    if length is not None:
        call_args.append(t.translate(length))

    return sa.func.substr(*call_args)


def _string_find(t, expr):
    """Translate a string-find expression into ``LOCATE(...) - 1``.

    MySQL's ``LOCATE(substr, str)`` takes the substring being searched for
    first and the string being searched second.  One is subtracted from its
    result.

    Parameters
    ----------
    t
        Translator providing ``translate``.
    expr
        The ibis string-find expression; its operation carries four
        arguments, unpacked here as ``arg``, ``substr``, ``start`` and a
        fourth one that is not used.

    Raises
    ------
    NotImplementedError
        If ``start`` is given.
    """
    # NOTE(review): the fourth operand is silently ignored -- confirm whether
    # it should be rejected like `start`.
    arg, substr, start, _ = expr.op().args

    if start is not None:
        raise NotImplementedError

    sa_arg = t.translate(arg)
    sa_substr = t.translate(substr)

    # needle first, haystack second: LOCATE(substr, str)
    return sa.func.locate(sa_substr, sa_arg) - 1


def _capitalize(t, expr):
(arg,) = expr.op().args
sa_arg = t.translate(arg)
Expand Down Expand Up @@ -89,43 +63,13 @@ def _truncate(t, expr):
return sa.func.date_format(sa_arg, fmt)


def _cast(t, expr):
    """Translate a cast expression into a SQLAlchemy expression.

    Three source/target combinations are special-cased; anything else is
    rendered as a plain ``CAST`` to the translated target type.
    """
    source, target = expr.op().args

    translated = t.translate(source)
    target_sa_type = t.get_sqla_type(target)
    source_type = source.type()

    # NOTE(review): `timezone`, `to_timestamp`, `encode(..., 'escape')` and
    # the BYTEA remark below look PostgreSQL-flavoured -- confirm they are
    # valid for this backend.

    # specialize going from an integer type to a timestamp
    integer_to_timestamp = isinstance(source_type, dt.Integer) and isinstance(
        target_sa_type, sa.DateTime
    )
    if integer_to_timestamp:
        return sa.func.timezone('UTC', sa.func.to_timestamp(translated))

    if source_type.equals(dt.binary) and target.equals(dt.string):
        return sa.func.encode(translated, 'escape')

    if target.equals(dt.binary):
        # decode yields a column of memoryview which is annoying to deal with
        # in pandas. CAST(expr AS BYTEA) is correct and returns byte strings.
        return sa.cast(translated, sa.LargeBinary())

    return sa.cast(translated, target_sa_type)


def _log(t, expr):
    """Translate a logarithm expression into ``log(<base>, <arg>)``.

    Both operands are translated in the order they are stored (value, then
    base); the generated call receives the base first.
    """
    value_expr, base_expr = expr.op().args

    value = t.translate(value_expr)
    base = t.translate(base_expr)

    return sa.func.log(base, value)


def _identical_to(t, expr):
    """Translate an identical-to comparison using the ``<=>`` operator.

    When both operands compare equal via ``equals`` the Python value
    ``True`` is returned without translating anything.
    """
    operands = expr.op().args
    left, right = operands

    if left.equals(right):
        return True

    lhs, rhs = map(t.translate, operands)
    return lhs.op('<=>')(rhs)


def _round(t, expr):
arg, digits = expr.op().args
sa_arg = t.translate(arg)
Expand All @@ -138,16 +82,6 @@ def _round(t, expr):
return sa.func.round(sa_arg, sa_digits)


def _floor_divide(t, expr):
    """Translate a two-operand division into ``floor(left / right)``."""
    numerator, denominator = (
        t.translate(operand) for operand in expr.op().args
    )
    quotient = numerator / denominator
    return sa.func.floor(quotient)


def _string_join(t, expr):
    """Translate a string join into ``concat_ws(<sep>, <element>, ...)``."""
    sep, elements = expr.op().args

    # translate the separator first, then each element in order
    separator = t.translate(sep)
    pieces = [t.translate(element) for element in elements]

    return sa.func.concat_ws(separator, *pieces)


def _interval_from_integer(t, expr):
arg, unit = expr.op().args
if unit in {'ms', 'ns'}:
Expand Down Expand Up @@ -176,7 +110,7 @@ def _timestamp_diff(t, expr):
return sa.func.timestampdiff(sa.text('SECOND'), sa_right, sa_left)


def _literal(t, expr):
def _literal(_, expr):
if isinstance(expr, ir.IntervalScalar):
if expr.type().unit in {'ms', 'ns'}:
raise com.UnsupportedOperationError(
Expand All @@ -195,10 +129,6 @@ def _literal(t, expr):
return sa.literal(value)


def _random(t, expr):
    """Return a call to the SQL ``random`` function.

    Neither the translator nor the expression is consulted.
    """
    random_fn = sa.func.random
    return random_fn()


def _group_concat(t, expr):
op = expr.op()
arg, sep, where = op.args
Expand All @@ -214,7 +144,7 @@ def _day_of_week_index(t, expr):
(arg,) = expr.op().args
left = sa.func.dayofweek(t.translate(arg)) - 2
right = 7
return ((left % right) + right) % right
return (left % right + right) % right


def _day_of_week_name(t, expr):
Expand All @@ -227,22 +157,20 @@ def _day_of_week_name(t, expr):
ops.Literal: _literal,
ops.IfNull: fixed_arity(sa.func.ifnull, 2),
# strings
ops.Substring: _substr,
ops.StringFind: _string_find,
ops.StringFind: _gen_string_find(sa.func.locate),
ops.Capitalize: _capitalize,
ops.RegexSearch: infix_op('REGEXP'),
ops.RegexSearch: fixed_arity(lambda x, y: x.op('REGEXP')(y), 2),
# math
ops.Log: _log,
ops.Log2: unary(sa.func.log2),
ops.Log10: unary(sa.func.log10),
ops.Round: _round,
ops.RandomScalar: _random,
# dates and times
ops.DateAdd: infix_op('+'),
ops.DateSub: infix_op('-'),
ops.DateAdd: fixed_arity(operator.add, 2),
ops.DateSub: fixed_arity(operator.sub, 2),
ops.DateDiff: fixed_arity(sa.func.datediff, 2),
ops.TimestampAdd: infix_op('+'),
ops.TimestampSub: infix_op('-'),
ops.TimestampAdd: fixed_arity(operator.add, 2),
ops.TimestampSub: fixed_arity(operator.sub, 2),
ops.TimestampDiff: _timestamp_diff,
ops.DateTruncate: _truncate,
ops.TimestampTruncate: _truncate,
Expand All @@ -251,16 +179,15 @@ def _day_of_week_name(t, expr):
ops.ExtractYear: _extract('year'),
ops.ExtractMonth: _extract('month'),
ops.ExtractDay: _extract('day'),
ops.ExtractDayOfYear: unary('dayofyear'),
ops.ExtractDayOfYear: unary(sa.func.dayofyear),
ops.ExtractQuarter: _extract('quarter'),
ops.ExtractEpochSeconds: unary('UNIX_TIMESTAMP'),
ops.ExtractWeekOfYear: fixed_arity('weekofyear', 1),
ops.ExtractEpochSeconds: unary(sa.func.UNIX_TIMESTAMP),
ops.ExtractWeekOfYear: unary(sa.func.weekofyear),
ops.ExtractHour: _extract('hour'),
ops.ExtractMinute: _extract('minute'),
ops.ExtractSecond: _extract('second'),
ops.ExtractMillisecond: _extract('millisecond'),
# reductions
ops.IdenticalTo: _identical_to,
ops.TimestampNow: fixed_arity(sa.func.now, 0),
# others
ops.GroupConcat: _group_concat,
Expand Down
Loading

0 comments on commit 55f12c8

Please sign in to comment.