feat(api): add timestamp range
cpcloud committed Dec 18, 2023
1 parent f20e34e commit c567fe0
Showing 16 changed files with 497 additions and 43 deletions.
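None of the hunks below shows the user-facing wiring, so as a rough usage sketch of what the new `ops.TimestampRange` operation enables (assuming the existing `ibis.range` constructor is the entry point that gains timestamp support; the literal values are made up):

import ibis

# Hypothetical usage: timestamp bounds plus an interval step should now
# compile on the backends touched by this commit.
expr = ibis.range(
    ibis.timestamp("2024-01-01 00:00:00+00:00"),
    ibis.timestamp("2024-01-02 00:00:00+00:00"),
    ibis.interval(hours=6),
)
# Expected: a half-open array of timestamps [start, stop), stepped by 6 hours.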
67 changes: 58 additions & 9 deletions ibis/backends/bigquery/registry.py
@@ -776,14 +776,62 @@ def _group_concat(translator, op):
return f"STRING_AGG({arg}, {sep})"


def _integer_range(translator, op):
start = translator.translate(op.start)
stop = translator.translate(op.stop)
step = translator.translate(op.step)
n = f"FLOOR(({stop} - {start}) / NULLIF({step}, 0))"
gen_array = f"GENERATE_ARRAY({start}, {stop}, {step})"
inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
return f"IF({n} > 0, ARRAY({inner}), [])"
def _zero(dtype):
    if dtype.is_interval():
        return "MAKE_INTERVAL()"
    return "0"


def _sign(value, dtype):
    if dtype.is_interval():
        zero = _zero(dtype)
        return f"""\
CASE
  WHEN {value} < {zero} THEN -1
  WHEN {value} = {zero} THEN 0
  WHEN {value} > {zero} THEN 1
  ELSE NULL
END"""
    return f"SIGN({value})"


def _nullifzero(step, zero, step_dtype):
    if step_dtype.is_interval():
        return f"IF({step} = {zero}, NULL, {step})"
    return f"NULLIF({step}, {zero})"


def _make_range(func):
    def _range(translator, op):
        start = translator.translate(op.start)
        stop = translator.translate(op.stop)
        step = translator.translate(op.step)

        step_dtype = op.step.dtype
        step_sign = _sign(step, step_dtype)
        delta_sign = _sign(f"({stop} - {start})", step_dtype)
        zero = _zero(step_dtype)
        nullifzero = _nullifzero(step, zero, step_dtype)

        condition = f"{nullifzero} IS NOT NULL AND {step_sign} = {delta_sign}"
        gen_array = f"{func}({start}, {stop}, {step})"
        inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
        return f"IF({condition}, ARRAY({inner}), [])"

    return _range


def _timestamp_range(translator, op):
    start = op.start
    stop = op.stop

    if start.dtype.timezone is None or stop.dtype.timezone is None:
        raise com.IbisTypeError(
            "Timestamps without timezone values are not supported when generating timestamp ranges"
        )

    rule = _make_range("GENERATE_TIMESTAMP_ARRAY")
    return rule(translator, op)


OPERATION_REGISTRY = {
@@ -949,7 +997,8 @@ def _integer_range(translator, op):
    ops.TimeDelta: _time_delta,
    ops.DateDelta: _date_delta,
    ops.TimestampDelta: _timestamp_delta,
    ops.IntegerRange: _make_range("GENERATE_ARRAY"),
    ops.TimestampRange: _timestamp_range,
}

_invalid_operations = {
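For the BigQuery hunk above, a pure-Python model of the semantics the generated expression is meant to have (an illustrative sketch, not code from this commit): the result is empty unless the step is nonzero and points from start toward stop, and the UNNEST filter `x <> stop` keeps the range half-open.

def model_range(start, stop, step):
    # sign(x) in {-1, 0, 1}, mirroring SIGN()/the interval CASE above
    sign = lambda x: (x > 0) - (x < 0)
    if step == 0 or sign(step) != sign(stop - start):
        return []
    out, x = [], start
    while sign(stop - x) == sign(step):
        out.append(x)
        x += step
    return [v for v in out if v != stop]  # the `x <> stop` filter

assert model_range(0, 10, 3) == [0, 3, 6, 9]
assert model_range(0, 9, 3) == [0, 3, 6]
assert model_range(0, 10, -1) == []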
6 changes: 6 additions & 0 deletions ibis/backends/clickhouse/__init__.py
@@ -3,6 +3,7 @@
import ast
import atexit
import glob
import warnings
from contextlib import closing, suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal
@@ -169,6 +170,11 @@ def do_connect(
            compress=compression,
            **kwargs,
        )
        try:
            with closing(self.raw_sql("SET session_timezone = 'UTC'")):
                pass
        except Exception as e:  # noqa: BLE001
            warnings.warn(f"Could not set timezone to UTC: {e}", category=UserWarning)
        self._temp_views = set()

    @property
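In isolation, the new ClickHouse connection step amounts to the following (assumes a reachable server and the clickhouse-connect driver; `session_timezone` only exists on reasonably recent ClickHouse versions, which is why the backend warns instead of failing hard):

import clickhouse_connect

client = clickhouse_connect.get_client(host="localhost")  # hypothetical host
client.command("SET session_timezone = 'UTC'")
print(client.command("SELECT timezone()"))  # expected: UTC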
27 changes: 27 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
@@ -1017,3 +1017,30 @@ def _agg_udf(op, *, where, **kw) -> str:
@translate_val.register(ops.TimestampDelta)
def _delta(op, *, part, left, right, **_):
    return sg.exp.DateDiff(this=left, expression=right, unit=part)


@translate_val.register(ops.TimestampRange)
def _timestamp_range(op, *, start, stop, step, **_):
    unit = op.step.dtype.unit.name.lower()

    if not isinstance(op.step, ops.Literal):
        raise com.UnsupportedOperationError(
            "ClickHouse doesn't support non-literal step values"
        )

    step_value = op.step.value

    offset = sg.to_identifier("offset")

    # e.g., offset -> dateAdd(DAY, offset, start)
    func = sg.exp.Lambda(
        this=F.dateAdd(sg.to_identifier(unit), offset, start), expressions=[offset]
    )

    if step_value == 0:
        return F.array()

    return F.arrayMap(
        func, F.range(0, F.timestampDiff(unit, start, stop), step_value)
    )
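Written out by hand for a literal one-hour step, the translation above produces an expression of roughly this shape (illustrative, not emitted verbatim):

example_sql = """
arrayMap(
    offset -> dateAdd(hour, offset, start),
    range(0, timestampDiff('hour', start, stop), 1)
)
"""
# `range` is half-open, so `stop` itself is never produced; a step of 0
# short-circuits to an empty array before any SQL is generated.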
4 changes: 3 additions & 1 deletion ibis/backends/duckdb/registry.py
@@ -183,7 +183,7 @@ def _literal(t, op):
    sqla_type = t.get_sqla_type(dtype)

    if dtype.is_interval():
        return getattr(sa.func, f"to_{dtype.unit.plural}")(value)
    elif dtype.is_array():
        values = value.tolist() if isinstance(value, np.ndarray) else value
        return sa.cast(sa.func.list_value(*values), sqla_type)
@@ -550,6 +550,8 @@ def _array_remove(t, op):
        ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2),
        ops.GeoX: unary(sa.func.ST_X),
        ops.GeoY: unary(sa.func.ST_Y),
        # other ops
        ops.TimestampRange: fixed_arity(sa.func.range, 3),
    }
)

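A quick check of the DuckDB builtin being mapped here (assumes the duckdb Python package; `range` over timestamps is half-open on the right, matching the semantics above):

import duckdb

duckdb.sql(
    "SELECT range(TIMESTAMP '2024-01-01', TIMESTAMP '2024-01-02', "
    "INTERVAL 6 HOUR) AS ts"
).show()
# Expected: 2024-01-01 00:00 through 18:00 in 6-hour steps; the stop
# timestamp itself is excluded.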
18 changes: 17 additions & 1 deletion ibis/backends/polars/compiler.py
@@ -1205,7 +1205,9 @@ def execute_agg_udf(op, **kw):
@translate.register(ops.IntegerRange)
def execute_integer_range(op, **kw):
    if not isinstance(op.step, ops.Literal):
        raise com.UnsupportedOperationError(
            "Dynamic integer step not supported by Polars"
        )
    step = op.step.value

    dtype = dtype_to_polars(op.dtype)
@@ -1217,3 +1219,17 @@ def execute_integer_range(op, **kw):
    start = translate(op.start, **kw)
    stop = translate(op.stop, **kw)
    return pl.int_ranges(start, stop, step, dtype=dtype)


@translate.register(ops.TimestampRange)
def execute_timestamp_range(op, **kw):
    if not isinstance(op.step, ops.Literal):
        raise com.UnsupportedOperationError(
            "Dynamic interval step not supported by Polars"
        )
    step = op.step.value
    unit = op.step.dtype.unit.value

    start = translate(op.start, **kw)
    stop = translate(op.stop, **kw)
    return pl.datetime_ranges(start, stop, f"{step}{unit}", closed="left")
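And a quick check of the Polars building block used above (assumes a recent polars; closed="left" is what makes the range exclude the stop endpoint):

from datetime import datetime

import polars as pl

out = pl.select(
    pl.datetime_ranges(
        pl.lit(datetime(2024, 1, 1)),
        pl.lit(datetime(2024, 1, 2)),
        interval="6h",
        closed="left",
    ).alias("ts")
)
# Expected: one row holding [00:00, 06:00, 12:00, 18:00] on 2024-01-01.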
32 changes: 25 additions & 7 deletions ibis/backends/postgres/registry.py
@@ -618,19 +618,36 @@ def _array_filter(t, op):
    )


def zero_value(dtype):
    if dtype.is_interval():
        return sa.func.make_interval()
    return 0


def interval_sign(v):
    zero = sa.func.make_interval()
    return sa.case((v == zero, 0), (v < zero, -1), (v > zero, 1))


def _sign(value, dtype):
    if dtype.is_interval():
        return interval_sign(value)
    return sa.func.sign(value)


def _range(t, op):
    start = t.translate(op.start)
    stop = t.translate(op.stop)
    step = t.translate(op.step)
    satype = t.get_sqla_type(op.dtype)
    seq = sa.func.generate_series(start, stop, step, type_=satype)
    zero = zero_value(op.step.dtype)
    return sa.case(
        # TODO(cpcloud): revisit using array_remove when my brain is working
        (
            sa.and_(
                sa.func.nullif(step, zero).is_not(None),
                _sign(step, op.step.dtype) == _sign(stop - start, op.step.dtype),
            ),
            sa.func.array_remove(
                sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype
            ),
@@ -839,6 +856,7 @@ def _integer_range(t, op):
        ops.ArrayPosition: fixed_arity(_array_position, 2),
        ops.ArrayMap: _array_map,
        ops.ArrayFilter: _array_filter,
        ops.IntegerRange: _range,
        ops.TimestampRange: _range,
    }
)
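Since Postgres has no sign() for intervals, interval_sign builds the comparison out of a CASE; the guard above compiles to roughly this shape for an interval step (hand-written illustration, identifiers simplified, sign(...) standing in for those CASE expressions; the else branch falls outside this hunk):

example_sql = """
CASE
    WHEN NULLIF(step, make_interval()) IS NOT NULL
         AND sign(step) = sign(stop - start)
    THEN array_remove(ARRAY(SELECT generate_series(start, stop, step)), stop)
    ...
END
"""
# generate_series includes `stop` when it lands exactly on a step boundary,
# hence array_remove(..., stop) to keep the range half-open.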
45 changes: 36 additions & 9 deletions ibis/backends/pyspark/compiler.py
@@ -2082,18 +2082,45 @@ def compile_flatten(t, op, **kwargs):
    return F.flatten(t.translate(op.arg, **kwargs))


def _zero_value(dtype):
    if dtype.is_interval():
        return F.expr(f"INTERVAL 0 {dtype.resolution}")
    return F.lit(0)


def _build_sequence(start, stop, step, zero):
    seq = F.sequence(start, stop, step)
    length = F.size(seq)
    last_element = F.element_at(seq, length)
    # slice off the last element if we'd be inclusive on the right
    seq = F.when(last_element == stop, F.slice(seq, 1, length - 1)).otherwise(seq)
    return F.when(
        (step != zero) & (F.signum(step) == F.signum(stop - start)), seq
    ).otherwise(F.array())


@compiles(ops.IntegerRange)
def compile_integer_range(t, op, **kwargs):
    start = t.translate(op.start, **kwargs)
    stop = t.translate(op.stop, **kwargs)
    step = t.translate(op.step, **kwargs)

    return _build_sequence(start, stop, step, _zero_value(op.step.dtype))


@compiles(ops.TimestampRange)
def compile_timestamp_range(t, op, **kwargs):
    start = t.translate(op.start, **kwargs)
    stop = t.translate(op.stop, **kwargs)

    if not isinstance(op.step, ops.Literal):
        raise com.UnsupportedOperationError(
            "`step` argument of timestamp range must be a literal"
        )

    step_value = op.step.value
    unit = op.step.dtype.resolution

    step = F.expr(f"INTERVAL {step_value} {unit}")

    return _build_sequence(start, stop, step, _zero_value(op.step.dtype))
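A minimal sketch of why _build_sequence trims the last element (assumes a local SparkSession; F.sequence is inclusive of `stop`):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.sql(
    "SELECT timestamp'2024-01-01 00:00:00' AS start, "
    "timestamp'2024-01-02 00:00:00' AS stop"
)
seq = F.sequence(df.start, df.stop, F.expr("INTERVAL 12 HOUR"))
df.select(seq.alias("ts")).show(truncate=False)
# sequence alone yields [00:00, 12:00, 2024-01-02 00:00] -- stop included --
# which is exactly the element _build_sequence slices off.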
27 changes: 26 additions & 1 deletion ibis/backends/snowflake/converter.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING

from ibis.formats.pandas import PandasData
@@ -18,7 +19,31 @@ def convert_JSON(s, dtype, pandas_type):
        converter = SnowflakePandasData.convert_JSON_element(dtype)
        return s.map(converter, na_action="ignore").astype("object")

    convert_Struct = convert_Map = convert_JSON

    @staticmethod
    def get_element_converter(dtype):
        funcgen = getattr(
            SnowflakePandasData,
            f"convert_{type(dtype).__name__}_element",
            lambda _: lambda x: x,
        )
        return funcgen(dtype)

    def convert_Timestamp_element(dtype):
        return lambda values: list(map(datetime.datetime.fromisoformat, values))

    def convert_Date_element(dtype):
        return lambda values: list(map(datetime.date.fromisoformat, values))

    def convert_Time_element(dtype):
        return lambda values: list(map(datetime.time.fromisoformat, values))

    @staticmethod
    def convert_Array(s, dtype, pandas_type):
        raw_json_objects = SnowflakePandasData.convert_JSON(s, dtype, pandas_type)
        converter = SnowflakePandasData.get_element_converter(dtype.value_type)
        return raw_json_objects.map(converter, na_action="ignore")


class SnowflakePyArrowData(PyArrowData):
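The element converters above exist because Snowflake renders array elements as ISO-8601 strings; in isolation they do no more than this:

import datetime

values = ["2024-01-01T00:00:00.000000+00:00", "2024-01-01T06:00:00.000000+00:00"]
converted = list(map(datetime.datetime.fromisoformat, values))
# [datetime.datetime(2024, 1, 1, 0, 0, tzinfo=datetime.timezone.utc), ...]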
60 changes: 60 additions & 0 deletions ibis/backends/snowflake/registry.py
@@ -280,6 +280,65 @@ def _timestamp_bucket(t, op):
    )


class _flatten(sa.sql.functions.GenericFunction):
    def __init__(self, arg, *, type: sa.types.TypeEngine) -> None:
        super().__init__(arg)
        self.type = sa.sql.sqltypes.TableValueType(
            sa.Column("index", sa.BIGINT()), sa.Column("value", type)
        )


@compiles(_flatten, "snowflake")
def compiles_flatten(element, compiler, **kw):
    (arg,) = element.clauses.clauses
    return f"TABLE(FLATTEN(INPUT => {compiler.process(arg, **kw)}, MODE => 'ARRAY'))"


def _timestamp_range(t, op):
    if not isinstance(op.step, ops.Literal):
        raise com.UnsupportedOperationError("`step` argument must be a literal")

    start = t.translate(op.start)
    stop = t.translate(op.stop)

    unit = op.step.dtype.unit.name.lower()
    step = op.step.value

    value_type = op.dtype.value_type

    f = _flatten(
        sa.func.array_generate_range(0, sa.func.datediff(unit, start, stop), step),
        type=t.get_sqla_type(op.start.dtype),
    ).alias("f")
    return sa.func.iff(
        step != 0,
        sa.select(
            sa.func.array_agg(
                sa.func.replace(
                    # conversion to varchar is necessary to control
                    # the timestamp format
                    #
                    # otherwise, since timestamps in arrays become strings
                    # anyway due to lack of parameterized type support in
                    # Snowflake the format depends on a session parameter
                    sa.func.to_varchar(
                        sa.func.dateadd(unit, f.c.value, start),
                        'YYYY-MM-DD"T"HH24:MI:SS.FF6'
                        + (value_type.timezone is not None) * "TZH:TZM",
                    ),
                    # timezones are always hour:minute offsets from UTC, not
                    # named, so replacing "Z" shouldn't be an issue
                    "Z",
                    "+00:00",
                ),
            )
        )
        .select_from(f)
        .scalar_subquery(),
        sa.func.array_construct(),
    )


_TIMESTAMP_UNITS_TO_SCALE = {"s": 0, "ms": 3, "us": 6, "ns": 9}

_SF_POS_INF = sa.func.to_double("Inf")
@@ -504,6 +563,7 @@ def _timestamp_bucket(t, op):
            ),
            3,
        ),
        ops.TimestampRange: _timestamp_range,
    }
)
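The explicit to_varchar format plus the "Z" normalization is what lets the converter round-trip these values: before Python 3.11, datetime.fromisoformat rejects a literal "Z" suffix, so the server-side replace matters (illustrative values):

import datetime

rendered = "2024-01-01T00:00:00.000000Z"      # a shape Snowflake could render
normalized = rendered.replace("Z", "+00:00")  # the sa.func.replace above
parsed = datetime.datetime.fromisoformat(normalized)
assert parsed.tzinfo is not None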
