Skip to content

Commit

Permalink
perf(rust, python): optimize arr.min/arr.max (pola-rs#7050)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and josemasar committed Feb 21, 2023
1 parent d39c9c3 commit 7fb1ef0
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 11 deletions.
1 change: 1 addition & 0 deletions polars/polars-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ pub mod index;
pub mod is_valid;
pub mod kernels;
pub mod prelude;
pub mod slice;
pub mod trusted_len;
pub mod utils;
17 changes: 17 additions & 0 deletions polars/polars-arrow/src/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use crate::data_types::IsFloat;
use crate::kernels::rolling::{compare_fn_nan_max, compare_fn_nan_min};

pub trait ExtremaNanAware<T> {
fn min_value_nan_aware(&self) -> Option<&T>;
fn max_value_nan_aware(&self) -> Option<&T>;
}

impl<T: PartialOrd + IsFloat> ExtremaNanAware<T> for [T] {
fn min_value_nan_aware(&self) -> Option<&T> {
self.iter().min_by(|a, b| compare_fn_nan_max(*a, *b))
}

fn max_value_nan_aware(&self) -> Option<&T> {
self.iter().max_by(|a, b| compare_fn_nan_min(*a, *b))
}
}
223 changes: 223 additions & 0 deletions polars/polars-ops/src/chunked_array/list/min_max.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use polars_arrow::array::PolarsArray;
use polars_arrow::data_types::{ArrayRef, IsFloat};
use polars_arrow::slice::ExtremaNanAware;
use polars_arrow::utils::CustomIterTools;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;

use crate::chunked_array::list::namespace::has_inner_nulls;

fn min_between_offsets<T>(values: &[T], offset: &[i64]) -> PrimitiveArray<T>
where
T: NativeType + PartialOrd + IsFloat,
{
let mut running_offset = offset[0];

(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
slice.min_value_nan_aware().copied()
})
.collect_trusted()
}

fn dispatch_min<T>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
where
T: NativeType + PartialOrd + IsFloat,
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
let mut out = min_between_offsets(values, offsets);

if let Some(validity) = validity {
if out.has_validity() {
out.apply_validity(|other_validity| validity & &other_validity)
} else {
out = out.with_validity(Some(validity.clone()));
}
}
Box::new(out)
}

fn min_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
use DataType::*;
let chunks = ca
.downcast_iter()
.map(|arr| {
let offsets = arr.offsets().as_slice();
let values = arr.values().as_ref();

match inner_type {
Int8 => dispatch_min::<i8>(values, offsets, arr.validity()),
Int16 => dispatch_min::<i16>(values, offsets, arr.validity()),
Int32 => dispatch_min::<i32>(values, offsets, arr.validity()),
Int64 => dispatch_min::<i64>(values, offsets, arr.validity()),
UInt8 => dispatch_min::<u8>(values, offsets, arr.validity()),
UInt16 => dispatch_min::<u16>(values, offsets, arr.validity()),
UInt32 => dispatch_min::<u32>(values, offsets, arr.validity()),
UInt64 => dispatch_min::<u64>(values, offsets, arr.validity()),
Float32 => dispatch_min::<f32>(values, offsets, arr.validity()),
Float64 => dispatch_min::<f64>(values, offsets, arr.validity()),
_ => unimplemented!(),
}
})
.collect::<Vec<_>>();

Series::try_from((ca.name(), chunks)).unwrap()
}

pub(super) fn list_min_function(ca: &ListChunked) -> Series {
fn inner(ca: &ListChunked) -> Series {
match ca.inner_dtype() {
DataType::Boolean => {
let out: IdxCa = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().bool().unwrap().min()))
.collect_trusted();
out.into_series()
}
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(dt, |$T| {
let out: ChunkedArray<$T> = ca
.amortized_iter()
.map(|opt_s|
{
let s = opt_s?;
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
ca.min()
}
)
.collect_trusted();
out.into_series()
})
}
_ => ca
.apply_amortized(|s| s.as_ref().min_as_series())
.explode()
.unwrap()
.into_series(),
}
}

if has_inner_nulls(ca) {
return inner(ca);
};

match ca.inner_dtype() {
dt if dt.is_numeric() => min_list_numerical(ca, &dt),
_ => inner(ca),
}
}

fn max_between_offsets<T>(values: &[T], offset: &[i64]) -> PrimitiveArray<T>
where
T: NativeType + PartialOrd + IsFloat,
{
let mut running_offset = offset[0];

(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
slice.max_value_nan_aware().copied()
})
.collect_trusted()
}

fn dispatch_max<T>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
where
T: NativeType + PartialOrd + IsFloat,
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
let mut out = max_between_offsets(values, offsets);

if let Some(validity) = validity {
if out.has_validity() {
out.apply_validity(|other_validity| validity & &other_validity)
} else {
out = out.with_validity(Some(validity.clone()));
}
}
Box::new(out)
}

fn max_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
use DataType::*;
let chunks = ca
.downcast_iter()
.map(|arr| {
let offsets = arr.offsets().as_slice();
let values = arr.values().as_ref();

match inner_type {
Int8 => dispatch_max::<i8>(values, offsets, arr.validity()),
Int16 => dispatch_max::<i16>(values, offsets, arr.validity()),
Int32 => dispatch_max::<i32>(values, offsets, arr.validity()),
Int64 => dispatch_max::<i64>(values, offsets, arr.validity()),
UInt8 => dispatch_max::<u8>(values, offsets, arr.validity()),
UInt16 => dispatch_max::<u16>(values, offsets, arr.validity()),
UInt32 => dispatch_max::<u32>(values, offsets, arr.validity()),
UInt64 => dispatch_max::<u64>(values, offsets, arr.validity()),
Float32 => dispatch_max::<f32>(values, offsets, arr.validity()),
Float64 => dispatch_max::<f64>(values, offsets, arr.validity()),
_ => unimplemented!(),
}
})
.collect::<Vec<_>>();

Series::try_from((ca.name(), chunks)).unwrap()
}

pub(super) fn list_max_function(ca: &ListChunked) -> Series {
fn inner(ca: &ListChunked) -> Series {
match ca.inner_dtype() {
DataType::Boolean => {
let out: IdxCa = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().bool().unwrap().max()))
.collect_trusted();
out.into_series()
}
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(dt, |$T| {
let out: ChunkedArray<$T> = ca
.amortized_iter()
.map(|opt_s|
{
let s = opt_s?;
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
ca.max()
}
)
.collect_trusted();
out.into_series()
})
}
_ => ca
.apply_amortized(|s| s.as_ref().max_as_series())
.explode()
.unwrap()
.into_series(),
}
}

if has_inner_nulls(ca) {
return inner(ca);
};

match ca.inner_dtype() {
dt if dt.is_numeric() => max_list_numerical(ca, &dt),
_ => inner(ca),
}
}
1 change: 1 addition & 0 deletions polars/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use polars_core::prelude::*;
mod count;
#[cfg(feature = "hash")]
pub(crate) mod hash;
mod min_max;
mod namespace;
mod sum_mean;
#[cfg(feature = "list_to_struct")]
Expand Down
15 changes: 4 additions & 11 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use polars_core::series::ops::NullBehavior;
use polars_core::utils::{try_get_supertype, CustomIterTools};

use super::*;
use crate::chunked_array::list::min_max::{list_max_function, list_min_function};
use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical};
use crate::series::ArgAgg;

fn has_inner_nulls(ca: &ListChunked) -> bool {
pub(super) fn has_inner_nulls(ca: &ListChunked) -> bool {
for arr in ca.downcast_iter() {
if arr.values().null_count() > 0 {
return true;
Expand Down Expand Up @@ -117,19 +118,11 @@ pub trait ListNameSpaceImpl: AsList {
}

fn lst_max(&self) -> Series {
let ca = self.as_list();
ca.apply_amortized(|s| s.as_ref().max_as_series())
.explode()
.unwrap()
.into_series()
list_max_function(self.as_list())
}

fn lst_min(&self) -> Series {
let ca = self.as_list();
ca.apply_amortized(|s| s.as_ref().min_as_series())
.explode()
.unwrap()
.into_series()
list_min_function(self.as_list())
}

fn lst_sum(&self) -> Series {
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,23 @@ def test_list_mean() -> None:
assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], None]}).select(
pl.col("a").arr.mean()
).to_dict(False) == {"a": [1.0, 2.0, 2.5, None]}


def test_list_min_max() -> None:
for dt in pl.datatypes.NUMERIC_DTYPES:
df = pl.DataFrame(
{"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]},
schema={"a": pl.List(dt)},
)
assert df.select(pl.col("a").arr.min())["a"].series_equal(
df.select(pl.col("a").arr.first())["a"]
)
assert df.select(pl.col("a").arr.max())["a"].series_equal(
df.select(pl.col("a").arr.last())["a"]
)

df = pl.DataFrame(
{"a": [[1], [1, 5, -1, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]},
)
assert df.select(pl.col("a").arr.min()).to_dict(False) == {"a": [1, -1, 1, 1, None]}
assert df.select(pl.col("a").arr.max()).to_dict(False) == {"a": [1, 5, 4, 5, None]}

0 comments on commit 7fb1ef0

Please sign in to comment.