From caee5c16587acf371aec3cac35eea6caf21954b3 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 22 Jul 2023 08:06:20 -0400 Subject: [PATCH] feat(api): implement negative slice indexing --- .../base/sql/alchemy/query_builder.py | 39 ++- .../base/sql/compiler/query_builder.py | 54 ++-- .../base/sql/compiler/select_builder.py | 26 +- ibis/backends/bigquery/compiler.py | 1 + .../backends/clickhouse/compiler/relations.py | 18 +- ibis/backends/dask/execution/generic.py | 12 + ibis/backends/datafusion/compiler.py | 14 +- ibis/backends/druid/compiler.py | 1 + ibis/backends/mssql/compiler.py | 1 + ibis/backends/mysql/compiler.py | 1 + ibis/backends/oracle/__init__.py | 1 + ibis/backends/pandas/execution/generic.py | 7 +- ibis/backends/polars/compiler.py | 10 +- ibis/backends/pyspark/compiler.py | 20 +- ibis/backends/sqlite/compiler.py | 1 + ibis/backends/tests/test_generic.py | 232 +++++++++++++++++- ibis/backends/trino/compiler.py | 1 + ibis/expr/operations/relations.py | 4 +- ibis/expr/types/relations.py | 76 ++++-- .../test_respect_set_limit/out.sql | 10 +- ibis/tests/expr/test_table.py | 28 +-- .../result.sql | 18 ++ .../test_multiple_limits/out.sql | 6 +- ibis/tests/sql/test_select_sql.py | 15 ++ ibis/util.py | 86 +++++++ 25 files changed, 584 insertions(+), 98 deletions(-) create mode 100644 ibis/tests/sql/snapshots/test_select_sql/test_chain_limit_doesnt_collapse/result.sql diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py index bc8d1e1b11ff..8b8662eb35af 100644 --- a/ibis/backends/base/sql/alchemy/query_builder.py +++ b/ibis/backends/base/sql/alchemy/query_builder.py @@ -340,8 +340,32 @@ def _add_limit(self, fragment): if self.limit is None: return fragment - fragment = fragment.limit(self.limit.n) - if offset := self.limit.offset: + frag = fragment + + n = self.limit.n + + if n is None: + n = self.context.compiler.null_limit + elif not isinstance(n, int): + n = ( + sa.select(self._translate(n)) + .select_from(frag.subquery()) + .scalar_subquery() + ) + + if n is not None: + fragment = fragment.limit(n) + + offset = self.limit.offset + + if not isinstance(offset, int): + offset = ( + sa.select(self._translate(offset)) + .select_from(frag.subquery()) + .scalar_subquery() + ) + + if offset != 0 and n != 0: fragment = fragment.offset(offset) return fragment @@ -393,6 +417,17 @@ class AlchemyCompiler(Compiler): supports_indexed_grouping_keys = True + # Value to use when the user specified `n` from the `limit` API is + # `None`. + # + # For some backends this is: + # * the identifier ALL (sa.literal_column('ALL')) + # * a NULL literal (sa.null()) + # + # and some don't accept an unbounded limit at all: the `LIMIT` + # keyword must simply be left out of the query + null_limit = sa.null() + @classmethod def to_sql(cls, expr, context=None, params=None, exists=False): if context is None: diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py index 3fcb6657ccb3..a6970401a313 100644 --- a/ibis/backends/base/sql/compiler/query_builder.py +++ b/ibis/backends/base/sql/compiler/query_builder.py @@ -11,7 +11,7 @@ import ibis.expr.types as ir from ibis import util from ibis.backends.base.sql.compiler.base import DML, QueryAST, SetOp -from ibis.backends.base.sql.compiler.select_builder import SelectBuilder, _LimitSpec +from ibis.backends.base.sql.compiler.select_builder import SelectBuilder from ibis.backends.base.sql.compiler.translator import ExprTranslator, QueryContext from ibis.backends.base.sql.registry import quote_identifier from ibis.common.grounds import Comparable @@ -422,14 +422,27 @@ def format_order_by(self): return buf.getvalue() def format_limit(self): - if not self.limit: + if self.limit is None: return None buf = StringIO() n = self.limit.n - buf.write(f"LIMIT {n}") - if offset := self.limit.offset: + + if n is None: + n = self.context.compiler.null_limit + elif not isinstance(n, int): + n = f"(SELECT {self._translate(n)} {self.format_table_set()})" + + if n is not None: + buf.write(f"LIMIT {n}") + + offset = self.limit.offset + + if not isinstance(offset, int): + offset = f"(SELECT {self._translate(offset)} {self.format_table_set()})" + + if offset != 0 and n != 0: buf.write(f" OFFSET {offset}") return buf.getvalue() @@ -501,6 +514,7 @@ class Compiler: cheap_in_memory_tables = False support_values_syntax_in_select = True + null_limit = None @classmethod def make_context(cls, params=None): @@ -555,27 +569,17 @@ def to_ast(cls, node, context=None): @classmethod def to_ast_ensure_limit(cls, expr, limit=None, params=None): context = cls.make_context(params=params) - query_ast = cls.to_ast(expr, context) - - # note: limit can still be None at this point, if the global - # default_limit is None - for query in reversed(query_ast.queries): - if ( - isinstance(query, Select) - and not isinstance(expr, ir.Scalar) - and query.table_set is not None - ): - if query.limit is None: - if limit == "default": - query_limit = options.sql.default_limit - else: - query_limit = limit - if query_limit: - query.limit = _LimitSpec(query_limit, offset=0) - elif limit is not None and limit != "default": - query.limit = _LimitSpec(limit, query.limit.offset) - - return query_ast + table = expr.as_table() + + if limit == "default": + query_limit = options.sql.default_limit + else: + query_limit = limit + + if query_limit is not None: + table = table.limit(query_limit) + + return cls.to_ast(table, context) @classmethod def to_sql(cls, node, context=None, params=None): diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py index 15f316d15c58..4966c3778811 100644 --- a/ibis/backends/base/sql/compiler/select_builder.py +++ b/ibis/backends/base/sql/compiler/select_builder.py @@ -10,8 +10,8 @@ class _LimitSpec(NamedTuple): - n: int - offset: int + n: ops.Value | int | None + offset: ops.Value | int = 0 class SelectBuilder: @@ -182,21 +182,15 @@ def _collect_FillNa(self, op, toplevel=False): self._collect(new_op, toplevel=toplevel) def _collect_Limit(self, op, toplevel=False): - if not toplevel: - return - - n = op.n - offset = op.offset or 0 - - if self.limit is None: - self.limit = _LimitSpec(n, offset) - else: - self.limit = _LimitSpec( - min(n, self.limit.n), - offset + self.limit.offset, - ) + if toplevel: + if isinstance(table := op.table, ops.Limit): + self.table_set = table + self.select_set = [table] + else: + self._collect(table, toplevel=toplevel) - self._collect(op.table, toplevel=toplevel) + assert self.limit is None + self.limit = _LimitSpec(op.n, op.offset) def _collect_Union(self, op, toplevel=False): if toplevel: diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py index ecd4a7de2efc..2a3632428961 100644 --- a/ibis/backends/bigquery/compiler.py +++ b/ibis/backends/bigquery/compiler.py @@ -157,6 +157,7 @@ class BigQueryCompiler(sql_compiler.Compiler): difference_class = BigQueryDifference support_values_syntax_in_select = False + null_limit = None @staticmethod def _generate_setup_queries(expr, context): diff --git a/ibis/backends/clickhouse/compiler/relations.py b/ibis/backends/clickhouse/compiler/relations.py index 6b04cf7805f0..625d5fd5e7d0 100644 --- a/ibis/backends/clickhouse/compiler/relations.py +++ b/ibis/backends/clickhouse/compiler/relations.py @@ -179,12 +179,20 @@ def _set_op(op: ops.SetOp, *, left, right, **_): @translate_rel.register def _limit(op: ops.Limit, *, table, **kw): - n = op.n - limited = sg.select("*").from_(table).limit(n) + result = sg.select("*").from_(table) - if offset := op.offset: - limited = limited.offset(offset) - return limited + if (limit := op.n) is not None: + if not isinstance(limit, int): + limit = f"(SELECT {translate_val(limit, **kw)} FROM {table})" + result = result.limit(limit) + + if not isinstance(offset := op.offset, int): + offset = f"(SELECT {translate_val(offset, **kw)} FROM {table})" + + if offset != 0: + return result.offset(offset) + else: + return result @translate_rel.register diff --git a/ibis/backends/dask/execution/generic.py b/ibis/backends/dask/execution/generic.py index e98698daf632..f2a71f91137d 100644 --- a/ibis/backends/dask/execution/generic.py +++ b/ibis/backends/dask/execution/generic.py @@ -348,12 +348,24 @@ def execute_cast_series_date(op, data, type, **kwargs): def execute_limit_frame(op, data, nrows, offset, **kwargs): # NOTE: Dask Dataframes do not support iloc row based indexing # Need to add a globally consecutive index in order to select nrows number of rows + if nrows == 0: + return dd.from_pandas( + pd.DataFrame(columns=data.columns).astype(data.dtypes), npartitions=1 + ) unique_col_name = ibis.util.guid() df = add_globally_consecutive_column(data, col_name=unique_col_name) ret = df.loc[offset : (offset + nrows) - 1] return rename_index(ret, None) +@execute_node.register(ops.Limit, dd.DataFrame, type(None), integer_types) +def execute_limit_frame_no_limit(op, data, nrows, offset, **kwargs): + unique_col_name = ibis.util.guid() + df = add_globally_consecutive_column(data, col_name=unique_col_name) + ret = df.loc[offset : (offset + len(df)) - 1] + return rename_index(ret, None) + + @execute_node.register(ops.Not, (dd.core.Scalar, dd.Series)) def execute_not_scalar_or_series(op, data, **kwargs): return ~data diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index cdd77f38d27b..6246f3b75a87 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -143,9 +143,17 @@ def selection(op, **kw): @translate.register(ops.Limit) def limit(op, **kw): - if op.offset: - raise NotImplementedError("DataFusion does not support offset") - return translate(op.table, **kw).limit(op.n) + if (n := op.n) is not None and not isinstance(n, int): + raise NotImplementedError("Dynamic limit not supported") + + if not isinstance(offset := op.offset, int) or (offset != 0 and n != 0): + raise NotImplementedError("Dynamic offset not supported") + + t = translate(op.table, **kw) + + if n is not None: + return t.limit(n) + return t @translate.register(ops.Aggregation) diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py index 0f4e1772c0a7..60a48bbf5b45 100644 --- a/ibis/backends/druid/compiler.py +++ b/ibis/backends/druid/compiler.py @@ -28,3 +28,4 @@ def translate(self, op): class DruidCompiler(AlchemyCompiler): translator_class = DruidExprTranslator + null_limit = sa.literal_column("ALL") diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py index 3ead29ef497e..7722702eb95a 100644 --- a/ibis/backends/mssql/compiler.py +++ b/ibis/backends/mssql/compiler.py @@ -34,3 +34,4 @@ class MsSqlCompiler(AlchemyCompiler): translator_class = MsSqlExprTranslator supports_indexed_grouping_keys = False + null_limit = None diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py index 0c33a1324238..c33f49318bad 100644 --- a/ibis/backends/mysql/compiler.py +++ b/ibis/backends/mysql/compiler.py @@ -23,3 +23,4 @@ class MySQLExprTranslator(AlchemyExprTranslator): class MySQLCompiler(AlchemyCompiler): translator_class = MySQLExprTranslator support_values_syntax_in_select = False + null_limit = None diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index ae9cfd91629b..30f86e42ab36 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -71,6 +71,7 @@ class OracleCompiler(AlchemyCompiler): translator_class = OracleExprTranslator support_values_syntax_in_select = False supports_indexed_grouping_keys = False + null_limit = None class Backend(BaseAlchemyBackend): diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py index c404c7fdb71e..58a47198d6d7 100644 --- a/ibis/backends/pandas/execution/generic.py +++ b/ibis/backends/pandas/execution/generic.py @@ -97,10 +97,15 @@ def execute_interval_literal(op, value, dtype, **kwargs): @execute_node.register(ops.Limit, pd.DataFrame, integer_types, integer_types) -def execute_limit_frame(op, data, nrows, offset, **kwargs): +def execute_limit_frame(op, data, nrows: int, offset: int, **kwargs): return data.iloc[offset : offset + nrows] +@execute_node.register(ops.Limit, pd.DataFrame, type(None), integer_types) +def execute_limit_frame_no_limit(op, data, nrows: None, offset: int, **kwargs): + return data.iloc[offset:] + + @execute_node.register(ops.Cast, SeriesGroupBy, dt.DataType) def execute_cast_series_group_by(op, data, type, **kwargs): result = execute_cast_series_generic(op, data.obj, type, **kwargs) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 3a6a0cac7e9f..fea30b8b09c0 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -233,10 +233,14 @@ def selection(op, **kw): @translate.register(ops.Limit) def limit(op, **kw): + if (n := op.n) is not None and not isinstance(n, int): + raise NotImplementedError("Dynamic limit not supported") + + if not isinstance(offset := op.offset, int): + raise NotImplementedError("Dynamic offset not supported") + lf = translate(op.table, **kw) - if op.offset: - return lf.slice(op.offset, op.n) - return lf.limit(op.n) + return lf.slice(offset, n) @translate.register(ops.Aggregation) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index f97781d6bbde..57917442c6b2 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -272,13 +272,25 @@ def compile_cast(t, op, **kwargs): @compiles(ops.Limit) def compile_limit(t, op, **kwargs): - if op.offset != 0: + if (n := op.n) is not None and not isinstance(n, int): raise com.UnsupportedArgumentError( - "PySpark backend does not support non-zero offset is for " - f"limit operation. Got offset {op.offset}." + "Dynamic LIMIT is not implemented upstream in PySpark" + ) + if not isinstance(offset := op.offset, int): + raise com.UnsupportedArgumentError( + "Dynamic OFFSET is not implemented upstream in PySpark" + ) + if n != 0 and offset != 0: + raise com.UnsupportedArgumentError( + "PySpark backend does not support non-zero offset values for " + f"the limit operation. Got offset {offset:d}." ) df = t.translate(op.table, **kwargs) - return df.limit(op.n) + + if n is not None: + return df.limit(n) + else: + return df @compiles(ops.And) diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index f42abfb4da6b..97a4604524af 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -31,3 +31,4 @@ class SQLiteExprTranslator(AlchemyExprTranslator): class SQLiteCompiler(AlchemyCompiler): translator_class = SQLiteExprTranslator support_values_syntax_in_select = False + null_limit = None diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 46a96735383c..47212f3f1771 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -37,9 +37,21 @@ try: import clickhouse_connect as cc - ClickhouseDriverOperationalError = cc.driver.ProgrammingError + ClickhouseDriverDatabaseError = cc.driver.exceptions.DatabaseError except ImportError: - ClickhouseDriverOperationalError = None + ClickhouseDriverDatabaseError = None + + +try: + from google.api_core.exceptions import BadRequest +except ImportError: + BadRequest = None + + +try: + from impala.error import HiveServer2Error +except ImportError: + HiveServer2Error = None NULL_BACKEND_TYPES = { @@ -1285,3 +1297,219 @@ def test_try_cast_table(con): ) def test_try_cast_func(con, from_val, to_type, func): assert func(con.execute(ibis.literal(from_val).try_cast(to_type))) + + +@pytest.mark.parametrize( + ("slc", "expected_count_fn"), + [ + ################### + ### NONE/ZERO start + # no stop + param(slice(None, 0), lambda _: 0, id="[:0]"), + param(slice(None, None), lambda t: t.count().to_pandas(), id="[:]"), + param(slice(0, 0), lambda _: 0, id="[0:0]"), + param(slice(0, None), lambda t: t.count().to_pandas(), id="[0:]"), + # positive stop + param(slice(None, 2), lambda _: 2, id="[:2]"), + param(slice(0, 2), lambda _: 2, id="[0:2]"), + ################## + ### NEGATIVE start + # zero stop + param(slice(-3, 0), lambda _: 0, id="[-3:0]"), + # negative stop + param(slice(-3, -3), lambda _: 0, id="[-3:-3]"), + param(slice(-3, -4), lambda _: 0, id="[-3:-4]"), + param(slice(-3, -5), lambda _: 0, id="[-3:-5]"), + ################## + ### POSITIVE start + # no stop + param(slice(3, 0), lambda _: 0, id="[3:0]"), + param( + slice(3, None), + lambda t: t.count().to_pandas() - 3, + id="[3:]", + marks=[ + pytest.mark.notyet( + ["bigquery"], + raises=BadRequest, + reason="bigquery doesn't support OFFSET without LIMIT", + ), + pytest.mark.notyet( + ["datafusion"], + raises=NotImplementedError, + reason="no support for offset yet", + ), + pytest.mark.notyet( + ["mssql"], + raises=sa.exc.CompileError, + reason="mssql doesn't support OFFSET without LIMIT", + ), + pytest.mark.never( + ["impala"], + raises=HiveServer2Error, + reason="impala doesn't support OFFSET without ORDER BY", + ), + pytest.mark.notyet( + ["pyspark"], + raises=com.UnsupportedArgumentError, + reason="pyspark doesn't support non-zero offset until version 3.4", + ), + ], + ), + # positive stop + param(slice(3, 2), lambda _: 0, id="[3:2]"), + param( + slice(3, 4), + lambda _: 1, + id="[3:4]", + marks=[ + pytest.mark.notyet( + ["datafusion"], + raises=NotImplementedError, + reason="no support for offset yet", + ), + pytest.mark.notyet( + ["mssql"], + raises=sa.exc.CompileError, + reason="mssql doesn't support OFFSET without LIMIT", + ), + pytest.mark.notyet( + ["impala"], + raises=HiveServer2Error, + reason="impala doesn't support OFFSET without ORDER BY", + ), + pytest.mark.notyet( + ["pyspark"], + raises=com.UnsupportedArgumentError, + reason="pyspark doesn't support non-zero offset until version 3.4", + ), + ], + ), + ], +) +def test_static_table_slice(backend, slc, expected_count_fn): + t = backend.functional_alltypes + + rows = t[slc] + count = rows.count().to_pandas() + + expected_count = expected_count_fn(t) + assert count == expected_count + + +@pytest.mark.parametrize( + ("slc", "expected_count_fn"), + [ + ### NONE/ZERO start + # negative stop + param(slice(None, -2), lambda t: t.count().to_pandas() - 2, id="[:-2]"), + param(slice(0, -2), lambda t: t.count().to_pandas() - 2, id="[0:-2]"), + # no stop + param(slice(-3, None), lambda _: 3, id="[-3:]"), + ################## + ### NEGATIVE start + # negative stop + param(slice(-3, -2), lambda _: 1, id="[-3:-2]"), + # positive stop + param(slice(-4000, 7000), lambda _: 3700, id="[-4000:7000]"), + param(slice(-3, 2), lambda _: 0, id="[-3:2]"), + ################## + ### POSITIVE start + # negative stop + param(slice(3, -2), lambda t: t.count().to_pandas() - 5, id="[3:-2]"), + param(slice(3, -4), lambda t: t.count().to_pandas() - 7, id="[3:-4]"), + ], + ids=str, +) +@pytest.mark.notyet( + ["mysql", "snowflake", "trino"], + raises=sa.exc.ProgrammingError, + reason="backend doesn't support dynamic limit/offset", +) +@pytest.mark.notimpl( + ["mssql"], + raises=sa.exc.CompileError, + reason="mssql doesn't support dynamic limit/offset without an ORDER BY", +) +@pytest.mark.notyet( + ["clickhouse"], + raises=ClickhouseDriverDatabaseError, + reason="clickhouse doesn't support dynamic limit/offset", +) +@pytest.mark.notyet(["druid"], reason="druid doesn't support dynamic limit/offset") +@pytest.mark.notyet(["polars"], reason="polars doesn't support dynamic limit/offset") +@pytest.mark.notyet( + ["bigquery"], + reason="bigquery doesn't support dynamic limit/offset", + raises=BadRequest, +) +@pytest.mark.notyet( + ["datafusion"], + reason="datafusion doesn't support dynamic limit/offset", + raises=NotImplementedError, +) +@pytest.mark.never( + ["impala"], + reason="impala doesn't support dynamic limit/offset", + raises=HiveServer2Error, +) +@pytest.mark.notyet(["pyspark"], reason="pyspark doesn't support dynamic limit/offset") +def test_dynamic_table_slice(backend, slc, expected_count_fn): + t = backend.functional_alltypes + + rows = t[slc] + count = rows.count().to_pandas() + + expected_count = expected_count_fn(t) + assert count == expected_count + + +@pytest.mark.notyet( + ["mysql", "snowflake", "trino"], + raises=sa.exc.ProgrammingError, + reason="backend doesn't support dynamic limit/offset", +) +@pytest.mark.notyet( + ["clickhouse"], + raises=ClickhouseDriverDatabaseError, + reason="clickhouse doesn't support dynamic limit/offset", +) +@pytest.mark.notyet(["druid"], reason="druid doesn't support dynamic limit/offset") +@pytest.mark.notyet(["polars"], reason="polars doesn't support dynamic limit/offset") +@pytest.mark.notyet( + ["bigquery"], + reason="bigquery doesn't support dynamic limit/offset", + raises=BadRequest, +) +@pytest.mark.notyet( + ["datafusion"], + reason="datafusion doesn't support dynamic limit/offset", + raises=NotImplementedError, +) +@pytest.mark.never( + ["impala"], + reason="impala doesn't support dynamic limit/offset", + raises=HiveServer2Error, +) +@pytest.mark.notyet(["pyspark"], reason="pyspark doesn't support dynamic limit/offset") +@pytest.mark.xfail_version( + duckdb=["duckdb<=0.8.1"], + raises=AssertionError, + reason="https://github.com/duckdb/duckdb/issues/8412", +) +def test_dynamic_table_slice_with_computed_offset(backend): + t = backend.functional_alltypes + + col = "id" + df = t[[col]].to_pandas() + + assert len(df) == df[col].nunique() + + n = 10 + + expr = t[[col]].order_by(col)[-n:] + + expected = df.sort_values([col]).iloc[-n:].reset_index(drop=True) + result = expr.to_pandas() + + backend.assert_frame_equal(result, expected) diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index 796222cd3eeb..8f52d58e04e5 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -50,3 +50,4 @@ def _rewrite_string_contains(op): class TrinoSQLCompiler(AlchemyCompiler): cheap_in_memory_tables = False translator_class = TrinoSQLExprTranslator + null_limit = sa.literal_column("ALL") diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 52be5f74cb7e..9dcc2ac1ac19 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -377,8 +377,8 @@ class Difference(SetOp): @public class Limit(Relation): table: Relation - n: int - offset: int + n: UnionType[int, Scalar[dt.Integer], None] = None + offset: UnionType[int, Scalar[dt.Integer]] = 0 @property def schema(self): diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 7b4f57e65050..ba250653961e 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -434,6 +434,48 @@ def __getitem__(self, what): │ Adelie │ Torgersen │ 36.7 │ 19.3 │ 193 │ … │ └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + Some backends support negative slice indexing + + >>> t[-5:] # last 5 rows + ┏━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string │ string │ float64 │ float64 │ int64 │ … │ + ├───────────┼────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Chinstrap │ Dream │ 55.8 │ 19.8 │ 207 │ … │ + │ Chinstrap │ Dream │ 43.5 │ 18.1 │ 202 │ … │ + │ Chinstrap │ Dream │ 49.6 │ 18.2 │ 193 │ … │ + │ Chinstrap │ Dream │ 50.8 │ 19.0 │ 210 │ … │ + │ Chinstrap │ Dream │ 50.2 │ 18.7 │ 198 │ … │ + └───────────┴────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> t[-5:-3] # last 5th to 3rd rows + ┏━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string │ string │ float64 │ float64 │ int64 │ … │ + ├───────────┼────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Chinstrap │ Dream │ 55.8 │ 19.8 │ 207 │ … │ + │ Chinstrap │ Dream │ 43.5 │ 18.1 │ 202 │ … │ + └───────────┴────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> t[2:-2] # chop off the first two and last two rows + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string │ string │ float64 │ float64 │ int64 │ … │ + ├─────────┼───────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Adelie │ Torgersen │ 40.3 │ 18.0 │ 195 │ … │ + │ Adelie │ Torgersen │ nan │ nan │ NULL │ … │ + │ Adelie │ Torgersen │ 36.7 │ 19.3 │ 193 │ … │ + │ Adelie │ Torgersen │ 39.3 │ 20.6 │ 190 │ … │ + │ Adelie │ Torgersen │ 38.9 │ 17.8 │ 181 │ … │ + │ Adelie │ Torgersen │ 39.2 │ 19.6 │ 195 │ … │ + │ Adelie │ Torgersen │ 34.1 │ 18.1 │ 193 │ … │ + │ Adelie │ Torgersen │ 42.0 │ 20.2 │ 190 │ … │ + │ Adelie │ Torgersen │ 37.8 │ 17.1 │ 186 │ … │ + │ Adelie │ Torgersen │ 37.8 │ 17.3 │ 180 │ … │ + │ … │ … │ … │ … │ … │ … │ + └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + Select columns >>> t[["island", "bill_length_mm"]].head() @@ -522,19 +564,8 @@ def __getitem__(self, what): return ops.TableColumn(self, what).to_expr() if isinstance(what, slice): - step = what.step - if step is not None and step != 1: - raise ValueError("Slice step can only be 1") - start = what.start or 0 - stop = what.stop - - if stop is None or stop < 0: - raise ValueError("End index must be a positive number") - - if start < 0: - raise ValueError("Start index must be a positive number") - - return self.limit(stop - start, offset=start) + limit, offset = util.slice_to_limit_offset(what, self.count()) + return self.limit(limit, offset=offset) what = bind_expr(self, what) @@ -1089,7 +1120,7 @@ def distinct( return res.select(self.columns) return res - def limit(self, n: int, offset: int = 0) -> Table: + def limit(self, n: int | None, offset: int = 0) -> Table: """Select `n` rows from `self` starting at `offset`. !!! note "The result set is not deterministic without a call to [`order_by`][ibis.expr.types.relations.Table.order_by]." @@ -1097,7 +1128,8 @@ def limit(self, n: int, offset: int = 0) -> Table: Parameters ---------- n - Number of rows to include + Number of rows to include. If `None`, the entire table is selected + starting from `offset`. offset Number of rows to skip first @@ -1131,11 +1163,23 @@ def limit(self, n: int, offset: int = 0) -> Table: │ 1 │ a │ └───────┴────────┘ + You can use `None` with `offset` to slice starting from a particular row + + >>> t.limit(None, offset=1) + ┏━━━━━━━┳━━━━━━━━┓ + ┃ a ┃ b ┃ + ┡━━━━━━━╇━━━━━━━━┩ + │ int64 │ string │ + ├───────┼────────┤ + │ 1 │ a │ + │ 2 │ a │ + └───────┴────────┘ + See Also -------- [`Table.order_by`][ibis.expr.types.relations.Table.order_by] """ - return ops.Limit(self, n, offset=offset).to_expr() + return ops.Limit(self, n, offset).to_expr() def head(self, n: int = 5) -> Table: """Select the first `n` rows of a table. diff --git a/ibis/tests/expr/snapshots/test_interactive/test_respect_set_limit/out.sql b/ibis/tests/expr/snapshots/test_interactive/test_respect_set_limit/out.sql index bac12682efc3..de1d76b1264b 100644 --- a/ibis/tests/expr/snapshots/test_interactive/test_respect_set_limit/out.sql +++ b/ibis/tests/expr/snapshots/test_interactive/test_respect_set_limit/out.sql @@ -1,3 +1,7 @@ -SELECT t0.`id`, t0.`bool_col` -FROM functional_alltypes t0 -LIMIT 10 \ No newline at end of file +SELECT t0.* +FROM ( + SELECT t1.`id`, t1.`bool_col` + FROM functional_alltypes t1 + LIMIT 10 +) t0 +LIMIT 11 \ No newline at end of file diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 057afd09c90a..863f6a366053 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -525,27 +525,25 @@ def test_order_by_nonexistent_column_errors(table, expr_func, key, exc_type): def test_slice(table): - expr = table[:5] + expr1 = table[:5] expr2 = table[:5:1] - assert_equal(expr, table.limit(5)) - assert_equal(expr, expr2) + expr3 = table[5:] + assert_equal(expr1, table.limit(5)) + assert_equal(expr1, expr2) + assert_equal(expr3, table.limit(None, offset=5)) - expr = table[2:7] + expr1 = table[2:7] expr2 = table[2:7:1] - assert_equal(expr, table.limit(5, offset=2)) - assert_equal(expr, expr2) - - with pytest.raises(ValueError): - table[2:15:2] + expr3 = table[2::1] + assert_equal(expr1, table.limit(5, offset=2)) + assert_equal(expr1, expr2) + assert_equal(expr3, table.limit(None, offset=2)) - with pytest.raises(ValueError): - table[5:] - - with pytest.raises(ValueError): - table[:-5] +@pytest.mark.parametrize("step", [-1, 0, 2]) +def test_invalid_slice(table, step): with pytest.raises(ValueError): - table[-10:-5] + table[:5:step] def test_table_count(table): diff --git a/ibis/tests/sql/snapshots/test_select_sql/test_chain_limit_doesnt_collapse/result.sql b/ibis/tests/sql/snapshots/test_select_sql/test_chain_limit_doesnt_collapse/result.sql new file mode 100644 index 000000000000..e25947808580 --- /dev/null +++ b/ibis/tests/sql/snapshots/test_select_sql/test_chain_limit_doesnt_collapse/result.sql @@ -0,0 +1,18 @@ +WITH t0 AS ( + SELECT t2.`city`, count(t2.`city`) AS `Count(city)` + FROM tbl t2 + GROUP BY 1 +) +SELECT t1.* +FROM ( + SELECT t0.* + FROM t0 + ORDER BY t0.`Count(city)` DESC + LIMIT 10 +) t1 +LIMIT 5 OFFSET (SELECT count(1) + -5 FROM ( + SELECT t0.* + FROM t0 + ORDER BY t0.`Count(city)` DESC + LIMIT 10 +) t1) \ No newline at end of file diff --git a/ibis/tests/sql/snapshots/test_select_sql/test_multiple_limits/out.sql b/ibis/tests/sql/snapshots/test_select_sql/test_multiple_limits/out.sql index 88c710a94478..b4f67ae8d56d 100644 --- a/ibis/tests/sql/snapshots/test_select_sql/test_multiple_limits/out.sql +++ b/ibis/tests/sql/snapshots/test_select_sql/test_multiple_limits/out.sql @@ -1,3 +1,7 @@ SELECT t0.* -FROM functional_alltypes t0 +FROM ( + SELECT t1.* + FROM functional_alltypes t1 + LIMIT 20 +) t0 LIMIT 10 \ No newline at end of file diff --git a/ibis/tests/sql/test_select_sql.py b/ibis/tests/sql/test_select_sql.py index ee660bd1eda1..517714c5226a 100644 --- a/ibis/tests/sql/test_select_sql.py +++ b/ibis/tests/sql/test_select_sql.py @@ -842,3 +842,18 @@ def compute(t): u = ibis.union(t1, t2) snapshot.assert_match(to_sql(u), "result.sql") + + +def test_chain_limit_doesnt_collapse(snapshot): + t = ibis.table( + [ + ("foo", "string"), + ("bar", "string"), + ("city", "string"), + ("v1", "double"), + ("v2", "double"), + ], + "tbl", + ) + expr = t.city.topk(10)[-5:] + snapshot.assert_match(to_sql(expr), "result.sql") diff --git a/ibis/util.py b/ibis/util.py index 05f1fe6cc925..5cfd13677bdb 100644 --- a/ibis/util.py +++ b/ibis/util.py @@ -29,6 +29,8 @@ from numbers import Real from pathlib import Path + import ibis.expr.types as ir + T = TypeVar("T", covariant=True) U = TypeVar("U", covariant=True) K = TypeVar("K") @@ -529,3 +531,87 @@ def gen_name(namespace: str) -> str: """Create a case-insensitive uuid4 unique table name.""" uid = base64.b32encode(uuid.uuid4().bytes).decode().rstrip("=").lower() return f"_ibis_{namespace}_{uid}" + + +def slice_to_limit_offset( + what: slice, count: ir.IntegerScalar +) -> tuple[int | ir.IntegerScalar, int | ir.IntegerScalar]: + """Convert a Python [`slice`][slice] to a `limit`, `offset` pair. + + Parameters + ---------- + what + The slice to convert + count + The total number of rows in the table as an expression + + Returns + ------- + tuple[int | ir.IntegerScalar, int | ir.IntegerScalar] + The offset and limit to use in a `Table.limit` call + + Examples + -------- + >>> import ibis + >>> t = ibis.table(dict(a="int", b="string"), name="t") + + First 10 rows + >>> count = t.count() + >>> what = slice(0, 10) + >>> limit, offset = slice_to_limit_offset(what, count) + >>> limit + 10 + >>> offset + 0 + + Last 10 rows + >>> what = slice(-10, None) + >>> limit, offset = slice_to_limit_offset(what, count) + >>> limit + 10 + >>> offset + r0 := UnboundTable: t + a int64 + b string + + Add(CountStar(t), -10): CountStar(r0) + -10 + + From 5th row to 10th row + >>> what = slice(5, 10) + >>> limit, offset = slice_to_limit_offset(what, count) + >>> limit, offset + (5, 5) + """ + if (step := what.step) is not None and step != 1: + raise ValueError("Slice step can only be 1") + + import ibis + + start = what.start + stop = what.stop + + if start is None or start >= 0: + offset = start or 0 + + if stop is None: + limit = None + elif stop == 0: + limit = 0 + elif stop < 0: + limit = count + (stop - offset) + else: # stop > 0 + limit = max(stop - offset, 0) + else: # start < 0 + offset = count + start + + if stop is None: + limit = -start + elif stop == 0: + limit = offset = 0 + elif stop < 0: + limit = max(stop - start, 0) + if limit == 0: + offset = 0 + else: # stop > 0 + limit = ibis.greatest((stop - start) - count, 0) + return limit, offset