diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index 9508522a953e..2acfeb5ade0e 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -496,6 +496,7 @@ def _sort_key(t, expr): ops.FloorDivide: _floor_divide, # other ops.SortKey: _sort_key, + ops.Date: unary(lambda arg: sa.cast(arg, sa.DATE)), } diff --git a/ibis/backends/mysql/registry.py b/ibis/backends/mysql/registry.py index a120201d9984..94b55e8a7f58 100644 --- a/ibis/backends/mysql/registry.py +++ b/ibis/backends/mysql/registry.py @@ -240,7 +240,6 @@ def _day_of_week_name(t, expr): ops.Round: _round, ops.RandomScalar: _random, # dates and times - ops.Date: unary(sa.func.date), ops.DateAdd: infix_op('+'), ops.DateSub: infix_op('-'), ops.DateDiff: fixed_arity(sa.func.datediff, 2), diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index d5a76701e10d..e888942b20ab 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -119,17 +119,12 @@ def _timestamp_truncate(t, expr): def _interval_from_integer(t, expr): - arg, unit = expr.op().args - sa_arg = t.translate(arg) + op = expr.op() + sa_arg = t.translate(op.arg) interval = sa.text(f"INTERVAL '1 {expr.type().resolution}'") return sa_arg * interval -def _timestamp_add(t, expr): - sa_args = list(map(t.translate, expr.op().args)) - return sa_args[0] + sa_args[1] - - def _is_nan(t, expr): (arg,) = expr.op().args sa_arg = t.translate(arg) @@ -678,7 +673,6 @@ def _day_of_week_name(t, expr): ops.Round: _round, ops.Modulus: _mod, # dates and times - ops.Date: unary(lambda x: sa.cast(x, sa.Date)), ops.DateTruncate: _timestamp_truncate, ops.TimestampTruncate: _timestamp_truncate, ops.IntervalFromInteger: _interval_from_integer, diff --git a/ibis/backends/sqlite/registry.py b/ibis/backends/sqlite/registry.py index 2de752d25973..81981d9cd41c 100644 --- a/ibis/backends/sqlite/registry.py +++ b/ibis/backends/sqlite/registry.py @@ -1,5 +1,6 @@ import sqlalchemy as sa import toolz +from multipledispatch import Dispatcher import ibis import ibis.common.exceptions as com @@ -7,6 +8,7 @@ import ibis.expr.operations as ops import ibis.expr.types as ir from ibis.backends.base.sql.alchemy import ( + AlchemyExprTranslator, fixed_arity, sqlalchemy_operation_registry, sqlalchemy_window_functions_registry, @@ -19,32 +21,55 @@ operation_registry.update(sqlalchemy_window_functions_registry) -def _cast(t, expr): - # It's not all fun and games with SQLite +sqlite_cast = Dispatcher("sqlite_cast") + + +@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Timestamp) +def _unixepoch(t, arg, _): + return sa.func.datetime(t.translate(arg), "unixepoch") + + +@sqlite_cast.register(AlchemyExprTranslator, ir.StringValue, dt.Timestamp) +def _string_to_timestamp(t, arg, _): + return sa.func.strftime('%Y-%m-%d %H:%M:%f', t.translate(arg)) + + +@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Date) +def _integer_to_date(t, arg, _): + return sa.func.date(sa.func.datetime(t.translate(arg), "unixepoch")) + + +@sqlite_cast.register( + AlchemyExprTranslator, + (ir.StringValue, ir.TimestampValue), + dt.Date, +) +def _string_or_timestamp_to_date(t, arg, _): + return sa.func.date(t.translate(arg)) + +@sqlite_cast.register( + AlchemyExprTranslator, + ir.ValueExpr, + (dt.Date, dt.Timestamp), +) +def _value_to_temporal(t, arg, _): + raise com.UnsupportedOperationError(type(arg)) + + +@sqlite_cast.register(AlchemyExprTranslator, ir.CategoryValue, dt.Int32) +def _category_to_int(t, arg, _): + return t.translate(arg) + + +@sqlite_cast.register(AlchemyExprTranslator, ir.ValueExpr, dt.DataType) +def _default_cast_impl(t, arg, target_type): + return sa.cast(t.translate(arg), t.get_sqla_type(target_type)) + + +def _cast(t, expr): op = expr.op() - arg, target_type = op.args - sa_arg = t.translate(arg) - sa_type = t.get_sqla_type(target_type) - - if isinstance(target_type, dt.Timestamp): - if isinstance(arg, ir.IntegerValue): - return sa.func.datetime(sa_arg, 'unixepoch') - elif isinstance(arg, ir.StringValue): - return sa.func.strftime('%Y-%m-%d %H:%M:%f', sa_arg) - raise com.UnsupportedOperationError(type(arg)) - - if isinstance(target_type, dt.Date): - if isinstance(arg, ir.IntegerValue): - return sa.func.date(sa.func.datetime(sa_arg, 'unixepoch')) - elif isinstance(arg, ir.StringValue): - return sa.func.date(sa_arg) - raise com.UnsupportedOperationError(type(arg)) - - if isinstance(arg, ir.CategoryValue) and target_type == 'int32': - return sa_arg - else: - return sa.cast(sa_arg, sa_type) + return sqlite_cast(t, op.arg, op.to) def _substr(t, expr): @@ -237,6 +262,7 @@ def _rpad(t, expr): ops.Greatest: varargs(sa.func.max), ops.IfNull: fixed_arity(sa.func.ifnull, 2), ops.DateTruncate: _truncate(sa.func.date), + ops.Date: unary(sa.func.date), ops.TimestampTruncate: _truncate(sa.func.datetime), ops.Strftime: _strftime, ops.ExtractYear: _strftime_int('%Y'), diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py index 52949f4fe33a..1fcb79c2c5af 100644 --- a/ibis/backends/tests/test_param.py +++ b/ibis/backends/tests/test_param.py @@ -27,10 +27,8 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value): ('start_string', 'end_string'), [('2009-03-01', '2010-07-03'), ('2014-12-01', '2017-01-05')], ) -@pytest.mark.notimpl(["datafusion", "pyspark", "sqlite"]) -def test_date_scalar_parameter( - backend, alltypes, df, start_string, end_string -): +@pytest.mark.notimpl(["datafusion", "pyspark"]) +def test_date_scalar_parameter(backend, alltypes, start_string, end_string): start, end = ibis.param(dt.date), ibis.param(dt.date) col = alltypes.timestamp_col.date() diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 97fe118821e3..6b8d6797c66c 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -13,9 +13,20 @@ @pytest.mark.parametrize('attr', ['year', 'month', 'day']) -@pytest.mark.notimpl(["datafusion", "sqlite"]) -def test_date_extract(backend, alltypes, df, attr): - expr = getattr(alltypes.timestamp_col.date(), attr)() +@pytest.mark.parametrize( + "expr_fn", + [ + param(lambda c: c.date(), id="date"), + param( + lambda c: c.cast("date"), + id="cast", + marks=pytest.mark.notimpl(["impala"]), + ), + ], +) +@pytest.mark.notimpl(["datafusion"]) +def test_date_extract(backend, alltypes, df, attr, expr_fn): + expr = getattr(expr_fn(alltypes.timestamp_col), attr)() expected = getattr(df.timestamp_col.dt, attr).astype('int32') result = expr.execute() @@ -172,16 +183,17 @@ def test_timestamp_truncate(backend, alltypes, df, unit): "mysql", "postgres", "pyspark", + "sqlite", ] ), ), ], ) -@pytest.mark.notimpl(["datafusion", "sqlite"]) +@pytest.mark.notimpl(["datafusion"]) def test_date_truncate(backend, alltypes, df, unit): expr = alltypes.timestamp_col.date().truncate(unit) - dtype = f'datetime64[{unit}]' + dtype = f"datetime64[{unit}]" expected = pd.Series(df.timestamp_col.values.astype(dtype)) result = expr.execute()