diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index 5aa20a786ea77..521044368caa7 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -205,6 +205,10 @@ impl BinaryViewArrayGeneric { &self.views } + pub fn into_views(self) -> Vec { + self.views.make_mut() + } + pub fn try_new( data_type: ArrowDataType, views: Buffer, @@ -265,28 +269,8 @@ impl BinaryViewArrayGeneric { /// Assumes that the `i < self.len`. #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> &T { - let v = *self.views.get_unchecked_release(i); - let len = v.length; - - // view layout: - // length: 4 bytes - // prefix: 4 bytes - // buffer_index: 4 bytes - // offset: 4 bytes - - // inlined layout: - // length: 4 bytes - // data: 12 bytes - - let bytes = if len <= 12 { - let ptr = self.views.as_ptr() as *const u8; - std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize) - } else { - let data = self.buffers.get_unchecked_release(v.buffer_idx as usize); - let offset = v.offset as usize; - data.get_unchecked_release(offset..offset + len as usize) - }; - T::from_bytes_unchecked(bytes) + let v = self.views.get_unchecked_release(i); + T::from_bytes_unchecked(v.get_slice_unchecked(&self.buffers)) } /// Returns an iterator of `Option<&T>` over every element of this array. diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs index ec29bd2b60e3a..25482754337aa 100644 --- a/crates/polars-arrow/src/array/binview/mutable.rs +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -343,24 +343,30 @@ impl MutableBinaryViewArray { /// Assumes that the `i < self.len`. #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> &T { - let v = *self.views.get_unchecked(i); - let len = v.length; + self.value_from_view_unchecked(self.views.get_unchecked(i)) + } - // view layout: + /// Returns the element indicated by the given view. 
+ /// + /// # Safety + /// Assumes the View belongs to this MutableBinaryViewArray. + pub unsafe fn value_from_view_unchecked<'a>(&'a self, view: &'a View) -> &'a T { + // View layout: // length: 4 bytes // prefix: 4 bytes // buffer_index: 4 bytes // offset: 4 bytes - // inlined layout: + // Inlined layout: // length: 4 bytes // data: 12 bytes + let len = view.length; let bytes = if len <= 12 { - let ptr = self.views.as_ptr() as *const u8; - std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize) + let ptr = view as *const View as *const u8; + std::slice::from_raw_parts(ptr.add(4), len as usize) } else { - let buffer_idx = v.buffer_idx as usize; - let offset = v.offset; + let buffer_idx = view.buffer_idx as usize; + let offset = view.offset; let data = if buffer_idx == self.completed_buffers.len() { self.in_progress_buffer.as_slice() diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 4975930ee7445..ccb771d2417dc 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -57,6 +57,23 @@ impl View { } } } + + /// Constructs a byteslice from this view. + /// + /// # Safety + /// Assumes that this view is valid for the given buffers. 
+ pub unsafe fn get_slice_unchecked<'a>(&'a self, buffers: &'a [Buffer]) -> &'a [u8] { + unsafe { + if self.length <= 12 { + let ptr = self as *const View as *const u8; + std::slice::from_raw_parts(ptr.add(4), self.length as usize) + } else { + let data = buffers.get_unchecked_release(self.buffer_idx as usize); + let offset = self.offset as usize; + data.get_unchecked_release(offset..offset + self.length as usize) + } + } + } } impl IsNull for View { diff --git a/crates/polars-compute/src/filter/mod.rs b/crates/polars-compute/src/filter/mod.rs index 38cce5c103d29..2ac66243fb8e5 100644 --- a/crates/polars-compute/src/filter/mod.rs +++ b/crates/polars-compute/src/filter/mod.rs @@ -9,53 +9,50 @@ mod avx512; use arrow::array::growable::make_growable; use arrow::array::{new_empty_array, Array, BinaryViewArray, BooleanArray, PrimitiveArray}; use arrow::bitmap::utils::SlicesIterator; -use arrow::datatypes::ArrowDataType; +use arrow::bitmap::Bitmap; use arrow::with_match_primitive_type_full; -use polars_error::PolarsResult; -pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult> { +pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box { assert_eq!(array.len(), mask.len()); // Treat null mask values as false. if let Some(validities) = mask.validity() { - let values = mask.values(); - let new_values = values & validities; - let mask = BooleanArray::new(ArrowDataType::Boolean, new_values, None); - return filter(array, &mask); + let combined_mask = mask.values() & validities; + filter_with_bitmap(array, &combined_mask) + } else { + filter_with_bitmap(array, mask.values()) } +} +pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box { // Fast-path: completely empty or completely full mask. 
- let false_count = mask.values().unset_bits(); + let false_count = mask.unset_bits(); if false_count == mask.len() { - return Ok(new_empty_array(array.data_type().clone())); + return new_empty_array(array.data_type().clone()); } if false_count == 0 { - return Ok(array.to_boxed()); + return array.to_boxed(); } use arrow::datatypes::PhysicalType::*; match array.data_type().to_physical_type() { Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap(); - let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask.values()); - Ok(Box::new(PrimitiveArray::from_vec(values).with_validity(validity))) + let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask); + Box::new(PrimitiveArray::from_vec(values).with_validity(validity)) }), Boolean => { let array = array.as_any().downcast_ref::().unwrap(); - let (values, validity) = boolean::filter_bitmap_and_validity( - array.values(), - array.validity(), - mask.values(), - ); - Ok(BooleanArray::new(array.data_type().clone(), values, validity).boxed()) + let (values, validity) = + boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask); + BooleanArray::new(array.data_type().clone(), values, validity).boxed() }, BinaryView => { let array = array.as_any().downcast_ref::().unwrap(); let views = array.views(); let validity = array.validity(); - let (views, validity) = - primitive::filter_values_and_validity(views, validity, mask.values()); - Ok(unsafe { + let (views, validity) = primitive::filter_values_and_validity(views, validity, mask); + unsafe { BinaryViewArray::new_unchecked_unknown_md( array.data_type().clone(), views.into(), @@ -64,19 +61,19 @@ pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult { unreachable!() }, _ => { - let iter = SlicesIterator::new(mask.values()); + let iter = 
SlicesIterator::new(mask); let mut mutable = make_growable(&[array], false, iter.slots()); // SAFETY: // we are in bounds iter.for_each(|(start, len)| unsafe { mutable.extend(0, start, len) }); - Ok(mutable.as_box()) + mutable.as_box() }, } } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 210a09966ef4f..af4892890336f 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -104,6 +104,14 @@ where unsafe { Self::from_chunks(name, vec![Box::new(arr)]) } } + pub fn with_chunk_like(ca: &Self, arr: A) -> Self + where + A: Array, + T: PolarsDataType, + { + Self::from_chunk_iter_like(ca, std::iter::once(arr)) + } + pub fn from_chunk_iter(name: &str, iter: I) -> Self where I: IntoIterator, @@ -165,12 +173,14 @@ where }) .collect(); - ChunkedArray::new_with_dims( - field, - chunks, - length.try_into().expect(LENGTH_LIMIT_MSG), - null_count as IdxSize, - ) + unsafe { + ChunkedArray::new_with_dims( + field, + chunks, + length.try_into().expect(LENGTH_LIMIT_MSG), + null_count as IdxSize, + ) + } } /// Create a new [`ChunkedArray`] from existing chunks. 
diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index 33191cfafd3c5..0108742ef743f 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -64,7 +64,7 @@ impl DerefMut for Logical { } impl Logical { - pub(crate) fn new_logical(ca: ChunkedArray) -> Logical { + pub fn new_logical(ca: ChunkedArray) -> Logical { Logical(ca, PhantomData, None) } } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 1654a6e2d42a1..05d682fc6ed98 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use arrow::array::*; use arrow::bitmap::Bitmap; +use polars_compute::filter::filter_with_bitmap; use crate::prelude::*; @@ -148,16 +149,21 @@ impl ChunkedArray { /// If you want to explicitly the `length` and `null_count`, look at /// [`ChunkedArray::new_with_dims`] pub fn new_with_compute_len(field: Arc, chunks: Vec) -> Self { - let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0); - chunked_arr.compute_len(); - chunked_arr + unsafe { + let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0); + chunked_arr.compute_len(); + chunked_arr + } } /// Create a new [`ChunkedArray`] and explicitly set its `length` and `null_count`. /// /// If you want to compute the `length` and `null_count`, look at /// [`ChunkedArray::new_with_compute_len`] - pub fn new_with_dims( + /// + /// # Safety + /// The length and null_count must be correct. 
+ pub unsafe fn new_with_dims( field: Arc, chunks: Vec, length: IdxSize, @@ -424,6 +430,31 @@ impl ChunkedArray { } } + pub fn drop_nulls(&self) -> Self { + if self.null_count() == 0 { + self.clone() + } else { + let chunks = self + .downcast_iter() + .map(|arr| { + if arr.null_count() == 0 { + arr.to_boxed() + } else { + filter_with_bitmap(arr, arr.validity().unwrap()) + } + }) + .collect(); + unsafe { + Self::new_with_dims( + self.field.clone(), + chunks, + (self.len() - self.null_count()) as IdxSize, + 0, + ) + } + } + } + /// Get the buffer of bits representing null values #[inline] #[allow(clippy::type_complexity)] diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 232f7f67d5422..d54fb8c7dea62 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -69,7 +69,9 @@ where self.field.dtype = get_object_type::(); - ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count) + unsafe { + ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count) + } } } @@ -141,7 +143,7 @@ where len, }); - ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0) + unsafe { ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0) } } pub fn new_from_vec_and_validity(name: &str, v: Vec, validity: Bitmap) -> Self { @@ -155,7 +157,9 @@ where len, }); - ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize) + unsafe { + ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize) + } } pub fn new_empty(name: &str) -> Self { diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index 9bd64b029117f..a6cf1ca982881 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -31,7 
+31,7 @@ where arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) @@ -53,7 +53,7 @@ impl ChunkFilter for BooleanChunked { arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) @@ -82,7 +82,7 @@ impl ChunkFilter for BinaryChunked { arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) @@ -104,7 +104,7 @@ impl ChunkFilter for BinaryOffsetChunked { arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) @@ -129,7 +129,7 @@ impl ChunkFilter for ListChunked { arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) @@ -155,7 +155,7 @@ impl ChunkFilter for ArrayChunked { arity::binary_unchecked_same_type( self, filter, - |left, mask| filter_fn(left, mask).unwrap(), + |left, mask| filter_fn(left, mask), true, true, ) diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 5650e93505d13..ecf7db6cbfe38 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -657,6 +657,7 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult #[cfg(feature = "dtype-decimal")] Decimal(_, _) => s.clone(), List(inner) if !inner.is_nested() => s.clone(), + Null => s.clone(), _ => { let phys = s.to_physical_repr().into_owned(); polars_ensure!( diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index c43211c141e3d..02e06abc3c420 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ 
b/crates/polars-core/src/series/implementations/null.rs @@ -239,6 +239,10 @@ impl SeriesTrait for NullChunked { .into_series() } + fn sort_with(&self, _options: SortOptions) -> PolarsResult { + Ok(self.clone().into_series()) + } + fn is_null(&self) -> BooleanChunked { BooleanChunked::full(self.name(), true, self.len()) } diff --git a/crates/polars-ops/src/chunked_array/top_k.rs b/crates/polars-ops/src/chunked_array/top_k.rs index 99e24b9123623..c5edaaebd64fe 100644 --- a/crates/polars-ops/src/chunked_array/top_k.rs +++ b/crates/polars-ops/src/chunked_array/top_k.rs @@ -1,206 +1,153 @@ -use std::cmp::Ordering; - -use arrow::array::{BooleanArray, MutableBooleanArray}; -use arrow::bitmap::MutableBitmap; -use either::Either; +use arrow::array::{BinaryViewArray, BooleanArray, PrimitiveArray, StaticArray, View}; +use arrow::bitmap::{Bitmap, MutableBitmap}; use polars_core::chunked_array::ops::sort::arg_bottom_k::_arg_bottom_k; +use polars_core::downcast_as_macro_arg_physical; use polars_core::prelude::*; -use polars_core::{downcast_as_macro_arg_physical, POOL}; +use polars_core::series::IsSorted; use polars_utils::total_ord::TotalOrd; -use rayon::prelude::*; -fn arg_partition Ordering + Sync>( - v: &mut [T], - k: usize, - sort_options: SortOptions, - cmp: C, -) -> &[T] { - let (lower, _el, upper) = v.select_nth_unstable_by(k, &cmp); - let to_sort = if sort_options.descending { - lower +fn first_n_valid_mask(num_valid: usize, out_len: usize) -> Option { + if num_valid < out_len { + let mut bm = MutableBitmap::with_capacity(out_len); + bm.extend_constant(num_valid, true); + bm.extend_constant(out_len - num_valid, false); + Some(bm.freeze()) } else { - upper - }; - let cmp = |a: &T, b: &T| { - if sort_options.descending { - cmp(a, b) - } else { - cmp(b, a) - } - }; - match (sort_options.multithreaded, sort_options.maintain_order) { - (true, true) => POOL.install(|| { - to_sort.par_sort_by(cmp); - }), - (true, false) => POOL.install(|| { - to_sort.par_sort_unstable_by(cmp); 
- }), - (false, true) => to_sort.sort_by(cmp), - (false, false) => to_sort.sort_unstable_by(cmp), - }; - to_sort + None + } } -fn top_k_num_impl(ca: &ChunkedArray, k: usize, sort_options: SortOptions) -> ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: ChunkSort, -{ - if k >= ca.len() { - return ca.sort_with( - sort_options - .with_maintain_order(false) - .with_order_reversed(), - ); +fn top_k_bool_impl( + ca: &ChunkedArray, + k: usize, + descending: bool, +) -> ChunkedArray { + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); } - // descending is opposite from sort as top-k returns largest - let k = if sort_options.descending { - std::cmp::min(k, ca.len()) + let null_count = ca.null_count(); + let non_null_count = ca.len() - ca.null_count(); + let true_count = ca.sum().unwrap() as usize; + let false_count = non_null_count - true_count; + let mut out_len = k.min(ca.len()); + let validity = first_n_valid_mask(non_null_count, out_len); + + // Logical sequence of physical bits. 
+ let sequence = if descending { + [ + (false_count, false), + (true_count, true), + (null_count, false), + ] } else { - ca.len().saturating_sub(k + 1) + [ + (true_count, true), + (false_count, false), + (null_count, false), + ] }; - match ca.to_vec_null_aware() { - Either::Left(mut v) => { - let values = arg_partition( - &mut v, - k, - sort_options.with_maintain_order(false), - TotalOrd::tot_cmp, - ); - ChunkedArray::from_slice(ca.name(), values) - }, - Either::Right(mut v) => { - let values = arg_partition( - &mut v, - k, - sort_options.with_maintain_order(false), - TotalOrd::tot_cmp, - ); - let mut out = ChunkedArray::from_iter(values.iter().copied()); - out.rename(ca.name()); - out - }, + let mut bm = MutableBitmap::with_capacity(out_len); + for (n, value) in sequence { + if out_len == 0 { + break; + } + let extra = out_len.min(n); + bm.extend_constant(extra, value); + out_len -= extra; } + + let arr = BooleanArray::from_data_default(bm.into(), validity); + ChunkedArray::with_chunk_like(ca, arr) } -fn top_k_bool_impl( - ca: &ChunkedArray, - k: usize, - sort_options: SortOptions, -) -> ChunkedArray { - if ca.null_count() == 0 { - let true_count = ca.sum().unwrap() as usize; - let mut bitmap = MutableBitmap::with_capacity(k); - if !sort_options.descending { - // true first - bitmap.extend_constant(std::cmp::min(k, true_count), true); - bitmap.extend_constant(k.saturating_sub(true_count), false); - } else { - let false_count = ca.len().saturating_sub(true_count); - bitmap.extend_constant(std::cmp::min(k, false_count), false); - bitmap.extend_constant(k.saturating_sub(false_count), true); - } - let arr = BooleanArray::from_data_default(bitmap.into(), None); - unsafe { - ChunkedArray::from_chunks_and_dtype(ca.name(), vec![Box::new(arr)], DataType::Boolean) - } - } else { - let null_count = ca.null_count(); - let true_count = ca.sum().unwrap() as usize; - let false_count = ca.len() - true_count - null_count; - let mut remaining = k; - - fn 
extend_constant_check_remaining( - array: &mut MutableBooleanArray, - remaining: &mut usize, - additional: usize, - value: Option, - ) { - array.extend_constant(std::cmp::min(additional, *remaining), value); - *remaining = remaining.saturating_sub(additional); - } +fn top_k_num_impl(ca: &ChunkedArray, k: usize, descending: bool) -> ChunkedArray +where + T: PolarsNumericType, +{ + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); + } - let mut array = MutableBooleanArray::with_capacity(k); - if !sort_options.descending { - if sort_options.nulls_last { - // True -> False -> Null - extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true)); - extend_constant_check_remaining( - &mut array, - &mut remaining, - false_count, - Some(false), - ); - extend_constant_check_remaining(&mut array, &mut remaining, null_count, None); - } else { - // Null -> True -> False - extend_constant_check_remaining(&mut array, &mut remaining, null_count, None); - extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true)); - extend_constant_check_remaining( - &mut array, - &mut remaining, - false_count, - Some(false), - ); - } + // Get rid of all the nulls and transform into Vec. + let nnca = ca.drop_nulls().rechunk(); + let chunk = nnca.downcast_into_iter().next().unwrap(); + let (_, buffer, _) = chunk.into_inner(); + let mut vec = buffer.make_mut(); + + // Partition. 
+ if k < vec.len() { + if descending { + vec.select_nth_unstable_by(k, TotalOrd::tot_cmp); } else { - // False -> True -> Null - extend_constant_check_remaining(&mut array, &mut remaining, false_count, Some(false)); - extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true)); - extend_constant_check_remaining(&mut array, &mut remaining, null_count, None); + vec.select_nth_unstable_by(k, |a, b| TotalOrd::tot_cmp(b, a)); } - let mut new_ca: ChunkedArray = BooleanArray::from(array).into(); - new_ca.rename(ca.name()); - new_ca } + + // Reconstruct output (with nulls at the end). + let out_len = k.min(ca.len()); + let non_null_count = ca.len() - ca.null_count(); + vec.resize(out_len, T::Native::default()); + let validity = first_n_valid_mask(non_null_count, out_len); + + let arr = PrimitiveArray::from_vec(vec).with_validity_typed(validity); + ChunkedArray::with_chunk_like(ca, arr) } fn top_k_binary_impl( ca: &ChunkedArray, k: usize, - sort_options: SortOptions, + descending: bool, ) -> ChunkedArray { - if k >= ca.len() { - return ca.sort_with( - sort_options - .with_order_reversed() - // single series main order is meaningless - .with_maintain_order(false), - ); + if k >= ca.len() && ca.null_count() == 0 { + return ca.clone(); } - // descending is opposite from sort as top-k returns largest - let k = if sort_options.descending { - std::cmp::min(k, ca.len()) - } else { - ca.len().saturating_sub(k + 1) - }; - - if ca.null_count() == 0 { - let mut v: Vec<&[u8]> = Vec::with_capacity(ca.len()); - for arr in ca.downcast_iter() { - v.extend(arr.non_null_values_iter()); - } - let values = arg_partition(&mut v, k, sort_options, TotalOrd::tot_cmp); - ChunkedArray::from_slice(ca.name(), values) - } else { - let mut v = Vec::with_capacity(ca.len()); - for arr in ca.downcast_iter() { - v.extend(arr.iter()); + // Get rid of all the nulls and transform into mutable views. 
+ let nnca = ca.drop_nulls().rechunk(); + let chunk = nnca.downcast_into_iter().next().unwrap(); + let buffers = chunk.data_buffers().clone(); + let mut views = chunk.into_views(); + + // Partition. + if k < views.len() { + if descending { + views.select_nth_unstable_by(k, |a, b| unsafe { + let a_sl = a.get_slice_unchecked(&buffers); + let b_sl = b.get_slice_unchecked(&buffers); + a_sl.cmp(b_sl) + }); + } else { + views.select_nth_unstable_by(k, |a, b| unsafe { + let a_sl = a.get_slice_unchecked(&buffers); + let b_sl = b.get_slice_unchecked(&buffers); + b_sl.cmp(a_sl) + }); } - let values = arg_partition(&mut v, k, sort_options, TotalOrd::tot_cmp); - let mut out = ChunkedArray::from_iter(values.iter().copied()); - out.rename(ca.name()); - out } + + // Reconstruct output (with nulls at the end). + let out_len = k.min(ca.len()); + let non_null_count = ca.len() - ca.null_count(); + views.resize(out_len, View::default()); + let validity = first_n_valid_mask(non_null_count, out_len); + + let arr = unsafe { + BinaryViewArray::new_unchecked_unknown_md( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + None, + ) + }; + ChunkedArray::with_chunk_like(ca, arr) } -pub fn top_k(s: &[Series], sort_options: SortOptions) -> PolarsResult { +pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { fn extract_target_and_k(s: &[Series]) -> PolarsResult<(usize, &Series)> { let k_s = &s[1]; - polars_ensure!( k_s.len() == 1, ComputeError: "`k` must be a single value for `top_k`." 
@@ -211,7 +158,6 @@ pub fn top_k(s: &[Series], sort_options: SortOptions) -> PolarsResult { }; let src = &s[0]; - Ok((k as usize, src)) } @@ -221,15 +167,29 @@ pub fn top_k(s: &[Series], sort_options: SortOptions) -> PolarsResult { return Ok(src.clone()); } - match src.is_sorted_flag() { - polars_core::series::IsSorted::Ascending => { - // TopK is the k element in the bottom of ascending sorted array - return Ok(src.slice((src.len() - k) as i64, k).reverse()); - }, - polars_core::series::IsSorted::Descending => { - return Ok(src.slice(0, k)); - }, - _ => {}, + let sorted_flag = src.is_sorted_flag(); + let is_sorted = match src.is_sorted_flag() { + IsSorted::Ascending => true, + IsSorted::Descending => true, + IsSorted::Not => false, + }; + if is_sorted { + let out_len = k.min(src.len()); + let ignored_len = src.len() - out_len; + + let slice_at_start = (sorted_flag == IsSorted::Ascending) ^ descending; + let nulls_at_start = src.get(0).unwrap() == AnyValue::Null; + let offset = if nulls_at_start == slice_at_start { + src.null_count().min(ignored_len) + } else { + 0 + }; + + return if slice_at_start { + Ok(src.slice(offset as i64, out_len)) + } else { + Ok(src.slice(-(offset as i64) - (out_len as i64), out_len)) + }; } let origin_dtype = src.dtype(); @@ -237,19 +197,29 @@ pub fn top_k(s: &[Series], sort_options: SortOptions) -> PolarsResult { let s = src.to_physical_repr(); match s.dtype() { - DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, sort_options).into_series()), + DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_series()), DataType::String => { - let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, sort_options); + let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending); let ca = unsafe { ca.to_string_unchecked() }; Ok(ca.into_series()) }, - DataType::Binary => { - Ok(top_k_binary_impl(s.binary().unwrap(), k, sort_options).into_series()) + DataType::Binary => 
Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_series()), + DataType::Decimal(_, _) => { + let src = src.decimal().unwrap(); + let ca = top_k_num_impl(src, k, descending); + let mut lca = DecimalChunked::new_logical(ca); + lca.2 = Some(DataType::Decimal(src.precision(), Some(src.scale()))); + Ok(lca.into_series()) + }, + DataType::Null => Ok(src.slice(0, k)), + DataType::Struct(_) => { + // Fallback to more generic impl. + top_k_by_impl(k, src, &[src.clone()], vec![descending]) }, _dt => { macro_rules! dispatch { ($ca:expr) => {{ - top_k_num_impl($ca, k, sort_options).into_series() + top_k_num_impl($ca, k, descending).into_series() }}; } unsafe { downcast_as_macro_arg_physical!(&s, dispatch).cast_unchecked(origin_dtype) } @@ -257,7 +227,7 @@ pub fn top_k(s: &[Series], sort_options: SortOptions) -> PolarsResult { } } -pub fn top_k_by(s: &[Series], sort_options: SortMultipleOptions) -> PolarsResult { +pub fn top_k_by(s: &[Series], descending: Vec) -> PolarsResult { /// Return (k, src, by) fn extract_parameters(s: &[Series]) -> PolarsResult<(usize, &Series, &[Series])> { let k_s = &s[1]; @@ -294,22 +264,28 @@ pub fn top_k_by(s: &[Series], sort_options: SortMultipleOptions) -> PolarsResult } } - top_k_by_impl(k, src, by, sort_options) + top_k_by_impl(k, src, by, descending) } fn top_k_by_impl( k: usize, src: &Series, by: &[Series], - sort_options: SortMultipleOptions, + descending: Vec, ) -> PolarsResult { if src.is_empty() { return Ok(src.clone()); } - let multithreaded = sort_options.multithreaded; + let multithreaded = k >= 10000; + let mut sort_options = SortMultipleOptions { + descending: descending.into_iter().map(|x| !x).collect(), + nulls_last: vec![true; by.len()], + multithreaded, + maintain_order: false, + }; - let idx = _arg_bottom_k(k, by, &mut sort_options.with_order_reversed())?; + let idx = _arg_bottom_k(k, by, &mut sort_options)?; let result = unsafe { if multithreaded { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs 
b/crates/polars-plan/src/dsl/function_expr/mod.rs index b46f71a77fa4e..6fdab6763f502 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -190,11 +190,11 @@ pub enum FunctionExpr { AsStruct, #[cfg(feature = "top_k")] TopK { - sort_options: SortOptions, + descending: bool, }, #[cfg(feature = "top_k")] TopKBy { - sort_options: SortMultipleOptions, + descending: Vec, }, #[cfg(feature = "cum_agg")] CumCount { @@ -452,7 +452,7 @@ impl Hash for FunctionExpr { has_max.hash(state); }, #[cfg(feature = "top_k")] - TopK { sort_options } => sort_options.hash(state), + TopK { descending } => descending.hash(state), #[cfg(feature = "cum_agg")] CumCount { reverse } => reverse.hash(state), #[cfg(feature = "cum_agg")] @@ -575,7 +575,7 @@ impl Hash for FunctionExpr { Reinterpret(signed) => signed.hash(state), ExtendConstant => {}, #[cfg(feature = "top_k")] - TopKBy { sort_options } => sort_options.hash(state), + TopKBy { descending } => descending.hash(state), } } } @@ -650,9 +650,7 @@ impl Display for FunctionExpr { #[cfg(feature = "dtype-struct")] AsStruct => "as_struct", #[cfg(feature = "top_k")] - TopK { - sort_options: SortOptions { descending, .. 
}, - } => { + TopK { descending } => { if *descending { "bottom_k" } else { @@ -989,11 +987,11 @@ impl From for SpecialEq> { map_as_slice!(coerce::as_struct) }, #[cfg(feature = "top_k")] - TopK { sort_options } => { - map_as_slice!(top_k, sort_options) + TopK { descending } => { + map_as_slice!(top_k, descending) }, #[cfg(feature = "top_k")] - TopKBy { sort_options } => map_as_slice!(top_k_by, sort_options.clone()), + TopKBy { descending } => map_as_slice!(top_k_by, descending.clone()), Shift => map_as_slice!(shift_and_fill::shift), #[cfg(feature = "cum_agg")] CumCount { reverse } => map!(cum::cum_count, reverse), diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 6d6f9304e2ece..876c4e72fda88 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -450,8 +450,8 @@ impl Expr { /// /// This has time complexity `O(n + k log(n))`. #[cfg(feature = "top_k")] - pub fn top_k(self, k: Expr, sort_options: SortOptions) -> Self { - self.apply_many_private(FunctionExpr::TopK { sort_options }, &[k], false, false) + pub fn top_k(self, k: Expr) -> Self { + self.apply_many_private(FunctionExpr::TopK { descending: false }, &[k], false, false) } /// Returns the `k` largest rows by given column. @@ -462,26 +462,19 @@ impl Expr { self, k: K, by: E, - sort_options: SortMultipleOptions, + descending: Vec, ) -> Self { let mut args = vec![k.into()]; args.extend(by.as_ref().iter().map(|e| -> Expr { e.clone().into() })); - self.apply_many_private(FunctionExpr::TopKBy { sort_options }, &args, false, false) + self.apply_many_private(FunctionExpr::TopKBy { descending }, &args, false, false) } /// Returns the `k` smallest elements. /// /// This has time complexity `O(n + k log(n))`. 
#[cfg(feature = "top_k")] - pub fn bottom_k(self, k: Expr, sort_options: SortOptions) -> Self { - self.apply_many_private( - FunctionExpr::TopK { - sort_options: sort_options.with_order_reversed(), - }, - &[k], - false, - false, - ) + pub fn bottom_k(self, k: Expr) -> Self { + self.apply_many_private(FunctionExpr::TopK { descending: true }, &[k], false, false) } /// Returns the `k` smallest rows by given column. @@ -493,18 +486,12 @@ impl Expr { self, k: K, by: E, - sort_options: SortMultipleOptions, + descending: Vec, ) -> Self { let mut args = vec![k.into()]; args.extend(by.as_ref().iter().map(|e| -> Expr { e.clone().into() })); - self.apply_many_private( - FunctionExpr::TopKBy { - sort_options: sort_options.with_order_reversed(), - }, - &args, - false, - false, - ) + let descending = descending.into_iter().map(|x| !x).collect(); + self.apply_many_private(FunctionExpr::TopKBy { descending }, &args, false, false) } /// Reverse column diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 9fc727ab71ca7..fbd80f4e058fe 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4730,6 +4730,11 @@ def top_k( """ Return the `k` largest rows. + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. + Parameters ---------- k @@ -4806,6 +4811,11 @@ def bottom_k( """ Return the `k` smallest rows. + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. 
+ Parameters ---------- k diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 4b74959af9f9d..349deb89daca7 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -1824,9 +1824,13 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Self: r""" Return the `k` largest elements. + Non-null elements are always preferred over null elements. The output + is not guaranteed to be in any particular order, call :func:`sort` + after this function if you wish the output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + .. math:: O(n) Parameters ---------- @@ -1854,11 +1858,11 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Self: │ --- ┆ --- │ │ i64 ┆ i64 │ ╞═══════╪══════════╡ - │ 99 ┆ 1 │ - │ 98 ┆ 2 │ - │ 4 ┆ 3 │ - │ 3 ┆ 4 │ - │ 2 ┆ 98 │ + │ 4 ┆ 1 │ + │ 98 ┆ 98 │ + │ 2 ┆ 2 │ + │ 3 ┆ 3 │ + │ 99 ┆ 4 │ └───────┴──────────┘ """ k = parse_as_expression(k) @@ -1874,9 +1878,14 @@ def top_k_by( r""" Return the elements corresponding to the `k` largest elements of the `by` column(s). + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + .. math:: O(n \log{n}) Parameters ---------- @@ -1985,9 +1994,13 @@ def bottom_k(self, k: int | IntoExprColumn = 5) -> Self: r""" Return the `k` smallest elements. + Non-null elements are always preferred over null elements. The output is + not guaranteed to be in any particular order, call :func:`sort` after + this function if you wish the output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + .. 
math:: O(n) Parameters ---------- @@ -2017,11 +2030,11 @@ def bottom_k(self, k: int | IntoExprColumn = 5) -> Self: │ --- ┆ --- │ │ i64 ┆ i64 │ ╞═══════╪══════════╡ - │ 99 ┆ 1 │ - │ 98 ┆ 2 │ - │ 4 ┆ 3 │ - │ 3 ┆ 4 │ - │ 2 ┆ 98 │ + │ 4 ┆ 1 │ + │ 98 ┆ 98 │ + │ 2 ┆ 2 │ + │ 3 ┆ 3 │ + │ 99 ┆ 4 │ └───────┴──────────┘ """ k = parse_as_expression(k) @@ -2037,9 +2050,14 @@ def bottom_k_by( r""" Return the elements corresponding to the `k` smallest elements of the `by` column(s). + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + .. math:: O(n \log{n}) Parameters ---------- diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index a1ea726351ea4..bbd4cbadc4ff9 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1379,6 +1379,11 @@ def top_k( """ Return the `k` largest rows. + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. + Parameters ---------- k @@ -1448,6 +1453,11 @@ def bottom_k( """ Return the `k` smallest rows. + Non-null elements are always preferred over null elements, regardless of + the value of `descending`. The output is not guaranteed to be in any + particular order, call :func:`sort` after this function if you wish the + output to be sorted. 
+ Parameters ---------- k diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index e0bbde0a3f406..9b481d947aec0 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3114,9 +3114,13 @@ def top_k(self, k: int = 5) -> Series: r""" Return the `k` largest elements. + Non-null elements are always preferred over null elements. The output is + not guaranteed to be in any particular order, call :func:`sort` after + this function if you wish the output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + .. math:: O(n) Parameters ---------- @@ -3144,9 +3148,13 @@ def bottom_k(self, k: int = 5) -> Series: r""" Return the `k` smallest elements. + Non-null elements are always preferred over null elements. The output is + not guaranteed to be in any particular order, call :func:`sort` after + this function if you wish the output to be sorted. + This has time complexity: - .. math:: O(n + k \log{n}) + ..
math:: O(n) Parameters ---------- diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 7f630003a4cec..57a43f1ff6728 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -294,28 +294,18 @@ impl PyExpr { #[cfg(feature = "top_k")] fn top_k(&self, k: Self) -> Self { - self.inner.clone().top_k(k.inner, SortOptions::new()).into() + self.inner.clone().top_k(k.inner).into() } #[cfg(feature = "top_k")] fn top_k_by(&self, by: Vec, k: Self, descending: Vec) -> Self { let by = by.into_iter().map(|e| e.inner).collect::>(); - self.inner - .clone() - .top_k_by( - k.inner, - by, - SortMultipleOptions::new().with_order_descending_multi(descending), - ) - .into() + self.inner.clone().top_k_by(k.inner, by, descending).into() } #[cfg(feature = "top_k")] fn bottom_k(&self, k: Self) -> Self { - self.inner - .clone() - .bottom_k(k.inner, SortOptions::new()) - .into() + self.inner.clone().bottom_k(k.inner).into() } #[cfg(feature = "top_k")] @@ -323,11 +313,7 @@ impl PyExpr { let by = by.into_iter().map(|e| e.inner).collect::>(); self.inner .clone() - .bottom_k_by( - k.inner, - by, - SortMultipleOptions::new().with_order_descending_multi(descending), - ) + .bottom_k_by(k.inner, by, descending) .into() } diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 91f826d1d0a97..116274d8138fd 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -1111,9 +1111,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { has_max: _, } => return Err(PyNotImplementedError::new_err("clip")), FunctionExpr::AsStruct => return Err(PyNotImplementedError::new_err("as struct")), - FunctionExpr::TopK { sort_options: _ } => { - return Err(PyNotImplementedError::new_err("top k")) - }, + FunctionExpr::TopK { .. 
} => return Err(PyNotImplementedError::new_err("top k")), FunctionExpr::CumCount { reverse } => ("cumcount", reverse).to_object(py), FunctionExpr::CumSum { reverse } => ("cumsum", reverse).to_object(py), FunctionExpr::CumProd { reverse } => ("cumprod", reverse).to_object(py), @@ -1234,7 +1232,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::Business(_) => { return Err(PyNotImplementedError::new_err("business")) }, - FunctionExpr::TopKBy { sort_options: _ } => { + FunctionExpr::TopKBy { .. } => { return Err(PyNotImplementedError::new_err("top_k_by")) }, FunctionExpr::EwmMeanBy { half_life: _ } => { diff --git a/py-polars/tests/unit/operations/test_top_k.py b/py-polars/tests/unit/operations/test_top_k.py index 0fbfac7d56365..4c3578ed252a7 100644 --- a/py-polars/tests/unit/operations/test_top_k.py +++ b/py-polars/tests/unit/operations/test_top_k.py @@ -1,15 +1,18 @@ import pytest +from hypothesis import given +from hypothesis.strategies import booleans import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import series def test_top_k() -> None: # expression s = pl.Series("a", [3, 8, 1, 5, 2]) - assert_series_equal(s.top_k(3), pl.Series("a", [8, 5, 3])) - assert_series_equal(s.bottom_k(4), pl.Series("a", [1, 2, 3, 5])) + assert_series_equal(s.top_k(3), pl.Series("a", [8, 5, 3]), check_order=False) + assert_series_equal(s.bottom_k(4), pl.Series("a", [3, 2, 1, 5]), check_order=False) # 5886 df = pl.DataFrame( @@ -23,6 +26,7 @@ def test_top_k() -> None: assert_frame_equal( df.select(pl.col("test").top_k(10)), pl.DataFrame({"test": [4, 3, 2, 1]}), + check_row_order=False, ) assert_frame_equal( @@ -31,6 +35,7 @@ def test_top_k() -> None: bottom_k=pl.col("test").bottom_k(pl.col("val").min()), ), pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}), + check_row_order=False, ) assert_frame_equal( @@ -39,6 +44,7 @@ def test_top_k() -> None: 
pl.col("bool_val").bottom_k(2).alias("bottom_k"), ), pl.DataFrame({"top_k": [True, True], "bottom_k": [False, False]}), + check_row_order=False, ) assert_frame_equal( @@ -47,6 +53,7 @@ def test_top_k() -> None: pl.col("str_value").bottom_k(2).alias("bottom_k"), ), pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}), + check_row_order=False, ) with pytest.raises(pl.ComputeError, match="`k` must be set for `top_k`"): @@ -70,15 +77,18 @@ def test_top_k() -> None: assert_frame_equal( df.top_k(3, by=["a", "b"]), pl.DataFrame({"a": [4, 3, 2], "b": [4, 1, 3]}), + check_row_order=False, ) assert_frame_equal( df.top_k(3, by=["a", "b"], descending=True), pl.DataFrame({"a": [1, 2, 2], "b": [3, 2, 2]}), + check_row_order=False, ) assert_frame_equal( df.bottom_k(4, by=["a", "b"], descending=True), pl.DataFrame({"a": [4, 3, 2, 2], "b": [4, 1, 3, 2]}), + check_row_order=False, ) df2 = pl.DataFrame( @@ -102,6 +112,7 @@ def test_top_k() -> None: "b_top_by_b": [12, 11], } ), + check_row_order=False, ) assert_frame_equal( @@ -117,6 +128,7 @@ def test_top_k() -> None: "b_top_by_b": [7, 8], } ), + check_row_order=False, ) assert_frame_equal( @@ -132,6 +144,7 @@ def test_top_k() -> None: "b_bottom_by_b": [7, 8], } ), + check_row_order=False, ) assert_frame_equal( @@ -151,6 +164,7 @@ def test_top_k() -> None: "b_bottom_by_b": [12, 11], } ), + check_row_order=False, ) assert_frame_equal( @@ -164,6 +178,7 @@ def test_top_k() -> None: "b": [9, 10, 11, 7, 8], } ), + check_row_order=False, ) assert_frame_equal( @@ -177,6 +192,7 @@ def test_top_k() -> None: "b": [12, 10, 11, 8, 7], } ), + check_row_order=False, ) assert_frame_equal( @@ -194,6 +210,7 @@ def test_top_k() -> None: "c_top_by_cb": ["Orange", "Banana"], } ), + check_row_order=False, ) assert_frame_equal( @@ -215,6 +232,7 @@ def test_top_k() -> None: "c_bottom_by_cb": ["Apple", "Apple"], } ), + check_row_order=False, ) assert_frame_equal( @@ -236,6 +254,7 @@ def test_top_k() -> None: "c_top_by_cb": ["Apple", "Apple"], } ), + 
check_row_order=False, ) assert_frame_equal( @@ -257,6 +276,7 @@ def test_top_k() -> None: "c_bottom_by_cb": ["Orange", "Banana"], } ), + check_row_order=False, ) assert_frame_equal( @@ -278,6 +298,7 @@ def test_top_k() -> None: "c_top_by_cb": ["Orange", "Banana"], } ), + check_row_order=False, ) assert_frame_equal( @@ -299,6 +320,7 @@ def test_top_k() -> None: "c_bottom_by_cb": ["Orange", "Banana"], } ), + check_row_order=False, ) with pytest.raises( @@ -318,9 +340,9 @@ def test_top_k_descending() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = df.top_k(1, by=["a", "b"], descending=True) expected = pl.DataFrame({"a": [1], "b": [4]}) - assert_frame_equal(result, expected) + assert_frame_equal(result, expected, check_row_order=False) result = df.top_k(1, by=["a", "b"], descending=[True, True]) - assert_frame_equal(result, expected) + assert_frame_equal(result, expected, check_row_order=False) with pytest.raises( ValueError, match=r"the length of `descending` \(1\) does not match the length of `by` \(2\)", @@ -334,21 +356,16 @@ def test_top_k_9385() -> None: assert result.collect()["b"].to_list() == [False] -def test_top_k_sorted_flag() -> None: - # top-k/bottom-k - df = pl.DataFrame({"foo": [56, 2, 3]}) - assert df.top_k(2, by="foo")["foo"].flags["SORTED_DESC"] - assert df.bottom_k(2, by="foo")["foo"].flags["SORTED_ASC"] - - def test_top_k_empty() -> None: df = pl.DataFrame({"test": []}) assert_frame_equal(df.select([pl.col("test").top_k(2)]), df) -def test_top_k_nulls() -> None: - s = pl.Series([1, 2, 3, None, None]) +@given(s=series(excluded_dtypes=[pl.Null, pl.Struct]), should_sort=booleans()) +def test_top_k_nulls(s: pl.Series, should_sort: bool) -> None: + if should_sort: + s = s.sort() valid_count = s.len() - s.null_count() result = s.top_k(valid_count) @@ -358,14 +375,14 @@ def test_top_k_nulls() -> None: assert result.null_count() == s.null_count() result = s.top_k(s.len() * 2) - assert_series_equal(result.sort(), s.sort()) + 
assert_series_equal(result, s, check_order=False) + +@given(s=series(excluded_dtypes=[pl.Null, pl.Struct]), should_sort=booleans()) +def test_bottom_k_nulls(s: pl.Series, should_sort: bool) -> None: + if should_sort: + s = s.sort() -@pytest.mark.xfail( - reason="Currently bugged, see: https://github.com/pola-rs/polars/issues/16748" -) -def test_bottom_k_nulls() -> None: - s = pl.Series([1, 2, 3, None, None]) valid_count = s.len() - s.null_count() result = s.bottom_k(valid_count) @@ -375,4 +392,4 @@ def test_bottom_k_nulls() -> None: assert result.null_count() == s.null_count() result = s.bottom_k(s.len() * 2) - assert_series_equal(result.sort(), s.sort()) + assert_series_equal(result, s, check_order=False)