From 7fb1ef01d757359bf92d92f105f37336a9fccb30 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 20 Feb 2023 17:23:22 +0100 Subject: [PATCH] perf(rust, python): optimize arr.min/arr.max (#7050) --- polars/polars-arrow/src/lib.rs | 1 + polars/polars-arrow/src/slice.rs | 17 ++ .../src/chunked_array/list/min_max.rs | 223 ++++++++++++++++++ .../polars-ops/src/chunked_array/list/mod.rs | 1 + .../src/chunked_array/list/namespace.rs | 15 +- py-polars/tests/unit/datatypes/test_list.py | 20 ++ 6 files changed, 266 insertions(+), 11 deletions(-) create mode 100644 polars/polars-arrow/src/slice.rs create mode 100644 polars/polars-ops/src/chunked_array/list/min_max.rs diff --git a/polars/polars-arrow/src/lib.rs b/polars/polars-arrow/src/lib.rs index 5ef306a9addc..9f87df3667aa 100644 --- a/polars/polars-arrow/src/lib.rs +++ b/polars/polars-arrow/src/lib.rs @@ -11,5 +11,6 @@ pub mod index; pub mod is_valid; pub mod kernels; pub mod prelude; +pub mod slice; pub mod trusted_len; pub mod utils; diff --git a/polars/polars-arrow/src/slice.rs b/polars/polars-arrow/src/slice.rs new file mode 100644 index 000000000000..6dc039ead714 --- /dev/null +++ b/polars/polars-arrow/src/slice.rs @@ -0,0 +1,17 @@ +use crate::data_types::IsFloat; +use crate::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min}; + +pub trait ExtremaNanAware { + fn min_value_nan_aware(&self) -> Option<&T>; + fn max_value_nan_aware(&self) -> Option<&T>; +} + +impl ExtremaNanAware for [T] { + fn min_value_nan_aware(&self) -> Option<&T> { + self.iter().min_by(|a, b| compare_fn_nan_max(*a, *b)) + } + + fn max_value_nan_aware(&self) -> Option<&T> { + self.iter().max_by(|a, b| compare_fn_nan_min(*a, *b)) + } +} diff --git a/polars/polars-ops/src/chunked_array/list/min_max.rs b/polars/polars-ops/src/chunked_array/list/min_max.rs new file mode 100644 index 000000000000..4292ceb9f1c1 --- /dev/null +++ b/polars/polars-ops/src/chunked_array/list/min_max.rs @@ -0,0 +1,223 @@ +use arrow::array::{Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::types::NativeType; +use polars_arrow::array::PolarsArray; +use polars_arrow::data_types::{ArrayRef, IsFloat}; +use polars_arrow::slice::ExtremaNanAware; +use polars_arrow::utils::CustomIterTools; +use polars_core::prelude::*; +use polars_core::with_match_physical_numeric_polars_type; + +use crate::chunked_array::list::namespace::has_inner_nulls; + +fn min_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray +where + T: NativeType + PartialOrd + IsFloat, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + slice.min_value_nan_aware().copied() + }) + .collect_trusted() +} + +fn dispatch_min(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + PartialOrd + IsFloat, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + let mut out = min_between_offsets(values, offsets); + + if let Some(validity) = validity { + if out.has_validity() { + out.apply_validity(|other_validity| validity & &other_validity) + } else { + out = out.with_validity(Some(validity.clone())); + } + } + Box::new(out) +} + +fn min_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_min::(values, offsets, arr.validity()), + Int16 => dispatch_min::(values, offsets, arr.validity()), + Int32 => dispatch_min::(values, offsets, arr.validity()), + Int64 => dispatch_min::(values, offsets, arr.validity()), + UInt8 => dispatch_min::(values, offsets, arr.validity()), + UInt16 => dispatch_min::(values, offsets, arr.validity()), + UInt32 => dispatch_min::(values, offsets, arr.validity()), + UInt64 => dispatch_min::(values, offsets, arr.validity()), + Float32 => dispatch_min::(values, offsets, arr.validity()), + Float64 => dispatch_min::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name(), chunks)).unwrap() +} + +pub(super) fn list_min_function(ca: &ListChunked) -> Series { + fn inner(ca: &ListChunked) -> Series { + match ca.inner_dtype() { + DataType::Boolean => { + let out: IdxCa = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().bool().unwrap().min())) + .collect_trusted(); + out.into_series() + } + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let out: ChunkedArray<$T> = ca + .amortized_iter() + .map(|opt_s| + { + let s = opt_s?; + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + ca.min() + } + ) + .collect_trusted(); + out.into_series() + }) + } + _ => ca + .apply_amortized(|s| s.as_ref().min_as_series()) + .explode() + .unwrap() + .into_series(), + } + } + + if has_inner_nulls(ca) { + return inner(ca); + }; + + match ca.inner_dtype() { + dt if dt.is_numeric() => min_list_numerical(ca, &dt), + _ => inner(ca), + } +} + +fn max_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray +where + T: NativeType + PartialOrd + IsFloat, +{ + let mut running_offset = offset[0]; + + (offset[1..]) + .iter() + .map(|end| { + let current_offset = running_offset; + running_offset = *end; + + let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; + slice.max_value_nan_aware().copied() + }) + .collect_trusted() +} + +fn dispatch_max(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef +where + T: NativeType + PartialOrd + IsFloat, +{ + let values = arr.as_any().downcast_ref::>().unwrap(); + let values = values.values().as_slice(); + let mut out = max_between_offsets(values, offsets); + + if let Some(validity) = validity { + if out.has_validity() { + out.apply_validity(|other_validity| validity & &other_validity) + } else { + out = out.with_validity(Some(validity.clone())); + } + } + Box::new(out) +} + +fn max_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { + use DataType::*; + let chunks = ca + .downcast_iter() + .map(|arr| { + let offsets = arr.offsets().as_slice(); + let values = arr.values().as_ref(); + + match inner_type { + Int8 => dispatch_max::(values, offsets, arr.validity()), + Int16 => dispatch_max::(values, offsets, arr.validity()), + Int32 => dispatch_max::(values, offsets, arr.validity()), + Int64 => dispatch_max::(values, offsets, arr.validity()), + UInt8 => dispatch_max::(values, offsets, arr.validity()), + UInt16 => dispatch_max::(values, offsets, arr.validity()), + UInt32 => dispatch_max::(values, offsets, arr.validity()), + UInt64 => dispatch_max::(values, offsets, arr.validity()), + Float32 => dispatch_max::(values, offsets, arr.validity()), + Float64 => dispatch_max::(values, offsets, arr.validity()), + _ => unimplemented!(), + } + }) + .collect::>(); + + Series::try_from((ca.name(), chunks)).unwrap() +} + +pub(super) fn list_max_function(ca: &ListChunked) -> Series { + fn inner(ca: &ListChunked) -> Series { + match ca.inner_dtype() { + DataType::Boolean => { + let out: IdxCa = ca + .amortized_iter() + .map(|s| s.and_then(|s| s.as_ref().bool().unwrap().max())) + .collect_trusted(); + out.into_series() + } + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(dt, |$T| { + let out: ChunkedArray<$T> = ca + .amortized_iter() + .map(|opt_s| + { + let s = opt_s?; + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + ca.max() + } + ) + .collect_trusted(); + out.into_series() + }) + } + _ => ca + .apply_amortized(|s| s.as_ref().max_as_series()) + .explode() + .unwrap() + .into_series(), + } + } + + if has_inner_nulls(ca) { + return inner(ca); + }; + + match ca.inner_dtype() { + dt if dt.is_numeric() => max_list_numerical(ca, &dt), + _ => inner(ca), + } +} diff --git a/polars/polars-ops/src/chunked_array/list/mod.rs b/polars/polars-ops/src/chunked_array/list/mod.rs index 3c58fe4ee88b..57dbdd8262e5 100644 --- a/polars/polars-ops/src/chunked_array/list/mod.rs +++ b/polars/polars-ops/src/chunked_array/list/mod.rs @@ -3,6 +3,7 @@ use polars_core::prelude::*; mod count; #[cfg(feature = "hash")] pub(crate) mod hash; +mod min_max; mod namespace; mod sum_mean; #[cfg(feature = "list_to_struct")] diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index 52bd08f91c80..b742bcf7cb5b 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -13,10 +13,11 @@ use polars_core::series::ops::NullBehavior; use polars_core::utils::{try_get_supertype, CustomIterTools}; use super::*; +use crate::chunked_array::list::min_max::{list_max_function, list_min_function}; use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical}; use crate::series::ArgAgg; -fn has_inner_nulls(ca: &ListChunked) -> bool { +pub(super) fn has_inner_nulls(ca: &ListChunked) -> bool { for arr in ca.downcast_iter() { if arr.values().null_count() > 0 { return true; @@ -117,19 +118,11 @@ pub trait ListNameSpaceImpl: AsList { } fn lst_max(&self) -> Series { - let ca = self.as_list(); - ca.apply_amortized(|s| s.as_ref().max_as_series()) - .explode() - .unwrap() - .into_series() + list_max_function(self.as_list()) } fn lst_min(&self) -> Series { - let ca = self.as_list(); - ca.apply_amortized(|s| s.as_ref().min_as_series()) - .explode() - .unwrap() - .into_series() + list_min_function(self.as_list()) } fn lst_sum(&self) -> Series { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index bdbe14ed929d..999747ee01b3 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -415,3 +415,23 @@ def test_list_mean() -> None: assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], None]}).select( pl.col("a").arr.mean() ).to_dict(False) == {"a": [1.0, 2.0, 2.5, None]} + + +def test_list_min_max() -> None: + for dt in pl.datatypes.NUMERIC_DTYPES: + df = pl.DataFrame( + {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + schema={"a": pl.List(dt)}, + ) + assert df.select(pl.col("a").arr.min())["a"].series_equal( + df.select(pl.col("a").arr.first())["a"] + ) + assert df.select(pl.col("a").arr.max())["a"].series_equal( + df.select(pl.col("a").arr.last())["a"] + ) + + df = pl.DataFrame( + {"a": [[1], [1, 5, -1, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}, + ) + assert df.select(pl.col("a").arr.min()).to_dict(False) == {"a": [1, -1, 1, 1, None]} + assert df.select(pl.col("a").arr.max()).to_dict(False) == {"a": [1, 5, 4, 5, None]}