Skip to content

Commit

Permalink
feat(timestamps): add support for timestamp/date +/- intervals for ad…
Browse files Browse the repository at this point in the history
…ditional backends (#9799)
  • Loading branch information
cpcloud authored Aug 8, 2024
1 parent de6d988 commit 79cef68
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 96 deletions.
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,7 @@ def visit_DayOfWeekName(self, op, *, arg):
)

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(
this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper())
)
return sge.Interval(this=arg, unit=self.v[unit.singular.upper()])

### String Instruments
def visit_Strip(self, op, *, arg):
Expand Down
23 changes: 17 additions & 6 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.BitXor,
ops.Covariance,
ops.CountDistinctStar,
ops.DateAdd,
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.IsInf,
ops.IsNan,
ops.LPad,
Expand All @@ -115,9 +113,7 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimestampAdd,
ops.TimestampDiff,
ops.TimestampSub,
ops.Unnest,
)

Expand Down Expand Up @@ -404,6 +400,8 @@ def visit_Cast(self, op, *, arg, to):
return arg
elif from_.is_integer() and to.is_timestamp():
return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00")
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
return super().visit_Cast(op, arg=arg, to=to)

def visit_Sum(self, op, *, arg, where):
Expand Down Expand Up @@ -500,5 +498,18 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke

return result

def visit_TimestampAdd(self, op, *, left, right):
return self.f.dateadd(
right.unit, self.cast(right.this, dt.int64), left, dialect=self.dialect
)

def visit_TimestampSub(self, op, *, left, right):
return self.f.dateadd(
right.unit, -self.cast(right.this, dt.int64), left, dialect=self.dialect
)

visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub


compiler = MSSQLCompiler()
25 changes: 7 additions & 18 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,16 @@ def visit_Cast(self, op, *, arg, to):
# for TEXT (except when casting of course!)
return arg
elif from_.is_integer() and to.is_interval():
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit
)
return sge.Interval(this=arg, unit=self.v[to.unit.singular.upper()])
elif from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg)
return super().visit_Cast(op, arg=arg, to=to)

def visit_TimestampDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="SECOND"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.SECOND, right, left, dialect=self.dialect)

def visit_DateDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="DAY"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.DAY, right, left, dialect=self.dialect)

def visit_ApproxCountDistinct(self, op, *, arg, where):
if where is not None:
Expand Down Expand Up @@ -317,16 +311,16 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit):

def visit_DateTimeDelta(self, op, *, left, right, part):
return self.f.timestampdiff(
sge.Var(this=part.this), right, left, dialect=self.dialect
self.v[part.this], right, left, dialect=self.dialect
)

visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta

def visit_ExtractMillisecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000)
return self.f.floor(self.f.extract(self.v.microsecond, arg) / 1_000)

def visit_ExtractMicrosecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg))
return self.f.floor(self.f.extract(self.v.microsecond, arg))

def visit_Strip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="BOTH")
Expand All @@ -337,14 +331,9 @@ def visit_LStrip(self, op, *, arg):
def visit_RStrip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="TRAILING")

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper()))

def visit_TimestampAdd(self, op, *, left, right):
if op.right.dtype.unit.short == "ms":
right = sge.Interval(
this=right.this * 1_000, unit=sge.Var(this="MICROSECOND")
)
right = sge.Interval(this=right.this * 1_000, unit=self.v.MICROSECOND)
return self.f.date_add(left, right, dialect=self.dialect)

def visit_UnwrapJSONString(self, op, *, arg):
Expand Down
45 changes: 27 additions & 18 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class OracleCompiler(SQLGlotCompiler):
ops.TimestampDelta,
ops.TimestampFromYMDHMS,
ops.TimeFromHMS,
ops.IntervalFromInteger,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.DateDiff,
Expand Down Expand Up @@ -138,30 +137,37 @@ def visit_Literal(self, op, *, value, dtype):
elif dtype.is_uuid():
return sge.convert(str(value))
elif dtype.is_interval():
if dtype.unit.short in ("Y", "M"):
return self.f.numtoyminterval(value, dtype.unit.name)
elif dtype.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(value, dtype.unit.name)
else:
raise com.UnsupportedOperationError(
f"Intervals with precision {dtype.unit.name} not supported in Oracle."
)
return self._value_to_interval(value, dtype.unit)

return super().visit_Literal(op, value=value, dtype=dtype)

def _value_to_interval(self, arg, unit):
short = unit.short

if short in ("Y", "M"):
return self.f.numtoyminterval(arg, unit.singular)
elif short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, unit.singular)
elif short == "ms":
return self.f.numtodsinterval(arg / 1e3, "second")
elif short in "us":
return self.f.numtodsinterval(arg / 1e6, "second")
elif short in "ns":
return self.f.numtodsinterval(arg / 1e9, "second")
else:
raise com.UnsupportedArgumentError(
f"Interval {unit.name} not supported by Oracle"
)

def visit_Cast(self, op, *, arg, to):
if to.is_interval():
from_ = op.arg.dtype
if from_.is_numeric() and to.is_interval():
# CASTing to an INTERVAL in Oracle requires specifying digits of
# precision that are a pain. There are two helper functions that
# should be used instead.
if to.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, to.unit.name)
elif to.unit.short in ("Y", "M"):
return self.f.numtoyminterval(arg, to.unit.name)
else:
raise com.UnsupportedArgumentError(
f"Interval {to.unit.name} not supported by Oracle"
)
return self._value_to_interval(arg, to.unit)
elif from_.is_string() and to.is_date():
return self.f.to_date(arg, "FXYYYY-MM-DD")
return self.cast(arg, to)

def visit_Limit(self, op, *, parent, n, offset):
Expand Down Expand Up @@ -457,5 +463,8 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by):

return out

def visit_IntervalFromInteger(self, op, *, arg, unit):
return self._value_to_interval(arg, unit)


compiler = OracleCompiler()
3 changes: 0 additions & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ def visit_TimestampSub(self, op, *, left, right):
visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.name])

def visit_IntegerRange(self, op, *, start, stop, step):
return self.if_(
step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array()
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@

try:
from oracledb.exceptions import DatabaseError as OracleDatabaseError
from oracledb.exceptions import InterfaceError as OracleInterfaceError
except ImportError:
OracleDatabaseError = None
OracleDatabaseError = OracleInterfaceError = None

try:
from pyodbc import DataError as PyODBCDataError
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col):
["2009-01-20", datetime.date(2009, 1, 20), datetime.datetime(2009, 1, 20)],
ids=["string", "date", "datetime"],
)
@pytest.mark.notimpl(["druid", "oracle"])
@pytest.mark.notimpl(["druid"])
def test_scalar_param_date(backend, alltypes, value):
param = ibis.param("date")
ds_col = alltypes.date_string_col
Expand Down
Loading

0 comments on commit 79cef68

Please sign in to comment.