Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: dask group by with kwargs #1676

Merged
merged 6 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,26 +438,20 @@ def max(self) -> Self:
)

def std(self, ddof: int) -> Self:
    """Return an expression computing the sample standard deviation.

    Arguments:
        ddof: Delta degrees of freedom, forwarded to the underlying
            dask object's ``std`` call.

    Because ``ddof`` is passed straight through to ``_input.std``, no
    extra ``_depth`` bookkeeping is required for ``ddof != 1`` — the
    kwargs are carried by ``_from_call`` itself.
    """
    return self._from_call(
        lambda _input, ddof: _input.std(ddof=ddof),
        "std",
        ddof=ddof,
        returns_scalar=True,
    )

def var(self, ddof: int) -> Self:
    """Return an expression computing the sample variance.

    Arguments:
        ddof: Delta degrees of freedom, forwarded to the underlying
            dask object's ``var`` call.

    Mirrors ``std`` above: ``ddof`` travels through ``_from_call`` as a
    keyword argument, so no ``_depth`` adjustment is needed.
    """
    return self._from_call(
        lambda _input, ddof: _input.var(ddof=ddof),
        "var",
        ddof=ddof,
        returns_scalar=True,
    )

def skew(self: Self) -> Self:
return self._from_call(
Expand Down
59 changes: 50 additions & 9 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,38 @@ def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
)


def var(
    ddof: int = 1,
) -> Callable[
    [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
]:
    """Return dask's groupby ``var`` aggregation with ``ddof`` pre-bound.

    Arguments:
        ddof: Delta degrees of freedom, bound into the returned callable.

    Dask is imported lazily inside the function so this module does not
    require dask at import time.

    NOTE(review): the annotation references pandas groupby types while the
    partial wraps ``dask.dataframe``'s ``DataFrameGroupBy.var`` — the exact
    return type should be confirmed (flagged by the original author too).
    """
    from functools import partial

    import dask.dataframe as dd

    return partial(dd.groupby.DataFrameGroupBy.var, ddof=ddof)


def std(
    ddof: int = 1,
) -> Callable[
    [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
]:
    """Return dask's groupby ``std`` aggregation with ``ddof`` pre-bound.

    Dask is imported lazily inside the function so this module does not
    require dask at import time.
    """
    from functools import partial

    import dask.dataframe as dd

    # Bind ddof now; the grouped-by frame is supplied later by `.agg`.
    std_method = dd.groupby.DataFrameGroupBy.std
    return partial(std_method, ddof=ddof)


POLARS_TO_DASK_AGGREGATIONS = {
"sum": "sum",
"mean": "mean",
"median": "median",
"max": "max",
"min": "min",
"std": "std",
"var": "var",
"std": std,
"var": var,
"len": "size",
"n_unique": n_unique,
"count": "count",
Expand Down Expand Up @@ -137,8 +161,12 @@ def agg_dask(
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
for output_name in expr._output_names:
simple_aggregations[output_name] = (keys[0], function_name)
simple_aggregations.update(
{
output_name: (keys[0], function_name)
for output_name in expr._output_names
}
)
continue

# e.g. agg(nw.mean('a')) # noqa: ERA001
Expand All @@ -149,13 +177,26 @@ def agg_dask(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
kwargs = (
{"ddof": expr._kwargs.get("ddof", 1)}
if function_name in {"std", "var"}
else {}
)

agg_function = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
# deal with n_unique case in a "lazy" mode to not depend on dask globally
function_name = function_name() if callable(function_name) else function_name

for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
agg_function = (
agg_function(**kwargs) if callable(agg_function) else agg_function
)

simple_aggregations.update(
{
output_name: (root_name, agg_function)
for root_name, output_name in zip(
expr._root_names, expr._output_names
)
}
)
result_simple = grouped.agg(**simple_aggregations)
return from_dataframe(result_simple.reset_index())

Expand Down
5 changes: 0 additions & 5 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,7 @@ def test_group_by_depth_1_std_var(
constructor: Constructor,
attr: str,
ddof: int,
request: pytest.FixtureRequest,
) -> None:
if "dask" in str(constructor):
# Complex aggregation for dask
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]}
_pow = 0.5 if attr == "std" else 1
expected = {
Expand Down
Loading