diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs
index 18b4bb554a76..576e9b430726 100644
--- a/crates/polars-core/src/chunked_array/object/extension/mod.rs
+++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs
@@ -133,7 +133,9 @@ pub(crate) fn create_extension<I: Iterator<Item = Option<T>> + TrustedLen, T: Si
 #[cfg(test)]
 mod test {
     use std::fmt::{Display, Formatter};
+    use std::hash::{Hash, Hasher};
 
+    use polars_utils::total_ord::TotalHash;
     use polars_utils::unitvec;
 
     use super::*;
@@ -151,6 +153,15 @@ mod test {
         }
     }
 
+    impl TotalHash for Foo {
+        fn tot_hash<H>(&self, state: &mut H)
+        where
+            H: Hasher,
+        {
+            self.hash(state);
+        }
+    }
+
     impl Display for Foo {
         fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
             write!(f, "{:?}", self)
diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs
index 65fb98b4d96e..9f17a1d1b434 100644
--- a/crates/polars-core/src/chunked_array/object/mod.rs
+++ b/crates/polars-core/src/chunked_array/object/mod.rs
@@ -4,6 +4,7 @@ use std::hash::Hash;
 
 use arrow::bitmap::utils::{BitmapIter, ZipValidity};
 use arrow::bitmap::Bitmap;
+use polars_utils::total_ord::TotalHash;
 
 use crate::prelude::*;
 
@@ -36,7 +37,7 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display {
 
 /// Values need to implement this so that they can be stored into a Series and DataFrame
 pub trait PolarsObject:
-    Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + TotalEq
+    Any + Debug + Clone + Send + Sync + Default + Display + Hash + TotalHash + PartialEq + Eq + TotalEq
 {
     /// This should be used as type information. Consider this a part of the type system.
     fn type_name() -> &'static str;
diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs
index 34e6946f7e7f..bf5d8b037582 100644
--- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs
+++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs
@@ -1,6 +1,7 @@
 use std::hash::Hash;
 
 use arrow::bitmap::MutableBitmap;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 #[cfg(feature = "object")]
 use crate::datatypes::ObjectType;
@@ -60,12 +61,13 @@ impl<T: PolarsObject> ChunkUnique<ObjectType<T>> for ObjectChunked<T> {
 
 fn arg_unique<T>(a: impl Iterator<Item = T>, capacity: usize) -> Vec<IdxSize>
 where
-    T: Hash + Eq,
+    T: ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let mut set = PlHashSet::new();
     let mut unique = Vec::with_capacity(capacity);
     a.enumerate().for_each(|(idx, val)| {
-        if set.insert(val) {
+        if set.insert(val.to_total_ord()) {
            unique.push(idx as IdxSize)
         }
     });
@@ -83,8 +85,9 @@ macro_rules! arg_unique_ca {
 
 impl<T> ChunkUnique<T> for ChunkedArray<T>
 where
-    T: PolarsIntegerType,
-    T::Native: Hash + Eq + Ord,
+    T: PolarsNumericType,
+    T::Native: TotalHash + TotalEq + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Ord,
     ChunkedArray<T>: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray<T>, Item = BooleanChunked>,
 {
     fn unique(&self) -> PolarsResult<Self> {
@@ -96,25 +99,23 @@ where
             IsSorted::Ascending | IsSorted::Descending => {
                 if self.null_count() > 0 {
                     let mut arr = MutablePrimitiveArray::with_capacity(self.len());
-                    let mut iter = self.into_iter();
-                    let mut last = None;
-                    if let Some(val) = iter.next() {
-                        last = val;
-                        arr.push(val)
-                    };
+                    if !self.is_empty() {
+                        let mut iter = self.iter();
+                        let last = iter.next().unwrap();
+                        arr.push(last);
+                        let mut last = last.to_total_ord();
 
-                    #[allow(clippy::unnecessary_filter_map)]
-                    let to_extend = iter.filter_map(|opt_val| {
-                        if opt_val != last {
-                            last = opt_val;
-                            Some(opt_val)
-                        } else {
-                            None
-                        }
-                    });
+                        let to_extend = iter.filter(|opt_val| {
+                            let opt_val_tot_ord = opt_val.to_total_ord();
+                            let out = opt_val_tot_ord != last;
+                            last = opt_val_tot_ord;
+                            out
+                        });
+
+                        arr.extend(to_extend);
+                    }
 
-                    arr.extend(to_extend);
                     let arr: PrimitiveArray<T::Native> = arr.into();
                     Ok(ChunkedArray::with_chunk(self.name(), arr))
                 } else {
@@ -142,15 +143,18 @@ where
             IsSorted::Ascending | IsSorted::Descending => {
                 if self.null_count() > 0 {
                     let mut count = 0;
-                    let mut iter = self.into_iter();
-                    let mut last = None;
-                    if let Some(val) = iter.next() {
-                        last = val;
-                        count += 1;
-                    };
+                    if self.is_empty() {
+                        return Ok(count);
+                    }
+
+                    let mut iter = self.iter();
+                    let mut last = iter.next().unwrap().to_total_ord();
+
+                    count += 1;
 
                     iter.for_each(|opt_val| {
+                        let opt_val = opt_val.to_total_ord();
                         if opt_val != last {
                             last = opt_val;
                             count += 1;
@@ -254,30 +258,6 @@ impl ChunkUnique<BooleanType> for BooleanChunked {
     }
 }
 
-impl ChunkUnique<Float32Type> for Float32Chunked {
-    fn unique(&self) -> PolarsResult<ChunkedArray<Float32Type>> {
-        let ca = self.bit_repr_small();
-        let ca = ca.unique()?;
-        Ok(ca._reinterpret_float())
-    }
-
-    fn arg_unique(&self) -> PolarsResult<IdxCa> {
-        self.bit_repr_small().arg_unique()
-    }
-}
-
-impl ChunkUnique<Float64Type> for Float64Chunked {
-    fn unique(&self) -> PolarsResult<ChunkedArray<Float64Type>> {
-        let ca = self.bit_repr_large();
-        let ca = ca.unique()?;
-        Ok(ca._reinterpret_float())
-    }
-
-    fn arg_unique(&self) -> PolarsResult<IdxCa> {
-        self.bit_repr_large().arg_unique()
-    }
-}
-
 #[cfg(test)]
 mod test {
     use crate::prelude::*;
diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs
index 3f47896d5d87..4e4ee99f90c8 100644
--- a/crates/polars-core/src/datatypes/any_value.rs
+++ b/crates/polars-core/src/datatypes/any_value.rs
@@ -10,6 +10,7 @@ use polars_utils::format_smartstring;
 use polars_utils::slice::GetSaferUnchecked;
 #[cfg(feature = "dtype-categorical")]
 use polars_utils::sync::SyncPtr;
+use polars_utils::total_ord::ToTotalOrd;
 use polars_utils::unwrap::UnwrapUncheckedRelease;
 
 use super::*;
@@ -893,8 +894,8 @@ impl AnyValue<'_> {
             (Int16(l), Int16(r)) => *l == *r,
             (Int32(l), Int32(r)) => *l == *r,
             (Int64(l), Int64(r)) => *l == *r,
-            (Float32(l), Float32(r)) => *l == *r,
-            (Float64(l), Float64(r)) => *l == *r,
+            (Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(),
+            (Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(),
             (String(l), String(r)) => l == r,
             (String(l), StringOwned(r)) => l == r,
             (StringOwned(l), String(r)) => l == r,
@@ -978,8 +979,8 @@ impl PartialOrd for AnyValue<'_> {
             (Int16(l), Int16(r)) => l.partial_cmp(r),
             (Int32(l), Int32(r)) => l.partial_cmp(r),
             (Int64(l), Int64(r)) => l.partial_cmp(r),
-            (Float32(l), Float32(r)) => l.partial_cmp(r),
-            (Float64(l), Float64(r)) => l.partial_cmp(r),
+            (Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
+            (Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()),
             (String(l), String(r)) => l.partial_cmp(*r),
             (Binary(l), Binary(r)) => l.partial_cmp(*r),
             _ => None,
diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs
index 796b5c2f33d1..96624473f786 100644
--- a/crates/polars-core/src/frame/group_by/hashing.rs
+++ b/crates/polars-core/src/frame/group_by/hashing.rs
@@ -6,6 +6,7 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash};
 use polars_utils::idx_vec::IdxVec;
 use polars_utils::iter::EnumerateIdxTrait;
 use polars_utils::sync::SyncPtr;
+use polars_utils::total_ord::{ToTotalOrd, TotalHash};
 use polars_utils::unitvec;
 use rayon::prelude::*;
 
@@ -144,12 +145,15 @@ fn finish_group_order_vecs(
 
 pub(crate) fn group_by<T>(a: impl Iterator<Item = T>, sorted: bool) -> GroupsProxy
 where
-    T: Hash + Eq,
+    T: TotalHash + TotalEq + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let init_size = get_init_size();
-    let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> = PlHashMap::with_capacity(init_size);
+    let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
+        PlHashMap::with_capacity(init_size);
     let mut cnt = 0;
     a.for_each(|k| {
+        let k = k.to_total_ord();
         let idx = cnt;
         cnt += 1;
         let entry = hash_tbl.entry(k);
@@ -188,7 +192,8 @@ pub(crate) fn group_by_threaded_slice<T, IntoSlice>(
     sorted: bool,
 ) -> GroupsProxy
 where
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
     IntoSlice: AsRef<[T]> + Send + Sync,
 {
     let init_size = get_init_size();
@@ -200,7 +205,7 @@ where
     (0..n_partitions)
         .into_par_iter()
         .map(|thread_no| {
-            let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> =
+            let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
                 PlHashMap::with_capacity(init_size);
 
             let mut offset = 0;
@@ -211,17 +216,18 @@ where
                 let mut cnt = 0;
                 keys.iter().for_each(|k| {
+                    let k = k.to_total_ord();
                     let idx = cnt + offset;
                     cnt += 1;
 
                     if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) {
                         let hash = hasher.hash_one(k);
-                        let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, k);
+                        let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k);
 
                         match entry {
                             RawEntryMut::Vacant(entry) => {
                                 let tuples = unitvec![idx];
-                                entry.insert_with_hasher(hash, *k, (idx, tuples), |k| {
+                                entry.insert_with_hasher(hash, k, (idx, tuples), |k| {
                                     hasher.hash_one(k)
                                 });
                             },
@@ -252,7 +258,8 @@ pub(crate) fn group_by_threaded_iter<T, I>(
 where
     I: IntoIterator<Item = T> + Send + Sync + Clone,
     I::IntoIter: ExactSizeIterator,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash,
 {
     let init_size = get_init_size();
@@ -263,7 +270,7 @@ where
     (0..n_partitions)
         .into_par_iter()
         .map(|thread_no| {
-            let mut hash_tbl: PlHashMap<T, (IdxSize, IdxVec)> =
+            let mut hash_tbl: PlHashMap<T::TotalOrdItem, (IdxSize, IdxVec)> =
                 PlHashMap::with_capacity(init_size);
 
             let mut offset = 0;
@@ -274,6 +281,7 @@ where
                 let mut cnt = 0;
                 keys.for_each(|k| {
+                    let k = k.to_total_ord();
                     let idx = cnt + offset;
                     cnt += 1;
diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs
index 8cedb4c6044d..2d1c14caacd4 100644
--- a/crates/polars-core/src/frame/group_by/into_groups.rs
+++ b/crates/polars-core/src/frame/group_by/into_groups.rs
@@ -2,6 +2,7 @@ use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter;
 use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
 use arrow::legacy::prelude::*;
+use polars_utils::total_ord::{ToTotalOrd, TotalHash};
 
 use super::*;
 use crate::config::verbose;
@@ -25,9 +26,9 @@ fn group_multithreaded<T: PolarsDataType>(ca: &ChunkedArray<T>) -> bool {
 
 fn num_groups_proxy<T>(ca: &ChunkedArray<T>, multithreaded: bool, sorted: bool) -> GroupsProxy
 where
-    T: PolarsIntegerType,
-    T::Native: Hash + Eq + Send + DirtyHash,
-    Option<T::Native>: DirtyHash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy + Send + DirtyHash,
 {
     if multithreaded && group_multithreaded(ca) {
         let n_partitions = _set_partition_size();
@@ -163,14 +164,28 @@ where
             };
             num_groups_proxy(ca, multithreaded, sorted)
         },
-        DataType::Int64 | DataType::Float64 => {
+        DataType::Int64 => {
             let ca = self.bit_repr_large();
             num_groups_proxy(&ca, multithreaded, sorted)
         },
-        DataType::Int32 | DataType::Float32 => {
+        DataType::Int32 => {
             let ca = self.bit_repr_small();
             num_groups_proxy(&ca, multithreaded, sorted)
         },
+        DataType::Float64 => {
+            // convince the compiler that we are this type.
+            let ca: &Float64Chunked = unsafe {
+                &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float64Type>)
+            };
+            num_groups_proxy(ca, multithreaded, sorted)
+        },
+        DataType::Float32 => {
+            // convince the compiler that we are this type.
+            let ca: &Float32Chunked = unsafe {
+                &*(self as *const ChunkedArray<T> as *const ChunkedArray<Float32Type>)
+            };
+            num_groups_proxy(ca, multithreaded, sorted)
+        },
         #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))]
         DataType::Int8 => {
             // convince the compiler that we are this type.
diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs
index 809e300bcd01..f6dd8d7fa9a4 100644
--- a/crates/polars-core/src/frame/group_by/mod.rs
+++ b/crates/polars-core/src/frame/group_by/mod.rs
@@ -38,7 +38,10 @@ fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame {
             _ => {
                 if s.dtype().to_physical().is_numeric() {
                     let s = s.to_physical_repr();
-                    if s.bit_repr_is_large() {
+
+                    if s.dtype().is_float() {
+                        s.into_owned().into_series()
+                    } else if s.bit_repr_is_large() {
                         s.bit_repr_large().into_series()
                     } else {
                         s.bit_repr_small().into_series()
diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs
index 4b882bb2ce5e..403123b5c71b 100644
--- a/crates/polars-core/src/hashing/vector_hasher.rs
+++ b/crates/polars-core/src/hashing/vector_hasher.rs
@@ -1,6 +1,7 @@
 use arrow::bitmap::utils::get_bit_unchecked;
 #[cfg(feature = "group_by_list")]
 use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter;
+use polars_utils::total_ord::{ToTotalOrd, TotalHash};
 use rayon::prelude::*;
 use xxhash_rust::xxh3::xxh3_64_with_seed;
 
@@ -67,10 +68,11 @@ fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Ve
     });
 }
 
-fn integer_vec_hash<T>(ca: &ChunkedArray<T>, random_state: RandomState, buf: &mut Vec<u64>)
+fn numeric_vec_hash<T>(ca: &ChunkedArray<T>, random_state: RandomState, buf: &mut Vec<u64>)
 where
-    T: PolarsIntegerType,
-    T::Native: Hash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash,
 {
     // Note that we don't use the no null branch! This can break in unexpected ways.
     // for instance with threading we split an array in n_threads, this may lead to
@@ -89,16 +91,17 @@ where
                 .as_slice()
                 .iter()
                 .copied()
-                .map(|v| random_state.hash_one(v)),
+                .map(|v| random_state.hash_one(v.to_total_ord())),
         );
     });
 
     insert_null_hash(&ca.chunks, random_state, buf)
 }
 
-fn integer_vec_hash_combine<T>(ca: &ChunkedArray<T>, random_state: RandomState, hashes: &mut [u64])
+fn numeric_vec_hash_combine<T>(ca: &ChunkedArray<T>, random_state: RandomState, hashes: &mut [u64])
 where
-    T: PolarsIntegerType,
-    T::Native: Hash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash,
 {
     let null_h = get_null_hash_value(&random_state);
@@ -111,7 +114,7 @@ where
                 .iter()
                 .zip(&mut hashes[offset..])
                 .for_each(|(v, h)| {
-                    *h = folded_multiply(random_state.hash_one(v) ^ *h, MULTIPLE);
+                    *h = folded_multiply(random_state.hash_one(v.to_total_ord()) ^ *h, MULTIPLE);
                 }),
             _ => {
                 let validity = arr.validity().unwrap();
@@ -121,7 +124,7 @@ where
                     .zip(&mut hashes[offset..])
                     .zip(arr.values().as_slice())
                     .for_each(|((valid, h), l)| {
-                        let lh = random_state.hash_one(l);
+                        let lh = random_state.hash_one(l.to_total_ord());
                         let to_hash = [null_h, lh][valid as usize];
 
                         // inlined from ahash. This ensures we combine with the previous state
@@ -133,11 +136,11 @@ where
     });
 }
 
-macro_rules! vec_hash_int {
+macro_rules! vec_hash_numeric {
     ($ca:ident) => {
         impl VecHash for $ca {
             fn vec_hash(&self, random_state: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
-                integer_vec_hash(self, random_state, buf);
+                numeric_vec_hash(self, random_state, buf);
                 Ok(())
             }
 
@@ -146,21 +149,23 @@ macro_rules! vec_hash_int {
                 random_state: RandomState,
                 hashes: &mut [u64],
             ) -> PolarsResult<()> {
-                integer_vec_hash_combine(self, random_state, hashes);
+                numeric_vec_hash_combine(self, random_state, hashes);
                 Ok(())
             }
         }
     };
 }
 
-vec_hash_int!(Int64Chunked);
-vec_hash_int!(Int32Chunked);
-vec_hash_int!(Int16Chunked);
-vec_hash_int!(Int8Chunked);
-vec_hash_int!(UInt64Chunked);
-vec_hash_int!(UInt32Chunked);
-vec_hash_int!(UInt16Chunked);
-vec_hash_int!(UInt8Chunked);
+vec_hash_numeric!(Int64Chunked);
+vec_hash_numeric!(Int32Chunked);
+vec_hash_numeric!(Int16Chunked);
+vec_hash_numeric!(Int8Chunked);
+vec_hash_numeric!(UInt64Chunked);
+vec_hash_numeric!(UInt32Chunked);
+vec_hash_numeric!(UInt16Chunked);
+vec_hash_numeric!(UInt8Chunked);
+vec_hash_numeric!(Float64Chunked);
+vec_hash_numeric!(Float32Chunked);
 
 impl VecHash for StringChunked {
     fn vec_hash(&self, random_state: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
@@ -370,30 +375,6 @@ impl VecHash for BooleanChunked {
     }
 }
 
-impl VecHash for Float32Chunked {
-    fn vec_hash(&self, random_state: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
-        self.bit_repr_small().vec_hash(random_state, buf)?;
-        Ok(())
-    }
-
-    fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> {
-        self.bit_repr_small()
-            .vec_hash_combine(random_state, hashes)?;
-        Ok(())
-    }
-}
-impl VecHash for Float64Chunked {
-    fn vec_hash(&self, random_state: RandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
-        self.bit_repr_large().vec_hash(random_state, buf)?;
-        Ok(())
-    }
-    fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> {
-        self.bit_repr_large()
-            .vec_hash_combine(random_state, hashes)?;
-        Ok(())
-    }
-}
-
 #[cfg(feature = "group_by_list")]
 impl VecHash for ListChunked {
     fn vec_hash(&self, _random_state: RandomState, _buf: &mut Vec<u64>) -> PolarsResult<()> {
diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs
index 6a107d595a48..2c8f5e65ca73 100644
--- a/crates/polars-core/src/utils/series.rs
+++ b/crates/polars-core/src/utils/series.rs
@@ -2,7 +2,7 @@ use crate::prelude::*;
 use crate::series::unstable::UnstableSeries;
 use crate::series::IsSorted;
 
-/// Transform to physical type and coerce floating point and similar sized integer to a bit representation
+/// Transform to physical type and coerce similar sized integer to a bit representation
 /// to reduce compiler bloat
 pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec<Series> {
     s.iter()
@@ -11,8 +11,6 @@ pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec<Series> {
             match physical.dtype() {
                 DataType::Int64 => physical.bit_repr_large().into_series(),
                 DataType::Int32 => physical.bit_repr_small().into_series(),
-                DataType::Float32 => physical.bit_repr_small().into_series(),
-                DataType::Float64 => physical.bit_repr_large().into_series(),
                 _ => physical.into_owned(),
             }
         })
diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs
index 5931753f6ebf..711777d142ed 100644
--- a/crates/polars-ops/src/chunked_array/list/hash.rs
+++ b/crates/polars-ops/src/chunked_array/list/hash.rs
@@ -4,14 +4,16 @@ use polars_core::export::_boost_hash_combine;
 use polars_core::export::ahash::{self};
 use polars_core::export::rayon::prelude::*;
 use polars_core::utils::NoNull;
-use polars_core::POOL;
+use polars_core::{with_match_physical_float_polars_type, POOL};
+use polars_utils::total_ord::{ToTotalOrd, TotalHash};
 
 use super::*;
 
 fn hash_agg<T>(ca: &ChunkedArray<T>, random_state: &ahash::RandomState) -> u64
 where
-    T: PolarsIntegerType,
-    T::Native: Hash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash,
 {
     // Note that we don't use the no null branch! This can break in unexpected ways.
     // for instance with threading we split an array in n_threads, this may lead to
@@ -30,7 +32,7 @@ where
         for opt_v in arr.iter() {
             match opt_v {
                 Some(v) => {
-                    let r = random_state.hash_one(v);
+                    let r = random_state.hash_one(v.to_total_ord());
                     hash_agg = _boost_hash_combine(hash_agg, r);
                 },
                 None => {
@@ -60,7 +62,12 @@ pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UI
         .map(|opt_s: Option<Series>| match opt_s {
             None => null_hash,
             Some(s) => {
-                if s.bit_repr_is_large() {
+                if s.dtype().is_float() {
+                    with_match_physical_float_polars_type!(s.dtype(), |$T| {
+                        let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
+                        hash_agg(ca, &build_hasher)
+                    })
+                } else if s.bit_repr_is_large() {
                     let ca = s.bit_repr_large();
                     hash_agg(&ca, &build_hasher)
                 } else {
diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs
index 9970478184b4..38f1bd2cfad3 100644
--- a/crates/polars-ops/src/frame/join/asof/groups.rs
+++ b/crates/polars-ops/src/frame/join/asof/groups.rs
@@ -4,10 +4,11 @@ use ahash::RandomState;
 use num_traits::Zero;
 use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE};
 use polars_core::utils::{split_ca, split_df};
-use polars_core::POOL;
+use polars_core::{with_match_physical_float_polars_type, POOL};
 use polars_utils::abs_diff::AbsDiff;
 use polars_utils::hashing::{hash_to_partition, DirtyHash};
 use polars_utils::nulls::IsNull;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 use rayon::prelude::*;
 use smartstring::alias::String as SmartString;
 
@@ -71,7 +72,8 @@ fn asof_join_by_numeric<T, S, A, F>(
 where
     T: PolarsDataType,
     S: PolarsNumericType,
-    S::Native: Hash + Eq + DirtyHash + IsNull,
+    S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <S::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash + IsNull + Copy,
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
@@ -109,6 +111,7 @@ where
                 results.push(None);
                 continue;
             };
+            let by_left_k = by_left_k.to_total_ord();
             let idx_left = (rel_idx_left + offset) as IdxSize;
             let Some(left_val) = left_val_arr.get(idx_left as usize) else {
                 results.push(None);
@@ -118,7 +121,7 @@ where
             let group_probe_table = unsafe {
                 hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables))
             };
-            let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else {
+            let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else {
                 results.push(None);
                 continue;
             };
@@ -326,7 +329,15 @@ where
             asof_join_by_binary::<T, A, F>(left_by, right_by, left_asof, right_asof, filter)
         },
         _ => {
-            if left_by_s.bit_repr_is_large() {
+            if left_by_s.dtype().is_float() {
+                with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| {
+                    let left_by: &ChunkedArray<$T> = left_by_s.as_ref().as_ref().as_ref();
+                    let right_by: &ChunkedArray<$T> = right_by_s.as_ref().as_ref().as_ref();
+                    asof_join_by_numeric::<T, $T, A, F>(
+                        left_by, right_by, left_asof, right_asof, filter,
+                    )?
+                })
+            } else if left_by_s.bit_repr_is_large() {
                 let left_by = left_by_s.bit_repr_large();
                 let right_by = right_by_s.bit_repr_large();
                 asof_join_by_numeric::<T, UInt64Type, A, F>(
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs
index ee92dfdd6c45..72cea0d9589e 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs
@@ -2,6 +2,7 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash};
 use polars_utils::idx_vec::IdxVec;
 use polars_utils::nulls::IsNull;
 use polars_utils::sync::SyncPtr;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 use polars_utils::unitvec;
 
 use super::*;
@@ -12,9 +13,13 @@ use super::*;
 // Use a small element per thread threshold for debugging/testing purposes.
 const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 };
 
-pub(crate) fn build_tables<T, I>(keys: Vec<I>, join_nulls: bool) -> Vec<PlHashMap<T, IdxVec>>
+pub(crate) fn build_tables<T, I>(
+    keys: Vec<I>,
+    join_nulls: bool,
+) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>>
 where
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull,
+    T: TotalHash + TotalEq + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Copy + Hash + Eq + DirtyHash + IsNull,
     I: IntoIterator<Item = T> + Send + Sync + Clone,
 {
     // FIXME: change interface to split the input here, instead of taking
@@ -28,10 +33,11 @@ where
 
     // Don't bother parallelizing anything for small inputs.
     if num_keys_est < 2 * MIN_ELEMS_PER_THREAD {
-        let mut hm: PlHashMap<T, IdxVec> = PlHashMap::new();
+        let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> = PlHashMap::new();
         let mut offset = 0;
         for it in keys {
             for k in it {
+                let k = k.to_total_ord();
                 if !k.is_null() || join_nulls {
                     hm.entry(k).or_default().push(offset);
                 }
@@ -49,6 +55,7 @@ where
             .map(|key_portion| {
                 let mut partition_sizes = vec![0; n_partitions];
                 for key in key_portion.clone() {
+                    let key = key.to_total_ord();
                     let p = hash_to_partition(key.dirty_hash(), n_partitions);
                     unsafe {
                         *partition_sizes.get_unchecked_mut(p) += 1;
@@ -85,7 +92,7 @@ where
     }
 
     // Scatter values into partitions.
-    let mut scatter_keys: Vec<T> = Vec::with_capacity(num_keys);
+    let mut scatter_keys: Vec<T::TotalOrdItem> = Vec::with_capacity(num_keys);
     let mut scatter_idxs: Vec<IdxSize> = Vec::with_capacity(num_keys);
     let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) };
     let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) };
@@ -96,6 +103,7 @@ where
             let mut partition_offsets =
                 per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec();
             for (i, key) in key_portion.into_iter().enumerate() {
+                let key = key.to_total_ord();
                 unsafe {
                     let p = hash_to_partition(key.dirty_hash(), n_partitions);
                     let off = partition_offsets.get_unchecked_mut(p);
@@ -124,7 +132,8 @@ where
             let partition_range = partition_offsets[p]..partition_offsets[p + 1];
             let full_size = partition_range.len();
             let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64);
-            let mut hm: PlHashMap<T, IdxVec> = PlHashMap::with_capacity(conservative_size);
+            let mut hm: PlHashMap<T::TotalOrdItem, IdxVec> =
+                PlHashMap::with_capacity(conservative_size);
 
             unsafe {
                 for i in partition_range {
@@ -160,8 +169,6 @@ where
 pub(super) fn probe_to_offsets<T, I>(probe: &[I]) -> Vec<usize>
 where
     I: IntoIterator<Item = T> + Clone,
-    // <I as IntoIterator>::IntoIter: TrustedLen,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
 {
     probe
         .iter()
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
index dca9e1326097..9001e3780fb2 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs
@@ -1,7 +1,8 @@
 use arrow::array::PrimitiveArray;
-use num_traits::NumCast;
+use polars_core::with_match_physical_float_polars_type;
 use polars_utils::hashing::DirtyHash;
 use polars_utils::nulls::IsNull;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 use super::*;
 use crate::series::SeriesSealed;
@@ -28,6 +29,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 let lhs = lhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
                 let rhs = rhs.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
                 hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls)
+            } else if lhs.dtype().is_float() {
+                with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
+                    let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
+                    let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
+                    num_group_join_left(lhs, rhs, validate, join_nulls)
+                })
             } else if s_self.bit_repr_is_large() {
                 let lhs = lhs.bit_repr_large();
                 let rhs = rhs.bit_repr_large();
@@ -58,6 +65,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 } else {
                     hash_join_tuples_left_semi(lhs, rhs)
                 }
+            } else if lhs.dtype().is_float() {
+                with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
+                    let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
+                    let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
+                    num_group_join_anti_semi(lhs, rhs, anti)
+                })
             } else if s_self.bit_repr_is_large() {
                 let lhs = lhs.bit_repr_large();
                 let rhs = rhs.bit_repr_large();
@@ -93,6 +106,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                     hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?,
                     !swapped,
                 ))
+            } else if lhs.dtype().is_float() {
+                with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
+                    let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
+                    let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
+                    group_join_inner::<$T>(lhs, rhs, validate, join_nulls)
+                })
             } else if s_self.bit_repr_is_large() {
                 let lhs = s_self.bit_repr_large();
                 let rhs = other.bit_repr_large();
@@ -124,6 +143,12 @@ pub trait SeriesJoin: SeriesSealed + Sized {
                 let lhs = lhs.iter().collect::<Vec<_>>();
                 let rhs = rhs.iter().collect::<Vec<_>>();
                 hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls)
+            } else if lhs.dtype().is_float() {
+                with_match_physical_float_polars_type!(lhs.dtype(), |$T| {
+                    let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref();
+                    let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref();
+                    hash_join_outer(lhs, rhs, validate, join_nulls)
+                })
             } else if s_self.bit_repr_is_large() {
                 let lhs = s_self.bit_repr_large();
                 let rhs = other.bit_repr_large();
@@ -161,7 +186,10 @@ fn group_join_inner<T>(
 where
     T: PolarsDataType,
     for<'a> &'a T::Array: IntoIterator<Item = Option<T::Physical<'a>>>,
-    for<'a> T::Physical<'a>: Hash + Eq + Send + DirtyHash + Copy + Send + Sync + IsNull,
+    for<'a> T::Physical<'a>:
+        Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd,
+    for<'a> <T::Physical<'a> as ToTotalOrd>::TotalOrdItem:
+        Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
 {
     let n_threads = POOL.current_num_threads();
     let (a, b, swapped) = det_hash_prone_order!(left, right);
@@ -243,9 +271,11 @@ fn num_group_join_left<T>(
     join_nulls: bool,
 ) -> PolarsResult<LeftJoinIds>
 where
-    T: PolarsIntegerType,
-    T::Native: Hash + Eq + Send + DirtyHash + IsNull,
-    Option<T::Native>: DirtyHash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Copy + Hash + Eq + DirtyHash + IsNull,
+    Option<T::Native>: DirtyHash + ToTotalOrd,
+    <Option<T::Native> as ToTotalOrd>::TotalOrdItem: DirtyHash,
 {
     let n_threads = POOL.current_num_threads();
     let splitted_a = split_ca(left, n_threads).unwrap();
@@ -300,8 +330,9 @@ fn hash_join_outer<T>(
     join_nulls: bool,
 ) -> PolarsResult<(PrimitiveArray<IdxSize>, PrimitiveArray<IdxSize>)>
 where
-    T: PolarsIntegerType + Sync,
-    T::Native: Eq + Hash + NumCast,
+    T: PolarsNumericType,
+    T::Native: TotalHash + TotalEq + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + IsNull,
 {
     let (a, b, swapped) = det_hash_prone_order!(ca_in, other);
@@ -395,9 +426,11 @@ fn num_group_join_anti_semi<T>(
     anti: bool,
 ) -> Vec<IdxSize>
 where
-    T: PolarsIntegerType,
-    T::Native: Hash + Eq + Send + DirtyHash,
-    Option<T::Native>: DirtyHash,
+    T: PolarsNumericType,
+    T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
+    Option<T::Native>: DirtyHash + ToTotalOrd,
+    <Option<T::Native> as ToTotalOrd>::TotalOrdItem: DirtyHash,
 {
     let n_threads = POOL.current_num_threads();
     let splitted_a = split_ca(left, n_threads).unwrap();
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
index bc5e5d4acdce..58bdd286a814 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs
@@ -4,23 +4,25 @@ use polars_utils::idx_vec::IdxVec;
 use polars_utils::iter::EnumerateIdxTrait;
 use polars_utils::nulls::IsNull;
 use polars_utils::sync::SyncPtr;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 use super::*;
 
 pub(super) fn probe_inner<T, F, I>(
     probe: I,
-    hash_tbls: &[PlHashMap<T, IdxVec>],
+    hash_tbls: &[PlHashMap<<T as ToTotalOrd>::TotalOrdItem, IdxVec>],
     results: &mut Vec<(IdxSize, IdxSize)>,
     local_offset: IdxSize,
     n_tables: usize,
     swap_fn: F,
 ) where
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
     I: IntoIterator<Item = T>,
-    // <I as IntoIterator>::IntoIter: TrustedLen,
     F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize),
 {
     probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| {
+        let k = k.to_total_ord();
        let idx_a = idx_a + local_offset;
         // probe table that contains the hashed value
         let current_probe_table =
@@ -45,8 +47,8 @@ pub(super) fn hash_join_tuples_inner<T, I>(
 ) -> PolarsResult<(Vec<IdxSize>, Vec<IdxSize>)>
 where
     I: IntoIterator<Item = T> + Send + Sync + Clone,
-    // <I as IntoIterator>::IntoIter: TrustedLen,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull,
+    T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
 {
     // NOTE: see the left join for more elaborate comments
     // first we hash one relation
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
index 51956d41585d..7bdbb5dcaade 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs
@@ -1,6 +1,7 @@
 use polars_core::utils::flatten::flatten_par;
 use polars_utils::hashing::{hash_to_partition, DirtyHash};
 use polars_utils::nulls::IsNull;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 use super::*;
 
@@ -112,7 +113,8 @@ pub(super) fn hash_join_tuples_left<T, I>(
 where
     I: IntoIterator<Item = T>,
     <I as IntoIterator>::IntoIter: Send + Sync + Clone,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull,
+    T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull,
 {
     let probe = probe.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
     let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
@@ -147,6 +149,7 @@ where
         let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap());
 
         probe.enumerate().for_each(|(idx_a, k)| {
+            let k = k.to_total_ord();
             let idx_a = (idx_a + offset) as IdxSize;
             // probe table that contains the hashed value
             let current_probe_table = unsafe {
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs
index 33c4a376de87..98bb3c2e377c 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs
@@ -3,6 +3,7 @@ use arrow::legacy::utils::CustomIterTools;
 use polars_utils::hashing::hash_to_partition;
 use polars_utils::idx_vec::IdxVec;
 use polars_utils::nulls::IsNull;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 use polars_utils::unitvec;
 
 use super::*;
@@ -14,7 +15,8 @@ pub(crate) fn create_hash_and_keys_threaded_vectorized<I, T>(
 where
     I: IntoIterator<Item = T> + Send,
     I::IntoIter: TrustedLen,
-    T: Send + Hash + Eq,
+    T: TotalHash + TotalEq + Send + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let build_hasher = build_hasher.unwrap_or_default();
     let hashes = POOL.install(|| {
@@ -23,7 +25,7 @@ where
             .map(|iter| {
                 // create hashes and keys
                 iter.into_iter()
-                    .map(|val| (build_hasher.hash_one(&val), val))
+                    .map(|val| (build_hasher.hash_one(&val.to_total_ord()), val))
                     .collect_trusted::<Vec<_>>()
             })
             .collect()
@@ -33,10 +35,11 @@ where
 
 pub(crate) fn prepare_hashed_relation_threaded<T, I>(
     iters: Vec<I>,
-) -> Vec<PlHashMap<T, (bool, IdxVec)>>
+) -> Vec<PlHashMap<<T as ToTotalOrd>::TotalOrdItem, (bool, IdxVec)>>
 where
     I: Iterator<Item = T> + Send + TrustedLen,
-    T: Send + Hash + Eq + Sync + Copy,
+    T: Send + Sync + TotalHash + TotalEq + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let n_partitions = _set_partition_size();
     let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None);
@@ -50,7 +53,7 @@ where
         .map(|partition_no| {
             let build_hasher = build_hasher.clone();
             let hashes_and_keys = &hashes_and_keys;
-            let mut hash_tbl: PlHashMap<T, (bool, IdxVec)> =
+            let mut hash_tbl: PlHashMap<T::TotalOrdItem, (bool, IdxVec)> =
                 PlHashMap::with_hasher(build_hasher);
 
             let mut offset = 0;
@@ -60,6 +63,7 @@ where
                     .iter()
                     .enumerate()
                     .for_each(|(idx, (h, k))| {
+                        let k = k.to_total_ord();
                        let idx = idx as IdxSize;
                         // partition hashes by thread no.
                         // So only a part of the hashes go to this hashmap
@@ -68,11 +72,11 @@ where
                         let entry = hash_tbl
                             .raw_entry_mut()
                             // uses the key to check equality to find and entry
-                            .from_key_hashed_nocheck(*h, k);
+                            .from_key_hashed_nocheck(*h, &k);
 
                         match entry {
                             RawEntryMut::Vacant(entry) => {
-                                entry.insert_hashed_nocheck(*h, *k, (false, unitvec![idx]));
+                                entry.insert_hashed_nocheck(*h, k, (false, unitvec![idx]));
                             },
                             RawEntryMut::Occupied(mut entry) => {
                                 let (_k, v) = entry.get_key_value_mut();
@@ -94,7 +98,7 @@ where
 #[allow(clippy::too_many_arguments)]
 fn probe_outer<T, F, G, H>(
     probe_hashes: &[Vec<(u64, T)>],
-    hash_tbls: &mut [PlHashMap<T, (bool, IdxVec)>],
+    hash_tbls: &mut [PlHashMap<<T as ToTotalOrd>::TotalOrdItem, (bool, IdxVec)>],
     results: &mut (
         MutablePrimitiveArray<IdxSize>,
         MutablePrimitiveArray<IdxSize>,
@@ -108,7 +112,8 @@ fn probe_outer<T, F, G, H>(
     swap_fn_drain: H,
     join_nulls: bool,
 ) where
-    T: Send + Hash + Eq + Sync + Copy + IsNull,
+    T: TotalHash + TotalEq + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + IsNull,
     // idx_a, idx_b -> ...
     F: Fn(IdxSize, IdxSize) -> (Option<IdxSize>, Option<IdxSize>),
     // idx_a -> ...
@@ -120,6 +125,7 @@ fn probe_outer<T, F, G, H>(
     let mut idx_a = 0;
     for probe_hashes in probe_hashes {
         for (h, key) in probe_hashes {
+            let key = key.to_total_ord();
             let h = *h;
             // probe table that contains the hashed value
             let current_probe_table =
@@ -127,7 +133,7 @@ fn probe_outer<T, F, G, H>(
 
             let entry = current_probe_table
                 .raw_entry_mut()
-                .from_key_hashed_nocheck(h, key);
+                .from_key_hashed_nocheck(h, &key);
 
             match entry {
                 // match and remove
@@ -182,7 +188,8 @@ where
     J: IntoIterator<Item = T>,
     <I as IntoIterator>::IntoIter: TrustedLen + Send,
     <J as IntoIterator>::IntoIter: TrustedLen + Send,
-    T: Hash + Eq + Copy + Sync + Send + IsNull,
+    T: Send + Sync + TotalHash + TotalEq + IsNull + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + IsNull,
 {
     let probe = probe.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
     let build = build.into_iter().map(|i| i.into_iter()).collect::<Vec<_>>();
diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs
index 00ad29499715..c2d5695b8b72 100644
--- a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs
@@ -1,11 +1,15 @@
 use polars_utils::hashing::{hash_to_partition, DirtyHash};
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 use super::*;
 
 /// Only keeps track of membership in right table
-pub(super) fn create_probe_table_semi_anti<T, I>(keys: Vec<I>) -> Vec<PlHashSet<T>>
+pub(super) fn create_probe_table_semi_anti<T, I>(
+    keys: Vec<I>,
+) -> Vec<PlHashSet<<T as ToTotalOrd>::TotalOrdItem>>
 where
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
     I: IntoIterator<Item = T> + Copy + Send + Sync,
 {
     let n_partitions = _set_partition_size();
@@ -14,9 +18,10 @@ where
     // We use the hash to partition the keys to the matching hashtable.
     // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition.
     let par_iter = (0..n_partitions).into_par_iter().map(|partition_no| {
-        let mut hash_tbl: PlHashSet<T> = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE);
+        let mut hash_tbl: PlHashSet<T::TotalOrdItem> = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE);
         for keys in &keys {
             keys.into_iter().for_each(|k| {
+                let k = k.to_total_ord();
                 if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) {
                     hash_tbl.insert(k);
                 }
@@ -35,7 +40,8 @@ fn semi_anti_impl<T, I>(
 ) -> impl ParallelIterator<Item = (IdxSize, bool)>
 where
     I: IntoIterator<Item = T> + Copy + Send + Sync,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
 {
     // first we hash one relation
     let hash_sets = create_probe_table_semi_anti(build);
@@ -61,6 +67,7 @@ where
             let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap());
 
             probe_iter.enumerate().for_each(|(idx_a, k)| {
+                let k = k.to_total_ord();
                 let idx_a = (idx_a + offset) as IdxSize;
                 // probe table that contains the hashed value
                 let current_probe_table =
@@ -83,7 +90,8 @@ where
 pub(super) fn hash_join_tuples_left_anti<T, I>(probe: Vec<I>, build: Vec<I>) -> Vec<IdxSize>
 where
     I: IntoIterator<Item = T> + Copy + Send + Sync,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
 {
     let par_iter = semi_anti_impl(probe, build)
         .filter(|tpls| !tpls.1)
@@ -94,7 +102,8 @@ where
 pub(super) fn hash_join_tuples_left_semi<T, I>(probe: Vec<I>, build: Vec<I>) -> Vec<IdxSize>
 where
     I: IntoIterator<Item = T> + Copy + Send + Sync,
-    T: Send + Hash + Eq + Sync + Copy + DirtyHash,
+    T: TotalHash + TotalEq + DirtyHash + ToTotalOrd,
+    <T as ToTotalOrd>::TotalOrdItem: Hash + Eq + DirtyHash,
 {
     let par_iter = semi_anti_impl(probe, build)
         .filter(|tpls| tpls.1)
diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs
index d494c80f6bdf..ac2c3e948e2d 100644
--- a/crates/polars-ops/src/frame/pivot/positioning.rs
+++ b/crates/polars-ops/src/frame/pivot/positioning.rs
@@ -3,6 +3,7 @@ use std::hash::Hash;
 use arrow::legacy::trusted_len::TrustedLenPush;
 use polars_core::prelude::*;
 use polars_utils::sync::SyncPtr;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 use super::*;
 
@@ -175,23 +176,23 @@ where
 fn compute_col_idx_numeric<T>(column_agg_physical: &ChunkedArray<T>) -> Vec<IdxSize>
 where
     T: PolarsNumericType,
-    T::Native: Hash + Eq,
+    T::Native: TotalHash + TotalEq + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);
     let mut idx = 0 as IdxSize;
     let mut out = Vec::with_capacity(column_agg_physical.len());
 
-    for arr in column_agg_physical.downcast_iter() {
-        for opt_v in arr.into_iter() {
-            let idx = *col_to_idx.entry(opt_v).or_insert_with(|| {
-                let old_idx = idx;
-                idx += 1;
-                old_idx
-            });
-            // SAFETY:
-            // we pre-allocated
-            unsafe { out.push_unchecked(idx) };
-        }
+    for opt_v in column_agg_physical.iter() {
+        let opt_v = opt_v.to_total_ord();
+        let idx = *col_to_idx.entry(opt_v).or_insert_with(|| {
+            let old_idx = idx;
+            idx += 1;
+            old_idx
+        });
+        // SAFETY:
+        // we pre-allocated
+        unsafe { out.push_unchecked(idx) };
     }
     out
 }
@@ -232,14 +233,22 @@ pub(super) fn compute_col_idx(
     use DataType::*;
     let col_locations = match column_agg_physical.dtype() {
-        Int32 | UInt32 | Float32 => {
+        Int32 | UInt32 => {
             let ca = column_agg_physical.bit_repr_small();
             compute_col_idx_numeric(&ca)
         },
-        Int64 | UInt64 | Float64 => {
+        Int64 | UInt64 => {
             let ca = column_agg_physical.bit_repr_large();
             compute_col_idx_numeric(&ca)
         },
+        Float64 => {
+            let ca: &ChunkedArray<Float64Type> = column_agg_physical.as_ref().as_ref().as_ref();
+            compute_col_idx_numeric(ca)
+        },
+        Float32 => {
+            let ca: &ChunkedArray<Float32Type> = column_agg_physical.as_ref().as_ref().as_ref();
+            compute_col_idx_numeric(ca)
+        },
         Struct(_) => {
             let ca = column_agg_physical.struct_().unwrap();
             let ca = ca.rows_encode()?;
@@ -286,7 +295,8 @@ fn compute_row_index<'a, T>(
 ) -> (Vec<IdxSize>, usize, Option<Vec<Series>>)
 where
     T: PolarsDataType,
-    T::Physical<'a>: Hash + Eq + Copy,
+    Option<T::Physical<'a>>: TotalHash + TotalEq + Copy + ToTotalOrd,
+    <Option<T::Physical<'a>> as ToTotalOrd>::TotalOrdItem: Hash + Eq,
     ChunkedArray<T>: FromIterator<Option<T::Physical<'a>>>,
     ChunkedArray<T>: IntoSeries,
 {
@@ -295,26 +305,29 @@ where
     let mut idx = 0 as IdxSize;
     let mut row_locations = Vec::with_capacity(index_agg_physical.len());
 
-    for arr in index_agg_physical.downcast_iter() {
-        for opt_v in arr.iter() {
-            let idx = *row_to_idx.entry(opt_v).or_insert_with(|| {
-                let old_idx = idx;
-                idx += 1;
-                old_idx
-            });
+    for opt_v in index_agg_physical.iter() {
+        let opt_v = opt_v.to_total_ord();
+        let idx = *row_to_idx.entry(opt_v).or_insert_with(|| {
+            let old_idx = idx;
+            idx += 1;
+            old_idx
+        });
 
-            // SAFETY:
-            // we pre-allocated
-            unsafe {
-                row_locations.push_unchecked(idx);
-            }
+        // SAFETY:
+        // we pre-allocated
+        unsafe {
+            row_locations.push_unchecked(idx);
         }
     }
     let row_index = match count {
         0 => {
             let mut s = row_to_idx
                 .into_iter()
-                .map(|(k, _)| k)
+                .map(|(k, _)| {
+                    let out = Option::<T::Physical<'a>>::peel_total_ord(k);
+                    let out: Option<T::Physical<'a>> = unsafe { std::mem::transmute_copy(&out) };
+                    out
+                })
                 .collect::<ChunkedArray<T>>()
                 .into_series();
             s.rename(&index[0]);
@@ -386,14 +399,22 @@ pub(super) fn compute_row_idx(
     use DataType::*;
     match index_agg_physical.dtype() {
-        Int32 | UInt32 | Float32 => {
+        Int32 | UInt32 => {
             let ca = index_agg_physical.bit_repr_small();
             compute_row_index(index, &ca, count, index_s.dtype())
         },
-        Int64 | UInt64 | Float64 => {
+        Int64 | UInt64 => {
             let ca = index_agg_physical.bit_repr_large();
             compute_row_index(index, &ca, count, index_s.dtype())
         },
+        Float64 => {
+            let ca: &ChunkedArray<Float64Type> = index_agg_physical.as_ref().as_ref().as_ref();
+            compute_row_index(index, ca, count, index_s.dtype())
+        },
+        Float32 => {
+            let ca: &ChunkedArray<Float32Type> = index_agg_physical.as_ref().as_ref().as_ref();
+            compute_row_index(index, ca, count, index_s.dtype())
+        },
         Boolean => {
             let ca = index_agg_physical.bool().unwrap();
             compute_row_index(index, ca, count, index_s.dtype())
diff --git a/crates/polars-ops/src/series/ops/approx_unique.rs b/crates/polars-ops/src/series/ops/approx_unique.rs
index fe5d70372395..c1eabb6c20ff 100644
--- a/crates/polars-ops/src/series/ops/approx_unique.rs
+++ b/crates/polars-ops/src/series/ops/approx_unique.rs
@@ -2,6 +2,7 @@ use std::hash::Hash;
 
 use polars_core::prelude::*;
 use polars_core::with_match_physical_integer_polars_type;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 #[cfg(feature = "approx_unique")]
 use crate::series::ops::approx_algo::HyperLogLog;
@@ -9,10 +10,11 @@ use crate::series::ops::approx_algo::HyperLogLog;
 fn approx_n_unique_ca<'a, T>(ca: &'a ChunkedArray<T>) -> PolarsResult<Series>
 where
     T: PolarsDataType,
-    T::Physical<'a>: Hash + Eq,
+    Option<T::Physical<'a>>: TotalHash + TotalEq + ToTotalOrd,
+    <Option<T::Physical<'a>> as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let mut hllp = HyperLogLog::new();
-    ca.iter().for_each(|item| hllp.add(&item));
+    ca.iter().for_each(|item| hllp.add(&item.to_total_ord()));
     let c = hllp.count() as IdxSize;
 
     Ok(Series::new(ca.name(), &[c]))
@@ -28,8 +30,12 @@ fn dispatcher(s: &Series) -> PolarsResult<Series> {
             let ca = s.str().unwrap().as_binary();
             approx_n_unique_ca(&ca)
         },
-        Float32 => approx_n_unique_ca(&s.bit_repr_small()),
-        Float64 => approx_n_unique_ca(&s.bit_repr_large()),
+        Float32 => approx_n_unique_ca(AsRef::<ChunkedArray<Float32Type>>::as_ref(
+            s.as_ref().as_ref(),
+        )),
+        Float64 => approx_n_unique_ca(AsRef::<ChunkedArray<Float64Type>>::as_ref(
+            s.as_ref().as_ref(),
+        )),
         dt if dt.is_numeric() => {
             with_match_physical_integer_polars_type!(s.dtype(), |$T| {
                 let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
diff --git a/crates/polars-ops/src/series/ops/is_first_distinct.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs
index 178c80bb980d..b75ae23dba1f 100644
--- a/crates/polars-ops/src/series/ops/is_first_distinct.rs
+++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs
@@ -5,16 +5,18 @@ use arrow::bitmap::MutableBitmap;
 use arrow::legacy::bit_util::*;
 use arrow::legacy::utils::CustomIterTools;
 use polars_core::prelude::*;
-use polars_core::with_match_physical_integer_polars_type;
+use polars_core::with_match_physical_numeric_polars_type;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 fn is_first_distinct_numeric<T>(ca: &ChunkedArray<T>) -> BooleanChunked
 where
     T: PolarsNumericType,
-    T::Native: Hash + Eq,
+    T::Native: TotalHash + TotalEq + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let mut unique = PlHashSet::new();
     let chunks = ca.downcast_iter().map(|arr| -> BooleanArray {
         arr.into_iter()
-            .map(|opt_v| unique.insert(opt_v))
+            .map(|opt_v| unique.insert(opt_v.to_total_ord()))
             .collect_trusted()
     });
 
@@ -126,16 +128,8 @@ pub fn is_first_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
             let s = s.cast(&Binary).unwrap();
             return is_first_distinct(&s);
         },
-        Float32 => {
-            let ca = s.bit_repr_small();
-            is_first_distinct_numeric(&ca)
-        },
-        Float64 => {
-            let ca = s.bit_repr_large();
-            is_first_distinct_numeric(&ca)
-        },
         dt if dt.is_numeric() => {
-            with_match_physical_integer_polars_type!(s.dtype(), |$T| {
+            with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
                 let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
                 is_first_distinct_numeric(ca)
             })
diff --git a/crates/polars-ops/src/series/ops/is_last_distinct.rs b/crates/polars-ops/src/series/ops/is_last_distinct.rs
index 84fe94a5c002..40d57d438151 100644
--- a/crates/polars-ops/src/series/ops/is_last_distinct.rs
+++ b/crates/polars-ops/src/series/ops/is_last_distinct.rs
@@ -5,7 +5,8 @@ use arrow::bitmap::MutableBitmap;
 use arrow::legacy::utils::CustomIterTools;
 use polars_core::prelude::*;
 use polars_core::utils::NoNull;
-use polars_core::with_match_physical_integer_polars_type;
+use polars_core::with_match_physical_numeric_polars_type;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 pub fn is_last_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
     // fast path.
@@ -31,16 +32,8 @@ pub fn is_last_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
             let s = s.cast(&Binary).unwrap();
             return is_last_distinct(&s);
         },
-        Float32 => {
-            let ca = s.bit_repr_small();
-            is_last_distinct_numeric(&ca)
-        },
-        Float64 => {
-            let ca = s.bit_repr_large();
-            is_last_distinct_numeric(&ca)
-        },
         dt if dt.is_numeric() => {
-            with_match_physical_integer_polars_type!(s.dtype(), |$T| {
+            with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
                 let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
                 is_last_distinct_numeric(ca)
             })
@@ -131,7 +124,8 @@ fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked
 
 fn is_last_distinct_numeric<T>(ca: &ChunkedArray<T>) -> BooleanChunked
 where
     T: PolarsNumericType,
-    T::Native: Hash + Eq,
+    T::Native: TotalHash + TotalEq + ToTotalOrd,
+    <T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let ca = ca.rechunk();
     let arr = ca.downcast_iter().next().unwrap();
@@ -139,7 +133,7 @@ where
     let mut new_ca: BooleanChunked = arr
         .into_iter()
         .rev()
-        .map(|opt_v| unique.insert(opt_v))
+        .map(|opt_v| unique.insert(opt_v.to_total_ord()))
         .collect_reversed::<NoNull<BooleanChunked>>()
         .into_inner();
     new_ca.rename(ca.name());
diff --git a/crates/polars-ops/src/series/ops/unique.rs b/crates/polars-ops/src/series/ops/unique.rs
index e35847b120a8..3a2d9b5652fe 100644
--- a/crates/polars-ops/src/series/ops/unique.rs
+++ b/crates/polars-ops/src/series/ops/unique.rs
@@ -3,14 +3,18 @@ use std::hash::Hash;
 
 use polars_core::hashing::_HASHMAP_INIT_SIZE;
 use polars_core::prelude::*;
 use polars_core::utils::NoNull;
+use polars_core::with_match_physical_numeric_polars_type;
+use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash};
 
 fn unique_counts_helper<I, J>(items: I) -> IdxCa
 where
     I: Iterator<Item = J>,
-    J: Hash + Eq,
+    J: TotalHash + TotalEq + ToTotalOrd,
+    <J as ToTotalOrd>::TotalOrdItem: Hash + Eq,
 {
     let mut map = PlIndexMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default());
     for item in items {
+        let item = item.to_total_ord();
         map.entry(item)
             .and_modify(|cnt| {
                 *cnt += 1;
@@ -24,13 +28,12 @@ where
 /// Returns a count of the unique values in the order of appearance.
 pub fn unique_counts(s: &Series) -> PolarsResult<Series> {
     if s.dtype().to_physical().is_numeric() {
-        if s.bit_repr_is_large() {
-            let ca = s.bit_repr_large();
-            Ok(unique_counts_helper(ca.iter()).into_series())
-        } else {
-            let ca = s.bit_repr_small();
+        let s_physical = s.to_physical_repr();
+
+        with_match_physical_numeric_polars_type!(s_physical.dtype(), |$T| {
+            let ca: &ChunkedArray<$T> = s_physical.as_ref().as_ref().as_ref();
             Ok(unique_counts_helper(ca.iter()).into_series())
-        }
+        })
     } else {
         match s.dtype() {
             DataType::String => {
diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
index b3d3440adc5e..bdabf9d815a2 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
@@ -1,4 +1,5 @@
 use polars_utils::arena::Arena;
+use polars_utils::total_ord::ToTotalOrd;
 
 #[cfg(all(feature = "strings", feature = "concat_str"))]
 use crate::dsl::function_expr::StringFunction;
@@ -58,10 +59,10 @@ macro_rules! eval_binary_bool_type {
         if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) {
             match (lit_left, lit_right) {
                 (LiteralValue::Float32(x), LiteralValue::Float32(y)) => {
-                    Some(AExpr::Literal(LiteralValue::Boolean(x $operand y)))
+                    Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord())))
                 }
                 (LiteralValue::Float64(x), LiteralValue::Float64(y)) => {
-                    Some(AExpr::Literal(LiteralValue::Boolean(x $operand y)))
+                    Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord())))
                 }
                 #[cfg(feature = "dtype-i8")]
                 (LiteralValue::Int8(x), LiteralValue::Int8(y)) => {
diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs
index 8dac484d5d96..9ef6f091ebf4 100644
--- a/crates/polars-utils/src/total_ord.rs
+++ b/crates/polars-utils/src/total_ord.rs
@@ -3,6 +3,9 @@ use std::hash::{Hash, Hasher};
 
 use bytemuck::TransparentWrapper;
 
+use crate::hashing::{BytesHash, DirtyHash};
+use crate::nulls::IsNull;
+
 /// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to
 /// the same value.
 pub fn canonical_f32(x: f32) -> f32 {
@@ -149,6 +152,37 @@ impl<T: Clone> Clone for TotalOrdWrap<T> {
 
 impl<T: Copy> Copy for TotalOrdWrap<T> {}
 
+impl<T: IsNull> IsNull for TotalOrdWrap<T> {
+    const HAS_NULLS: bool = T::HAS_NULLS;
+    type Inner = T::Inner;
+
+    fn is_null(&self) -> bool {
+        self.0.is_null()
+    }
+
+    fn unwrap_inner(self) -> Self::Inner {
+        self.0.unwrap_inner()
+    }
+}
+
+impl DirtyHash for f32 {
+    fn dirty_hash(&self) -> u64 {
+        canonical_f32(*self).to_bits().dirty_hash()
+    }
+}
+
+impl DirtyHash for f64 {
+    fn dirty_hash(&self) -> u64 {
+        canonical_f64(*self).to_bits().dirty_hash()
+    }
+}
+
+impl<T: DirtyHash> DirtyHash for TotalOrdWrap<T> {
+    fn dirty_hash(&self) -> u64 {
+        self.0.dirty_hash()
+    }
+}
+
 macro_rules! impl_trivial_total {
     ($T: ty) => {
         impl TotalEq for $T {
@@ -402,3 +436,140 @@ impl<T: TotalOrd, U: TotalOrd> TotalOrd for (T, U) {
             .then_with(|| self.1.tot_cmp(&other.1))
     }
 }
+
+impl<'a> TotalHash for BytesHash<'a> {
+    fn tot_hash<H>(&self, state: &mut H)
+    where
+        H: Hasher,
+    {
+        self.hash(state)
+    }
+}
+
+impl<'a> TotalEq for BytesHash<'a> {
+    fn tot_eq(&self, other: &Self) -> bool {
+        self == other
+    }
+}
+
+/// This elides creating a [`TotalOrdWrap`] for types that don't need it.
+pub trait ToTotalOrd {
+    type TotalOrdItem: Send + Sync;
+    type SourceItem;
+
+    fn to_total_ord(&self) -> Self::TotalOrdItem;
+
+    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem;
+}
+
+macro_rules! impl_to_total_ord_identity {
+    ($T: ty) => {
+        impl ToTotalOrd for $T {
+            type TotalOrdItem = $T;
+            type SourceItem = $T;
+
+            fn to_total_ord(&self) -> Self::TotalOrdItem {
+                self.clone()
+            }
+
+            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+                ord_item
+            }
+        }
+    };
+}
+
+impl_to_total_ord_identity!(bool);
+impl_to_total_ord_identity!(u8);
+impl_to_total_ord_identity!(u16);
+impl_to_total_ord_identity!(u32);
+impl_to_total_ord_identity!(u64);
+impl_to_total_ord_identity!(u128);
+impl_to_total_ord_identity!(usize);
+impl_to_total_ord_identity!(i8);
+impl_to_total_ord_identity!(i16);
+impl_to_total_ord_identity!(i32);
+impl_to_total_ord_identity!(i64);
+impl_to_total_ord_identity!(i128);
+impl_to_total_ord_identity!(isize);
+impl_to_total_ord_identity!(char);
+impl_to_total_ord_identity!(String);
+
+macro_rules! impl_to_total_ord_lifetimed_identity {
+    ($T: ty) => {
+        impl<'a> ToTotalOrd for &'a $T {
+            type TotalOrdItem = &'a $T;
+            type SourceItem = &'a $T;
+
+            fn to_total_ord(&self) -> Self::TotalOrdItem {
+                *self
+            }
+
+            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+                ord_item
+            }
+        }
+    };
+}
+
+impl_to_total_ord_lifetimed_identity!(str);
+impl_to_total_ord_lifetimed_identity!([u8]);
+
+macro_rules! impl_to_total_ord_wrapped {
+    ($T: ty) => {
+        impl ToTotalOrd for $T {
+            type TotalOrdItem = TotalOrdWrap<$T>;
+            type SourceItem = $T;
+
+            fn to_total_ord(&self) -> Self::TotalOrdItem {
+                TotalOrdWrap(self.clone())
+            }
+
+            fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+                ord_item.0
+            }
+        }
+    };
+}
+
+impl_to_total_ord_wrapped!(f32);
+impl_to_total_ord_wrapped!(f64);
+
+impl<T: Copy + Send + Sync> ToTotalOrd for Option<T> {
+    type TotalOrdItem = TotalOrdWrap<Option<T>>;
+    type SourceItem = Option<T>;
+
+    fn to_total_ord(&self) -> Self::TotalOrdItem {
+        TotalOrdWrap(*self)
+    }
+
+    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+        ord_item.0
+    }
+}
+
+impl<'a> ToTotalOrd for BytesHash<'a> {
+    type TotalOrdItem = BytesHash<'a>;
+    type SourceItem = BytesHash<'a>;
+
+    fn to_total_ord(&self) -> Self::TotalOrdItem {
+        *self
+    }
+
+    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+        ord_item
+    }
+}
+
+impl<T: ToTotalOrd> ToTotalOrd for &T {
+    type TotalOrdItem = T::TotalOrdItem;
+    type SourceItem = T::SourceItem;
+
+    fn to_total_ord(&self) -> Self::TotalOrdItem {
+        (*self).to_total_ord()
+    }
+
+    fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem {
+        T::peel_total_ord(ord_item)
+    }
+}
diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs
index 59da00e0078d..04ae800dc370 100644
--- a/py-polars/src/conversion/mod.rs
+++ b/py-polars/src/conversion/mod.rs
@@ -20,7 +20,7 @@ use polars_core::utils::arrow::types::NativeType;
 use polars_lazy::prelude::*;
 #[cfg(feature = "cloud")]
 use polars_rs::io::cloud::CloudOptions;
-use polars_utils::total_ord::TotalEq;
+use polars_utils::total_ord::{TotalEq, TotalHash};
 use pyo3::basic::CompareOp;
 use pyo3::conversion::{FromPyObject, IntoPy};
 use pyo3::exceptions::{PyTypeError, PyValueError};
@@ -498,6 +498,15 @@ impl TotalEq for ObjectValue {
     }
 }
 
+impl TotalHash for ObjectValue {
+    fn tot_hash<H>(&self, state: &mut H)
+    where
+        H: Hasher,
+    {
+        self.hash(state);
+    }
+}
+
 impl Display for ObjectValue {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.inner)
diff --git a/py-polars/tests/unit/datatypes/test_float.py b/py-polars/tests/unit/datatypes/test_float.py
index 62c975b26298..d41ef1390c54 100644
--- a/py-polars/tests/unit/datatypes/test_float.py
+++ b/py-polars/tests/unit/datatypes/test_float.py
@@ -1,4 +1,7 @@
+import pytest
+
 import polars as pl
+from polars.testing import assert_series_equal
 
 
 def test_nan_in_group_by_agg() -> None:
@@ -32,3 +35,237 @@ def test_nan_aggregations() -> None:
         str(df.group_by("b").agg(aggs).to_dict(as_series=False))
         == "{'b': [1], 'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}"
     )
+
+
+@pytest.mark.parametrize(
+    ("s", "expect"),
+    [
+        (
+            pl.Series(
+                "x",
+                [
+                    -0.0,
+                    0.0,
+                    float("-nan"),
+                    float("nan"),
+                    1.0,
+                    None,
+                ],
+            ),
+            pl.Series("x", [None, 0.0, 1.0, float("nan")]),
+        ),
+        (
+            # No nulls
+            pl.Series(
+                "x",
+                [
+                    -0.0,
+                    0.0,
+                    float("-nan"),
+                    float("nan"),
+                    1.0,
+                ],
+            ),
+            pl.Series("x", [0.0, 1.0, float("nan")]),
+        ),
+    ],
+)
+def test_unique(s: pl.Series, expect: pl.Series) -> None:
+    out = s.unique()
+    assert_series_equal(expect, out)
+
+    out = s.n_unique()
+    assert expect.len() == out
+
+    out = s.gather(s.arg_unique()).sort()
+    assert_series_equal(expect, out)
+
+
+def test_unique_counts() -> None:
+    s = pl.Series(
+        "x",
+        [
+            -0.0,
+            0.0,
+            float("-nan"),
+            float("nan"),
+            1.0,
+            None,
+        ],
+    )
+    expect = pl.Series("x", [2, 2, 1, 1], dtype=pl.UInt32)
+    out = s.unique_counts()
+    assert_series_equal(expect, out)
+
+
+def test_hash() -> None:
+    s = pl.Series(
+        "x",
+        [
+            -0.0,
+            0.0,
+            float("-nan"),
+            float("nan"),
+            1.0,
+            None,
+        ],
+    ).hash()
+
+    # check them against each other since hash is not stable
+    assert s.item(0) == s.item(1)  # hash(-0.0) == hash(0.0)
+    assert s.item(2) == s.item(3)  # hash(float('-nan')) == hash(float('nan'))
+
+
+def test_group_by() -> None:
+    # Test num_groups_proxy
+    # * -0.0 and 0.0 in same groups
+    # * -nan and nan in same groups
+    df = (
+        pl.Series(
+            "x",
+            [
+                -0.0,
+                0.0,
+                float("-nan"),
+                float("nan"),
+                1.0,
+                None,
+            ],
+        )
+        .to_frame()
+        .with_row_index()
+        .with_columns(a=pl.lit("a"))
+    )
+
+    expect = pl.Series("index", [[0, 1], [2, 3], [4], [5]], dtype=pl.List(pl.UInt32))
+
+    for group_keys in (("x",), ("x", "a")):
+        for maintain_order in (True, False):
+            for drop_nulls in (True, False):
+                out = df
+                if drop_nulls:
+                    out = out.drop_nulls()
+
+                out = (
+                    out.group_by(group_keys, maintain_order=maintain_order)
+                    .agg("index")
+                    .sort(pl.col("index").list.get(0))
+                    .select("index")
+                    .to_series()
+                )
+
+                if drop_nulls:
+                    assert_series_equal(expect.head(3), out)
+                else:
+                    assert_series_equal(expect, out)
+
+
+def test_joins() -> None:
+    # Test that -0.0 joins with 0.0 and nan joins with nan
+    df = (
+        pl.Series(
+            "x",
+            [
+                -0.0,
+                0.0,
+                float("-nan"),
+                float("nan"),
+                1.0,
+                None,
+            ],
+        )
+        .to_frame()
+        .with_row_index()
+        .with_columns(a=pl.lit("a"))
+    )
+
+    rhs = (
+        pl.Series("x", [0.0, float("nan"), 3.0])
+        .to_frame()
+        .with_columns(a=pl.lit("a"), rhs=True)
+    )
+
+    for join_on in (
+        # Single and multiple keys
+        ("x",),
+        (
+            "x",
+            "a",
+        ),
+    ):
+        how = "left"
+        expect = pl.Series("rhs", [True, True, True, True, None, None])
+        out = df.join(rhs, on=join_on, how=how).sort("index").select("rhs").to_series()
+        assert_series_equal(expect, out)
+
+        how = "inner"
+        expect = pl.Series("index", [0, 1, 2, 3], dtype=pl.UInt32)
+        out = (
+            df.join(rhs, on=join_on, how=how).sort("index").select("index").to_series()
+        )
+        assert_series_equal(expect, out)
+
+        how = "outer"
+        expect = pl.Series("rhs", [True, True, True, True, None, None, True])
+        out = (
+            df.join(rhs, on=join_on, how=how)
+            .sort("index", nulls_last=True)
+            .select("rhs")
+            .to_series()
+        )
+        assert_series_equal(expect, out)
+
+        how = "semi"
+        expect = pl.Series("x", [-0.0, 0.0, float("-nan"), float("nan")])
+        out = (
+            df.join(rhs, on=join_on, how=how)
+            .sort("index", nulls_last=True)
+            .select("x")
+            .to_series()
+        )
+        assert_series_equal(expect, out)
+
+        how = "anti"
+        expect = pl.Series("x", [1.0, None])
+        out = (
+            df.join(rhs, on=join_on, how=how)
+            .sort("index", nulls_last=True)
+            .select("x")
+            .to_series()
+        )
+        assert_series_equal(expect, out)
+
+    # test asof
+    # note that nans never join because nans are always greater than the other
+    # side of the comparison (in this case the tolerance)
+    expect = pl.Series("rhs", [True, True, None, None, None, None])
+    out = (
+        df.sort("x")
+        .join_asof(rhs.sort("x"), on="x", tolerance=0)
+        .sort("index")
+        .select("rhs")
+        .to_series()
+    )
+    assert_series_equal(expect, out)
+
+
+def test_first_last_distinct() -> None:
+    s = pl.Series(
+        "x",
+        [
+            -0.0,
+            0.0,
+            float("-nan"),
+            float("nan"),
+            1.0,
+            None,
+        ],
+    )
+
+    assert_series_equal(
+        pl.Series("x", [True, False, True, False, True, True]), s.is_first_distinct()
+    )
+
+    assert_series_equal(
+        pl.Series("x", [False, True, False, True, True, True]), s.is_last_distinct()
+    )
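
---

Reviewer note (not part of the patch): the sketch below is a minimal, self-contained illustration of the technique this PR applies throughout — wrapping floats in a newtype with *total* equality and hashing so they can serve as hash-table keys in group-by and joins. The names `TotalOrdWrap` and `canonical_f64` mirror `polars_utils::total_ord`, but this implementation is simplified, assumes only the standard library, and is not the exact code added here:

```rust
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

/// Simplified stand-in for `polars_utils::total_ord::TotalOrdWrap`:
/// a newtype giving f64 total equality/hashing so it can be a map key.
#[derive(Clone, Copy)]
struct TotalOrdWrap(f64);

/// Fold -0.0 into 0.0 and every NaN payload into one canonical NaN,
/// mirroring the role of `canonical_f64` in the patch.
fn canonical_f64(x: f64) -> f64 {
    let x = x + 0.0; // -0.0 + 0.0 == 0.0
    if x.is_nan() {
        f64::NAN
    } else {
        x
    }
}

impl PartialEq for TotalOrdWrap {
    fn eq(&self, other: &Self) -> bool {
        canonical_f64(self.0).to_bits() == canonical_f64(other.0).to_bits()
    }
}
impl Eq for TotalOrdWrap {}

impl Hash for TotalOrdWrap {
    fn hash<H: Hasher>(&self, state: &mut H) {
        canonical_f64(self.0).to_bits().hash(state);
    }
}

fn main() {
    // Group row indices by key, the way `group_by` in hashing.rs does
    // after calling `to_total_ord()` on every key.
    let keys = [-0.0, 0.0, f64::NAN, -f64::NAN, 1.0];
    let mut groups: HashMap<TotalOrdWrap, Vec<usize>> = HashMap::new();
    for (idx, k) in keys.iter().enumerate() {
        groups.entry(TotalOrdWrap(*k)).or_default().push(idx);
    }
    // Three groups: {-0.0, 0.0}, {nan, -nan}, {1.0}.
    assert_eq!(groups.len(), 3);
}
```

The `ToTotalOrd` trait added in `crates/polars-utils/src/total_ord.rs` exists so this wrapping cost is paid only where needed: integers and other already-totally-ordered keys map to themselves (`TotalOrdItem = Self`), while floats (and `Option`s of them) go through `TotalOrdWrap`. That is why the previous float-to-integer `bit_repr` reinterpretation paths could be deleted across group-by, joins, pivot, and uniqueness ops.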