From bc6caa8792b4d2e39ef4a48063147faa5013dccf Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 18 Jan 2024 20:34:33 +0800 Subject: [PATCH] fix: `gather_every` should work on agg context --- py-polars/src/expr/general.rs | 2 +- py-polars/tests/unit/dataframe/test_df.py | 17 +++++++++++++++++ py-polars/tests/unit/namespaces/test_list.py | 7 +++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 88f7915a1fbc..1e166d3541cd 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -426,7 +426,7 @@ impl PyExpr { fn gather_every(&self, n: usize, offset: usize) -> Self { self.inner .clone() - .map( + .apply( move |s: Series| { polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n can't be zero"); Ok(Some(s.gather_every(n, offset))) diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 2b6d965734da..b559026a2ea2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -392,6 +392,23 @@ def test_gather_every() -> None: assert_frame_equal(expected_df, df.gather_every(2, offset=1)) +def test_gather_every_agg() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 2, 2, 2], + "a": ["a", "b", "c", "d", "e", "f"], + } + ) + out = df.group_by(pl.col("g")).agg(pl.col("a").gather_every(2)).sort("g") + expected = pl.DataFrame( + { + "g": [1, 2], + "a": [["a", "c"], ["d", "f"]], + } + ) + assert_frame_equal(out, expected) + + def test_take_misc(fruits_cars: pl.DataFrame) -> None: df = fruits_cars diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 4abff393be14..cefea46ef69d 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -798,3 +798,10 @@ def test_list_get_logical_type() -> None: dtype=pl.Date, ) assert_series_equal(out, expected) + + +def test_list_eval_gater_every_13410() -> None: + df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]}) + out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2))) + expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]}) + assert_frame_equal(out, expected)