Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Feb 20, 2024
1 parent b35367c commit 39470e5
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions py-polars/tests/unit/datatypes/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_hash() -> None:


def test_group_by() -> None:
# Test num_groups_proxy
df = (
pl.Series(
"x",
Expand All @@ -113,25 +114,27 @@ def test_group_by() -> None:
)
.to_frame()
.with_row_index()
.with_columns(a=pl.lit("a"))
)

expect = pl.Series("index", [[0, 1], [2, 3], [4], [5]], dtype=pl.List(pl.UInt32))

for maintain_order in (True, False):
for drop_nulls in (True, False):
out = df
if drop_nulls:
out = out.drop_nulls()

out = (
out.group_by("x", maintain_order=maintain_order)
.agg("index")
.sort(pl.col("index").list.get(0))
.select("index")
.to_series()
)

if drop_nulls:
assert_series_equal(expect.head(3), out)
else:
assert_series_equal(expect, out)
for group_keys in (("x",), ("x", "a")):
for maintain_order in (True, False):
for drop_nulls in (True, False):
out = df
if drop_nulls:
out = out.drop_nulls()

out = (
out.group_by(group_keys, maintain_order=maintain_order)
.agg("index")
.sort(pl.col("index").list.get(0))
.select("index")
.to_series()
)

if drop_nulls:
assert_series_equal(expect.head(3), out)
else:
assert_series_equal(expect, out)

0 comments on commit 39470e5

Please sign in to comment.