Skip to content

Commit

Permalink
feat(sqlite): support most date/timestamp interval arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 9, 2024
1 parent fe29210 commit 75f594d
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 57 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,10 @@ jobs:
- name: show installed deps
run: poetry run pip list

- name: show version of python-linked sqlite
if: matrix.backend.name == 'sqlite'
run: poetry run python -c 'import sqlite3; print(sqlite3.sqlite_version)'

- name: "run parallel tests: ${{ matrix.backend.name }}"
if: ${{ !matrix.backend.serial }}
run: just ci-check -m ${{ matrix.backend.name }} --numprocesses auto --dist=loadgroup
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def is_older_than(module_name, given_version):
# For now, many of our tests don't do this, and we're working to change this situation
# by improving all tests file by file. All files that have already been improved are
# added to this list to prevent regression.
FIlES_WITH_STRICT_EXCEPTION_CHECK = [
FILES_WITH_STRICT_EXCEPTION_CHECK = [
"ibis/backends/tests/test_api.py",
"ibis/backends/tests/test_array.py",
"ibis/backends/tests/test_aggregation.py",
Expand Down Expand Up @@ -337,7 +337,7 @@ def _filter_none_from_raises(kwargs):
for marker in item.iter_markers(name="notimpl"):
if backend in marker.args[0]:
if (
item.location[0] in FIlES_WITH_STRICT_EXCEPTION_CHECK
item.location[0] in FILES_WITH_STRICT_EXCEPTION_CHECK
and "raises" not in marker.kwargs.keys()
):
raise ValueError("notimpl requires a raises")
Expand All @@ -351,7 +351,7 @@ def _filter_none_from_raises(kwargs):
for marker in item.iter_markers(name="notyet"):
if backend in marker.args[0]:
if (
item.location[0] in FIlES_WITH_STRICT_EXCEPTION_CHECK
item.location[0] in FILES_WITH_STRICT_EXCEPTION_CHECK
and "raises" not in marker.kwargs.keys()
):
raise ValueError("notyet requires a raises")
Expand Down
69 changes: 58 additions & 11 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import sqlite3

import sqlglot as sg
import sqlglot.expressions as sge
Expand All @@ -19,6 +20,8 @@ class SQLiteCompiler(SQLGlotCompiler):

dialect = SQLite
type_mapper = SQLiteType
supports_time_shift_modifiers = sqlite3.sqlite_version_info >= (3, 46, 0)
supports_subsec = sqlite3.sqlite_version_info >= (3, 42, 0)

# We could set `supports_order_by=True` for SQLite >= 3.44.0 (2023-11-01).
agg = AggGen(supports_filter=True)
Expand Down Expand Up @@ -53,10 +56,7 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.IntervalFromInteger,
ops.TimestampBucket,
ops.TimestampAdd,
ops.TimestampSub,
ops.TimestampDiff,
ops.StringToDate,
ops.StringToTimestamp,
Expand Down Expand Up @@ -333,18 +333,65 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
return self._temporal_truncate(self.f.anon.datetime, arg, unit)

def visit_DateArithmetic(self, op, *, left, right):
unit = op.right.dtype.unit
sign = "+" if isinstance(op, ops.DateAdd) else "-"
if unit not in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
right = right.this

if (unit := op.right.dtype.unit) in (
IntervalUnit.QUARTER,
IntervalUnit.MICROSECOND,
IntervalUnit.NANOSECOND,
):
raise com.UnsupportedOperationError(
"SQLite does not allow binary op {sign!r} with INTERVAL offset {unit}"
f"SQLite does not support `{unit}` units in temporal arithmetic"
)
if isinstance(op.right, ops.Literal):
return self.f.date(left, f"{sign}{op.right.value} {unit.plural}")
elif unit == IntervalUnit.WEEK:
unit = IntervalUnit.DAY
right *= 7
elif unit == IntervalUnit.MILLISECOND:
# sqlite doesn't allow milliseconds, so divide milliseconds by 1e3 to
# get seconds, and change the unit to seconds
unit = IntervalUnit.SECOND
right /= 1e3

# compute whether we're adding or subtracting an interval
sign = "+" if isinstance(op, (ops.DateAdd, ops.TimestampAdd)) else "-"

modifiers = []

# floor the result if the unit is a year, month, or day to match other
# backend behavior
if unit in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
if not self.supports_time_shift_modifiers:
raise com.UnsupportedOperationError(
"SQLite does not support time shift modifiers until version 3.46; "
f"found version {sqlite3.sqlite_version}"
)
modifiers.append("floor")

if isinstance(op, (ops.TimestampAdd, ops.TimestampSub)):
# if the left operand is a timestamp, return as much precision as
# possible
if not self.supports_subsec:
raise com.UnsupportedOperationError(
"SQLite does not support subsecond resolution until version 3.42; "
f"found version {sqlite3.sqlite_version}"
)
func = self.f.datetime
modifiers.append("subsec")
else:
return self.f.date(left, self.f.concat(sign, right, f" {unit.plural}"))
func = self.f.date

return func(
left,
self.f.concat(
sign, self.cast(right, dt.string), " ", unit.singular.lower()
),
*modifiers,
dialect=self.dialect,
)

visit_DateAdd = visit_DateSub = visit_DateArithmetic
visit_TimestampAdd = visit_TimestampSub = visit_DateAdd = visit_DateSub = (
visit_DateArithmetic
)

def visit_DateDiff(self, op, *, left, right):
return self.f.julianday(left) - self.f.julianday(right)
Expand Down
Loading

0 comments on commit 75f594d

Please sign in to comment.