Skip to content

Commit

Permalink
fix(rust): allow get access to list of categoricals (#14015)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters authored Jan 26, 2024
1 parent 298e71f commit 3651ecd
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
9 changes: 6 additions & 3 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,12 @@ pub trait ListNameSpaceImpl: AsList {
.downcast_iter()
.map(|arr| sublist_get(arr, idx))
.collect::<Vec<_>>();
Series::try_from((ca.name(), chunks))
.unwrap()
.cast(&ca.inner_dtype())
// Safety: every element in list has dtype equal to its inner type
unsafe {
Series::try_from((ca.name(), chunks))
.unwrap()
.cast_unchecked(&ca.inner_dtype())
}
}

#[cfg(feature = "list_gather")]
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def test_list_arr_get() -> None:
) == {"lists": [None, None, 4]}


def test_list_categorical_get() -> None:
df = pl.DataFrame(
{
"actions": pl.Series(
[["a", "b"], ["c"], [None], None], dtype=pl.List(pl.Categorical)
),
}
)
expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical)
assert_series_equal(df["actions"].list.get(0), expected, categorical_as_str=True)


def test_contains() -> None:
a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])
out = a.list.contains(2)
Expand Down

0 comments on commit 3651ecd

Please sign in to comment.