Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(timestamps): add support for timestamp/date +/- intervals for additional backends #9799

Merged
merged 10 commits into from
Aug 8, 2024
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,7 @@ def visit_DayOfWeekName(self, op, *, arg):
)

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(
this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper())
)
return sge.Interval(this=arg, unit=self.v[unit.singular.upper()])

### String Instruments
def visit_Strip(self, op, *, arg):
Expand Down
23 changes: 17 additions & 6 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.BitXor,
ops.Covariance,
ops.CountDistinctStar,
ops.DateAdd,
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.IsInf,
ops.IsNan,
ops.LPad,
Expand All @@ -115,9 +113,7 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimestampAdd,
ops.TimestampDiff,
ops.TimestampSub,
ops.Unnest,
)

Expand Down Expand Up @@ -404,6 +400,8 @@ def visit_Cast(self, op, *, arg, to):
return arg
elif from_.is_integer() and to.is_timestamp():
return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00")
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is such a common implementation I'm going to try and encode the logic in the base implementation of visit_Cast

return super().visit_Cast(op, arg=arg, to=to)

def visit_Sum(self, op, *, arg, where):
Expand Down Expand Up @@ -500,5 +498,18 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke

return result

def visit_TimestampAdd(self, op, *, left, right):
return self.f.dateadd(
right.unit, self.cast(right.this, dt.int64), left, dialect=self.dialect
)

def visit_TimestampSub(self, op, *, left, right):
return self.f.dateadd(
right.unit, -self.cast(right.this, dt.int64), left, dialect=self.dialect
)

visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub


compiler = MSSQLCompiler()
25 changes: 7 additions & 18 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,16 @@ def visit_Cast(self, op, *, arg, to):
# for TEXT (except when casting of course!)
return arg
elif from_.is_integer() and to.is_interval():
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit
)
return sge.Interval(this=arg, unit=self.v[to.unit.singular.upper()])
elif from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg)
return super().visit_Cast(op, arg=arg, to=to)

def visit_TimestampDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="SECOND"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.SECOND, right, left, dialect=self.dialect)

def visit_DateDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="DAY"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.DAY, right, left, dialect=self.dialect)

def visit_ApproxCountDistinct(self, op, *, arg, where):
if where is not None:
Expand Down Expand Up @@ -317,16 +311,16 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit):

def visit_DateTimeDelta(self, op, *, left, right, part):
return self.f.timestampdiff(
sge.Var(this=part.this), right, left, dialect=self.dialect
self.v[part.this], right, left, dialect=self.dialect
)

visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta

def visit_ExtractMillisecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000)
return self.f.floor(self.f.extract(self.v.microsecond, arg) / 1_000)

def visit_ExtractMicrosecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg))
return self.f.floor(self.f.extract(self.v.microsecond, arg))

def visit_Strip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="BOTH")
Expand All @@ -337,14 +331,9 @@ def visit_LStrip(self, op, *, arg):
def visit_RStrip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="TRAILING")

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper()))

def visit_TimestampAdd(self, op, *, left, right):
if op.right.dtype.unit.short == "ms":
right = sge.Interval(
this=right.this * 1_000, unit=sge.Var(this="MICROSECOND")
)
right = sge.Interval(this=right.this * 1_000, unit=self.v.MICROSECOND)
return self.f.date_add(left, right, dialect=self.dialect)

def visit_UnwrapJSONString(self, op, *, arg):
Expand Down
45 changes: 27 additions & 18 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
ops.TimestampDelta,
ops.TimestampFromYMDHMS,
ops.TimeFromHMS,
ops.IntervalFromInteger,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.DateDiff,
Expand Down Expand Up @@ -138,30 +137,37 @@
elif dtype.is_uuid():
return sge.convert(str(value))
elif dtype.is_interval():
if dtype.unit.short in ("Y", "M"):
return self.f.numtoyminterval(value, dtype.unit.name)
elif dtype.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(value, dtype.unit.name)
else:
raise com.UnsupportedOperationError(
f"Intervals with precision {dtype.unit.name} not supported in Oracle."
)
return self._value_to_interval(value, dtype.unit)

return super().visit_Literal(op, value=value, dtype=dtype)

def _value_to_interval(self, arg, unit):
short = unit.short

if short in ("Y", "M"):
return self.f.numtoyminterval(arg, unit.singular)
elif short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, unit.singular)
elif short == "ms":
return self.f.numtodsinterval(arg / 1e3, "second")
elif short in "us":
return self.f.numtodsinterval(arg / 1e6, "second")
elif short in "ns":
return self.f.numtodsinterval(arg / 1e9, "second")

Check warning on line 156 in ibis/backends/sql/compilers/oracle.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/oracle.py#L156

Added line #L156 was not covered by tests
else:
raise com.UnsupportedArgumentError(
f"Interval {unit.name} not supported by Oracle"
)

def visit_Cast(self, op, *, arg, to):
if to.is_interval():
from_ = op.arg.dtype
if from_.is_numeric() and to.is_interval():
# CASTing to an INTERVAL in Oracle requires specifying digits of
# precision that are a pain. There are two helper functions that
# should be used instead.
if to.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, to.unit.name)
elif to.unit.short in ("Y", "M"):
return self.f.numtoyminterval(arg, to.unit.name)
else:
raise com.UnsupportedArgumentError(
f"Interval {to.unit.name} not supported by Oracle"
)
return self._value_to_interval(arg, to.unit)
elif from_.is_string() and to.is_date():
return self.f.to_date(arg, "FXYYYY-MM-DD")
return self.cast(arg, to)

def visit_Limit(self, op, *, parent, n, offset):
Expand Down Expand Up @@ -457,5 +463,8 @@

return out

def visit_IntervalFromInteger(self, op, *, arg, unit):
return self._value_to_interval(arg, unit)


compiler = OracleCompiler()
3 changes: 0 additions & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,6 @@ def visit_TimestampSub(self, op, *, left, right):
visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.name])

def visit_IntegerRange(self, op, *, start, stop, step):
return self.if_(
step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array()
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@

try:
from oracledb.exceptions import DatabaseError as OracleDatabaseError
from oracledb.exceptions import InterfaceError as OracleInterfaceError
except ImportError:
OracleDatabaseError = None
OracleDatabaseError = OracleInterfaceError = None

try:
from pyodbc import DataError as PyODBCDataError
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_scalar_param(backend, alltypes, df, value, dtype, col):
["2009-01-20", datetime.date(2009, 1, 20), datetime.datetime(2009, 1, 20)],
ids=["string", "date", "datetime"],
)
@pytest.mark.notimpl(["druid", "oracle"])
@pytest.mark.notimpl(["druid"])
def test_scalar_param_date(backend, alltypes, value):
param = ibis.param("date")
ds_col = alltypes.date_string_col
Expand Down
Loading
Loading