Skip to content

Commit

Permalink
feat(api): support selectors in window function order_by and `group…
Browse files Browse the repository at this point in the history
…_by` (#9649)
  • Loading branch information
cpcloud authored Jul 22, 2024
1 parent 4d135b3 commit 0ad47de
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 25 deletions.
31 changes: 31 additions & 0 deletions ibis/common/selectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING

from ibis.common.grounds import Concrete

if TYPE_CHECKING:
from collections.abc import Sequence

import ibis.expr.types as ir


class Selector(Concrete):
"""A column selector."""

@abc.abstractmethod
def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""Expand `table` into value expressions that match the selector.
Parameters
----------
table
An ibis table expression
Returns
-------
Sequence[Value]
A sequence of value expressions that match the selector
"""
5 changes: 3 additions & 2 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ibis.common.deferred import Deferred, Resolver, deferrable
from ibis.common.exceptions import IbisInputError
from ibis.common.grounds import Concrete
from ibis.common.selectors import Selector # noqa: TCH001
from ibis.common.typing import VarTuple # noqa: TCH001

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,8 +146,8 @@ class WindowBuilder(Builder):
how: Literal["rows", "range"] = "rows"
start: Optional[RangeWindowBoundary] = None
end: Optional[RangeWindowBoundary] = None
groupings: VarTuple[Union[str, Resolver, ops.Value]] = ()
orderings: VarTuple[Union[str, Resolver, ops.SortKey]] = ()
groupings: VarTuple[Union[str, Resolver, Selector, ops.Value]] = ()
orderings: VarTuple[Union[str, Resolver, Selector, ops.SortKey]] = ()

@attribute
def _table(self):
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
import ibis.expr.schema as sch
from ibis import util
from ibis.common.deferred import Deferred, Resolver
from ibis.common.selectors import Selector
from ibis.expr.rewrites import DerefMap
from ibis.expr.types.core import Expr, _FixedTextJupyterMixin
from ibis.expr.types.generic import Value, literal
from ibis.expr.types.pretty import to_rich
from ibis.expr.types.temporal import TimestampColumn
from ibis.selectors import Selector
from ibis.util import deprecated

if TYPE_CHECKING:
Expand Down
24 changes: 2 additions & 22 deletions ibis/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

from __future__ import annotations

import abc
import functools
import inspect
import operator
Expand All @@ -67,27 +66,8 @@
from ibis.common.collections import frozendict # noqa: TCH001
from ibis.common.deferred import Deferred, Resolver
from ibis.common.exceptions import IbisError
from ibis.common.grounds import Concrete, Singleton


class Selector(Concrete):
"""A column selector."""

@abc.abstractmethod
def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""Expand `table` into value expressions that match the selector.
Parameters
----------
table
An ibis table expression
Returns
-------
Sequence[Value]
A sequence of value expressions that match the selector
"""
from ibis.common.grounds import Singleton
from ibis.common.selectors import Selector


class Predicate(Selector):
Expand Down
23 changes: 23 additions & 0 deletions ibis/tests/expr/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,26 @@ def test_order_by_with_selectors(penguins):

with pytest.raises(exc.IbisError):
penguins.order_by(~s.all())


def test_window_function_group_by(penguins):
expr = penguins.species.count().over(group_by=s.c("island"))
assert expr.equals(penguins.species.count().over(group_by=penguins.island))


def test_window_function_order_by(penguins):
expr = penguins.island.count().over(order_by=s.c("species"))
assert expr.equals(penguins.island.count().over(order_by=penguins.species))


def test_window_function_group_by_order_by(penguins):
expr = penguins.species.count().over(
group_by=s.c("island"),
order_by=s.c("year") | (~s.c("island", "species") & s.of_type("str")),
)
assert expr.equals(
penguins.species.count().over(
group_by=penguins.island,
order_by=[penguins.sex, penguins.year],
)
)

0 comments on commit 0ad47de

Please sign in to comment.