diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 4b0143b3e1ced..3f84fa0f0670e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1705,6 +1705,9 @@ def rank(self, method='average', ascending=True, na_option='keep', ----- DataFrame with ranking of values within each group """ + if na_option not in {'keep', 'top', 'bottom'}: + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + raise ValueError(msg) return self._cython_transform('rank', numeric_only=False, ties_method=method, ascending=ascending, na_option=na_option, pct=pct, axis=axis) diff --git a/pandas/tests/groupby/test_rank.py b/pandas/tests/groupby/test_rank.py index 0628f9c79a154..2740b6475f18d 100644 --- a/pandas/tests/groupby/test_rank.py +++ b/pandas/tests/groupby/test_rank.py @@ -252,14 +252,24 @@ def test_rank_object_raises(ties_method, ascending, na_option, with tm.assert_raises_regex(TypeError, "not callable"): df.groupby('key').rank(method=ties_method, ascending=ascending, - na_option='bad', pct=pct) + na_option=na_option, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): - df.groupby('key').rank(method=ties_method, - ascending=ascending, - na_option=True, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): +@pytest.mark.parametrize("na_option", [True, "bad", 1]) +@pytest.mark.parametrize("ties_method", [ + 'average', 'min', 'max', 'first', 'dense']) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize("vals", [ + ['bar', 'bar', 'foo', 'bar', 'baz'], + ['bar', np.nan, 'foo', np.nan, 'baz'], + [1, np.nan, 2, np.nan, 3] +]) +def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals): + df = DataFrame({'key': ['foo'] * 5, 'val': vals}) + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + + with tm.assert_raises_regex(TypeError, msg): df.groupby('key').rank(method=ties_method, ascending=ascending, - na_option=na_option, pct=pct) + na_option='bad', pct=pct)