From fc5d2fd2cafa84e569fe4f14fbc7fb7b014b4d5d Mon Sep 17 00:00:00 2001 From: Chloe He Date: Fri, 31 May 2024 14:50:12 -0700 Subject: [PATCH] refactor WindowedTable class and add group_by method --- ibis/expr/types/relations.py | 1 - ibis/expr/types/temporal_windows.py | 55 ++++++++++++++++-------- ibis/tests/expr/test_temporal_windows.py | 30 ++++++++++++- 3 files changed, 65 insertions(+), 21 deletions(-) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index c25600778c534..b1d39a404192e 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -218,7 +218,6 @@ def bind(self, *args, **kwargs): args = () else: args = util.promote_list(args[0]) - # bind positional arguments values = [] for arg in args: diff --git a/ibis/expr/types/temporal_windows.py b/ibis/expr/types/temporal_windows.py index 4507394b7796f..170b3b29aa990 100644 --- a/ibis/expr/types/temporal_windows.py +++ b/ibis/expr/types/temporal_windows.py @@ -1,12 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from public import public import ibis.common.exceptions as com import ibis.expr.operations as ops import ibis.expr.types as ir +from ibis.common.collections import FrozenOrderedDict # noqa: TCH001 +from ibis.common.grounds import Concrete +from ibis.expr.operations.relations import Unaliased # noqa: TCH001 from ibis.expr.types.relations import unwrap_aliases if TYPE_CHECKING: @@ -14,27 +17,31 @@ @public -class WindowedTable: +class WindowedTable(Concrete): """An intermediate table expression to hold windowing information.""" - def __init__(self, parent: ir.Table, time_col: ops.Column): + parent: ir.Table + time_col: ops.Column + window_type: Literal["tumble", "hop"] | None = None + window_size: ir.IntervalScalar | None = None + window_slide: ir.IntervalScalar | None = None + window_offset: ir.IntervalScalar | None = None + groups: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None + metrics: FrozenOrderedDict[str, Unaliased[ops.Column]] | None = None + + def __init__(self, time_col: ops.Column, **kwargs): if time_col is None: raise com.IbisInputError( "Window aggregations require `time_col` as an argument" ) - self.parent = parent - self.time_col = time_col + super().__init__(time_col=time_col, **kwargs) def tumble( self, size: ir.IntervalScalar, offset: ir.IntervalScalar | None = None, ) -> WindowedTable: - self.window_type = "tumble" - self.window_slide = None - self.window_size = size - self.window_offset = offset - return self + return self.copy(window_type="tumble", window_size=size, window_offset=offset) def hop( self, @@ -42,24 +49,28 @@ def hop( slide: ir.IntervalScalar, offset: ir.IntervalScalar | None = None, ) -> WindowedTable: - self.window_type = "hop" - self.window_size = size - self.window_slide = slide - self.window_offset = offset - return self + return self.copy( + window_type="hop", + window_size=size, + window_slide=slide, + window_offset=offset, + ) def aggregate( self, metrics: Sequence[ir.Scalar] | None = (), - by: Sequence[ir.Value] | None = (), + by: str | ir.Value | Sequence[str] | Sequence[ir.Value] | None = (), **kwargs: ir.Value, ) -> ir.Table: - groups = self.parent.bind(by) + by = self.parent.bind(by) metrics = self.parent.bind(metrics, **kwargs) - groups = unwrap_aliases(groups) + by = unwrap_aliases(by) metrics = unwrap_aliases(metrics) + groups = dict(self.groups) if self.groups is not None else {} + groups.update(by) + return ops.WindowAggregate( self.parent, self.window_type, @@ -72,3 +83,11 @@ def aggregate( ).to_expr() agg = aggregate + + def group_by( + self, *by: str | ir.Value | Sequence[str] | Sequence[ir.Value] + ) -> WindowedTable: + by = tuple(v for v in by if v is not None) + groups = self.parent.bind(*by) + groups = unwrap_aliases(groups) + return self.copy(groups=groups) diff --git a/ibis/tests/expr/test_temporal_windows.py b/ibis/tests/expr/test_temporal_windows.py index 8f6d6030e768d..0b1615d4f5af8 100644 --- a/ibis/tests/expr/test_temporal_windows.py +++ b/ibis/tests/expr/test_temporal_windows.py @@ -19,9 +19,10 @@ ], ids=["tumble", "hop"], ) -def test_window_by_agg_schema(table, method): +@pytest.mark.parametrize("by", ["g", _.g, ["g"]]) +def test_window_by_agg_schema(table, method, by): expr = method(table.window_by(time_col=table.i)) - expr = expr.agg(by=["g"], a_sum=_.a.sum()) + expr = expr.agg(by=by, a_sum=_.a.sum()) expected_schema = ibis.schema( { "window_start": "timestamp", @@ -36,3 +37,28 @@ def test_window_by_agg_schema(table, method): def test_window_by_with_non_timestamp_column(table): with pytest.raises(com.IbisInputError): table.window_by(time_col=table.a) + + +@pytest.mark.parametrize( + "method", + [ + methodcaller("tumble", size=ibis.interval(minutes=15)), + methodcaller( + "hop", size=ibis.interval(minutes=15), slide=ibis.interval(minutes=1) + ), + ], + ids=["tumble", "hop"], +) +@pytest.mark.parametrize("by", ["g", _.g, ["g"]]) +def test_window_by_group_by_agg(table, method, by): + expr = method(table.window_by(time_col=table.i)) + expr = expr.group_by(by).agg(a_sum=_.a.sum()) + expected_schema = ibis.schema( + { + "window_start": "timestamp", + "window_end": "timestamp", + "g": "string", + "a_sum": "int64", + } + ) + assert expr.schema() == expected_schema