Skip to content

Commit

Permalink
refactor(intervals): consolidate interval conversion under `_make_inte…
Browse files Browse the repository at this point in the history
…rval` base compiler implementation
  • Loading branch information
cpcloud committed Aug 9, 2024
1 parent 84bfeb5 commit fe29210
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 51 deletions.
10 changes: 9 additions & 1 deletion ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,11 @@ def visit_Field(self, op, *, rel, name):
)

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype

if from_.is_integer() and to.is_interval():
return self._make_interval(arg, to.unit)

return self.cast(arg, to)

def visit_ScalarSubquery(self, op, *, rel):
Expand Down Expand Up @@ -941,8 +946,11 @@ def visit_DayOfWeekName(self, op, *, arg):
ifs=list(itertools.starmap(self.if_, enumerate(calendar.day_name))),
)

def _make_interval(self, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.singular])

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.singular.upper()])
return self._make_interval(arg, unit)

### String Instruments
def visit_Strip(self, op, *, arg):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,6 @@ def visit_Cast(self, op, *, arg, to):
f"BigQuery does not allow extracting date part `{from_.unit}` from intervals"
)
return self.f.extract(self.v[to.resolution.upper()], arg)
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
elif from_.is_floating() and to.is_integer():
return self.cast(self.f.trunc(arg), dt.int64)
return super().visit_Cast(op, arg=arg, to=to)
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def visit_NonNullLiteral(self, op, *, value, dtype):

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype
if from_.is_integer() and to.is_interval():
return sge.Interval(this=sge.convert(arg), unit=to.unit.singular.upper())
elif from_.is_temporal() and to.is_integer():
if from_.is_temporal() and to.is_integer():
return 1_000_000 * self.f.unix_timestamp(arg)
return super().visit_Cast(op, arg=arg, to=to)

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,6 @@ def visit_Cast(self, op, *, arg, to):
return arg
elif from_.is_integer() and to.is_timestamp():
return self.f.dateadd(self.v.s, arg, "1970-01-01 00:00:00")
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
return super().visit_Cast(op, arg=arg, to=to)

def visit_Sum(self, op, *, arg, where):
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def visit_Cast(self, op, *, arg, to):
# MariaDB does not support casting to JSON because it's an alias
# for TEXT (except when casting of course!)
return arg
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular.upper()])
elif from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg)
return super().visit_Cast(op, arg=arg, to=to)
Expand Down
9 changes: 2 additions & 7 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def visit_ArraySlice(self, op, *, arg, start, stop):
slice_expr = sge.Slice(this=start + 1, expression=stop)
return sge.paren(arg, copy=False)[slice_expr]

def visit_IntervalFromInteger(self, op, *, arg, unit):
def _make_interval(self, arg, unit):
plural = unit.plural
if plural == "minutes":
plural = "mins"
Expand Down Expand Up @@ -666,19 +666,14 @@ def visit_Cast(self, op, *, arg, to):
if (timezone := to.timezone) is not None:
arg = self.f.timezone(timezone, arg)
return arg
elif from_.is_integer() and to.is_interval():
unit = to.unit
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit), arg=arg, unit=unit
)
elif from_.is_string() and to.is_binary():
# Postgres and Python use the words "decode" and "encode" in
# opposite ways, sweet!
return self.f.decode(arg, "escape")
elif from_.is_binary() and to.is_string():
return self.f.encode(arg, "escape")

return self.cast(arg, op.to)
return super().visit_Cast(op, arg=arg, to=to)

visit_TryCast = visit_Cast

Expand Down
10 changes: 2 additions & 8 deletions ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sqlglot.expressions as sge

import ibis.common.exceptions as com
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers import PostgresCompiler
Expand Down Expand Up @@ -87,13 +86,8 @@ def visit_TimestampTruncate(self, op, *, arg, unit):

visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate

def visit_IntervalFromInteger(self, op, *, arg, unit):
if op.arg.shape == ds.scalar:
return sge.Interval(this=arg, unit=self.v[unit.name])
elif op.arg.shape == ds.columnar:
return arg * sge.Interval(this=sge.convert(1), unit=self.v[unit.name])
else:
raise ValueError("Invalid shape for converting to interval")
def _make_interval(self, arg, unit):
return arg * sge.Interval(this=sge.convert(1), unit=self.v[unit.name])

def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_binary():
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,7 @@ def visit_Cast(self, op, *, arg, to):
return self.if_(self.f.is_object(arg), arg, NULL)
elif to.is_array():
return self.if_(self.f.is_array(arg), arg, NULL)
elif op.arg.dtype.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.name])
return self.cast(arg, to)
return super().visit_Cast(op, arg=arg, to=to)

def visit_ToJSONMap(self, op, *, arg):
return self.if_(self.f.is_object(arg), arg, NULL)
Expand Down
34 changes: 17 additions & 17 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,7 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
elif dtype.is_time():
return self.cast(value.isoformat(), dtype)
elif dtype.is_interval():
return sge.Interval(
this=sge.convert(str(value)), unit=self.v[dtype.resolution.upper()]
)
return self._make_interval(sge.convert(str(value)), dtype.unit)
elif dtype.is_binary():
return self.f.from_hex(value.hex())
else:
Expand Down Expand Up @@ -442,15 +440,23 @@ def visit_TemporalDelta(self, op, *, part, left, right):

visit_TimeDelta = visit_DateDelta = visit_TimestampDelta = visit_TemporalDelta

def visit_IntervalFromInteger(self, op, *, arg, unit):
unit = op.unit.short
if unit in ("Y", "Q", "M", "W"):
def _make_interval(self, arg, unit):
short = unit.short
if short in ("Y", "Q", "M", "W"):
raise com.UnsupportedOperationError(f"Interval unit {unit!r} not supported")
return self.f.parse_duration(
self.f.concat(
self.cast(arg, dt.String(nullable=op.arg.dtype.nullable)), unit.lower()
elif short in ("D", "h", "m", "s", "ms", "us"):
if isinstance(arg, sge.Literal):
# force strings in interval literals because trino requires it
arg.args["is_string"] = True
return super()._make_interval(arg, unit)
else:
return self.f.parse_duration(
self.f.concat(self.cast(arg, dt.string), short.lower())
)
else:
raise com.UnsupportedOperationError(
f"Interval unit {unit.name!r} not supported"
)
)

def visit_Range(self, op, *, start, stop, step):
def zero_value(dtype):
Expand Down Expand Up @@ -492,13 +498,7 @@ def visit_ArrayIndex(self, op, *, arg, index):

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype
if from_.is_integer() and to.is_interval():
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit=to.unit),
arg=arg,
unit=to.unit,
)
elif from_.is_integer() and to.is_timestamp():
if from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg, to.timezone or "UTC")
return super().visit_Cast(op, arg=arg, to=to)

Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import ibis.expr.datatypes as dt
from ibis.backends.tests.errors import (
ClickHouseDatabaseError,
ExaQueryError,
GoogleBadRequest,
ImpalaHiveServer2Error,
MySQLOperationalError,
Expand Down Expand Up @@ -1100,11 +1099,6 @@ def test_first_last(backend):
["mssql"], raises=PyODBCProgrammingError, reason="not support by the backend"
)
@pytest.mark.notyet(["flink"], raises=Py4JJavaError, reason="bug in Flink")
@pytest.mark.notyet(
["exasol"],
raises=ExaQueryError,
reason="database can't handle UTC timestamps in DataFrames",
)
@pytest.mark.notyet(
["risingwave"],
raises=PsycoPg2InternalError,
Expand Down

0 comments on commit fe29210

Please sign in to comment.