Skip to content

Commit

Permalink
feat!: Treat float -0.0 == 0.0 and -NaN == NaN in group-by, joins and…
Browse files Browse the repository at this point in the history
… unique
  • Loading branch information
nameexhaustion committed Feb 21, 2024
1 parent f55e480 commit 35e88fc
Show file tree
Hide file tree
Showing 26 changed files with 746 additions and 233 deletions.
11 changes: 11 additions & 0 deletions crates/polars-core/src/chunked_array/object/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ pub(crate) fn create_extension<I: Iterator<Item = Option<T>> + TrustedLen, T: Si
#[cfg(test)]
mod test {
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};

use polars_utils::total_ord::TotalHash;
use polars_utils::unitvec;

use super::*;
Expand All @@ -151,6 +153,15 @@ mod test {
}
}

impl TotalHash for Foo {
    /// Total hashing for `Foo` simply delegates to its derived `Hash`
    /// implementation — `Foo` holds no float-like values, so the total
    /// ordering's hash can be identical to the regular one.
    fn tot_hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        Hash::hash(self, state);
    }
}

impl Display for Foo {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-core/src/chunked_array/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::hash::Hash;

use arrow::bitmap::utils::{BitmapIter, ZipValidity};
use arrow::bitmap::Bitmap;
use polars_utils::total_ord::TotalHash;

use crate::prelude::*;

Expand Down Expand Up @@ -36,7 +37,7 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display {

/// Values need to implement this so that they can be stored into a Series and DataFrame
pub trait PolarsObject:
Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + TotalEq
Any + Debug + Clone + Send + Sync + Default + Display + Hash + TotalHash + PartialEq + Eq + TotalEq
{
/// This should be used as type information. Consider this a part of the type system.
fn type_name() -> &'static str;
Expand Down
80 changes: 30 additions & 50 deletions crates/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::hash::Hash;

use arrow::bitmap::MutableBitmap;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};

#[cfg(feature = "object")]
use crate::datatypes::ObjectType;
Expand Down Expand Up @@ -60,12 +61,13 @@ impl<T: PolarsObject> ChunkUnique<ObjectType<T>> for ObjectChunked<T> {

fn arg_unique<T>(a: impl Iterator<Item = T>, capacity: usize) -> Vec<IdxSize>
where
T: Hash + Eq,
T: ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
{
let mut set = PlHashSet::new();
let mut unique = Vec::with_capacity(capacity);
a.enumerate().for_each(|(idx, val)| {
if set.insert(val) {
if set.insert(val.to_total_ord()) {
unique.push(idx as IdxSize)
}
});
Expand All @@ -83,8 +85,9 @@ macro_rules! arg_unique_ca {

impl<T> ChunkUnique<T> for ChunkedArray<T>
where
T: PolarsIntegerType,
T::Native: Hash + Eq + Ord,
T: PolarsNumericType,
T::Native: TotalHash + TotalEq + ToTotalOrd,
<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Ord,
ChunkedArray<T>: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray<T>, Item = BooleanChunked>,
{
fn unique(&self) -> PolarsResult<Self> {
Expand All @@ -96,25 +99,23 @@ where
IsSorted::Ascending | IsSorted::Descending => {
if self.null_count() > 0 {
let mut arr = MutablePrimitiveArray::with_capacity(self.len());
let mut iter = self.into_iter();
let mut last = None;

if let Some(val) = iter.next() {
last = val;
arr.push(val)
};
if !self.is_empty() {
let mut iter = self.iter();
let last = iter.next().unwrap();
arr.push(last);
let mut last = last.to_total_ord();

#[allow(clippy::unnecessary_filter_map)]
let to_extend = iter.filter_map(|opt_val| {
if opt_val != last {
last = opt_val;
Some(opt_val)
} else {
None
}
});
let to_extend = iter.filter(|opt_val| {
let opt_val_tot_ord = opt_val.to_total_ord();
let out = opt_val_tot_ord != last;
last = opt_val_tot_ord;
out
});

arr.extend(to_extend);
}

arr.extend(to_extend);
let arr: PrimitiveArray<T::Native> = arr.into();
Ok(ChunkedArray::with_chunk(self.name(), arr))
} else {
Expand Down Expand Up @@ -142,15 +143,18 @@ where
IsSorted::Ascending | IsSorted::Descending => {
if self.null_count() > 0 {
let mut count = 0;
let mut iter = self.into_iter();
let mut last = None;

if let Some(val) = iter.next() {
last = val;
count += 1;
};
if self.is_empty() {
return Ok(count);
}

let mut iter = self.iter();
let mut last = iter.next().unwrap().to_total_ord();

count += 1;

iter.for_each(|opt_val| {
let opt_val = opt_val.to_total_ord();
if opt_val != last {
last = opt_val;
count += 1;
Expand Down Expand Up @@ -254,30 +258,6 @@ impl ChunkUnique<BooleanType> for BooleanChunked {
}
}

impl ChunkUnique<Float32Type> for Float32Chunked {
    /// Computes unique values by operating on the raw bit representation
    /// (`u32`) and reinterpreting the result back to `f32`.
    ///
    /// NOTE(review): bitwise uniqueness distinguishes `-0.0` from `0.0` and
    /// different NaN bit patterns — confirm this is the intended semantics.
    fn unique(&self) -> PolarsResult<ChunkedArray<Float32Type>> {
        Ok(self.bit_repr_small().unique()?._reinterpret_float())
    }

    /// Indices of first occurrences, computed on the `u32` bit representation.
    fn arg_unique(&self) -> PolarsResult<IdxCa> {
        let bits = self.bit_repr_small();
        bits.arg_unique()
    }
}

impl ChunkUnique<Float64Type> for Float64Chunked {
    /// Computes unique values by operating on the raw bit representation
    /// (`u64`) and reinterpreting the result back to `f64`.
    ///
    /// NOTE(review): bitwise uniqueness distinguishes `-0.0` from `0.0` and
    /// different NaN bit patterns — confirm this is the intended semantics.
    fn unique(&self) -> PolarsResult<ChunkedArray<Float64Type>> {
        Ok(self.bit_repr_large().unique()?._reinterpret_float())
    }

    /// Indices of first occurrences, computed on the `u64` bit representation.
    fn arg_unique(&self) -> PolarsResult<IdxCa> {
        let bits = self.bit_repr_large();
        bits.arg_unique()
    }
}

#[cfg(test)]
mod test {
use crate::prelude::*;
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use polars_utils::format_smartstring;
use polars_utils::slice::GetSaferUnchecked;
#[cfg(feature = "dtype-categorical")]
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::ToTotalOrd;
use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::*;
Expand Down Expand Up @@ -893,8 +894,8 @@ impl AnyValue<'_> {
(Int16(l), Int16(r)) => *l == *r,
(Int32(l), Int32(r)) => *l == *r,
(Int64(l), Int64(r)) => *l == *r,
(Float32(l), Float32(r)) => *l == *r,
(Float64(l), Float64(r)) => *l == *r,
(Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(),
(Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(),
(String(l), String(r)) => l == r,
(String(l), StringOwned(r)) => l == r,
(StringOwned(l), String(r)) => l == r,
Expand Down Expand Up @@ -978,8 +979,8 @@ impl PartialOrd for AnyValue<'_> {
(Int16(l), Int16(r)) => l.partial_cmp(r),
(Int32(l), Int32(r)) => l.partial_cmp(r),
(Int64(l), Int64(r)) => l.partial_cmp(r),
(Float32(l), Float32(r)) => l.partial_cmp(r),
(Float64(l), Float64(r)) => l.partial_cmp(r),
(Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
(Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
(String(l), String(r)) => l.partial_cmp(*r),
(Binary(l), Binary(r)) => l.partial_cmp(*r),
_ => None,
Expand Down
24 changes: 16 additions & 8 deletions crates/polars-core/src/frame/group_by/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash};
use polars_utils::idx_vec::IdxVec;
use polars_utils::iter::EnumerateIdxTrait;
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};
use polars_utils::unitvec;
use rayon::prelude::*;

Expand Down Expand Up @@ -144,12 +145,15 @@ fn finish_group_order_vecs(

pub(crate) fn group_by<T>(a: impl Iterator<Item = T>, sorted: bool) -> GroupsProxy
where
T: Hash + Eq,
T: TotalHash + TotalEq + ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
{
let init_size = get_init_size();
let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> = PlHashMap::with_capacity(init_size);
let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
PlHashMap::with_capacity(init_size);
let mut cnt = 0;
a.for_each(|k| {
let k = k.to_total_ord();
let idx = cnt;
cnt += 1;
let entry = hash_tbl.entry(k);
Expand Down Expand Up @@ -188,7 +192,8 @@ pub(crate) fn group_by_threaded_slice<T, IntoSlice>(
sorted: bool,
) -> GroupsProxy
where
T: Send + Hash + Eq + Sync + Copy + DirtyHash,
T: TotalHash + TotalEq + ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
IntoSlice: AsRef<[T]> + Send + Sync,
{
let init_size = get_init_size();
Expand All @@ -200,7 +205,7 @@ where
(0..n_partitions)
.into_par_iter()
.map(|thread_no| {
let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> =
let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
PlHashMap::with_capacity(init_size);

let mut offset = 0;
Expand All @@ -211,17 +216,18 @@ where

let mut cnt = 0;
keys.iter().for_each(|k| {
let k = k.to_total_ord();
let idx = cnt + offset;
cnt += 1;

if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
let hash = hasher.hash_one(k);
let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, k);
let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k);

match entry {
RawEntryMut::Vacant(entry) => {
let tuples = unitvec![idx];
entry.insert_with_hasher(hash, *k, (idx, tuples), |k| {
entry.insert_with_hasher(hash, k, (idx, tuples), |k| {
hasher.hash_one(k)
});
},
Expand Down Expand Up @@ -252,7 +258,8 @@ pub(crate) fn group_by_threaded_iter<T, I>(
where
I: IntoIterator<Item = T> + Send + Sync + Clone,
I::IntoIter: ExactSizeIterator,
T: Send + Hash + Eq + Sync + Copy + DirtyHash,
T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
<T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
{
let init_size = get_init_size();

Expand All @@ -263,7 +270,7 @@ where
(0..n_partitions)
.into_par_iter()
.map(|thread_no| {
let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> =
let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
PlHashMap::with_capacity(init_size);

let mut offset = 0;
Expand All @@ -274,6 +281,7 @@ where

let mut cnt = 0;
keys.for_each(|k| {
let k = k.to_total_ord();
let idx = cnt + offset;
cnt += 1;

Expand Down
25 changes: 20 additions & 5 deletions crates/polars-core/src/frame/group_by/into_groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter;
use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
use arrow::legacy::prelude::*;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};

use super::*;
use crate::config::verbose;
Expand All @@ -25,9 +26,9 @@ fn group_multithreaded<T: PolarsDataType>(ca: &ChunkedArray<T>) -> bool {

fn num_groups_proxy<T>(ca: &ChunkedArray<T>, multithreaded: bool, sorted: bool) -> GroupsProxy
where
T: PolarsIntegerType,
T::Native: Hash + Eq + Send + DirtyHash,
Option<T::Native>: DirtyHash,
T: PolarsNumericType,
T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy + Send + DirtyHash,
{
if multithreaded && group_multithreaded(ca) {
let n_partitions = _set_partition_size();
Expand Down Expand Up @@ -163,14 +164,28 @@ where
};
num_groups_proxy(ca, multithreaded, sorted)
},
DataType::Int64 | DataType::Float64 => {
DataType::Int64 => {
let ca = self.bit_repr_large();
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Int32 | DataType::Float32 => {
DataType::Int32 => {
let ca = self.bit_repr_small();
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Float64 => {
// convince the compiler that we are this type.
let ca: &Float64Chunked = unsafe {
&*(self as *const ChunkedArray<T> as *const ChunkedArray<Float64Type>)
};
num_groups_proxy(ca, multithreaded, sorted)
},
DataType::Float32 => {
// convince the compiler that we are this type.
let ca: &Float32Chunked = unsafe {
&*(self as *const ChunkedArray<T> as *const ChunkedArray<Float32Type>)
};
num_groups_proxy(ca, multithreaded, sorted)
},
#[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
DataType::Int8 => {
// convince the compiler that we are this type.
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-core/src/frame/group_by/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame {
_ => {
if s.dtype().to_physical().is_numeric() {
let s = s.to_physical_repr();
if s.bit_repr_is_large() {

if s.dtype().is_float() {
s.into_owned().into_series()
} else if s.bit_repr_is_large() {
s.bit_repr_large().into_series()
} else {
s.bit_repr_small().into_series()
Expand Down
Loading

0 comments on commit 35e88fc

Please sign in to comment.