feat(api): add ibis.range function for generating sequences
cpcloud committed Nov 21, 2023
1 parent 5d1fadf commit f5a0a5a
Showing 13 changed files with 355 additions and 5 deletions.
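
For context, a minimal usage sketch of the new `ibis.range` API, assuming the DuckDB backend (one of the backends wired up in the diff below); the exact container returned by `execute` can vary by backend:

```python
import ibis

con = ibis.duckdb.connect()  # in-memory connection

# three-argument form: start, stop, step -- half-open, like Python's range
expr = ibis.range(0, 10, 2)
print(list(con.execute(expr)))  # e.g. [0, 2, 4, 6, 8]

# single-argument form, unnested into one row per element
print(list(con.execute(ibis.range(3).unnest())))  # e.g. [0, 1, 2]

# a zero step produces an empty array rather than raising
print(list(con.execute(ibis.range(7, 0, 0))))  # e.g. []
```
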
11 changes: 11 additions & 0 deletions ibis/backends/bigquery/registry.py
@@ -776,6 +776,16 @@ def _group_concat(translator, op):
return f"STRING_AGG({arg}, {sep})"


def _integer_range(translator, op):
    start = translator.translate(op.start)
    stop = translator.translate(op.stop)
    step = translator.translate(op.step)
    # NULLIF guards against a zero step: a NULL count falls through to the
    # empty-array branch of the IF below
    n = f"FLOOR(({stop} - {start}) / NULLIF({step}, 0))"
    # GENERATE_ARRAY includes its upper bound, so filter out `stop` to get
    # Python's half-open range semantics
    gen_array = f"GENERATE_ARRAY({start}, {stop}, {step})"
    inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
    return f"IF({n} > 0, ARRAY({inner}), [])"


OPERATION_REGISTRY = {
**operation_registry,
# Literal
@@ -939,6 +949,7 @@ def _group_concat(translator, op):
ops.TimeDelta: _time_delta,
ops.DateDelta: _date_delta,
ops.TimestampDelta: _timestamp_delta,
ops.IntegerRange: _integer_range,
}

_invalid_operations = {
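
For illustration, a quick way to see the SQL string the BigQuery rule above builds for concrete arguments, using a hypothetical stub translator that echoes already-rendered fragments (assumes `_integer_range` from the hunk above is in scope; not part of the commit):

```python
class EchoTranslator:
    """Stands in for the real compiler translator; echoes SQL fragments."""

    def translate(self, fragment):
        return fragment


class FakeRangeOp:
    """Pretend IntegerRange node whose children are pre-rendered SQL."""

    start, stop, step = "0", "10", "2"


print(_integer_range(EchoTranslator(), FakeRangeOp()))
# IF(FLOOR((10 - 0) / NULLIF(2, 0)) > 0,
#    ARRAY(SELECT x FROM UNNEST(GENERATE_ARRAY(0, 10, 2)) x WHERE x <> 10), [])
# (output wrapped here for readability)
```
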
1 change: 1 addition & 0 deletions ibis/backends/clickhouse/compiler/values.py
@@ -817,6 +817,7 @@ def formatter(op, *, left, right, **_):
ops.ExtractFragment: "fragment",
ops.ArrayPosition: "indexOf",
ops.ArrayFlatten: "arrayFlatten",
ops.IntegerRange: "range",
}


1 change: 1 addition & 0 deletions ibis/backends/duckdb/registry.py
@@ -497,6 +497,7 @@ def _to_json_collection(t, op):
ops.ToJSONMap: _to_json_collection,
ops.ToJSONArray: _to_json_collection,
ops.ArrayFlatten: unary(sa.func.flatten),
ops.IntegerRange: fixed_arity(sa.func.range, 3),
}
)

17 changes: 17 additions & 0 deletions ibis/backends/polars/compiler.py
@@ -1198,3 +1198,20 @@ def execute_agg_udf(op, **kw):
if (where := op.where) is not None:
first = first.filter(translate(where, **kw))
return getattr(first, op.__func_name__)(*rest)


@translate.register(ops.IntegerRange)
def execute_integer_range(op, **kw):
if not isinstance(op.step, ops.Literal):
raise NotImplementedError("Dynamic step not supported by Polars")
step = op.step.value

dtype = dtype_to_polars(op.dtype)
empty = pl.int_ranges(0, 0, dtype=dtype)

if step == 0:
return empty

start = translate(op.start, **kw)
stop = translate(op.stop, **kw)
return pl.int_ranges(start, stop, step, dtype=dtype)
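
For comparison, a standalone sketch of the Polars primitive the translation above leans on, evaluated with literal bounds (assumes a Polars version exposing `pl.int_ranges`, as the code above does):

```python
import polars as pl

# int_ranges builds a half-open integer range per row; with scalar bounds
# this yields a single list, mirroring ibis.range(0, 10, 2)
df = pl.select(pl.int_ranges(0, 10, 2, dtype=pl.Int64))
print(df.to_series().to_list())  # [[0, 2, 4, 6, 8]]
```
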
22 changes: 22 additions & 0 deletions ibis/backends/postgres/registry.py
@@ -604,6 +604,27 @@ def _array_filter(t, op):
)


def _integer_range(t, op):
start = t.translate(op.start)
stop = t.translate(op.stop)
step = t.translate(op.step)
satype = t.get_sqla_type(op.dtype)
    # `generate_series` errors when the step is zero, so guard against that
    # first and fall back to an empty array
n = sa.func.floor((stop - start) / sa.func.nullif(step, 0))
seq = sa.func.generate_series(start, stop, step, type_=satype)
return sa.case(
# TODO(cpcloud): revisit using array_remove when my brain is working
(
n > 0,
sa.func.array_remove(
sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype
),
),
else_=sa.cast(pg.array([]), satype),
)


operation_registry.update(
{
ops.Literal: _literal,
@@ -802,5 +823,6 @@ def _array_filter(t, op):
ops.ArrayPosition: fixed_arity(_array_position, 2),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.IntegerRange: _integer_range,
}
)
17 changes: 17 additions & 0 deletions ibis/backends/pyspark/compiler.py
@@ -2080,3 +2080,20 @@ def compile_levenshtein(t, op, **kwargs):
@compiles(ops.ArrayFlatten)
def compile_flatten(t, op, **kwargs):
return F.flatten(t.translate(op.arg, **kwargs))


@compiles(ops.IntegerRange)
def compile_integer_range(t, op, **kwargs):
    start = t.translate(op.start, **kwargs)
    stop = t.translate(op.stop, **kwargs)
    step = t.translate(op.step, **kwargs)

    # a zero step makes the denominator NULL, so `n` is NULL and the final
    # `when` falls through to an empty array
    denom = F.when(step == 0, F.lit(None)).otherwise(step)
    n = F.floor((stop - start) / denom)
    # `sequence` includes its upper bound; drop the last element when it
    # equals `stop` to match Python's half-open range semantics
    seq = F.sequence(start, stop, step)
    seq = F.slice(
        seq,
        1,
        F.size(seq) - F.when(F.element_at(seq, F.size(seq)) == stop, 1).otherwise(0),
    )
    return F.when(n > 0, seq).otherwise(F.array())
8 changes: 8 additions & 0 deletions ibis/backends/snowflake/registry.py
@@ -491,6 +491,14 @@ def _timestamp_bucket(t, op):
lambda part, left, right: sa.func.timestampdiff(part, right, left), 3
),
ops.TimestampBucket: _timestamp_bucket,
        ops.IntegerRange: fixed_arity(
            # return an empty array for a zero step instead of passing it to
            # ARRAY_GENERATE_RANGE
            lambda start, stop, step: sa.func.iff(
                step != 0,
                sa.func.array_generate_range(start, stop, step),
                sa.func.array_construct(),
            ),
            3,
        ),
}
)

138 changes: 138 additions & 0 deletions ibis/backends/tests/test_array.py
@@ -39,6 +39,12 @@
except ImportError:
PySparkAnalysisException = None


try:
from polars.exceptions import PolarsInvalidOperationError
except ImportError:
PolarsInvalidOperationError = None

pytestmark = [
pytest.mark.never(
["sqlite", "mysql", "mssql"],
@@ -910,3 +916,135 @@ def test_array_flatten(backend, flatten_data, column, expected):
expr = t[column].flatten()
result = backend.connection.execute(expr)
backend.assert_series_equal(result, expected, check_names=False)


polars_overflow = pytest.mark.notyet(
["polars"],
reason="but polars overflows allocation with some inputs",
raises=BaseException,
)


@pytest.mark.notyet(
["datafusion"],
reason="range isn't implemented upstream",
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
@pytest.mark.parametrize("n", [param(-2, marks=[polars_overflow]), 0, 2])
def test_range_single_argument(con, n):
expr = ibis.range(n)
result = con.execute(expr)
assert list(result) == list(range(n))


@pytest.mark.notyet(
["datafusion"],
reason="range and unnest aren't implemented upstream",
raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize(
"n",
[
param(
-2,
marks=[
pytest.mark.broken(
["snowflake"],
reason="snowflake unnests empty arrays to null",
raises=AssertionError,
)
],
),
param(
0,
marks=[
pytest.mark.broken(
["snowflake"],
reason="snowflake unnests empty arrays to null",
raises=AssertionError,
)
],
),
2,
],
)
@pytest.mark.notimpl(
["polars", "flink", "pandas", "dask"], raises=com.OperationNotDefinedError
)
def test_range_single_argument_unnest(con, n):
expr = ibis.range(n).unnest()
result = con.execute(expr)
tm.assert_series_equal(
result,
pd.Series(list(range(n)), dtype=result.dtype, name=expr.get_name()),
check_index=False,
)


@pytest.mark.parametrize(
"step",
[
param(
-2,
marks=[
pytest.mark.notyet(
["polars"],
reason="panic upstream",
raises=PolarsInvalidOperationError,
)
],
),
param(
-1,
marks=[
pytest.mark.notyet(
["polars"],
reason="panic upstream",
raises=PolarsInvalidOperationError,
)
],
),
1,
2,
],
)
@pytest.mark.parametrize(
("start", "stop"),
[
param(-7, -7),
param(-7, 0),
param(-7, 7),
param(0, -7, marks=[polars_overflow]),
param(0, 0),
param(0, 7),
param(7, -7, marks=[polars_overflow]),
param(7, 0, marks=[polars_overflow]),
param(7, 7),
],
)
@pytest.mark.notyet(
["datafusion"],
reason="range and unnest aren't implemented upstream",
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
def test_range_start_stop_step(con, start, stop, step):
expr = ibis.range(start, stop, step)
result = con.execute(expr)
assert list(result) == list(range(start, stop, step))


@pytest.mark.parametrize("stop", [-7, 0, 7])
@pytest.mark.parametrize("start", [-7, 0, 7])
@pytest.mark.notyet(
["clickhouse"], raises=ClickhouseDatabaseError, reason="not supported upstream"
)
@pytest.mark.notyet(
["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream"
)
@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError)
def test_range_start_stop_step_zero(con, start, stop):
expr = ibis.range(start, stop, 0)
result = con.execute(expr)
assert list(result) == []
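
The parametrizations above pin the semantics to Python's built-in `range`, with one deliberate difference for a zero step; a plain-Python reference model of the behaviour the backends are expected to match (my own sketch, not part of the commit):

```python
def expected_range(start: int, stop: int, step: int) -> list[int]:
    """Reference behaviour asserted by the backend tests above."""
    if step == 0:
        # unlike builtins.range, a zero step yields an empty array
        # instead of raising ValueError
        return []
    return list(range(start, stop, step))


assert expected_range(0, 7, 2) == [0, 2, 4, 6]
assert expected_range(7, -7, -2) == [7, 5, 3, 1, -1, -3, -5]
assert expected_range(7, 0, 1) == []
assert expected_range(7, 0, 0) == []
```
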
19 changes: 19 additions & 0 deletions ibis/backends/trino/registry.py
@@ -350,6 +350,24 @@ def _interval_from_integer(t, op):
return sa.type_coerce(sa.func.parse_duration(arg), INTERVAL)


def _integer_range(t, op):
start = t.translate(op.start)
stop = t.translate(op.stop)
step = t.translate(op.step)
satype = t.get_sqla_type(op.dtype)
# `sequence` doesn't allow arguments that would produce an empty range, so
# check that first
n = sa.func.floor((stop - start) / sa.func.nullif(step, 0))
return if_(
n > 0,
# TODO(cpcloud): revisit using array_remove when my brain is working
sa.func.array_remove(
sa.func.sequence(start, stop, step, type_=satype), stop, type_=satype
),
sa.literal_column("ARRAY[]"),
)


operation_registry.update(
{
# conditional expressions
@@ -547,6 +565,7 @@ def _interval_from_integer(t, op):
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
ops.IntervalFromInteger: _interval_from_integer,
ops.IntegerRange: _integer_range,
}
)
