From 96b8b93a959c24ea4c4a5f8aa8c375a167f5699b Mon Sep 17 00:00:00 2001 From: chielP Date: Fri, 26 Jan 2024 15:14:50 +0100 Subject: [PATCH] fix: Fix casting from categorical to numeric (#13957) --- .../chunked_array/logical/categorical/mod.rs | 17 +++++++++++++++++ .../tests/unit/datatypes/test_categorical.py | 12 ++++++++++++ 2 files changed, 29 insertions(+) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index ead8ae6e4271..ba282994de65 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -394,6 +394,23 @@ impl LogicalType for CategoricalChunked { // Otherwise we do nothing Ok(self.clone().set_ordering(*ordering, true).into_series()) }, + dt if dt.is_numeric() => { + // Apply the cast to to the categories and then index into the casted series + let categories = + StringChunked::with_chunk("", self.get_rev_map().get_categories().clone()); + let casted_series = categories.cast(dtype)?; + + #[cfg(feature = "bigidx")] + { + let s = self.physical.cast(&DataType::UInt64)?; + Ok(unsafe { casted_series.take_unchecked(s.u64()?) }) + } + #[cfg(not(feature = "bigidx"))] + { + // Safety: Invariant of categorical means indices are in bound + Ok(unsafe { casted_series.take_unchecked(&self.physical) }) + } + }, _ => self.physical.cast(dtype), } } diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index 07f7a2026305..4e02decb8fe9 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -800,3 +800,15 @@ def test_sort_categorical_retain_none( "foo", "ham", ] + + +def test_cast_from_cat_to_numeric() -> None: + cat_series = pl.Series( + "cat_series", + ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"], + ).cast(pl.Categorical) + maximum = cat_series.cast(pl.Float32).max() + assert abs(maximum - 2.43642724) < 1e-6 # type: ignore[operator] + + s = pl.Series(["1", "2", "3"], dtype=pl.Categorical) + assert s.cast(pl.UInt8).sum() == 6