From b74856114a4030f318baa8c5228ed7e881785594 Mon Sep 17 00:00:00 2001 From: James Edwards Date: Fri, 29 Mar 2024 12:39:25 -0400 Subject: [PATCH 1/6] Add null_on_oob option to expr.list.get, defaults to False --- .../src/chunked_array/list/namespace.rs | 10 +- .../src/chunked_array/list/to_struct.rs | 2 +- .../polars-plan/src/dsl/function_expr/list.rs | 43 +++++---- .../polars-plan/src/dsl/function_expr/mod.rs | 8 ++ crates/polars-plan/src/dsl/list.rs | 8 +- crates/polars-sql/src/functions.rs | 2 +- py-polars/polars/expr/list.py | 8 +- py-polars/src/expr/list.rs | 4 +- py-polars/tests/unit/datatypes/test_list.py | 2 +- .../tests/unit/namespaces/list/test_list.py | 92 ++++++++++++++++--- 10 files changed, 137 insertions(+), 42 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 38ca7732c40c..3f86728f3f61 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -342,8 +342,16 @@ pub trait ListNameSpaceImpl: AsList { /// So index `0` would return the first item of every sublist /// and index `-1` would return the last item of every sublist /// if an index is out of bounds, it will return a `None`. 
- fn lst_get(&self, idx: i64) -> PolarsResult { + fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { let ca = self.as_list(); + if !null_on_oob && ca + .iter() + .any(|sublist| { + sublist.and_then(|s| idx.negative_to_usize(s.len()).map(|idx| idx as IdxSize)).is_none() + }) { + polars_bail!(ComputeError: "get index is out of bounds"); + } + let chunks = ca .downcast_iter() .map(|arr| sublist_get(arr, idx)) diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index c43cfda13024..4b74a76692ed 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -72,7 +72,7 @@ pub trait ToStruct: AsList { (0..n_fields) .into_par_iter() .map(|i| { - ca.lst_get(i as i64).map(|mut s| { + ca.lst_get(i as i64, true).map(|mut s| { s.rename(&name_generator(i)); s }) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index b9dafcf9e305..57ea3d18b763 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -21,7 +21,7 @@ pub enum ListFunction { }, Slice, Shift, - Get, + Get(bool), #[cfg(feature = "list_gather")] Gather(bool), #[cfg(feature = "list_gather")] @@ -71,7 +71,7 @@ impl ListFunction { Sample { .. 
} => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), Shift => mapper.with_same_dtype(), - Get => mapper.map_to_list_and_array_inner_dtype(), + Get(_) => mapper.map_to_list_and_array_inner_dtype(), #[cfg(feature = "list_gather")] Gather(_) => mapper.with_same_dtype(), #[cfg(feature = "list_gather")] @@ -136,7 +136,7 @@ impl Display for ListFunction { }, Slice => "slice", Shift => "shift", - Get => "get", + Get(_) => "get", #[cfg(feature = "list_gather")] Gather(_) => "gather", #[cfg(feature = "list_gather")] @@ -203,9 +203,9 @@ impl From for SpecialEq> { }, Slice => wrap!(slice), Shift => map_as_slice!(shift), - Get => wrap!(get), + Get(null_on_oob) => wrap!(get, null_on_oob), #[cfg(feature = "list_gather")] - Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob), + Gather(null_on_oob) => map_as_slice!(gather, null_on_oob), #[cfg(feature = "list_gather")] GatherEvery => map_as_slice!(gather_every), #[cfg(feature = "list_count")] @@ -414,7 +414,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult> { first_ca.lst_concat(other).map(|ca| Some(ca.into_series())) } -pub(super) fn get(s: &mut [Series]) -> PolarsResult> { +pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult> { let ca = s[0].list()?; let index = s[1].cast(&DataType::Int64)?; let index = index.i64().unwrap(); @@ -423,7 +423,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { 1 => { let index = index.get(0); if let Some(index) = index { - ca.lst_get(index).map(Some) + ca.lst_get(index, null_on_oob).map(Some) } else { Ok(Some(Series::full_null( ca.name(), @@ -441,18 +441,25 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { .into_iter() .enumerate() .map(|(i, opt_idx)| { - opt_idx.and_then(|idx| { - let (start, end) = - unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; - let offset = if idx >= 0 { start + idx } else { end + idx }; - if offset >= end || offset < start || start == end { - None - } else { - Some(offset as IdxSize) + 
match opt_idx { + Some(idx) => { + let (start, end) = + unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; + let offset = if idx >= 0 { start + idx } else { end + idx }; + if offset >= end || offset < start || start == end { + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } + } else { + Ok(Some(offset as IdxSize)) + } } - }) + None => Ok(None) + } }) - .collect::(); + .collect::>()?; let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); unsafe { s.take_unchecked(&take_by) } .cast(&ca.inner_dtype()) @@ -475,7 +482,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult if idx.len() == 1 && null_on_oob { // fast path let idx = idx.get(0)?.try_extract::()?; - let out = ca.lst_get(idx)?; + let out = ca.lst_get(idx, null_on_oob)?; // make sure we return a list out.reshape(&[-1, 1]) } else { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 397959e9980f..5ff04aee0098 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -718,6 +718,14 @@ macro_rules! wrap { ($e:expr) => { SpecialEq::new(Arc::new($e)) }; + + ($e:expr, $($args:expr),*) => {{ + let f = move |s: &mut [Series]| { + $e(s, $($args),*) + }; + + SpecialEq::new(Arc::new(f)) + }}; } // Fn(&[Series], args) diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 603ec2553590..7e299833e131 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -151,9 +151,9 @@ impl ListNameSpace { } /// Get items in every sublist by index. 
- pub fn get(self, index: Expr) -> Expr { + pub fn get(self, index: Expr, null_on_oob: bool) -> Expr { self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Get), + FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)), &[index], false, false, @@ -187,12 +187,12 @@ impl ListNameSpace { /// Get first item of every sublist. pub fn first(self) -> Expr { - self.get(lit(0i64)) + self.get(lit(0i64), false) } /// Get last item of every sublist. pub fn last(self) -> Expr { - self.get(lit(-1i64)) + self.get(lit(-1i64), false) } /// Join all string items in a sublist and place a separator between them. diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 2149912be665..7ef6d2fe221c 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> { // Array functions // ---- ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), - ArrayGet => self.visit_binary(|e, i| e.list().get(i)), + ArrayGet => self.visit_binary(|e, i| e.list().get(i, false)), ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 474e586a1c62..6518fd0e4bee 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -505,7 +505,11 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E other_list.insert(0, wrap_expr(self._pyexpr)) return F.concat_list(other_list) - def get(self, index: int | Expr | str) -> Expr: + def get(self, + index: int | Expr | str, + *, + null_on_oob: bool = False, + ) -> Expr: """ Get the value by index in the sublists. 
@@ -534,7 +538,7 @@ def get(self, index: int | Expr | str) -> Expr: └───────────┴──────┘ """ index = parse_as_expression(index) - return wrap_expr(self._pyexpr.list_get(index)) + return wrap_expr(self._pyexpr.list_get(index, null_on_oob)) def gather( self, diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index fde544a6ce41..25744b38d741 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -44,8 +44,8 @@ impl PyExpr { self.inner.clone().list().eval(expr.inner, parallel).into() } - fn list_get(&self, index: PyExpr) -> Self { - self.inner.clone().list().get(index.inner).into() + fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self { + self.inner.clone().list().get(index.inner, null_on_oob).into() } fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index ba580ea8e6ba..fe0028d5067c 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -781,7 +781,7 @@ def test_list_gather_null_struct_14927() -> None: {"index": [1], "col_0": [None], "field_0": [None]}, schema={**df.schema, "field_0": pl.Float64}, ) - expr = pl.col("col_0").list.get(0).struct.field("field_0") + expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0") out = df.filter(pl.col("index") > 0).with_columns(expr) assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index 570716e14fe5..cea176ad7871 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -21,7 +21,7 @@ def test_list_arr_get() -> None: assert_series_equal(out, expected) out = pl.select(pl.lit(a).list.first()).to_series() assert_series_equal(out, expected) - + out = a.list.get(-1) expected = pl.Series("a", [3, 5, 9]) assert_series_equal(out, 
expected) @@ -30,25 +30,93 @@ def test_list_arr_get() -> None: out = pl.select(pl.lit(a).list.last()).to_series() assert_series_equal(out, expected) + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(3) + + b = pl.Series("b", [[1, 2, 3], [4], []]) + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + b.list.first() + + # Null index. + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(-3) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + pl.DataFrame( + {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} + ).with_columns( + [pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)] + ).to_dict(as_series=False) + + # get by indexes where some are out of bounds + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) + + # exact on oob boundary + df = pl.DataFrame( + { + "index": [3, 3, 3], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) == { + "lists": [None, None, 4] + } + df.select(pl.col("lists").list.get(pl.col("index"))).to_dict( + as_series=False + ) + + +def test_list_arr_get_null_on_oob() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(0, null_on_oob=True) + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list[0] + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + 
out = a.list.first() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.first()).to_series() + assert_series_equal(out, expected) + + out = a.list.get(-1, null_on_oob=True) + expected = pl.Series("a", [3, 5, 9]) + assert_series_equal(out, expected) + out = a.list.last() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.last()).to_series() + assert_series_equal(out, expected) + # Out of bounds index. - out = a.list.get(3) + out = a.list.get(3, null_on_oob=True) expected = pl.Series("a", [None, None, 9]) assert_series_equal(out, expected) # Null index. - out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=True)) expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() assert_frame_equal(out_df, expected_df) a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(-3) + out = a.list.get(-3, null_on_oob=True) expected = pl.Series("a", [1, None, 7]) assert_series_equal(out, expected) assert pl.DataFrame( {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} ).with_columns( - [pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)] + [pl.col("a").list.get(i, null_on_oob=True).alias(f"get_{i}") for i in range(4)] ).to_dict(as_series=False) == { "a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]], "get_0": [1, 2, 3, 4, 7, None], @@ -60,7 +128,7 @@ def test_list_arr_get() -> None: # get by indexes where some are out of bounds df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) - assert df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) == { + assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict(as_series=False) == { "cars": [2, 3, None, None] } # exact on oob boundary @@ -71,10 +139,10 @@ def test_list_arr_get() -> None: } ) - assert df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) == { + assert 
df.select(pl.col("lists").list.get(3, null_on_oob=True)).to_dict(as_series=False) == { "lists": [None, None, 4] } - assert df.select(pl.col("lists").list.get(pl.col("index"))).to_dict( + assert df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=True)).to_dict( as_series=False ) == {"lists": [None, None, 4]} @@ -88,7 +156,7 @@ def test_list_categorical_get() -> None: } ) expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) - assert_series_equal(df["actions"].list.get(0), expected, categorical_as_str=True) + assert_series_equal(df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True) def test_contains() -> None: @@ -156,8 +224,8 @@ def test_list_arr_empty() -> None: out = df.select( [ - pl.col("cars").list.first().alias("cars_first"), - pl.when(pl.col("cars").list.first() == 2) + pl.col("cars").list.get(0, null_on_oob=True).alias("cars_first"), + pl.when(pl.col("cars").list.get(0, null_on_oob=True) == 2) .then(1) .when(pl.col("cars").list.contains(2)) .then(2) @@ -597,7 +665,7 @@ def test_select_from_list_to_struct_11143() -> None: def test_list_arr_get_8810() -> None: assert pl.DataFrame(pl.Series("a", [None], pl.List(pl.Int64))).select( - pl.col("a").list.get(0) + pl.col("a").list.get(0, null_on_oob=True) ).to_dict(as_series=False) == {"a": [None]} From 3b2b94dbfe27eaa64fbfcb3d5500b9c7c32fd2f3 Mon Sep 17 00:00:00 2001 From: James Edwards Date: Fri, 29 Mar 2024 13:18:04 -0400 Subject: [PATCH 2/6] Adding null_on_oob to expr.list.get (some changes got left behind first time) --- .../src/chunked_array/list/namespace.rs | 12 ++++---- .../polars-plan/src/dsl/function_expr/list.rs | 28 +++++++++---------- py-polars/src/expr/list.rs | 6 +++- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 3f86728f3f61..8fc6e07892e5 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs
+++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -344,11 +344,13 @@ pub trait ListNameSpaceImpl: AsList { /// if an index is out of bounds, it will return a `None`. fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { let ca = self.as_list(); - if !null_on_oob && ca - .iter() - .any(|sublist| { - sublist.and_then(|s| idx.negative_to_usize(s.len()).map(|idx| idx as IdxSize)).is_none() - }) { + if !null_on_oob + && ca.iter().any(|sublist| { + sublist + .and_then(|s| idx.negative_to_usize(s.len()).map(|idx| idx as IdxSize)) + .is_none() + }) + { polars_bail!(ComputeError: "get index is out of bounds"); } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 57ea3d18b763..fe35e4f409bf 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -440,24 +440,22 @@ pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult { - let (start, end) = - unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; - let offset = if idx >= 0 { start + idx } else { end + idx }; - if offset >= end || offset < start || start == end { - if null_on_oob { - Ok(None) - } else { - polars_bail!(ComputeError: "get index is out of bounds"); - } + .map(|(i, opt_idx)| match opt_idx { + Some(idx) => { + let (start, end) = + unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; + let offset = if idx >= 0 { start + idx } else { end + idx }; + if offset >= end || offset < start || start == end { + if null_on_oob { + Ok(None) } else { - Ok(Some(offset as IdxSize)) + polars_bail!(ComputeError: "get index is out of bounds"); } + } else { + Ok(Some(offset as IdxSize)) } - None => Ok(None) - } + }, + None => Ok(None), }) .collect::>()?; let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 25744b38d741..b00476c7bb3a 100644 
--- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -45,7 +45,11 @@ impl PyExpr { } fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self { - self.inner.clone().list().get(index.inner, null_on_oob).into() + self.inner + .clone() + .list() + .get(index.inner, null_on_oob) + .into() } fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { From 81ca0f838b0ff1af8e3577f4bd8c493c05f82763 Mon Sep 17 00:00:00 2001 From: James Edwards Date: Fri, 29 Mar 2024 14:25:22 -0400 Subject: [PATCH 3/6] Fixed docstrings and converted first() and last() to null_on_oob=True for convenience --- crates/polars-plan/src/dsl/list.rs | 4 +- py-polars/polars/expr/list.py | 21 +++++---- py-polars/polars/series/list.py | 13 +++++- .../tests/unit/namespaces/list/test_list.py | 46 +++++++++---------- 4 files changed, 47 insertions(+), 37 deletions(-) diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 7e299833e131..0f6c15c755e7 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -187,12 +187,12 @@ impl ListNameSpace { /// Get first item of every sublist. pub fn first(self) -> Expr { - self.get(lit(0i64), false) + self.get(lit(0i64), true) } /// Get last item of every sublist. pub fn last(self) -> Expr { - self.get(lit(-1i64), false) + self.get(lit(-1i64), true) } /// Join all string items in a sublist and place a separator between them. 
diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 6518fd0e4bee..c9c98d75ae55 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -505,11 +505,12 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E other_list.insert(0, wrap_expr(self._pyexpr)) return F.concat_list(other_list) - def get(self, - index: int | Expr | str, - *, - null_on_oob: bool = False, - ) -> Expr: + def get( + self, + index: int | Expr | str, + *, + null_on_oob: bool = False, + ) -> Expr: """ Get the value by index in the sublists. @@ -521,11 +522,15 @@ def get(self, ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- >>> df = pl.DataFrame({"a": [[3, 2, 1], [], [1, 2]]}) - >>> df.with_columns(get=pl.col("a").list.get(0)) + >>> df.with_columns(get=pl.col("a").list.get(0, null_on_oob=True)) shape: (3, 2) ┌───────────┬──────┐ │ a ┆ get │ @@ -645,7 +650,7 @@ def first(self) -> Expr: │ [1, 2] ┆ 1 │ └───────────┴───────┘ """ - return self.get(0) + return self.get(0, null_on_oob=True) def last(self) -> Expr: """ @@ -666,7 +671,7 @@ def last(self) -> Expr: │ [1, 2] ┆ 2 │ └───────────┴──────┘ """ - return self.get(-1) + return self.get(-1, null_on_oob=True) def contains( self, item: float | str | bool | int | date | datetime | time | IntoExprColumn diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 610bb4595f88..8db895d003ab 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -347,7 +347,12 @@ def concat(self, other: list[Series] | Series | list[Any]) -> Series: ] """ - def get(self, index: int | Series | list[int]) -> Series: + def get( + self, + index: int | Series | list[int], + *, + null_on_oob: bool = False, + ) -> Series: """ Get the value by index in the sublists. 
@@ -359,11 +364,15 @@ def get(self, index: int | Series | list[int]) -> Series: ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- >>> s = pl.Series("a", [[3, 2, 1], [], [1, 2]]) - >>> s.list.get(0) + >>> s.list.get(0, null_on_oob=True) shape: (3,) Series: 'a' [i64] [ diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index cea176ad7871..fc0e5d469e7c 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -21,7 +21,7 @@ def test_list_arr_get() -> None: assert_series_equal(out, expected) out = pl.select(pl.lit(a).list.first()).to_series() assert_series_equal(out, expected) - + out = a.list.get(-1) expected = pl.Series("a", [3, 5, 9]) assert_series_equal(out, expected) @@ -32,10 +32,6 @@ def test_list_arr_get() -> None: with pytest.raises(pl.ComputeError, match="get index is out of bounds"): a.list.get(3) - - b = pl.Series("b", [[1, 2, 3], [4], []]) - with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - b.list.first() # Null index. 
out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) @@ -43,7 +39,7 @@ def test_list_arr_get() -> None: assert_frame_equal(out_df, expected_df) a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): a.list.get(-3) @@ -59,7 +55,7 @@ def test_list_arr_get() -> None: with pytest.raises(pl.ComputeError, match="get index is out of bounds"): df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) - + # exact on oob boundary df = pl.DataFrame( { @@ -69,17 +65,15 @@ def test_list_arr_get() -> None: ) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) == { - "lists": [None, None, 4] - } - df.select(pl.col("lists").list.get(pl.col("index"))).to_dict( - as_series=False - ) - + df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(pl.col("index"))).to_dict(as_series=False) + def test_list_arr_get_null_on_oob() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(0, null_on_oob=True) + out = a.list.first() expected = pl.Series("a", [1, 4, 6]) assert_series_equal(out, expected) out = a.list[0] @@ -128,9 +122,9 @@ def test_list_arr_get_null_on_oob() -> None: # get by indexes where some are out of bounds df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) - assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict(as_series=False) == { - "cars": [2, 3, None, None] - } + assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict( + as_series=False + ) == {"cars": [2, 3, None, None]} # exact on oob boundary df = pl.DataFrame( { @@ -139,12 +133,12 @@ def test_list_arr_get_null_on_oob() -> None: } ) - assert df.select(pl.col("lists").list.get(3, 
null_on_oob=True)).to_dict(as_series=False) == { - "lists": [None, None, 4] - } - assert df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=True)).to_dict( + assert df.select(pl.col("lists").list.get(3, null_on_oob=True)).to_dict( as_series=False ) == {"lists": [None, None, 4]} + assert df.select( + pl.col("lists").list.get(pl.col("index"), null_on_oob=True) + ).to_dict(as_series=False) == {"lists": [None, None, 4]} def test_list_categorical_get() -> None: @@ -156,7 +150,9 @@ def test_list_categorical_get() -> None: } ) expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) - assert_series_equal(df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True) + assert_series_equal( + df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True + ) def test_contains() -> None: @@ -224,8 +220,8 @@ def test_list_arr_empty() -> None: out = df.select( [ - pl.col("cars").list.get(0, null_on_oob=True).alias("cars_first"), - pl.when(pl.col("cars").list.get(0, null_on_oob=True) == 2) + pl.col("cars").list.first().alias("cars_first"), + pl.when(pl.col("cars").list.first() == 2) .then(1) .when(pl.col("cars").list.contains(2)) .then(2) From 8c0c162d877c7176c29684b084718f6707878d35 Mon Sep 17 00:00:00 2001 From: James Edwards Date: Sat, 30 Mar 2024 11:53:11 -0400 Subject: [PATCH 4/6] Use offset length iterator to make oob check more efficient --- crates/polars-arrow/src/legacy/kernels/list.rs | 7 +++++++ crates/polars-ops/src/chunked_array/list/namespace.rs | 10 ++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index e67d1638e99d..46c339323b1b 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -75,6 +75,13 @@ pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { unsafe { take_unchecked(&**values, &take_by) } } +/// Check 
if an index is out of bounds for at least one sublist. +pub fn index_is_oob(arr: &ListArray, index: i64) -> bool { + arr.offsets() + .lengths() + .any(|len| index.negative_to_usize(len).is_none()) +} + /// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]` pub fn array_to_unit_list(array: ArrayRef) -> ListArray { let len = array.len(); diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 8fc6e07892e5..f7e63ce0956e 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -1,7 +1,7 @@ use std::fmt::Write; use arrow::array::ValueSize; -use arrow::legacy::kernels::list::sublist_get; +use arrow::legacy::kernels::list::{index_is_oob, sublist_get}; use polars_core::chunked_array::builder::get_list_builder; #[cfg(feature = "list_gather")] use polars_core::export::num::ToPrimitive; @@ -344,13 +344,7 @@ pub trait ListNameSpaceImpl: AsList { /// if an index is out of bounds, it will return a `None`. fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { let ca = self.as_list(); - if !null_on_oob - && ca.iter().any(|sublist| { - sublist - .and_then(|s| idx.negative_to_usize(s.len()).map(|idx| idx as IdxSize)) - .is_none() - }) - { + if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) { polars_bail!(ComputeError: "get index is out of bounds"); } From f1eb90853dbc15792ffb674df7053305c995ca76 Mon Sep 17 00:00:00 2001 From: James Edwards Date: Sun, 31 Mar 2024 12:10:26 -0400 Subject: [PATCH 5/6] Changed default behavior to True to avoid breaking change. 
--- crates/polars-sql/src/functions.rs | 2 +- py-polars/polars/expr/list.py | 8 +++--- .../tests/unit/namespaces/list/test_list.py | 27 +++++++++++-------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 7ef6d2fe221c..6cfd5263c416 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> { // Array functions // ---- ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), - ArrayGet => self.visit_binary(|e, i| e.list().get(i, false)), + ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)), ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index c9c98d75ae55..3c827794ffdb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -509,7 +509,7 @@ def get( self, index: int | Expr | str, *, - null_on_oob: bool = False, + null_on_oob: bool = True, ) -> Expr: """ Get the value by index in the sublists. 
@@ -530,7 +530,7 @@ def get( Examples -------- >>> df = pl.DataFrame({"a": [[3, 2, 1], [], [1, 2]]}) - >>> df.with_columns(get=pl.col("a").list.get(0, null_on_oob=True)) + >>> df.with_columns(get=pl.col("a").list.get(0)) shape: (3, 2) ┌───────────┬──────┐ │ a ┆ get │ @@ -650,7 +650,7 @@ def first(self) -> Expr: │ [1, 2] ┆ 1 │ └───────────┴───────┘ """ - return self.get(0, null_on_oob=True) + return self.get(0) def last(self) -> Expr: """ @@ -671,7 +671,7 @@ def last(self) -> Expr: │ [1, 2] ┆ 2 │ └───────────┴──────┘ """ - return self.get(-1, null_on_oob=True) + return self.get(-1) def contains( self, item: float | str | bool | int | date | datetime | time | IntoExprColumn diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index fc0e5d469e7c..40dc3561598c 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -11,7 +11,7 @@ def test_list_arr_get() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(0) + out = a.list.get(0, null_on_oob=False) expected = pl.Series("a", [1, 4, 6]) assert_series_equal(out, expected) out = a.list[0] @@ -22,7 +22,7 @@ def test_list_arr_get() -> None: out = pl.select(pl.lit(a).list.first()).to_series() assert_series_equal(out, expected) - out = a.list.get(-1) + out = a.list.get(-1, null_on_oob=False) expected = pl.Series("a", [3, 5, 9]) assert_series_equal(out, expected) out = a.list.last() @@ -31,30 +31,35 @@ def test_list_arr_get() -> None: assert_series_equal(out, expected) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - a.list.get(3) + a.list.get(3, null_on_oob=False) # Null index. 
- out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=False)) expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() assert_frame_equal(out_df, expected_df) a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - a.list.get(-3) + a.list.get(-3, null_on_oob=False) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): pl.DataFrame( {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} ).with_columns( - [pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)] - ).to_dict(as_series=False) + [ + pl.col("a").list.get(i, null_on_oob=False).alias(f"get_{i}") + for i in range(4) + ] + ) # get by indexes where some are out of bounds df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) + df.select([pl.col("cars").list.get("indexes", null_on_oob=False)]).to_dict( + as_series=False + ) # exact on oob boundary df = pl.DataFrame( @@ -65,15 +70,15 @@ def test_list_arr_get() -> None: ) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) + df.select(pl.col("lists").list.get(3, null_on_oob=False)) with pytest.raises(pl.ComputeError, match="get index is out of bounds"): - df.select(pl.col("lists").list.get(pl.col("index"))).to_dict(as_series=False) + df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=False)) def test_list_arr_get_null_on_oob() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.first() + out = a.list.get(0, null_on_oob=True) expected = pl.Series("a", [1, 4, 6]) assert_series_equal(out, expected) out = a.list[0] From 80041742318f93a535635e59c7725eb7a8b85b02 Mon Sep 
17 00:00:00 2001 From: James Edwards Date: Sun, 31 Mar 2024 12:16:25 -0400 Subject: [PATCH 6/6] Changing default to null_on_oob=True (missed one) --- py-polars/polars/series/list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 8db895d003ab..88936360d4a6 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -351,7 +351,7 @@ def get( self, index: int | Series | list[int], *, - null_on_oob: bool = False, + null_on_oob: bool = True, ) -> Series: """ Get the value by index in the sublists.