Skip to content

Commit

Permalink
fix: Fix casting from categorical to numeric (#13957)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters committed Jan 26, 2024
1 parent 088d822 commit 96b8b93
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
17 changes: 17 additions & 0 deletions crates/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,23 @@ impl LogicalType for CategoricalChunked {
// Otherwise we do nothing
Ok(self.clone().set_ordering(*ordering, true).into_series())
},
dt if dt.is_numeric() => {
// Apply the cast to the categories and then index into the casted series
let categories =
StringChunked::with_chunk("", self.get_rev_map().get_categories().clone());
let casted_series = categories.cast(dtype)?;

#[cfg(feature = "bigidx")]
{
let s = self.physical.cast(&DataType::UInt64)?;
Ok(unsafe { casted_series.take_unchecked(s.u64()?) })
}
#[cfg(not(feature = "bigidx"))]
{
// SAFETY: the categorical invariant guarantees that the indices are in bounds
Ok(unsafe { casted_series.take_unchecked(&self.physical) })
}
},
_ => self.physical.cast(dtype),
}
}
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,15 @@ def test_sort_categorical_retain_none(
"foo",
"ham",
]


def test_cast_from_cat_to_numeric() -> None:
    """Casting a Categorical series to a numeric dtype must convert the category
    strings themselves to numbers.

    Covers both a float target (Float32) and an integer target (UInt8).
    """
    cat_series = pl.Series(
        "cat_series",
        ["0.69845702", "0.69317475", "2.43642724", "-0.95303469", "0.60684237"],
    ).cast(pl.Categorical)
    maximum = cat_series.cast(pl.Float32).max()
    # `max()` returns an optional scalar; assert the concrete type so that a
    # regression producing nulls fails here with a clear assertion instead of
    # a TypeError in the subtraction below (and so no `type: ignore` is needed).
    assert isinstance(maximum, float)
    # Float32 keeps roughly 7 significant digits, hence the tolerance.
    assert abs(maximum - 2.43642724) < 1e-6

    s = pl.Series(["1", "2", "3"], dtype=pl.Categorical)
    assert s.cast(pl.UInt8).sum() == 6

0 comments on commit 96b8b93

Please sign in to comment.