Skip to content

Commit

Permalink
CLN/TST: address TODOs/FIXMES #2 (#44174)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Oct 25, 2021
1 parent d8440f1 commit 7c00e0c
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 104 deletions.
5 changes: 4 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,8 +1517,11 @@ def _cython_agg_general(
if numeric_only:
if is_ser and not is_numeric_dtype(self._selected_obj.dtype):
# GH#41291 match Series behavior
kwd_name = "numeric_only"
if how in ["any", "all"]:
kwd_name = "bool_only"
raise NotImplementedError(
f"{type(self).__name__}.{how} does not implement numeric_only."
f"{type(self).__name__}.{how} does not implement {kwd_name}."
)
elif not is_ser:
data = data.get_numeric_data(copy=False)
Expand Down
4 changes: 1 addition & 3 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,7 @@ def _intersection(self, other: Index, sort=False):
new_index = new_index[::-1]

if sort is None:
# TODO: can revert to just `if sort is None` after GH#43666
if new_index.step < 0:
new_index = new_index[::-1]
new_index = new_index.sort_values()

return new_index

Expand Down
7 changes: 5 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4169,7 +4169,7 @@ def map(self, arg, na_action=None) -> Series:
3 I am a rabbit
dtype: object
"""
new_values = super()._map_values(arg, na_action=na_action)
new_values = self._map_values(arg, na_action=na_action)
return self._constructor(new_values, index=self.index).__finalize__(
self, method="map"
)
Expand Down Expand Up @@ -4396,8 +4396,11 @@ def _reduce(
else:
# dispatch to numpy arrays
if numeric_only:
kwd_name = "numeric_only"
if name in ["any", "all"]:
kwd_name = "bool_only"
raise NotImplementedError(
f"Series.{name} does not implement numeric_only."
f"Series.{name} does not implement {kwd_name}."
)
with np.errstate(all="ignore"):
return op(delegate, skipna=skipna, **kwds)
Expand Down
74 changes: 13 additions & 61 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import decimal
import math
import operator

import numpy as np
Expand Down Expand Up @@ -70,54 +69,7 @@ def data_for_grouping():
return DecimalArray([b, b, na, na, a, a, b, c])


class BaseDecimal:
@classmethod
def assert_series_equal(cls, left, right, *args, **kwargs):
def convert(x):
# need to convert array([Decimal(NaN)], dtype='object') to np.NaN
# because Series[object].isnan doesn't recognize decimal(NaN) as
# NA.
try:
return math.isnan(x)
except TypeError:
return False

if left.dtype == "object":
left_na = left.apply(convert)
else:
left_na = left.isna()
if right.dtype == "object":
right_na = right.apply(convert)
else:
right_na = right.isna()

tm.assert_series_equal(left_na, right_na)
return tm.assert_series_equal(left[~left_na], right[~right_na], *args, **kwargs)

@classmethod
def assert_frame_equal(cls, left, right, *args, **kwargs):
# TODO(EA): select_dtypes
tm.assert_index_equal(
left.columns,
right.columns,
exact=kwargs.get("check_column_type", "equiv"),
check_names=kwargs.get("check_names", True),
check_exact=kwargs.get("check_exact", False),
check_categorical=kwargs.get("check_categorical", True),
obj=f"{kwargs.get('obj', 'DataFrame')}.columns",
)

decimals = (left.dtypes == "decimal").index

for col in decimals:
cls.assert_series_equal(left[col], right[col], *args, **kwargs)

left = left.drop(columns=decimals)
right = right.drop(columns=decimals)
tm.assert_frame_equal(left, right, *args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
class TestDtype(base.BaseDtypeTests):
def test_hashable(self, dtype):
pass

Expand All @@ -129,27 +81,27 @@ def test_infer_dtype(self, data, data_missing, skipna):
assert infer_dtype(data_missing, skipna=skipna) == "unknown-array"


class TestInterface(BaseDecimal, base.BaseInterfaceTests):
class TestInterface(base.BaseInterfaceTests):
pass


class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
class TestConstructors(base.BaseConstructorsTests):
pass


class TestReshaping(BaseDecimal, base.BaseReshapingTests):
class TestReshaping(base.BaseReshapingTests):
pass


class TestGetitem(BaseDecimal, base.BaseGetitemTests):
class TestGetitem(base.BaseGetitemTests):
def test_take_na_value_other_decimal(self):
arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
self.assert_extension_array_equal(result, expected)


class TestMissing(BaseDecimal, base.BaseMissingTests):
class TestMissing(base.BaseMissingTests):
pass


Expand All @@ -175,7 +127,7 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
pass


class TestMethods(BaseDecimal, base.BaseMethodsTests):
class TestMethods(base.BaseMethodsTests):
@pytest.mark.parametrize("dropna", [True, False])
def test_value_counts(self, all_data, dropna, request):
all_data = all_data[:10]
Expand All @@ -200,20 +152,20 @@ def test_value_counts_with_normalize(self, data):
return super().test_value_counts_with_normalize(data)


class TestCasting(BaseDecimal, base.BaseCastingTests):
class TestCasting(base.BaseCastingTests):
pass


class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
class TestGroupby(base.BaseGroupbyTests):
def test_groupby_agg_extension(self, data_for_grouping):
super().test_groupby_agg_extension(data_for_grouping)


class TestSetitem(BaseDecimal, base.BaseSetitemTests):
class TestSetitem(base.BaseSetitemTests):
pass


class TestPrinting(BaseDecimal, base.BasePrintingTests):
class TestPrinting(base.BasePrintingTests):
def test_series_repr(self, data):
# Overriding this base test to explicitly test that
# the custom _formatter is used
Expand Down Expand Up @@ -282,7 +234,7 @@ def test_astype_dispatches(frame):
assert result.dtype.context.prec == ctx.prec


class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)

Expand Down Expand Up @@ -313,7 +265,7 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
super()._check_divmod_op(s, op, other, exc=None)


class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests):
class TestComparisonOps(base.BaseComparisonOpsTests):
def test_compare_scalar(self, data, all_compare_operators):
op_name = all_compare_operators
s = pd.Series(data)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/frame/methods/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_shift_axis1_multiple_blocks_with_int_fill(self):

@pytest.mark.filterwarnings("ignore:tshift is deprecated:FutureWarning")
def test_tshift(self, datetime_frame):
# TODO: remove this test when tshift deprecation is enforced
# TODO(2.0): remove this test when tshift deprecation is enforced

# PeriodIndex
ps = tm.makePeriodFrame()
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/frame/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,7 +2508,7 @@ def check_views():
# TODO: we can call check_views if we stop consolidating
# in setitem_with_indexer

# FIXME: until GH#35417, iloc.setitem into EA values does not preserve
# FIXME(GH#35417): until GH#35417, iloc.setitem into EA values does not preserve
# view, so we have to check in the other direction
# df.iloc[0, 2] = 0
# if not copy:
Expand All @@ -2522,7 +2522,7 @@ def check_views():
else:
assert a[0] == a.dtype.type(1)
assert b[0] == b.dtype.type(3)
# FIXME: enable after GH#35417
# FIXME(GH#35417): enable after GH#35417
# assert c[0] == 1
assert df.iloc[0, 2] == 1
else:
Expand Down
8 changes: 3 additions & 5 deletions pandas/tests/reductions/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,13 +929,11 @@ def test_all_any_params(self):
with tm.assert_produces_warning(FutureWarning):
s.all(bool_only=True, level=0)

# bool_only is not implemented alone.
# TODO GH38810 change this error message to:
# "Series.any does not implement bool_only"
msg = "Series.any does not implement numeric_only"
# GH#38810 bool_only is not implemented alone.
msg = "Series.any does not implement bool_only"
with pytest.raises(NotImplementedError, match=msg):
s.any(bool_only=True)
msg = "Series.all does not implement numeric_only."
msg = "Series.all does not implement bool_only."
with pytest.raises(NotImplementedError, match=msg):
s.all(bool_only=True)

Expand Down
56 changes: 29 additions & 27 deletions pandas/tests/series/methods/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ def test_rename(self, datetime_series):
renamed2 = ts.rename(rename_dict)
tm.assert_series_equal(renamed, renamed2)

def test_rename_partial_dict(self):
# partial dict
s = Series(np.arange(4), index=["a", "b", "c", "d"], dtype="int64")
renamed = s.rename({"b": "foo", "d": "bar"})
ser = Series(np.arange(4), index=["a", "b", "c", "d"], dtype="int64")
renamed = ser.rename({"b": "foo", "d": "bar"})
tm.assert_index_equal(renamed.index, Index(["a", "foo", "c", "bar"]))

def test_rename_retain_index_name(self):
# index with name
renamer = Series(
np.arange(4), index=Index(["a", "b", "c", "d"], name="name"), dtype="int64"
Expand All @@ -35,38 +37,38 @@ def test_rename(self, datetime_series):
assert renamed.index.name == renamer.index.name

def test_rename_by_series(self):
s = Series(range(5), name="foo")
ser = Series(range(5), name="foo")
renamer = Series({1: 10, 2: 20})
result = s.rename(renamer)
result = ser.rename(renamer)
expected = Series(range(5), index=[0, 10, 20, 3, 4], name="foo")
tm.assert_series_equal(result, expected)

def test_rename_set_name(self):
s = Series(range(4), index=list("abcd"))
ser = Series(range(4), index=list("abcd"))
for name in ["foo", 123, 123.0, datetime(2001, 11, 11), ("foo",)]:
result = s.rename(name)
result = ser.rename(name)
assert result.name == name
tm.assert_numpy_array_equal(result.index.values, s.index.values)
assert s.name is None
tm.assert_numpy_array_equal(result.index.values, ser.index.values)
assert ser.name is None

def test_rename_set_name_inplace(self):
s = Series(range(3), index=list("abc"))
ser = Series(range(3), index=list("abc"))
for name in ["foo", 123, 123.0, datetime(2001, 11, 11), ("foo",)]:
s.rename(name, inplace=True)
assert s.name == name
ser.rename(name, inplace=True)
assert ser.name == name

exp = np.array(["a", "b", "c"], dtype=np.object_)
tm.assert_numpy_array_equal(s.index.values, exp)
tm.assert_numpy_array_equal(ser.index.values, exp)

def test_rename_axis_supported(self):
# Supporting axis for compatibility, detailed in GH-18589
s = Series(range(5))
s.rename({}, axis=0)
s.rename({}, axis="index")
# FIXME: dont leave commenred-out
ser = Series(range(5))
ser.rename({}, axis=0)
ser.rename({}, axis="index")
# FIXME: dont leave commented-out
# TODO: clean up shared index validation
# with pytest.raises(ValueError, match="No axis named 5"):
# s.rename({}, axis=5)
# ser.rename({}, axis=5)

def test_rename_inplace(self, datetime_series):
renamer = lambda x: x.strftime("%Y%m%d")
Expand All @@ -81,24 +83,24 @@ class MyIndexer:
pass

ix = MyIndexer()
s = Series([1, 2, 3]).rename(ix)
assert s.name is ix
ser = Series([1, 2, 3]).rename(ix)
assert ser.name is ix

def test_rename_with_custom_indexer_inplace(self):
# GH 27814
class MyIndexer:
pass

ix = MyIndexer()
s = Series([1, 2, 3])
s.rename(ix, inplace=True)
assert s.name is ix
ser = Series([1, 2, 3])
ser.rename(ix, inplace=True)
assert ser.name is ix

def test_rename_callable(self):
# GH 17407
s = Series(range(1, 6), index=Index(range(2, 7), name="IntIndex"))
result = s.rename(str)
expected = s.rename(lambda i: str(i))
ser = Series(range(1, 6), index=Index(range(2, 7), name="IntIndex"))
result = ser.rename(str)
expected = ser.rename(lambda i: str(i))
tm.assert_series_equal(result, expected)

assert result.name == expected.name
Expand All @@ -111,8 +113,8 @@ def test_rename_series_with_multiindex(self):
]

index = MultiIndex.from_arrays(arrays, names=["first", "second"])
s = Series(np.ones(5), index=index)
result = s.rename(index={"one": "yes"}, level="second", errors="raise")
ser = Series(np.ones(5), index=index)
result = ser.rename(index={"one": "yes"}, level="second", errors="raise")

arrays_expected = [
["bar", "baz", "baz", "foo", "qux"],
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/series/methods/test_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_shift_dst(self):

@pytest.mark.filterwarnings("ignore:tshift is deprecated:FutureWarning")
def test_tshift(self, datetime_series):
# TODO: remove this test when tshift deprecation is enforced
# TODO(2.0): remove this test when tshift deprecation is enforced

# PeriodIndex
ps = tm.makePeriodSeries()
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/strings/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test_api_per_method(
inferred_dtype, values = any_allowed_skipna_inferred_dtype
method_name, args, kwargs = any_string_method

# TODO: get rid of these xfails
reason = None
if box is Index and values.size == 0:
if method_name in ["partition", "rpartition"] and kwargs.get("expand", True):
Expand Down

0 comments on commit 7c00e0c

Please sign in to comment.