From ca106c375a1875a05e1a1c76c71bebd6b48ba356 Mon Sep 17 00:00:00 2001 From: Peter Li Date: Sun, 29 Jul 2018 22:11:31 +0800 Subject: [PATCH] BUG: Better handling of invalid na_option argument for groupby.rank(#22124) --- pandas/core/groupby/groupby.py | 3 ++ pandas/tests/groupby/test_rank.py | 62 ++++++++++++++++++------------- 2 files changed, 39 insertions(+), 26 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 4b0143b3e1ced..3f84fa0f0670e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1705,6 +1705,9 @@ def rank(self, method='average', ascending=True, na_option='keep', ----- DataFrame with ranking of values within each group """ + if na_option not in {'keep', 'top', 'bottom'}: + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + raise ValueError(msg) return self._cython_transform('rank', numeric_only=False, ties_method=method, ascending=ascending, na_option=na_option, pct=pct, axis=axis) diff --git a/pandas/tests/groupby/test_rank.py b/pandas/tests/groupby/test_rank.py index 0628f9c79a154..f0dcf768e3607 100644 --- a/pandas/tests/groupby/test_rank.py +++ b/pandas/tests/groupby/test_rank.py @@ -172,35 +172,35 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp): [3., 3., np.nan, 1., 3., 2., np.nan, np.nan]), ('dense', False, 'keep', True, [3. / 3., 3. / 3., np.nan, 1. / 3., 3. / 3., 2. / 3., np.nan, np.nan]), - ('average', True, 'no_na', False, [2., 2., 7., 5., 2., 4., 7., 7.]), - ('average', True, 'no_na', True, + ('average', True, 'bottom', False, [2., 2., 7., 5., 2., 4., 7., 7.]), + ('average', True, 'bottom', True, [0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875]), - ('average', False, 'no_na', False, [4., 4., 7., 1., 4., 2., 7., 7.]), - ('average', False, 'no_na', True, + ('average', False, 'bottom', False, [4., 4., 7., 1., 4., 2., 7., 7.]), + ('average', False, 'bottom', True, [0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875]), - ('min', True, 'no_na', False, [1., 1., 6., 5., 1., 4., 6., 6.]), - ('min', True, 'no_na', True, + ('min', True, 'bottom', False, [1., 1., 6., 5., 1., 4., 6., 6.]), + ('min', True, 'bottom', True, [0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75]), - ('min', False, 'no_na', False, [3., 3., 6., 1., 3., 2., 6., 6.]), - ('min', False, 'no_na', True, + ('min', False, 'bottom', False, [3., 3., 6., 1., 3., 2., 6., 6.]), + ('min', False, 'bottom', True, [0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75]), - ('max', True, 'no_na', False, [3., 3., 8., 5., 3., 4., 8., 8.]), - ('max', True, 'no_na', True, + ('max', True, 'bottom', False, [3., 3., 8., 5., 3., 4., 8., 8.]), + ('max', True, 'bottom', True, [0.375, 0.375, 1., 0.625, 0.375, 0.5, 1., 1.]), - ('max', False, 'no_na', False, [5., 5., 8., 1., 5., 2., 8., 8.]), - ('max', False, 'no_na', True, + ('max', False, 'bottom', False, [5., 5., 8., 1., 5., 2., 8., 8.]), + ('max', False, 'bottom', True, [0.625, 0.625, 1., 0.125, 0.625, 0.25, 1., 1.]), - ('first', True, 'no_na', False, [1., 2., 6., 5., 3., 4., 7., 8.]), - ('first', True, 'no_na', True, + ('first', True, 'bottom', False, [1., 2., 6., 5., 3., 4., 7., 8.]), + ('first', True, 'bottom', True, [0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.]), - ('first', False, 'no_na', False, [3., 4., 6., 1., 5., 2., 7., 8.]), - ('first', False, 'no_na', True, + ('first', False, 'bottom', False, [3., 4., 6., 1., 5., 2., 7., 8.]), + ('first', False, 'bottom', True, [0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.]), - ('dense', True, 'no_na', False, [1., 1., 4., 3., 1., 2., 4., 4.]), - ('dense', True, 'no_na', True, + ('dense', True, 'bottom', False, [1., 1., 4., 3., 1., 2., 4., 4.]), + ('dense', True, 'bottom', True, [0.25, 0.25, 1., 0.75, 0.25, 0.5, 1., 1.]), - ('dense', False, 'no_na', False, [3., 3., 4., 1., 3., 2., 4., 4.]), - ('dense', False, 'no_na', True, + ('dense', False, 'bottom', False, [3., 3., 4., 1., 3., 2., 4., 4.]), + ('dense', False, 'bottom', True, [0.75, 0.75, 1., 0.25, 0.75, 0.5, 1., 1.]) ]) def test_rank_args_missing(grps, vals, ties_method, ascending, @@ -252,14 +252,24 @@ def test_rank_object_raises(ties_method, ascending, na_option, with tm.assert_raises_regex(TypeError, "not callable"): df.groupby('key').rank(method=ties_method, ascending=ascending, - na_option='bad', pct=pct) + na_option=na_option, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): - df.groupby('key').rank(method=ties_method, - ascending=ascending, - na_option=True, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): +@pytest.mark.parametrize("na_option", [True, "bad", 1]) +@pytest.mark.parametrize("ties_method", [ + 'average', 'min', 'max', 'first', 'dense']) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize("vals", [ + ['bar', 'bar', 'foo', 'bar', 'baz'], + ['bar', np.nan, 'foo', np.nan, 'baz'], + [1, np.nan, 2, np.nan, 3] +]) +def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals): + df = DataFrame({'key': ['foo'] * 5, 'val': vals}) + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + + with tm.assert_raises_regex(ValueError, msg): df.groupby('key').rank(method=ties_method, ascending=ascending, na_option=na_option, pct=pct)