From 5e1b200e30182baab8ad5e362a304ab6f374ea6c Mon Sep 17 00:00:00 2001 From: chielP Date: Fri, 26 Jan 2024 15:04:32 +0100 Subject: [PATCH 1/3] list get --- py-polars/tests/unit/namespaces/test_list.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 1e2a661d15d9..f22983078aa5 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -69,6 +69,13 @@ def test_list_arr_get() -> None: ) == {"lists": [None, None, 4]} +def test_list_categorical_get() -> None: + df = pl.DataFrame({ + 'actions': pl.Series([['a','b'], ['c'],[None], None], dtype=pl.List(pl.Categorical)), + }) + expected = pl.Series("actions",['a','c',None,None],dtype=pl.Categorical) + assert_series_equal(df['actions'].list.get(0),expected,categorical_as_str=True) + def test_contains() -> None: a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]]) out = a.list.contains(2) From 33f7ffecca24d9062dc511261d53612964d363c0 Mon Sep 17 00:00:00 2001 From: chielP Date: Fri, 26 Jan 2024 15:05:25 +0100 Subject: [PATCH 2/3] list get --- crates/polars-ops/src/chunked_array/list/namespace.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index fd6694d3f382..00bca56ec465 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -326,9 +326,12 @@ pub trait ListNameSpaceImpl: AsList { .downcast_iter() .map(|arr| sublist_get(arr, idx)) .collect::>(); - Series::try_from((ca.name(), chunks)) - .unwrap() - .cast(&ca.inner_dtype()) + // Safety: every element in list has dtype equal to its inner type + unsafe { + Series::try_from((ca.name(), chunks)) + .unwrap() + .cast_unchecked(&ca.inner_dtype()) + } } #[cfg(feature = "list_gather")] From 8326337dfa3bcf5343dfabf769ad4d2689e4c899 Mon Sep 17 00:00:00 2001 From: chielP Date: Fri, 26 Jan 2024 15:06:32 +0100 Subject: [PATCH 3/3] fmt --- py-polars/tests/unit/namespaces/test_list.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index f22983078aa5..04733cbedb4c 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -70,11 +70,16 @@ def test_list_arr_get() -> None: def test_list_categorical_get() -> None: - df = pl.DataFrame({ - 'actions': pl.Series([['a','b'], ['c'],[None], None], dtype=pl.List(pl.Categorical)), - }) - expected = pl.Series("actions",['a','c',None,None],dtype=pl.Categorical) - assert_series_equal(df['actions'].list.get(0),expected,categorical_as_str=True) + df = pl.DataFrame( + { + "actions": pl.Series( + [["a", "b"], ["c"], [None], None], dtype=pl.List(pl.Categorical) + ), + } + ) + expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) + assert_series_equal(df["actions"].list.get(0), expected, categorical_as_str=True) + def test_contains() -> None: a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])