Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): More ergonomic over args #6986

Merged
merged 1 commit into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3596,8 +3596,8 @@ def groupby(
Parameters
----------
by
Column or columns to group by. Accepts expression input. Strings are parsed
as column names.
Column(s) to group by. Accepts expression input. Strings are parsed as
column names.
*more_by
Additional columns to group by, specified as positional arguments.
maintain_order
Expand Down
132 changes: 83 additions & 49 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,68 +2766,102 @@ def last(self) -> Self:
"""
return self._from_pyexpr(self._pyexpr.last())

def over(self, expr: str | Expr | list[Expr | str]) -> Self:
def over(self, expr: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr) -> Self:
"""
Apply window function over a subgroup.
Compute expressions over the given groups.

This is similar to a groupby + aggregation + self join.
Or similar to `window functions in Postgres
<https://www.postgresql.org/docs/current/tutorial-window.html>`_.
This expression is similar to performing a groupby aggregation and joining the
result back into the original dataframe.

The outcome is similar to how `window functions
<https://www.postgresql.org/docs/current/tutorial-window.html>`_
work in PostgreSQL.

Parameters
----------
expr
Column(s) to group by.
Column(s) to group by. Accepts expression input. Strings are parsed as
column names.
*more_exprs
Additional columns to group by, specified as positional arguments.

Examples
--------
Pass the name of a column to compute the expression over that column.

>>> df = pl.DataFrame(
... {
... "groups": ["g1", "g1", "g2"],
... "values": [1, 2, 3],
... }
... )
>>> df.with_columns(pl.col("values").max().over("groups").alias("max_by_group"))
shape: (3, 3)
┌────────┬────────┬──────────────┐
│ groups ┆ values ┆ max_by_group │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞════════╪════════╪══════════════╡
│ g1 ┆ 1 ┆ 2 │
│ g1 ┆ 2 ┆ 2 │
│ g2 ┆ 3 ┆ 3 │
└────────┴────────┴──────────────┘
>>> df = pl.DataFrame(
... {
... "groups": [1, 1, 2, 2, 1, 2, 3, 3, 1],
... "values": [1, 2, 3, 4, 5, 6, 7, 8, 8],
... "a": ["a", "a", "b", "b", "b"],
... "b": [1, 2, 3, 5, 3],
... "c": [5, 4, 3, 2, 1],
... }
... )
>>> df.lazy().select(
... pl.col("groups").sum().over("groups"),
... ).collect()
shape: (9, 1)
┌────────┐
│ groups │
│ --- │
│ i64 │
╞════════╡
│ 4 │
│ 4 │
│ 6 │
│ 6 │
│ ... │
│ 6 │
│ 6 │
│ 6 │
│ 4 │
└────────┘

"""
pyexprs = selection_to_pyexpr_list(expr)

return self._from_pyexpr(self._pyexpr.over(pyexprs))
>>> df.with_columns(pl.col("c").max().over("a").suffix("_max"))
shape: (5, 4)
┌─────┬─────┬─────┬───────┐
│ a ┆ b ┆ c ┆ c_max │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═══════╡
│ a ┆ 1 ┆ 5 ┆ 5 │
│ a ┆ 2 ┆ 4 ┆ 5 │
│ b ┆ 3 ┆ 3 ┆ 3 │
│ b ┆ 5 ┆ 2 ┆ 3 │
│ b ┆ 3 ┆ 1 ┆ 3 │
└─────┴─────┴─────┴───────┘

Expression input is supported.

>>> df.with_columns(pl.col("c").max().over(pl.col("b") // 2).suffix("_max"))
shape: (5, 4)
┌─────┬─────┬─────┬───────┐
│ a ┆ b ┆ c ┆ c_max │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═══════╡
│ a ┆ 1 ┆ 5 ┆ 5 │
│ a ┆ 2 ┆ 4 ┆ 4 │
│ b ┆ 3 ┆ 3 ┆ 4 │
│ b ┆ 5 ┆ 2 ┆ 2 │
│ b ┆ 3 ┆ 1 ┆ 4 │
└─────┴─────┴─────┴───────┘

Group by multiple columns by passing a list of column names or expressions.

>>> df.with_columns(pl.col("c").min().over(["a", "b"]).suffix("_min"))
shape: (5, 4)
┌─────┬─────┬─────┬───────┐
│ a ┆ b ┆ c ┆ c_min │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═══════╡
│ a ┆ 1 ┆ 5 ┆ 5 │
│ a ┆ 2 ┆ 4 ┆ 4 │
│ b ┆ 3 ┆ 3 ┆ 1 │
│ b ┆ 5 ┆ 2 ┆ 2 │
│ b ┆ 3 ┆ 1 ┆ 1 │
└─────┴─────┴─────┴───────┘

Or use positional arguments to group by multiple columns in the same way.

>>> df.with_columns(pl.col("c").min().over("a", pl.col("b") % 2).suffix("_min"))
shape: (5, 4)
┌─────┬─────┬─────┬───────┐
│ a ┆ b ┆ c ┆ c_min │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╪═══════╡
│ a ┆ 1 ┆ 5 ┆ 5 │
│ a ┆ 2 ┆ 4 ┆ 4 │
│ b ┆ 3 ┆ 3 ┆ 1 │
│ b ┆ 5 ┆ 2 ┆ 1 │
│ b ┆ 3 ┆ 1 ┆ 1 │
└─────┴─────┴─────┴───────┘

"""
exprs = selection_to_pyexpr_list(expr)
exprs.extend(selection_to_pyexpr_list(more_exprs))
return self._from_pyexpr(self._pyexpr.over(exprs))

def is_unique(self) -> Self:
"""
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,8 +1694,8 @@ def groupby(
Parameters
----------
by
Column or columns to group by. Accepts expression input. Strings are parsed
as column names.
Column(s) to group by. Accepts expression input. Strings are parsed as
column names.
*more_by
Additional columns to group by, specified as positional arguments.
maintain_order
Expand Down
24 changes: 24 additions & 0 deletions py-polars/tests/unit/operations/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@
from polars.testing import assert_frame_equal, assert_series_equal


def test_over_args() -> None:
df = pl.DataFrame(
{
"a": ["a", "a", "b"],
"b": [1, 2, 3],
"c": [3, 2, 1],
}
)

# Single input
expected = pl.Series("c", [3, 3, 1]).to_frame()
result = df.select(pl.col("c").max().over("a"))
assert_frame_equal(result, expected)

# Multiple input as list
expected = pl.Series("c", [3, 2, 1]).to_frame()
result = df.select(pl.col("c").max().over(["a", "b"]))
assert_frame_equal(result, expected)

# Multiple input as positional args
result = df.select(pl.col("c").max().over("a", "b"))
assert_frame_equal(result, expected)


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64, pl.Int32])
def test_std(dtype: type[pl.DataType]) -> None:
if dtype == pl.Int32:
Expand Down