diff --git a/py-polars/tests/unit/test_groupby.py b/py-polars/tests/unit/test_groupby.py index 30aae59e5f10..5bae6300b46d 100644 --- a/py-polars/tests/unit/test_groupby.py +++ b/py-polars/tests/unit/test_groupby.py @@ -49,20 +49,51 @@ def test_groupby() -> None: assert result.rows() == [("a", 2), ("b", 3)] assert result.columns == ["b", "a"] - # make sure all the methods below run - assert sorted(df.groupby("b").first().rows()) == [("a", 1, None), ("b", 3, None)] - assert sorted(df.groupby("b").last().rows()) == [("a", 2, 1), ("b", 5, None)] - assert sorted(df.groupby("b").max().rows()) == [("a", 2, 1), ("b", 5, 1)] - assert sorted(df.groupby("b").min().rows()) == [("a", 1, 1), ("b", 3, 1)] - assert sorted(df.groupby("b").count().rows()) == [("a", 2), ("b", 3)] - assert sorted(df.groupby("b").mean().rows()) == [("a", 1.5, 1.0), ("b", 4.0, 1.0)] - assert sorted(df.groupby("b").n_unique().rows()) == [("a", 2, 2), ("b", 3, 2)] - assert sorted(df.groupby("b").median().rows()) == [("a", 1.5, 1.0), ("b", 4.0, 1.0)] - assert sorted(df.groupby("b").all().rows()) == [ - ("a", [1, 2], [None, 1]), - ("b", [3, 4, 5], [None, 1, None]), - ] - # assert sorted(df.groupby("b").quantile(0.5).rows()) == ... + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": ["a", "a", "b", "b", "b"], + "c": [None, 1, None, 1, None], + } + ) + + +@pytest.mark.parametrize( + ("method", "expected"), + [ + ("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]), + ("count", [("a", 2), ("b", 3)]), + ("first", [("a", 1, None), ("b", 3, None)]), + ("last", [("a", 2, 1), ("b", 5, None)]), + ("max", [("a", 2, 1), ("b", 5, 1)]), + ("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]), + ("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]), + ("min", [("a", 1, 1), ("b", 3, 1)]), + ("n_unique", [("a", 2, 2), ("b", 3, 2)]), + ], +) +def test_groupby_shorthands( + df: pl.DataFrame, method: str, expected: list[tuple[Any]] +) -> None: + gb = df.groupby("b", maintain_order=True) + result = getattr(gb, method)() + assert result.rows() == expected + + gb_lazy = df.lazy().groupby("b", maintain_order=True) + result = getattr(gb_lazy, method)().collect() + assert result.rows() == expected + + +def test_groupby_shorthand_quantile(df: pl.DataFrame) -> None: + result = df.groupby("b", maintain_order=True).quantile(0.5) + expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)] + assert result.rows() == expected + + result = df.lazy().groupby("b", maintain_order=True).quantile(0.5).collect() + assert result.rows() == expected def test_groupby_args() -> None: