From 0ebac438fd067b541a26dc21569fe1dcb2fc980b Mon Sep 17 00:00:00 2001 From: Dawid Makar Date: Wed, 27 Sep 2023 02:42:04 +0200 Subject: [PATCH] Automatic refactoring. Refactoring step id: UUID('c6990dfa-d782-4765-a762-804d74ae288b') --- pandas/tests/groupby/test_allowlist.py | 73 +++++++++++++++++++------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/pandas/tests/groupby/test_allowlist.py b/pandas/tests/groupby/test_allowlist.py index d495441593aed..d16e8ef56a42c 100644 --- a/pandas/tests/groupby/test_allowlist.py +++ b/pandas/tests/groupby/test_allowlist.py @@ -3,38 +3,72 @@ Do not add tests here! """ +from string import ascii_lowercase + +import numpy as np import pytest from pandas import ( DataFrame, + Series, date_range, ) import pandas._testing as tm +AGG_FUNCTIONS = [ + "sum", + "prod", + "min", + "max", + "median", + "mean", + "skew", + "std", + "var", + "sem", +] +AGG_FUNCTIONS_WITH_SKIPNA = ["skew"] + + +@pytest.fixture +def df(): + return DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.randn(8), + "D": np.random.randn(8), + } + ) -@pytest.mark.parametrize( - "op", - [ - "sum", - "prod", - "min", - "max", - "median", - "mean", - "skew", - "std", - "var", - "sem", - ], -) + +@pytest.fixture +def df_letters(): + letters = np.array(list(ascii_lowercase)) + N = 10 + random_letters = letters.take(np.random.randint(0, 26, N)) + df = DataFrame( + { + "floats": N / 10 * Series(np.random.random(N)), + "letters": Series(random_letters), + } + ) + return df + + +@pytest.fixture +def raw_frame(): + return DataFrame([0]) + + +@pytest.mark.parametrize("op", AGG_FUNCTIONS) @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("skipna", [True, False]) @pytest.mark.parametrize("sort", [True, False]) -def test_regression_allowlist_methods(op, axis, skipna, sort): +def test_regression_allowlist_methods(raw_frame, op, axis, skipna, sort): # GH6944 # GH 17537 # explicitly test the allowlist methods - raw_frame = DataFrame([0]) if axis == 0: frame = raw_frame msg = "The 'axis' keyword in DataFrame.groupby is deprecated and will be" @@ -45,8 +79,7 @@ def test_regression_allowlist_methods(op, axis, skipna, sort): with tm.assert_produces_warning(FutureWarning, match=msg): grouped = frame.groupby(level=0, axis=axis, sort=sort) - if op == "skew": - # skew has skipna + if op in AGG_FUNCTIONS_WITH_SKIPNA: result = getattr(grouped, op)(skipna=skipna) expected = frame.groupby(level=0).apply( lambda h: getattr(h, op)(axis=axis, skipna=skipna) @@ -121,4 +154,4 @@ def test_groupby_selection_other_methods(df): tm.assert_frame_equal( g.filter(lambda x: len(x) == 3), g_exp.filter(lambda x: len(x) == 3) - ) + ) \ No newline at end of file