Skip to content

Commit

Permalink
feat(sqlite): implement date_truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Mar 8, 2022
1 parent d630a77 commit 3ce4f2a
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 42 deletions.
1 change: 1 addition & 0 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def _sort_key(t, expr):
ops.FloorDivide: _floor_divide,
# other
ops.SortKey: _sort_key,
ops.Date: unary(lambda arg: sa.cast(arg, sa.DATE)),
}


Expand Down
1 change: 0 additions & 1 deletion ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ def _day_of_week_name(t, expr):
ops.Round: _round,
ops.RandomScalar: _random,
# dates and times
ops.Date: unary(sa.func.date),
ops.DateAdd: infix_op('+'),
ops.DateSub: infix_op('-'),
ops.DateDiff: fixed_arity(sa.func.datediff, 2),
Expand Down
10 changes: 2 additions & 8 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,12 @@ def _timestamp_truncate(t, expr):


def _interval_from_integer(t, expr):
arg, unit = expr.op().args
sa_arg = t.translate(arg)
op = expr.op()
sa_arg = t.translate(op.arg)
interval = sa.text(f"INTERVAL '1 {expr.type().resolution}'")
return sa_arg * interval


def _timestamp_add(t, expr):
sa_args = list(map(t.translate, expr.op().args))
return sa_args[0] + sa_args[1]


def _is_nan(t, expr):
(arg,) = expr.op().args
sa_arg = t.translate(arg)
Expand Down Expand Up @@ -678,7 +673,6 @@ def _day_of_week_name(t, expr):
ops.Round: _round,
ops.Modulus: _mod,
# dates and times
ops.Date: unary(lambda x: sa.cast(x, sa.Date)),
ops.DateTruncate: _timestamp_truncate,
ops.TimestampTruncate: _timestamp_truncate,
ops.IntervalFromInteger: _interval_from_integer,
Expand Down
74 changes: 50 additions & 24 deletions ibis/backends/sqlite/registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sqlalchemy as sa
import toolz
from multipledispatch import Dispatcher

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.alchemy import (
AlchemyExprTranslator,
fixed_arity,
sqlalchemy_operation_registry,
sqlalchemy_window_functions_registry,
Expand All @@ -19,32 +21,55 @@
operation_registry.update(sqlalchemy_window_functions_registry)


def _cast(t, expr):
# It's not all fun and games with SQLite
sqlite_cast = Dispatcher("sqlite_cast")


@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Timestamp)
def _unixepoch(t, arg, _):
return sa.func.datetime(t.translate(arg), "unixepoch")


@sqlite_cast.register(AlchemyExprTranslator, ir.StringValue, dt.Timestamp)
def _string_to_timestamp(t, arg, _):
return sa.func.strftime('%Y-%m-%d %H:%M:%f', t.translate(arg))


@sqlite_cast.register(AlchemyExprTranslator, ir.IntegerValue, dt.Date)
def _integer_to_date(t, arg, _):
return sa.func.date(sa.func.datetime(t.translate(arg), "unixepoch"))


@sqlite_cast.register(
AlchemyExprTranslator,
(ir.StringValue, ir.TimestampValue),
dt.Date,
)
def _string_or_timestamp_to_date(t, arg, _):
return sa.func.date(t.translate(arg))


@sqlite_cast.register(
AlchemyExprTranslator,
ir.ValueExpr,
(dt.Date, dt.Timestamp),
)
def _value_to_temporal(t, arg, _):
raise com.UnsupportedOperationError(type(arg))


@sqlite_cast.register(AlchemyExprTranslator, ir.CategoryValue, dt.Int32)
def _category_to_int(t, arg, _):
return t.translate(arg)


@sqlite_cast.register(AlchemyExprTranslator, ir.ValueExpr, dt.DataType)
def _default_cast_impl(t, arg, target_type):
return sa.cast(t.translate(arg), t.get_sqla_type(target_type))


def _cast(t, expr):
op = expr.op()
arg, target_type = op.args
sa_arg = t.translate(arg)
sa_type = t.get_sqla_type(target_type)

if isinstance(target_type, dt.Timestamp):
if isinstance(arg, ir.IntegerValue):
return sa.func.datetime(sa_arg, 'unixepoch')
elif isinstance(arg, ir.StringValue):
return sa.func.strftime('%Y-%m-%d %H:%M:%f', sa_arg)
raise com.UnsupportedOperationError(type(arg))

if isinstance(target_type, dt.Date):
if isinstance(arg, ir.IntegerValue):
return sa.func.date(sa.func.datetime(sa_arg, 'unixepoch'))
elif isinstance(arg, ir.StringValue):
return sa.func.date(sa_arg)
raise com.UnsupportedOperationError(type(arg))

if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
return sa_arg
else:
return sa.cast(sa_arg, sa_type)
return sqlite_cast(t, op.arg, op.to)


def _substr(t, expr):
Expand Down Expand Up @@ -237,6 +262,7 @@ def _rpad(t, expr):
ops.Greatest: varargs(sa.func.max),
ops.IfNull: fixed_arity(sa.func.ifnull, 2),
ops.DateTruncate: _truncate(sa.func.date),
ops.Date: unary(sa.func.date),
ops.TimestampTruncate: _truncate(sa.func.datetime),
ops.Strftime: _strftime,
ops.ExtractYear: _strftime_int('%Y'),
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def test_floating_scalar_parameter(backend, alltypes, df, column, raw_value):
('start_string', 'end_string'),
[('2009-03-01', '2010-07-03'), ('2014-12-01', '2017-01-05')],
)
@pytest.mark.notimpl(["datafusion", "pyspark", "sqlite"])
def test_date_scalar_parameter(
backend, alltypes, df, start_string, end_string
):
@pytest.mark.notimpl(["datafusion", "pyspark"])
def test_date_scalar_parameter(backend, alltypes, start_string, end_string):
start, end = ibis.param(dt.date), ibis.param(dt.date)

col = alltypes.timestamp_col.date()
Expand Down
22 changes: 17 additions & 5 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,20 @@


@pytest.mark.parametrize('attr', ['year', 'month', 'day'])
@pytest.mark.notimpl(["datafusion", "sqlite"])
def test_date_extract(backend, alltypes, df, attr):
expr = getattr(alltypes.timestamp_col.date(), attr)()
@pytest.mark.parametrize(
"expr_fn",
[
param(lambda c: c.date(), id="date"),
param(
lambda c: c.cast("date"),
id="cast",
marks=pytest.mark.notimpl(["impala"]),
),
],
)
@pytest.mark.notimpl(["datafusion"])
def test_date_extract(backend, alltypes, df, attr, expr_fn):
expr = getattr(expr_fn(alltypes.timestamp_col), attr)()
expected = getattr(df.timestamp_col.dt, attr).astype('int32')

result = expr.execute()
Expand Down Expand Up @@ -172,16 +183,17 @@ def test_timestamp_truncate(backend, alltypes, df, unit):
"mysql",
"postgres",
"pyspark",
"sqlite",
]
),
),
],
)
@pytest.mark.notimpl(["datafusion", "sqlite"])
@pytest.mark.notimpl(["datafusion"])
def test_date_truncate(backend, alltypes, df, unit):
expr = alltypes.timestamp_col.date().truncate(unit)

dtype = f'datetime64[{unit}]'
dtype = f"datetime64[{unit}]"
expected = pd.Series(df.timestamp_col.values.astype(dtype))

result = expr.execute()
Expand Down

0 comments on commit 3ce4f2a

Please sign in to comment.