Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): allow nonstrict cast of categorical/enum to enum #14910

Merged
merged 5 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 16 additions & 32 deletions crates/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ impl CategoricalChunked {
}
}

// Convert to fixed enum. In case a value is not in the categories return Error
pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> PolarsResult<Self> {
// Convert to fixed enum. Values not in categories are mapped to None.
pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
// Fast paths
match self.get_rev_map().as_ref() {
RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
return unsafe {
Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
CategoricalChunked::from_cats_and_rev_map_unchecked(
self.physical().clone(),
self.get_rev_map().clone(),
true,
self.get_ordering(),
))
)
};
},
_ => (),
Expand All @@ -159,34 +159,18 @@ impl CategoricalChunked {
let new_phys: UInt32Chunked = self
.physical()
.into_iter()
.map(|opt_v: Option<u32>| {
let Some(v) = opt_v else {
return Ok(None);
};

let Some(idx) = idx_map.get(&v) else {
polars_bail!(
not_in_enum,
value = old_rev_map.get(v),
categories = &categories
);
};
.map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
.collect();

Ok(Some(*idx))
})
.collect::<PolarsResult<_>>()?;

Ok(
// SAFETY: we created the physical from the enum categories
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
new_phys,
Arc::new(RevMapping::Local(categories.clone(), hash)),
true,
self.get_ordering(),
)
},
)
// SAFETY: we created the physical from the enum categories
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
new_phys,
Arc::new(RevMapping::Local(categories.clone(), hash)),
true,
self.get_ordering(),
)
}
}

pub(crate) fn get_flags(&self) -> Settings {
Expand Down Expand Up @@ -373,7 +357,7 @@ impl LogicalType for CategoricalChunked {
polars_bail!(ComputeError: "can not cast to enum with global mapping")
};
Ok(self
.to_enum(categories, *hash)?
.to_enum(categories, *hash)
.set_ordering(*ordering, true)
.into_series()
.with_name(self.name()))
Expand Down
28 changes: 26 additions & 2 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,26 @@ def test_casting_to_an_enum_from_categorical() -> None:
assert_series_equal(s2, expected)


def test_casting_to_an_enum_from_categorical_nonstrict() -> None:
dtype = pl.Enum(["a", "b"])
s = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical)
s2 = s.cast(dtype, strict=False)
assert s2.dtype == dtype
assert s2.null_count() == 2 # "c" mapped to null
expected = pl.Series([None, "a", "b", None], dtype=dtype)
assert_series_equal(s2, expected)


def test_casting_to_an_enum_from_enum_nonstrict() -> None:
dtype = pl.Enum(["a", "b"])
s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"]))
s2 = s.cast(dtype, strict=False)
assert s2.dtype == dtype
assert s2.null_count() == 2 # "c" mapped to null
expected = pl.Series([None, "a", "b", None], dtype=dtype)
assert_series_equal(s2, expected)


def test_casting_to_an_enum_from_integer() -> None:
dtype = pl.Enum(["a", "b", "c"])
expected = pl.Series([None, "b", "a", "c"], dtype=dtype)
Expand All @@ -139,7 +159,9 @@ def test_casting_to_an_enum_oob_from_integer() -> None:
def test_casting_to_an_enum_from_categorical_nonexistent() -> None:
with pytest.raises(
pl.ComputeError,
match=("value 'c' is not present in Enum"),
match=(
r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]"
),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"]))

Expand All @@ -159,7 +181,9 @@ def test_casting_to_an_enum_from_global_categorical() -> None:
def test_casting_to_an_enum_from_global_categorical_nonexistent() -> None:
with pytest.raises(
pl.ComputeError,
match=("value 'c' is not present in Enum"),
match=(
r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]"
),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"]))

Expand Down
Loading