Skip to content

Commit

Permalink
FIX-#6968: Align API with pandas (#6969)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev authored Feb 26, 2024
1 parent e367490 commit a1d5dd4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
8 changes: 4 additions & 4 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,17 +682,17 @@ def dtypes(self):
)
)

def first(self, numeric_only=False, min_count=-1):
def first(self, numeric_only=False, min_count=-1, skipna=True):
return self._wrap_aggregation(
type(self._query_compiler).groupby_first,
agg_kwargs=dict(min_count=min_count),
agg_kwargs=dict(min_count=min_count, skipna=skipna),
numeric_only=numeric_only,
)

def last(self, numeric_only=False, min_count=-1):
def last(self, numeric_only=False, min_count=-1, skipna=True):
return self._wrap_aggregation(
type(self._query_compiler).groupby_last,
agg_kwargs=dict(min_count=min_count),
agg_kwargs=dict(min_count=min_count, skipna=skipna),
numeric_only=numeric_only,
)

Expand Down
10 changes: 8 additions & 2 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,19 @@ def argmin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01,
result = -1
return result

def argsort(self, axis=0, kind="quicksort", order=None): # noqa: PR01, RT01, D200
def argsort(
self, axis=0, kind="quicksort", order=None, stable=None
): # noqa: PR01, RT01, D200
"""
Return the integer indices that would sort the Series values.
"""
return self.__constructor__(
query_compiler=self._query_compiler.argsort(
axis=axis, kind=kind, order=order
# 'stable' parameter has no effect in Pandas and is only accepted
# for compatibility with NumPy, so we're not passing it forward on purpose
axis=axis,
kind=kind,
order=order,
)
)

Expand Down
16 changes: 16 additions & 0 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3440,3 +3440,19 @@ def func(df):
func, include_groups=include_groups
),
)


@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(how, skipna):
md_df, pd_df = create_test_dfs(
{
"a": [2, 1, 1, 2, 3, 3] * 20,
"b": [np.nan, 3.0, np.nan, 4.0, np.nan, np.nan] * 20,
"c": [np.nan, 3.0, np.nan, 4.0, np.nan, np.nan] * 20,
}
)

pd_res = getattr(pd_df.groupby("a"), how)(skipna=skipna)
md_res = getattr(md_df.groupby("a"), how)(skipna=skipna)
df_equals(md_res, pd_res)

0 comments on commit a1d5dd4

Please sign in to comment.