Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Feb 16, 2023
1 parent d8e80ab commit 8ea4e6c
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,51 @@ def test_groupby() -> None:
assert result.rows() == [("a", 2), ("b", 3)]
assert result.columns == ["b", "a"]

# make sure all the methods below run
assert sorted(df.groupby("b").first().rows()) == [("a", 1, None), ("b", 3, None)]
assert sorted(df.groupby("b").last().rows()) == [("a", 2, 1), ("b", 5, None)]
assert sorted(df.groupby("b").max().rows()) == [("a", 2, 1), ("b", 5, 1)]
assert sorted(df.groupby("b").min().rows()) == [("a", 1, 1), ("b", 3, 1)]
assert sorted(df.groupby("b").count().rows()) == [("a", 2), ("b", 3)]
assert sorted(df.groupby("b").mean().rows()) == [("a", 1.5, 1.0), ("b", 4.0, 1.0)]
assert sorted(df.groupby("b").n_unique().rows()) == [("a", 2, 2), ("b", 3, 2)]
assert sorted(df.groupby("b").median().rows()) == [("a", 1.5, 1.0), ("b", 4.0, 1.0)]
assert sorted(df.groupby("b").all().rows()) == [
("a", [1, 2], [None, 1]),
("b", [3, 4, 5], [None, 1, None]),
]
# assert sorted(df.groupby("b").quantile(0.5).rows()) == ...

@pytest.fixture()
def df() -> pl.DataFrame:
    """Build the small frame shared by the groupby shorthand tests.

    Column "b" is the grouping key (two "a" rows followed by three "b" rows),
    "a" holds the integers 1..5 and "c" alternates between ``None`` and 1.
    """
    columns = {
        "a": [1, 2, 3, 4, 5],
        "b": ["a", "a", "b", "b", "b"],
        "c": [None, 1, None, 1, None],
    }
    return pl.DataFrame(columns)


@pytest.mark.parametrize(
    ("method", "expected"),
    [
        ("all", [("a", [1, 2], [None, 1]), ("b", [3, 4, 5], [None, 1, None])]),
        ("count", [("a", 2), ("b", 3)]),
        ("first", [("a", 1, None), ("b", 3, None)]),
        ("last", [("a", 2, 1), ("b", 5, None)]),
        ("max", [("a", 2, 1), ("b", 5, 1)]),
        ("mean", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("median", [("a", 1.5, 1.0), ("b", 4.0, 1.0)]),
        ("min", [("a", 1, 1), ("b", 3, 1)]),
        ("n_unique", [("a", 2, 2), ("b", 3, 2)]),
    ],
)
def test_groupby_shorthands(
    df: pl.DataFrame, method: str, expected: list[tuple[Any, ...]]
) -> None:
    """Check one argument-less groupby shorthand on the eager and lazy API.

    Parameters
    ----------
    df
        Frame supplied by the ``df`` fixture; it is grouped by column "b".
    method
        Name of the shorthand aggregation, looked up on the groupby object
        with ``getattr`` and called without arguments.
    expected
        Rows the aggregated frame must contain, in order. The rows are
        variable-length tuples: ``count`` yields two fields per row, the
        other methods three.
    """
    # The rows are compared to `expected` without sorting, so the group order
    # has to be stable; that is what maintain_order=True is passed for.
    gb = df.groupby("b", maintain_order=True)
    result = getattr(gb, method)()
    assert result.rows() == expected

    # Same aggregation through the lazy API; it must agree with the eager one.
    gb_lazy = df.lazy().groupby("b", maintain_order=True)
    result = getattr(gb_lazy, method)().collect()
    assert result.rows() == expected


def test_groupby_shorthand_quantile(df: pl.DataFrame) -> None:
    """Check the ``quantile(0.5)`` groupby shorthand on the eager and lazy API.

    Kept apart from the parametrized shorthand test because ``quantile`` is
    called with an argument.
    """
    expected = [("a", 2.0, 1.0), ("b", 4.0, 1.0)]

    eager_groups = df.groupby("b", maintain_order=True)
    eager_out = eager_groups.quantile(0.5)
    assert eager_out.rows() == expected

    lazy_groups = df.lazy().groupby("b", maintain_order=True)
    lazy_out = lazy_groups.quantile(0.5).collect()
    assert lazy_out.rows() == expected


def test_groupby_args() -> None:
Expand Down

0 comments on commit 8ea4e6c

Please sign in to comment.