Skip to content

Commit

Permalink
BUG/PERF: use lexsort_indexer in MultiIndex.argsort (#48495)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Oct 7, 2022
1 parent ee352b1 commit c0e6baf
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ Missing

MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.argsort` raising ``TypeError`` when index contains :attr:`NA` (:issue:`48495`)
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`)
- Bug in :meth:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
Expand Down
11 changes: 6 additions & 5 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,7 +1952,7 @@ def _lexsort_depth(self) -> int:
return self.sortorder
return _lexsort_depth(self.codes, self.nlevels)

def _sort_levels_monotonic(self) -> MultiIndex:
def _sort_levels_monotonic(self, raise_if_incomparable: bool = False) -> MultiIndex:
"""
This is an *internal* function.
Expand Down Expand Up @@ -1999,7 +1999,8 @@ def _sort_levels_monotonic(self) -> MultiIndex:
# indexer to reorder the levels
indexer = lev.argsort()
except TypeError:
pass
if raise_if_incomparable:
raise
else:
lev = lev.take(indexer)

Expand Down Expand Up @@ -2245,9 +2246,9 @@ def append(self, other):

def argsort(self, *args, **kwargs) -> npt.NDArray[np.intp]:
# NOTE(review): this is a rendered diff hunk whose +/- markers were stripped, so
# the replaced lines and their replacements appear back to back below. The
# old/new labels are inferred from the "6 additions & 5 deletions" hunk header.
# Fast path: taken only when the caller passes no positional or keyword arguments.
if len(args) == 0 and len(kwargs) == 0:
# -- presumably the pre-change lines (removed): lexsort over the level values --
# np.lexsort is significantly faster than self._values.argsort()
values = [self._get_level_values(i) for i in reversed(range(self.nlevels))]
return np.lexsort(values)
# -- presumably the post-change lines (added): ask _sort_levels_monotonic to
# re-raise the TypeError for incomparable level values, then pass the result of
# _get_codes_for_sorting() to lexsort_indexer --
# lexsort is significantly faster than self._values.argsort()
target = self._sort_levels_monotonic(raise_if_incomparable=True)
return lexsort_indexer(target._get_codes_for_sorting())
# Any explicit argument falls back to argsort on self._values, forwarding them.
return self._values.argsort(*args, **kwargs)

@Appender(_index_shared_docs["repeat"] % _index_doc_kwargs)
Expand Down
24 changes: 24 additions & 0 deletions pandas/tests/indexes/multi/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Index,
MultiIndex,
RangeIndex,
Timestamp,
)
import pandas._testing as tm
from pandas.core.indexes.frozen import FrozenList
Expand Down Expand Up @@ -280,3 +281,26 @@ def test_remove_unused_levels_with_nan():
result = idx.levels
expected = FrozenList([["a", np.nan], [4]])
assert str(result) == str(expected)


def test_sort_values_nan():
    """sort_values on a MultiIndex with missing (-1) codes in the second level.

    GH48495, GH48626
    """
    unsorted = MultiIndex(
        levels=[["A", "B", "C"], ["D"]],
        codes=[[1, 0, 2], [-1, -1, 0]],
    )
    # Same levels; only the first-level codes are reordered.
    expected = MultiIndex(
        levels=[["A", "B", "C"], ["D"]],
        codes=[[0, 1, 2], [-1, -1, 0]],
    )

    tm.assert_index_equal(unsorted.sort_values(), expected)


def test_sort_values_incomparable():
    """sort_values raises TypeError when a level mixes int and Timestamp.

    GH48495
    """
    mixed_level = [1, Timestamp("2000-01-01")]
    int_level = [3, 4]
    mi = MultiIndex.from_arrays([mixed_level, int_level])

    match = "'<' not supported between instances of 'Timestamp' and 'int'"
    with pytest.raises(TypeError, match=match):
        mi.sort_values()
32 changes: 32 additions & 0 deletions pandas/tests/indexing/multiindex/test_sorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import pytest

from pandas import (
NA,
DataFrame,
MultiIndex,
Series,
array,
)
import pandas._testing as tm

Expand Down Expand Up @@ -86,6 +88,36 @@ def test_sort_values_key(self):

tm.assert_frame_equal(result, expected)

def test_argsort_with_na(self):
# GH48495
arrays = [
array([2, NA, 1], dtype="Int64"),
array([1, 2, 3], dtype="Int64"),
]
index = MultiIndex.from_arrays(arrays)
result = index.argsort()
expected = np.array([2, 0, 1], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)

def test_sort_values_with_na(self):
# GH48495
arrays = [
array([2, NA, 1], dtype="Int64"),
array([1, 2, 3], dtype="Int64"),
]
index = MultiIndex.from_arrays(arrays)
index = index.sort_values()
result = DataFrame(range(3), index=index)

arrays = [
array([1, 2, NA], dtype="Int64"),
array([3, 1, 2], dtype="Int64"),
]
index = MultiIndex.from_arrays(arrays)
expected = DataFrame(range(3), index=index)

tm.assert_frame_equal(result, expected)

def test_frame_getitem_not_sorted(self, multiindex_dataframe_random_data):
frame = multiindex_dataframe_random_data
df = frame.T
Expand Down

0 comments on commit c0e6baf

Please sign in to comment.