Skip to content

Commit

Permalink
BUG: EWM(online) not raising NotImplementedError for unsupported (pan…
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored and noatamir committed Nov 9, 2022
1 parent f88b3ff commit da99346
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ Plotting

Groupby/resample/rolling
^^^^^^^^^^^^^^^^^^^^^^^^
- Bug in :class:`.ExponentialMovingWindow` with ``online`` not raising a ``NotImplementedError`` for unsupported operations (:issue:`48834`)
- Bug in :meth:`DataFrameGroupBy.sample` raises ``ValueError`` when the object is empty (:issue:`48459`)
-

Expand Down
10 changes: 5 additions & 5 deletions pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,10 @@ def reset(self) -> None:
self._mean.reset()

def aggregate(self, func, *args, **kwargs):
return NotImplementedError
raise NotImplementedError("aggregate is not implemented.")

def std(self, bias: bool = False, *args, **kwargs):
return NotImplementedError
raise NotImplementedError("std is not implemented.")

def corr(
self,
Expand All @@ -979,7 +979,7 @@ def corr(
numeric_only: bool = False,
**kwargs,
):
return NotImplementedError
raise NotImplementedError("corr is not implemented.")

def cov(
self,
Expand All @@ -989,10 +989,10 @@ def cov(
numeric_only: bool = False,
**kwargs,
):
return NotImplementedError
raise NotImplementedError("cov is not implemented.")

def var(self, bias: bool = False, *args, **kwargs):
return NotImplementedError
raise NotImplementedError("var is not implemented.")

def mean(self, *args, update=None, update_times=None, **kwargs):
"""
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/window/test_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,13 @@ def test_update_times_mean(
tm.assert_equal(result, expected.tail(3))

online_ewm.reset()

@pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"])
def test_ewm_notimplementederror_raises(self, method):
ser = Series(range(10))
kwargs = {}
if method == "aggregate":
kwargs["func"] = lambda x: x

with pytest.raises(NotImplementedError, match=".* is not implemented."):
getattr(ser.ewm(1).online(), method)(**kwargs)

0 comments on commit da99346

Please sign in to comment.