From f5a0a5a7dcd14efdacd91d7f0ba8d6578a739be9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 17 Nov 2023 07:09:12 -0500 Subject: [PATCH] feat(api): add `ibis.range` function for generating sequences --- ibis/backends/bigquery/registry.py | 11 ++ ibis/backends/clickhouse/compiler/values.py | 1 + ibis/backends/duckdb/registry.py | 1 + ibis/backends/polars/compiler.py | 17 +++ ibis/backends/postgres/registry.py | 22 ++++ ibis/backends/pyspark/compiler.py | 17 +++ ibis/backends/snowflake/registry.py | 8 ++ ibis/backends/tests/test_array.py | 138 ++++++++++++++++++++ ibis/backends/trino/registry.py | 19 +++ ibis/expr/api.py | 100 +++++++++++++- ibis/expr/operations/arrays.py | 15 +++ ibis/expr/operations/relations.py | 3 +- ibis/expr/types/generic.py | 8 +- 13 files changed, 355 insertions(+), 5 deletions(-) diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 8f4d702aa360..83851cb474ef 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -776,6 +776,16 @@ def _group_concat(translator, op): return f"STRING_AGG({arg}, {sep})" +def _integer_range(translator, op): + start = translator.translate(op.start) + stop = translator.translate(op.stop) + step = translator.translate(op.step) + n = f"FLOOR(({stop} - {start}) / NULLIF({step}, 0))" + gen_array = f"GENERATE_ARRAY({start}, {stop}, {step})" + inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}" + return f"IF({n} > 0, ARRAY({inner}), [])" + + OPERATION_REGISTRY = { **operation_registry, # Literal @@ -939,6 +949,7 @@ def _group_concat(translator, op): ops.TimeDelta: _time_delta, ops.DateDelta: _date_delta, ops.TimestampDelta: _timestamp_delta, + ops.IntegerRange: _integer_range, } _invalid_operations = { diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index bd2b99e8668f..fac0250edae0 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -817,6 +817,7 @@ def formatter(op, *, left, right, **_): ops.ExtractFragment: "fragment", ops.ArrayPosition: "indexOf", ops.ArrayFlatten: "arrayFlatten", + ops.IntegerRange: "range", } diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index 9e346560aff6..f2b7b3c54276 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -497,6 +497,7 @@ def _to_json_collection(t, op): ops.ToJSONMap: _to_json_collection, ops.ToJSONArray: _to_json_collection, ops.ArrayFlatten: unary(sa.func.flatten), + ops.IntegerRange: fixed_arity(sa.func.range, 3), } ) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index ecaad4c01dfa..b14beda35632 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -1198,3 +1198,20 @@ def execute_agg_udf(op, **kw): if (where := op.where) is not None: first = first.filter(translate(where, **kw)) return getattr(first, op.__func_name__)(*rest) + + +@translate.register(ops.IntegerRange) +def execute_integer_range(op, **kw): + if not isinstance(op.step, ops.Literal): + raise NotImplementedError("Dynamic step not supported by Polars") + step = op.step.value + + dtype = dtype_to_polars(op.dtype) + empty = pl.int_ranges(0, 0, dtype=dtype) + + if step == 0: + return empty + + start = translate(op.start, **kw) + stop = translate(op.stop, **kw) + return pl.int_ranges(start, stop, step, dtype=dtype) diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index da90a2ff2485..e9f506799c78 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -604,6 +604,27 @@ def _array_filter(t, op): ) +def _integer_range(t, op): + start = t.translate(op.start) + stop = t.translate(op.stop) + step = t.translate(op.step) + satype = t.get_sqla_type(op.dtype) + # `sequence` doesn't allow arguments that would produce an empty range, so + # check that first + n = sa.func.floor((stop - start) / sa.func.nullif(step, 0)) + seq = sa.func.generate_series(start, stop, step, type_=satype) + return sa.case( + # TODO(cpcloud): revisit using array_remove when my brain is working + ( + n > 0, + sa.func.array_remove( + sa.func.array(sa.select(seq).scalar_subquery()), stop, type_=satype + ), + ), + else_=sa.cast(pg.array([]), satype), + ) + + operation_registry.update( { ops.Literal: _literal, @@ -802,5 +823,6 @@ def _array_filter(t, op): ops.ArrayPosition: fixed_arity(_array_position, 2), ops.ArrayMap: _array_map, ops.ArrayFilter: _array_filter, + ops.IntegerRange: _integer_range, } ) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 7ebf423dca8e..0245c24bda1f 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -2080,3 +2080,20 @@ def compile_levenshtein(t, op, **kwargs): @compiles(ops.ArrayFlatten) def compile_flatten(t, op, **kwargs): return F.flatten(t.translate(op.arg, **kwargs)) + + +@compiles(ops.IntegerRange) +def compile_integer_range(t, op, **kwargs): + start = t.translate(op.start, **kwargs) + stop = t.translate(op.stop, **kwargs) + step = t.translate(op.step, **kwargs) + + denom = F.when(step == 0, F.lit(None)).otherwise(step) + n = F.floor((stop - start) / denom) + seq = F.sequence(start, stop, step) + seq = F.slice( + seq, + 1, + F.size(seq) - F.when(F.element_at(seq, F.size(seq)) == stop, 1).otherwise(0), + ) + return F.when(n > 0, seq).otherwise(F.array()) diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py index 4ce03b0d8bc4..8db1096a7831 100644 --- a/ibis/backends/snowflake/registry.py +++ b/ibis/backends/snowflake/registry.py @@ -491,6 +491,14 @@ def _timestamp_bucket(t, op): lambda part, left, right: sa.func.timestampdiff(part, right, left), 3 ), ops.TimestampBucket: _timestamp_bucket, + ops.IntegerRange: fixed_arity( + lambda start, stop, step: sa.func.iff( + step != 0, + sa.func.array_generate_range(start, stop, step), + sa.func.array_construct(), + ), + 3, + ), } ) diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 44033a029020..73ec7915f7a9 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -39,6 +39,12 @@ except ImportError: PySparkAnalysisException = None + +try: + from polars.exceptions import PolarsInvalidOperationError +except ImportError: + PolarsInvalidOperationError = None + pytestmark = [ pytest.mark.never( ["sqlite", "mysql", "mssql"], @@ -910,3 +916,135 @@ def test_array_flatten(backend, flatten_data, column, expected): expr = t[column].flatten() result = backend.connection.execute(expr) backend.assert_series_equal(result, expected, check_names=False) + + +polars_overflow = pytest.mark.notyet( + ["polars"], + reason="but polars overflows allocation with some inputs", + raises=BaseException, +) + + +@pytest.mark.notyet( + ["datafusion"], + reason="range isn't implemented upstream", + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError) +@pytest.mark.parametrize("n", [param(-2, marks=[polars_overflow]), 0, 2]) +def test_range_single_argument(con, n): + expr = ibis.range(n) + result = con.execute(expr) + assert list(result) == list(range(n)) + + +@pytest.mark.notyet( + ["datafusion"], + reason="range and unnest aren't implemented upstream", + raises=com.OperationNotDefinedError, +) +@pytest.mark.parametrize( + "n", + [ + param( + -2, + marks=[ + pytest.mark.broken( + ["snowflake"], + reason="snowflake unnests empty arrays to null", + raises=AssertionError, + ) + ], + ), + param( + 0, + marks=[ + pytest.mark.broken( + ["snowflake"], + reason="snowflake unnests empty arrays to null", + raises=AssertionError, + ) + ], + ), + 2, + ], +) +@pytest.mark.notimpl( + ["polars", "flink", "pandas", "dask"], raises=com.OperationNotDefinedError +) +def test_range_single_argument_unnest(con, n): + expr = ibis.range(n).unnest() + result = con.execute(expr) + tm.assert_series_equal( + result, + pd.Series(list(range(n)), dtype=result.dtype, name=expr.get_name()), + check_index=False, + ) + + +@pytest.mark.parametrize( + "step", + [ + param( + -2, + marks=[ + pytest.mark.notyet( + ["polars"], + reason="panic upstream", + raises=PolarsInvalidOperationError, + ) + ], + ), + param( + -1, + marks=[ + pytest.mark.notyet( + ["polars"], + reason="panic upstream", + raises=PolarsInvalidOperationError, + ) + ], + ), + 1, + 2, + ], +) +@pytest.mark.parametrize( + ("start", "stop"), + [ + param(-7, -7), + param(-7, 0), + param(-7, 7), + param(0, -7, marks=[polars_overflow]), + param(0, 0), + param(0, 7), + param(7, -7, marks=[polars_overflow]), + param(7, 0, marks=[polars_overflow]), + param(7, 7), + ], +) +@pytest.mark.notyet( + ["datafusion"], + reason="range and unnest aren't implemented upstream", + raises=com.OperationNotDefinedError, +) +@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError) +def test_range_start_stop_step(con, start, stop, step): + expr = ibis.range(start, stop, step) + result = con.execute(expr) + assert list(result) == list(range(start, stop, step)) + + +@pytest.mark.parametrize("stop", [-7, 0, 7]) +@pytest.mark.parametrize("start", [-7, 0, 7]) +@pytest.mark.notyet( + ["clickhouse"], raises=ClickhouseDatabaseError, reason="not supported upstream" +) +@pytest.mark.notyet( + ["datafusion"], raises=com.OperationNotDefinedError, reason="not supported upstream" +) +@pytest.mark.notimpl(["flink", "pandas", "dask"], raises=com.OperationNotDefinedError) +def test_range_start_stop_step_zero(con, start, stop): + expr = ibis.range(start, stop, 0) + result = con.execute(expr) + assert list(result) == [] diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 347753395910..87c3e776e2ef 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -350,6 +350,24 @@ def _interval_from_integer(t, op): return sa.type_coerce(sa.func.parse_duration(arg), INTERVAL) +def _integer_range(t, op): + start = t.translate(op.start) + stop = t.translate(op.stop) + step = t.translate(op.step) + satype = t.get_sqla_type(op.dtype) + # `sequence` doesn't allow arguments that would produce an empty range, so + # check that first + n = sa.func.floor((stop - start) / sa.func.nullif(step, 0)) + return if_( + n > 0, + # TODO(cpcloud): revisit using array_remove when my brain is working + sa.func.array_remove( + sa.func.sequence(start, stop, step, type_=satype), stop, type_=satype + ), + sa.literal_column("ARRAY[]"), + ) + + operation_registry.update( { # conditional expressions @@ -547,6 +565,7 @@ def _interval_from_integer(t, op): ops.IntervalAdd: fixed_arity(operator.add, 2), ops.IntervalSubtract: fixed_arity(operator.sub, 2), ops.IntervalFromInteger: _interval_from_integer, + ops.IntegerRange: _integer_range, } ) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 69d4461200ee..2f3631061fa1 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -2,6 +2,7 @@ from __future__ import annotations +import builtins import datetime import functools import numbers @@ -143,6 +144,7 @@ "parse_sql", "pi", "random", + "range", "range_window", "read_csv", "read_delta", @@ -471,7 +473,7 @@ def _memtable_from_dataframe( newcols = getattr( schema, "names", - (f"col{i:d}" for i in range(len(cols))), + (f"col{i:d}" for i in builtins.range(len(cols))), ) df = df.rename(columns=dict(zip(cols, newcols))) op = ops.InMemoryTable( @@ -1947,6 +1949,102 @@ def watermark(time_col: str, allowed_delay: ir.IntervalScalar) -> Watermark: return Watermark(time_col=time_col, allowed_delay=allowed_delay) +@functools.singledispatch +def range(start, stop, step) -> ir.ArrayValue: + """Generate a range of values. + + ::: {.callout-note} + `start` is inclucive and `stop` is exclusive, just like Python's builtin + [`range`](range). + + When `step` equals 0, however, this function will return an empty array. + + Python's `range` will raise an exception when `step` is zero. + ::: + + Parameters + ---------- + start + Lower bound of the range, inclusive. + stop + Upper bound of the range, exclusive. + step + Step value. Optional, defaults to 1. + + Returns + ------- + ArrayValue + An array of values + + Examples + -------- + >>> import ibis + >>> ibis.options.interactive = True + + Range using only a stop argument + + >>> ibis.range(5) + [0, 1, 2, 3, 4] + + Simple range using start and stop + + >>> ibis.range(1, 5) + [1, 2, 3, 4] + + Generate an empty range + + >>> ibis.range(0) + [] + + Negative step values are supported + + >>> ibis.range(10, 4, -2) + [10, 8, 6] + + `ibis.range` behaves the same as Python's range ... + + >>> ibis.range(0, 7, -1) + [] + + ... except when the step is zero, in which case `ibis.range` returns an + empty array + + >>> ibis.range(0, 5, 0) + [] + + Because the resulting expression is array, you can unnest the values + + >>> ibis.range(5).unnest().name("numbers") + ┏━━━━━━━━━┓ + ┃ numbers ┃ + ┡━━━━━━━━━┩ + │ int8 │ + ├─────────┤ + │ 0 │ + │ 1 │ + │ 2 │ + │ 3 │ + │ 4 │ + └─────────┘ + """ + raise NotImplementedError() + + +@range.register(int) +@range.register(ir.IntegerValue) +def _int_range( + start: int, + stop: int | ir.IntegerValue | None = None, + step: int | ir.IntegerValue | None = None, +) -> ir.ArrayValue: + if stop is None: + stop = start + start = 0 + if step is None: + step = 1 + return ops.IntegerRange(start=start, stop=stop, step=step).to_expr() + + def _wrap_deprecated(fn, prefix=""): """Deprecate the top-level geo function.""" diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 67d1e736e676..8b00ae5a17b6 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -202,3 +202,18 @@ class ArrayFlatten(Value): @property def dtype(self): return self.arg.dtype.value_type + + +class Range(Value): + shape = rlz.shape_like("args") + + @attribute + def dtype(self) -> dt.DataType: + return dt.Array(dt.highest_precedence((self.start.dtype, self.stop.dtype))) + + +@public +class IntegerRange(Range): + start: Value[dt.Integer] + stop: Value[dt.Integer] + step: Value[dt.Integer] diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 5789537a44f1..7d5f231d8da1 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -9,7 +9,6 @@ from public import public import ibis.common.exceptions as com -import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis import util @@ -502,7 +501,7 @@ def _projection(self): @public class DummyTable(Relation): # TODO(kszucs): verify that it has at least one element: Length(at_least=1) - values: VarTuple[Value[dt.Any, ds.Scalar]] + values: VarTuple[Value[dt.Any]] @attribute def schema(self): diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 87a0b7e69bb1..38bfea5f8148 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1172,8 +1172,12 @@ def as_table(self) -> ir.Table: "involving multiple base table references " "to a projection" ) - table = roots[0].to_expr() - return table.select(self) + + if roots: + return roots[0].to_expr().select(self) + + # no child table to select from + return ops.DummyTable(values=(self,)).to_expr() def to_pandas(self, **kwargs) -> pd.Series: """Convert a column expression to a pandas Series or scalar object.