From 0ad47de877021e4aa631a8f519b0b2a118f8f5a8 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 22 Jul 2024 09:18:52 -0700 Subject: [PATCH] feat(api): support selectors in window function `order_by` and `group_by` (#9649) --- ibis/common/selectors.py | 31 +++++++++++++++++++++++++++++++ ibis/expr/builders.py | 5 +++-- ibis/expr/types/relations.py | 2 +- ibis/selectors.py | 24 ++---------------------- ibis/tests/expr/test_selectors.py | 23 +++++++++++++++++++++++ 5 files changed, 60 insertions(+), 25 deletions(-) create mode 100644 ibis/common/selectors.py diff --git a/ibis/common/selectors.py b/ibis/common/selectors.py new file mode 100644 index 000000000000..10e1a96f1d8f --- /dev/null +++ b/ibis/common/selectors.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING + +from ibis.common.grounds import Concrete + +if TYPE_CHECKING: + from collections.abc import Sequence + + import ibis.expr.types as ir + + +class Selector(Concrete): + """A column selector.""" + + @abc.abstractmethod + def expand(self, table: ir.Table) -> Sequence[ir.Value]: + """Expand `table` into value expressions that match the selector. + + Parameters + ---------- + table + An ibis table expression + + Returns + ------- + Sequence[Value] + A sequence of value expressions that match the selector + + """ diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index 59d131123df0..1cb5b194ca66 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -13,6 +13,7 @@ from ibis.common.deferred import Deferred, Resolver, deferrable from ibis.common.exceptions import IbisInputError from ibis.common.grounds import Concrete +from ibis.common.selectors import Selector # noqa: TCH001 from ibis.common.typing import VarTuple # noqa: TCH001 if TYPE_CHECKING: @@ -145,8 +146,8 @@ class WindowBuilder(Builder): how: Literal["rows", "range"] = "rows" start: Optional[RangeWindowBoundary] = None end: Optional[RangeWindowBoundary] = None - groupings: VarTuple[Union[str, Resolver, ops.Value]] = () - orderings: VarTuple[Union[str, Resolver, ops.SortKey]] = () + groupings: VarTuple[Union[str, Resolver, Selector, ops.Value]] = () + orderings: VarTuple[Union[str, Resolver, Selector, ops.SortKey]] = () @attribute def _table(self): diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 7fd384e4bb95..0198f5a5318f 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -18,12 +18,12 @@ import ibis.expr.schema as sch from ibis import util from ibis.common.deferred import Deferred, Resolver +from ibis.common.selectors import Selector from ibis.expr.rewrites import DerefMap from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import Value, literal from ibis.expr.types.pretty import to_rich from ibis.expr.types.temporal import TimestampColumn -from ibis.selectors import Selector from ibis.util import deprecated if TYPE_CHECKING: diff --git a/ibis/selectors.py b/ibis/selectors.py index 6af620c6c0f6..2b89e1ffb35b 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -50,7 +50,6 @@ from __future__ import annotations -import abc import functools import inspect import operator @@ -67,27 +66,8 @@ from ibis.common.collections import frozendict # noqa: TCH001 from ibis.common.deferred import Deferred, Resolver from ibis.common.exceptions import IbisError -from ibis.common.grounds import Concrete, Singleton - - -class Selector(Concrete): - """A column selector.""" - - @abc.abstractmethod - def expand(self, table: ir.Table) -> Sequence[ir.Value]: - """Expand `table` into value expressions that match the selector. - - Parameters - ---------- - table - An ibis table expression - - Returns - ------- - Sequence[Value] - A sequence of value expressions that match the selector - - """ +from ibis.common.grounds import Singleton +from ibis.common.selectors import Selector class Predicate(Selector): diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py index 62c7821e91af..fac48716345a 100644 --- a/ibis/tests/expr/test_selectors.py +++ b/ibis/tests/expr/test_selectors.py @@ -494,3 +494,26 @@ def test_order_by_with_selectors(penguins): with pytest.raises(exc.IbisError): penguins.order_by(~s.all()) + + +def test_window_function_group_by(penguins): + expr = penguins.species.count().over(group_by=s.c("island")) + assert expr.equals(penguins.species.count().over(group_by=penguins.island)) + + +def test_window_function_order_by(penguins): + expr = penguins.island.count().over(order_by=s.c("species")) + assert expr.equals(penguins.island.count().over(order_by=penguins.species)) + + +def test_window_function_group_by_order_by(penguins): + expr = penguins.species.count().over( + group_by=s.c("island"), + order_by=s.c("year") | (~s.c("island", "species") & s.of_type("str")), + ) + assert expr.equals( + penguins.species.count().over( + group_by=penguins.island, + order_by=[penguins.sex, penguins.year], + ) + )