Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLN: Refactor some sorting code in Index set operations #24533

Merged
merged 1 commit into from
Jan 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2302,27 +2302,15 @@ def union(self, other):
allow_fill=False)
result = _concat._concat_compat((lvals, other_diff))

try:
lvals[0] < other_diff[0]
except TypeError as e:
warnings.warn("%s, sort order is undefined for "
"incomparable objects" % e, RuntimeWarning,
stacklevel=3)
else:
types = frozenset((self.inferred_type,
other.inferred_type))
if not types & _unsortable_types:
result.sort()

else:
result = lvals

try:
result = np.sort(result)
except TypeError as e:
warnings.warn("%s, sort order is undefined for "
"incomparable objects" % e, RuntimeWarning,
stacklevel=3)
try:
result = sorting.safe_sort(result)
except TypeError as e:
warnings.warn("%s, sort order is undefined for "
"incomparable objects" % e, RuntimeWarning,
stacklevel=3)
jreback marked this conversation as resolved.
Show resolved Hide resolved

# for subclasses
return self._wrap_setop_result(other, result)
Expand Down
30 changes: 6 additions & 24 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,7 @@ def test_union_name_preservation(self, first_list, second_list, first_name,

def test_union_dt_as_obj(self):
# TODO: Replace with fixturesult
with tm.assert_produces_warning(RuntimeWarning):
firstCat = self.strIndex.union(self.dateIndex)
firstCat = self.strIndex.union(self.dateIndex)
secondCat = self.strIndex.union(self.strIndex)

if self.dateIndex.dtype == np.object_:
Expand Down Expand Up @@ -1615,7 +1614,7 @@ def test_drop_tuple(self, values, to_drop):
@pytest.mark.parametrize("method,expected", [
('intersection', np.array([(1, 'A'), (2, 'A'), (1, 'B'), (2, 'B')],
dtype=[('num', int), ('let', 'a1')])),
('union', np.array([(1, 'A'), (2, 'A'), (1, 'B'), (2, 'B'), (1, 'C'),
('union', np.array([(1, 'A'), (1, 'B'), (1, 'C'), (2, 'A'), (2, 'B'),
(2, 'C')], dtype=[('num', int), ('let', 'a1')]))
])
def test_tuple_union_bug(self, method, expected):
Expand Down Expand Up @@ -2242,10 +2241,7 @@ def test_copy_name(self):
s1 = Series(2, index=first)
s2 = Series(3, index=second[:-1])

warning_type = RuntimeWarning if PY3 else None
with tm.assert_produces_warning(warning_type):
# Python 3: Unorderable types
s3 = s1 * s2
s3 = s1 * s2
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning above used to be raised in a couple of places. Now that the safe_sort function is used, it is no longer raised there, because safe_sort is able to sort some "unorderable" types (for example, mixed integers and strings):

def sort_mixed(values):
    # order ints before strings, safe in py3
    str_pos = np.array([isinstance(x, string_types) for x in values],
                       dtype=bool)
    nums = np.sort(values[~str_pos])
    strs = np.sort(values[str_pos])
    return np.concatenate([nums, np.asarray(strs, dtype=object)])
sorter = None
if PY3 and lib.infer_dtype(values) == 'mixed-integer':
    # unorderable in py3 if mixed str/int
    ordered = sort_mixed(values)
else:
    try:
        sorter = values.argsort()
        ordered = values.take(sorter)
    except TypeError:
        # try this anyway
        ordered = sort_mixed(values)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, this is the intent of the function


assert s3.index.name == 'mario'

Expand Down Expand Up @@ -2274,16 +2270,9 @@ def test_union_base(self):
first = index[3:]
second = index[:5]

if PY3:
# unorderable types
warn_type = RuntimeWarning
else:
warn_type = None

with tm.assert_produces_warning(warn_type):
result = first.union(second)
result = first.union(second)

expected = Index(['b', 2, 'c', 0, 'a', 1])
expected = Index([0, 1, 2, 'a', 'b', 'c'])
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize("klass", [
Expand All @@ -2294,14 +2283,7 @@ def test_union_different_type_base(self, klass):
first = index[3:]
second = index[:5]

if PY3:
# unorderable types
warn_type = RuntimeWarning
else:
warn_type = None

with tm.assert_produces_warning(warn_type):
result = first.union(klass(second.values))
result = first.union(klass(second.values))

assert tm.equalContents(result, index)

Expand Down
26 changes: 5 additions & 21 deletions pandas/tests/series/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,12 @@ def test_operators_bitwise(self):
s_0123 & [0.1, 4, 3.14, 2]

# s_0123 will be all false now because of reindexing like s_tft
if compat.PY3:
# unable to sort incompatible object via .union.
exp = Series([False] * 7, index=['b', 'c', 'a', 0, 1, 2, 3])
with tm.assert_produces_warning(RuntimeWarning):
assert_series_equal(s_tft & s_0123, exp)
else:
exp = Series([False] * 7, index=[0, 1, 2, 3, 'a', 'b', 'c'])
assert_series_equal(s_tft & s_0123, exp)
exp = Series([False] * 7, index=[0, 1, 2, 3, 'a', 'b', 'c'])
assert_series_equal(s_tft & s_0123, exp)

# s_tft will be all false now because of reindexing like s_0123
if compat.PY3:
# unable to sort incompatible object via .union.
exp = Series([False] * 7, index=[0, 1, 2, 3, 'b', 'c', 'a'])
with tm.assert_produces_warning(RuntimeWarning):
assert_series_equal(s_0123 & s_tft, exp)
else:
exp = Series([False] * 7, index=[0, 1, 2, 3, 'a', 'b', 'c'])
assert_series_equal(s_0123 & s_tft, exp)
exp = Series([False] * 7, index=[0, 1, 2, 3, 'a', 'b', 'c'])
assert_series_equal(s_0123 & s_tft, exp)

assert_series_equal(s_0123 & False, Series([False] * 4))
assert_series_equal(s_0123 ^ False, Series([False, True, True, True]))
Expand Down Expand Up @@ -280,11 +268,7 @@ def test_logical_ops_label_based(self):
assert_series_equal(result, a[a])

for e in [Series(['z'])]:
if compat.PY3:
with tm.assert_produces_warning(RuntimeWarning):
result = a[a | e]
else:
result = a[a | e]
result = a[a | e]
assert_series_equal(result, a[a])

# vs scalars
Expand Down