From 6f1accd04ff6957b24648ee947dc103979dca350 Mon Sep 17 00:00:00 2001
From: ganevgv
Date: Sat, 23 Nov 2019 23:04:40 +0000
Subject: [PATCH] TST: add test for ffill/bfill for non unique multilevel (#29763)

---
 pandas/tests/groupby/test_transform.py | 35 ++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py
index 3d9a349d94e10..c46180c1d11cd 100644
--- a/pandas/tests/groupby/test_transform.py
+++ b/pandas/tests/groupby/test_transform.py
@@ -911,6 +911,41 @@ def test_pct_change(test_series, freq, periods, fill_method, limit):
     tm.assert_frame_equal(result, expected.to_frame("vals"))
 
 
+@pytest.mark.parametrize(
+    "func, expected_status",
+    [
+        ("ffill", ["shrt", "shrt", "lng", np.nan, "shrt", "ntrl", "ntrl"]),
+        ("bfill", ["shrt", "lng", "lng", "shrt", "shrt", "ntrl", np.nan]),
+    ],
+)
+def test_ffill_bfill_non_unique_multilevel(func, expected_status):
+    # GH 19437: ffill/bfill grouped by a level of a non-unique MultiIndex
+    date = pd.to_datetime(
+        [
+            "2018-01-01",
+            "2018-01-01",
+            "2018-01-01",
+            "2018-01-01",
+            "2018-01-02",
+            "2018-01-01",
+            "2018-01-02",
+        ]
+    )
+    symbol = ["MSFT", "MSFT", "MSFT", "AAPL", "AAPL", "TSLA", "TSLA"]
+    status = ["shrt", np.nan, "lng", np.nan, "shrt", "ntrl", np.nan]
+
+    df = DataFrame({"date": date, "symbol": symbol, "status": status})
+    df = df.set_index(["date", "symbol"])  # (date, symbol) pairs repeat
+    result = getattr(df.groupby("symbol")["status"], func)()  # fill per symbol group
+
+    index = MultiIndex.from_tuples(
+        tuples=list(zip(*[date, symbol])), names=["date", "symbol"]
+    )
+    expected = Series(expected_status, index=index, name="status")
+
+    tm.assert_series_equal(result, expected)
+
+
 @pytest.mark.parametrize("func", [np.any, np.all])
 def test_any_all_np_func(func):
     # GH 20653