From 4b7dffd447067dc466264f1dda928bbe7f4beb34 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 4 Jan 2023 09:33:44 +0100 Subject: [PATCH] fix(rust, python): keep name when sorting categorical in lexical order (#6029) --- .../src/chunked_array/ops/sort/categorical.rs | 5 ++++- py-polars/tests/unit/test_categorical.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/polars/polars-core/src/chunked_array/ops/sort/categorical.rs b/polars/polars-core/src/chunked_array/ops/sort/categorical.rs index bbf56ef8f8166..26d3d04076b66 100644 --- a/polars/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/polars/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -65,11 +65,14 @@ impl CategoricalChunked { ); let cats: NoNull = vals.into_iter().map(|(idx, _v)| idx).collect_trusted(); + let mut cats = cats.into_inner(); + cats.rename(self.name()); + // safety: // we only reordered the indexes so we are still in bounds unsafe { CategoricalChunked::from_cats_and_rev_map_unchecked( - cats.into_inner(), + cats, self.get_rev_map().clone(), ) } diff --git a/py-polars/tests/unit/test_categorical.py b/py-polars/tests/unit/test_categorical.py index 3b952a7074e9c..425a86176d29b 100644 --- a/py-polars/tests/unit/test_categorical.py +++ b/py-polars/tests/unit/test_categorical.py @@ -288,3 +288,20 @@ def test_categorical_in_struct_nulls() -> None: assert s[0] == {"job": None, "counts": 3} assert s[1] == {"job": "doctor", "counts": 2} assert s[2] == {"job": "waiter", "counts": 1} + + +def test_sort_categoricals_6014() -> None: + with pl.StringCache(): + # create basic categorical + df1 = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_column( + pl.col("key").cast(pl.Categorical) + ) + # create lexically-ordered categorical + df2 = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_column( + pl.col("key").cast(pl.Categorical).cat.set_ordering("lexical") + ) + + out = df1.sort("key") + assert out.to_dict(False) == {"key": ["bbb", "aaa", 
"ccc"]} + out = df2.sort("key") + assert out.to_dict(False) == {"key": ["aaa", "bbb", "ccc"]}