Skip to content

Commit

Permalink
fix(rust): Allow casting integer types to Enum (#13955)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters authored Jan 26, 2024
1 parent 78b4254 commit 37ca0af
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
30 changes: 22 additions & 8 deletions crates/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,22 @@ where
match data_type {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(rev_map, ordering) => {
polars_ensure!(
self.dtype() == &DataType::UInt32,
ComputeError: "cannot cast numeric types to 'Categorical'"
);
// SAFETY
// we are guarded by the type system
let ca = unsafe { &*(self as *const ChunkedArray<T> as *const UInt32Chunked) };
if let Some(rev_map) = rev_map {
if let RevMapping::Enum(categories, _) = &**rev_map {
let ca = match self.dtype() {
DataType::UInt32 => {
// SAFETY: we are guarded by the type system
unsafe {
&*(self as *const ChunkedArray<T> as *const UInt32Chunked)
}
.clone()
},
dt if dt.is_integer() => self.cast(&DataType::UInt32)?.u32()?.clone(),
_ => {
polars_bail!(ComputeError: "cannot cast non integer types to 'Categorical'")
},
};

// Check if indices are in bounds
if let Some(m) = ca.max() {
if m >= categories.len() as u32 {
Expand All @@ -120,14 +127,21 @@ where
// SAFETY indices are in bound
unsafe {
return Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
ca.clone(),
ca,
rev_map.clone(),
*ordering,
)
.into_series());
}
}
}
polars_ensure!(
self.dtype() == &DataType::UInt32,
ComputeError: "cannot cast numeric types to 'Categorical'"
);
// SAFETY
// we are guarded by the type system
let ca = unsafe { &*(self as *const ChunkedArray<T> as *const UInt32Chunked) };

CategoricalChunked::from_global_indices(ca.clone(), *ordering)
.map(|ca| ca.into_series())
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,28 @@ def test_enum_categories_series_zero_copy() -> None:
result_dtype = s.dtype

assert result_dtype == dtype


@pytest.mark.parametrize(
"dtype",
[pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Int8, pl.Int16, pl.Int32, pl.Int64],
)
def test_enum_cast_from_other_integer_dtype(dtype: pl.DataType) -> None:
enum_dtype = pl.Enum(["a", "b", "c", "d"])
series = pl.Series([1, 2, 3, 3, 2, 1], dtype=dtype)
series.cast(enum_dtype)


def test_enum_cast_from_other_integer_dtype_oob() -> None:
enum_dtype = pl.Enum(["a", "b", "c", "d"])
series = pl.Series([-1, 2, 3, 3, 2, 1], dtype=pl.Int8)
with pytest.raises(
pl.ComputeError, match="conversion from `i8` to `enum` failed in column"
):
series.cast(enum_dtype)

series = pl.Series([2**34, 2, 3, 3, 2, 1], dtype=pl.UInt64)
with pytest.raises(
pl.ComputeError, match="conversion from `u64` to `enum` failed in column"
):
series.cast(enum_dtype)

0 comments on commit 37ca0af

Please sign in to comment.