Skip to content

Commit

Permalink
ENH: Improve error message in corr/cov for Rolling/Expanding/EWM when…
Browse files Browse the repository at this point in the history
… other isn't a DataFrame/Series (#41741)
  • Loading branch information
mroeschke authored May 31, 2021
1 parent 1b6cb0f commit e8dbdb0
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 45 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ Other enhancements
- Add keyword ``sort`` to :func:`pivot_table` to allow non-sorting of the result (:issue:`39143`)
- Add keyword ``dropna`` to :meth:`DataFrame.value_counts` to allow counting rows that include ``NA`` values (:issue:`41325`)
- :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`)
- Improved error message in ``corr` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`)

.. ---------------------------------------------------------------------------
Expand Down
44 changes: 10 additions & 34 deletions pandas/core/window/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Common utility functions for rolling operations"""
from collections import defaultdict
from typing import cast
import warnings

import numpy as np

Expand All @@ -15,17 +14,7 @@

def flex_binary_moment(arg1, arg2, f, pairwise=False):

if not (
isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
):
raise TypeError(
"arguments to moment function must be of type np.ndarray/Series/DataFrame"
)

if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
arg2, (np.ndarray, ABCSeries)
):
if isinstance(arg1, ABCSeries) and isinstance(arg2, ABCSeries):
X, Y = prep_binary(arg1, arg2)
return f(X, Y)

Expand All @@ -43,31 +32,25 @@ def dataframe_from_int_dict(data, frame_template):
if pairwise is False:
if arg1 is arg2:
# special case in order to handle duplicate column names
for i, col in enumerate(arg1.columns):
for i in range(len(arg1.columns)):
results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
return dataframe_from_int_dict(results, arg1)
else:
if not arg1.columns.is_unique:
raise ValueError("'arg1' columns are not unique")
if not arg2.columns.is_unique:
raise ValueError("'arg2' columns are not unique")
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
X, Y = arg1.align(arg2, join="outer")
X = X + 0 * Y
Y = Y + 0 * X

with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
res_columns = arg1.columns.union(arg2.columns)
X, Y = arg1.align(arg2, join="outer")
X, Y = prep_binary(X, Y)
res_columns = arg1.columns.union(arg2.columns)
for col in res_columns:
if col in X and col in Y:
results[col] = f(X[col], Y[col])
return DataFrame(results, index=X.index, columns=res_columns)
elif pairwise is True:
results = defaultdict(dict)
for i, k1 in enumerate(arg1.columns):
for j, k2 in enumerate(arg2.columns):
for i in range(len(arg1.columns)):
for j in range(len(arg2.columns)):
if j < i and arg2 is arg1:
# Symmetric case
results[i][j] = results[j][i]
Expand All @@ -85,10 +68,10 @@ def dataframe_from_int_dict(data, frame_template):
result = concat(
[
concat(
[results[i][j] for j, c in enumerate(arg2.columns)],
[results[i][j] for j in range(len(arg2.columns))],
ignore_index=True,
)
for i, c in enumerate(arg1.columns)
for i in range(len(arg1.columns))
],
ignore_index=True,
axis=1,
Expand Down Expand Up @@ -135,13 +118,10 @@ def dataframe_from_int_dict(data, frame_template):
)

return result

else:
raise ValueError("'pairwise' is not True/False")
else:
results = {
i: f(*prep_binary(arg1.iloc[:, i], arg2))
for i, col in enumerate(arg1.columns)
for i in range(len(arg1.columns))
}
return dataframe_from_int_dict(results, arg1)

Expand All @@ -165,11 +145,7 @@ def zsqrt(x):


def prep_binary(arg1, arg2):
if not isinstance(arg2, type(arg1)):
raise Exception("Input arrays must be of the same type!")

# mask out values, this also makes a common index...
X = arg1 + 0 * arg2
Y = arg2 + 0 * arg1

return X, Y
2 changes: 2 additions & 0 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ def _apply_pairwise(
other = target
# only default unset
pairwise = True if pairwise is None else pairwise
elif not isinstance(other, (ABCDataFrame, ABCSeries)):
raise ValueError("other must be a DataFrame or Series")

return flex_binary_moment(target, other, func, pairwise=bool(pairwise))

Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/window/moments/test_moments_consistency_ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def test_different_input_array_raise_exception(name):
A = Series(np.random.randn(50), index=np.arange(50))
A[:10] = np.NaN

msg = "Input arrays must be of the same type!"
msg = "other must be a DataFrame or Series"
# exception raised is Exception
with pytest.raises(Exception, match=msg):
with pytest.raises(ValueError, match=msg):
getattr(A.ewm(com=20, min_periods=5), name)(np.random.randn(50))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Series,
)
import pandas._testing as tm
from pandas.core.window.common import flex_binary_moment


def _rolling_consistency_cases():
Expand Down Expand Up @@ -133,14 +132,6 @@ def test_rolling_corr_with_zero_variance(window):
assert s.rolling(window=window).corr(other=other).isna().all()


def test_flex_binary_moment():
# GH3155
# don't blow the stack
msg = "arguments to moment function must be of type np.ndarray/Series/DataFrame"
with pytest.raises(TypeError, match=msg):
flex_binary_moment(5, 6, None)


def test_corr_sanity():
# GH 3155
df = DataFrame(
Expand Down

0 comments on commit e8dbdb0

Please sign in to comment.