diff --git a/ibis/backends/dask/tests/execution/test_window.py b/ibis/backends/dask/tests/execution/test_window.py index 127af1a87bf3..7c6c294615fb 100644 --- a/ibis/backends/dask/tests/execution/test_window.py +++ b/ibis/backends/dask/tests/execution/test_window.py @@ -200,8 +200,7 @@ def test_batting_avg_change_in_games_per_year(players, players_df): @pytest.mark.xfail( - raises=AssertionError, - reason="Dask doesn't support the `rank` method on SeriesGroupBy", + raises=AttributeError, reason="'Series' object has no attribute 'rank'" ) def test_batting_most_hits(players, players_df): expr = players.mutate( diff --git a/ibis/common/graph.py b/ibis/common/graph.py index 156bdde8b109..31b62a032d79 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -145,9 +145,8 @@ def map(self, fn: Callable, filter: Optional[Any] = None) -> dict[Node, Any]: the results as the second and the results of the children as keyword arguments. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- @@ -171,9 +170,8 @@ def find(self, type: type | tuple[type], filter: Optional[Any] = None) -> set[No type Type or tuple of types to find. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- @@ -197,9 +195,8 @@ def match( Pattern to match. `ibis.common.pattern()` function is used to coerce the input value into a pattern. See the pattern module for more details. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. context Optional context to use for the pattern matching. @@ -288,9 +285,8 @@ def from_bfs(cls, root: Node, filter: Optional[Any] = None) -> Self: root Root node of the graph. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- @@ -338,9 +334,8 @@ def from_dfs(cls, root: Node, filter: Optional[Any] = None) -> Self: root Root node of the graph. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- @@ -441,9 +436,8 @@ def bfs(node: Node, filter: Optional[Any] = None) -> Graph: node Root node of the graph. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- @@ -460,9 +454,8 @@ def dfs(node: Node, filter: Optional[Any] = None) -> Graph: node Root node of the graph. filter - Pattern-like object to filter out nodes from the traversal. Essentially - the traversal will only visit nodes that match the given pattern and - stop otherwise. + Pattern-like object to filter out nodes from the traversal. The traversal + will only visit nodes that match the given pattern and stop otherwise. Returns ------- diff --git a/ibis/common/tests/test_graph_benchmarks.py b/ibis/common/tests/test_graph_benchmarks.py index f0d11ff5c9ed..cf9f9df52b7b 100644 --- a/ibis/common/tests/test_graph_benchmarks.py +++ b/ibis/common/tests/test_graph_benchmarks.py @@ -2,6 +2,7 @@ from typing import Optional +import pytest from typing_extensions import Self # noqa: TCH002 from ibis.common.collections import frozendict @@ -32,11 +33,11 @@ def generate_node(depth): ) -def test_generate_node(): - for depth in [0, 1, 2, 10, 100]: - n = generate_node(depth) - assert isinstance(n, MyNode) - assert len(Graph.from_bfs(n).nodes()) == depth + 1 +@pytest.mark.parametrize("depth", [0, 1, 10]) +def test_generate_node(depth): + n = generate_node(depth) + assert isinstance(n, MyNode) + assert len(Graph.from_bfs(n).nodes()) == depth + 1 def test_bfs(benchmark): diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 788392e9e8b2..b8c1c28f6d65 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -13,7 +13,7 @@ import ibis.expr.types as ir from ibis import util from ibis.common.annotations import ValidationError -from ibis.common.deferred import deferred, var +from ibis.common.deferred import _, deferred, var from ibis.common.exceptions import IbisTypeError, IntegrityError from ibis.common.patterns import pattern from ibis.util import Namespace @@ -338,48 +338,24 @@ def propagate_down_window(func: ops.Value, frame: ops.WindowFrame): return type(func)(*clean_args) -# TODO(kszucs): rewrite to receive and return an ops.Node -def windowize_function(expr, frame): - assert isinstance(expr, ir.Expr), type(expr) - assert isinstance(frame, ops.WindowFrame) - - def _windowize(op, frame): - if isinstance(op, ops.WindowFunction): - walked_child = _walk(op.func, frame) - walked = walked_child.to_expr().over(op.frame).op() - elif isinstance(op, ops.Value): - walked = _walk(op, frame) - else: - walked = op - - if isinstance(walked, (ops.Analytic, ops.Reduction)): - return op.to_expr().over(frame).op() - elif isinstance(walked, ops.WindowFunction): - if frame is not None: - frame = walked.frame.copy( - group_by=frame.group_by + walked.frame.group_by, - order_by=frame.order_by + walked.frame.order_by, - ) - return walked.to_expr().over(frame).op() - else: - return walked - else: - return walked - - def _walk(op, frame): - # TODO(kszucs): rewrite to use the substitute utility - windowed_args = [] - for arg in op.args: - if isinstance(arg, ops.Value): - arg = _windowize(arg, frame) - elif isinstance(arg, tuple): - arg = tuple(_windowize(x, frame) for x in arg) +def windowize_function(expr, default_frame): + func = var("func") + frame = var("frame") - windowed_args.append(arg) + wrap_analytic = (p.Analytic | p.Reduction) >> c.WindowFunction(_, default_frame) + merge_frames = p.WindowFunction(func, frame) >> c.WindowFunction( + func, + frame.copy( + order_by=frame.order_by + default_frame.order_by, + group_by=frame.group_by + default_frame.group_by, + ), + ) - return type(op)(*windowed_args) + node = expr.op() + node = node.replace(merge_frames, filter=p.Value) + node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction) - return _windowize(expr.op(), frame).to_expr() + return node.to_expr() def contains_first_or_last_agg(exprs): @@ -458,8 +434,7 @@ def __init__(self, parent, proj_exprs): default_frame = ops.RowsWindowFrame(table=parent) self.clean_exprs = [ - windowize_function(expr, frame=default_frame) - for expr in self.resolved_exprs + windowize_function(expr, default_frame) for expr in self.resolved_exprs ] def get_result(self): diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 745886776f72..931452f42648 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -29,6 +29,7 @@ from ibis.selectors import Selector from ibis.expr.types.relations import bind_expr import ibis.common.exceptions as com +from public import public _function_types = tuple( filter( @@ -60,13 +61,11 @@ def _get_group_by_key(table, value): yield value -# TODO(kszucs): make a builder class for this +@public class GroupedTable: """An intermediate table expression to hold grouping information.""" - def __init__( - self, table, by, having=None, order_by=None, window=None, **expressions - ): + def __init__(self, table, by, having=None, order_by=None, **expressions): self.table = table self.by = list( itertools.chain( @@ -86,7 +85,6 @@ def __init__( self._order_by = order_by or [] self._having = having or [] - self._window = window def __getitem__(self, args): # Shortcut for projection with window functions @@ -133,7 +131,6 @@ def having(self, expr: ir.BooleanScalar) -> GroupedTable: self.by, having=self._having + util.promote_list(expr), order_by=self._order_by, - window=self._window, ) def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: @@ -158,7 +155,6 @@ def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: self.by, having=self._having, order_by=self._order_by + util.promote_list(expr), - window=self._window, ) def mutate( @@ -250,33 +246,24 @@ def _selectables(self, *exprs, **kwexprs): [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ table = self.table - default_frame = self._get_window() + default_frame = ops.RowsWindowFrame( + table=self.table, + group_by=bind_expr(self.table, self.by), + order_by=bind_expr(self.table, self._order_by), + ) return [ - an.windowize_function(e2, frame=default_frame) + an.windowize_function(e2, default_frame) for expr in exprs for e1 in util.promote_list(expr) for e2 in util.promote_list(table._ensure_expr(e1)) ] + [ - an.windowize_function(e, frame=default_frame).name(k) + an.windowize_function(e, default_frame).name(k) for k, expr in kwexprs.items() for e in util.promote_list(table._ensure_expr(expr)) ] projection = select - def _get_window(self): - if self._window is None: - return ops.RowsWindowFrame( - table=self.table, - group_by=self.by, - order_by=bind_expr(self.table, self._order_by), - ) - else: - return self._window.copy( - groupy_by=bind_expr(self.table, self._window.group_by + self.by), - order_by=bind_expr(self.table, self._window.order_by + self._order_by), - ) - def over( self, window=None, @@ -347,6 +334,7 @@ def wrapper(self, *args, **kwargs): return wrapper +@public class GroupedArray: def __init__(self, arr, parent): self.arr = arr @@ -361,6 +349,7 @@ def __init__(self, arr, parent): group_concat = _group_agg_dispatch("group_concat") +@public class GroupedNumbers(GroupedArray): mean = _group_agg_dispatch("mean") sum = _group_agg_dispatch("sum") diff --git a/ibis/tests/sql/test_select_sql.py b/ibis/tests/sql/test_select_sql.py index 1fd3c3a97087..517714c5226a 100644 --- a/ibis/tests/sql/test_select_sql.py +++ b/ibis/tests/sql/test_select_sql.py @@ -186,6 +186,7 @@ def test_bug_duplicated_where(airlines, snapshot): expr = t.group_by("dest").mutate( dest_avg=t.arrdelay.mean(), dev=t.arrdelay - t.arrdelay.mean() ) + tmp1 = expr[expr.dev.notnull()] tmp2 = tmp1.order_by(ibis.desc("dev")) expr = tmp2.limit(10)