Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust): bubble error when no available bitrepr #17116

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 45 additions & 32 deletions crates/polars-core/src/chunked_array/ops/bit_repr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use arrow::buffer::Buffer;

use crate::prelude::*;
use crate::series::BitRepr;

/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size
/// and alignment.
Expand Down Expand Up @@ -103,41 +104,41 @@ impl<T> ToBitRepr for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn bit_repr_is_large() -> bool {
std::mem::size_of::<T::Native>() == 8
}
fn to_bit_repr(&self) -> BitRepr {
let is_large = std::mem::size_of::<T::Native>() == 8;

fn bit_repr_large(&self) -> UInt64Chunked {
if std::mem::size_of::<T::Native>() == 8 {
if is_large {
if matches!(self.dtype(), DataType::UInt64) {
let ca = self.clone();
// Convince the compiler we are this type. This keeps flags.
return unsafe { std::mem::transmute::<ChunkedArray<T>, UInt64Chunked>(ca) };
return BitRepr::Large(unsafe {
std::mem::transmute::<ChunkedArray<T>, UInt64Chunked>(ca)
});
}
reinterpret_chunked_array(self)
} else {
unreachable!()
}
}

fn bit_repr_small(&self) -> UInt32Chunked {
if std::mem::size_of::<T::Native>() == 4 {
if matches!(self.dtype(), DataType::UInt32) {
let ca = self.clone();
// Convince the compiler we are this type. This preserves flags.
return unsafe { std::mem::transmute::<ChunkedArray<T>, UInt32Chunked>(ca) };
}
reinterpret_chunked_array(self)
BitRepr::Large(reinterpret_chunked_array(self))
} else {
// SAFETY: an unchecked cast to uint32 (which has no invariants) is
// always sound.
unsafe {
self.cast_unchecked(&DataType::UInt32)
.unwrap()
.u32()
.unwrap()
.clone()
}
BitRepr::Small(if std::mem::size_of::<T::Native>() == 4 {
if matches!(self.dtype(), DataType::UInt32) {
let ca = self.clone();
// Convince the compiler we are this type. This preserves flags.
return BitRepr::Small(unsafe {
std::mem::transmute::<ChunkedArray<T>, UInt32Chunked>(ca)
});
}

reinterpret_chunked_array(self)
} else {
// SAFETY: an unchecked cast to uint32 (which has no invariants) is
// always sound.
unsafe {
self.cast_unchecked(&DataType::UInt32)
.unwrap()
.u32()
.unwrap()
.clone()
}
})
}
}
}
Expand All @@ -160,7 +161,10 @@ impl Reinterpret for Int64Chunked {
}

fn reinterpret_unsigned(&self) -> Series {
self.bit_repr_large().into_series()
let BitRepr::Large(b) = self.to_bit_repr() else {
unreachable!()
};
b.into_series()
}
}

Expand All @@ -183,7 +187,10 @@ impl Reinterpret for Int32Chunked {
}

fn reinterpret_unsigned(&self) -> Series {
self.bit_repr_small().into_series()
let BitRepr::Small(b) = self.to_bit_repr() else {
unreachable!()
};
b.into_series()
}
}

Expand Down Expand Up @@ -250,7 +257,10 @@ impl Float32Chunked {
where
F: Fn(&Series) -> Series,
{
let s = self.bit_repr_small().into_series();
let BitRepr::Small(s) = self.to_bit_repr() else {
unreachable!()
};
let s = s.into_series();
let out = f(&s);
let out = out.u32().unwrap();
out._reinterpret_float().into()
Expand All @@ -261,7 +271,10 @@ impl Float64Chunked {
where
F: Fn(&Series) -> Series,
{
let s = self.bit_repr_large().into_series();
let BitRepr::Large(s) = self.to_bit_repr() else {
unreachable!()
};
let s = s.into_series();
let out = f(&s);
let out = out.u64().unwrap();
out._reinterpret_float().into()
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use serde::{Deserialize, Serialize};
pub use sort::options::*;

use crate::chunked_array::cast::CastOptions;
use crate::series::IsSorted;
use crate::series::{BitRepr, IsSorted};
#[cfg(feature = "reinterpret")]
pub trait Reinterpret {
fn reinterpret_signed(&self) -> Series {
Expand All @@ -59,10 +59,7 @@ pub trait Reinterpret {
/// This is useful in hashing contexts and reduces the number
/// of compiled code paths.
pub(crate) trait ToBitRepr {
fn bit_repr_is_large() -> bool;

fn bit_repr_large(&self) -> UInt64Chunked;
fn bit_repr_small(&self) -> UInt32Chunked;
fn to_bit_repr(&self) -> BitRepr;
}

pub trait ChunkAnyValue {
Expand Down
9 changes: 7 additions & 2 deletions crates/polars-core/src/frame/group_by/into_groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::*;
use crate::chunked_array::cast::CastOptions;
use crate::config::verbose;
use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered;
use crate::series::BitRepr;
use crate::utils::flatten::flatten_par;

/// Used to create the tuples for a group_by operation.
Expand Down Expand Up @@ -163,11 +164,15 @@ where
num_groups_proxy(ca, multithreaded, sorted)
},
DataType::Int64 => {
let ca = self.bit_repr_large();
let BitRepr::Large(ca) = self.to_bit_repr() else {
unreachable!()
};
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Int32 => {
let ca = self.bit_repr_small();
let BitRepr::Small(ca) = self.to_bit_repr() else {
unreachable!()
};
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Float64 => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,7 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<CategoricalChunked> {
fn bit_repr_is_large(&self) -> bool {
false
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0.physical().clone()
/// Returns the bit representation of this categorical series.
///
/// Clones the physical array of the wrapped categorical and wraps it in
/// `BitRepr::Small`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(BitRepr::Small(self.0.physical().clone()))
}
}
12 changes: 2 additions & 10 deletions crates/polars-core/src/series/implementations/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,7 @@ impl SeriesTrait for SeriesWrap<DateChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DateChunked> {
fn bit_repr_is_large(&self) -> bool {
false
}

fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
/// Returns the bit representation of this date series.
///
/// Delegates to `to_bit_repr` on the wrapped array and wraps the result in
/// `Some`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}
7 changes: 2 additions & 5 deletions crates/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ unsafe impl IntoSeries for DatetimeChunked {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DatetimeChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
/// Returns the bit representation of this datetime series.
///
/// Delegates to `to_bit_repr` on the wrapped array and wraps the result in
/// `Some`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}

Expand Down
6 changes: 5 additions & 1 deletion crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ unsafe impl IntoSeries for DecimalChunked {
}
}

impl private::PrivateSeriesNumeric for SeriesWrap<DecimalChunked> {}
// Decimal series do not expose a bit representation.
impl private::PrivateSeriesNumeric for SeriesWrap<DecimalChunked> {
/// Always returns `None`: no `BitRepr` is provided for decimal series, so
/// callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl SeriesWrap<DecimalChunked> {
fn apply_physical_to_s<F: Fn(&Int128Chunked) -> Int128Chunked>(&self, f: F) -> Series {
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@ unsafe impl IntoSeries for DurationChunked {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DurationChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
/// Returns the bit representation of this duration series.
///
/// Delegates to `to_bit_repr` on the wrapped array and wraps the result in
/// `Some`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}

Expand Down
50 changes: 32 additions & 18 deletions crates/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,33 +484,47 @@ impl_dyn_series!(Int32Chunked);
impl_dyn_series!(Int64Chunked);

impl<T: PolarsNumericType> private::PrivateSeriesNumeric for SeriesWrap<ChunkedArray<T>> {
fn bit_repr_is_large(&self) -> bool {
ChunkedArray::<T>::bit_repr_is_large()
/// Returns the bit representation of a numeric series.
///
/// Delegates to `to_bit_repr` on the wrapped `ChunkedArray<T>` and wraps the
/// result in `Some`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

// String series do not expose a bit representation.
impl private::PrivateSeriesNumeric for SeriesWrap<StringChunked> {
/// Always returns `None`: no `BitRepr` is provided for string series, so
/// callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
}
// Binary-offset series do not expose a bit representation.
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryOffsetChunked> {
/// Always returns `None`: no `BitRepr` is provided for binary-offset series,
/// so callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}
// List series do not expose a bit representation.
impl private::PrivateSeriesNumeric for SeriesWrap<ListChunked> {
/// Always returns `None`: no `BitRepr` is provided for list series, so
/// callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl private::PrivateSeriesNumeric for SeriesWrap<StringChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryOffsetChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<ListChunked> {}
#[cfg(feature = "dtype-array")]
impl private::PrivateSeriesNumeric for SeriesWrap<ArrayChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BooleanChunked> {
fn bit_repr_is_large(&self) -> bool {
false
impl private::PrivateSeriesNumeric for SeriesWrap<ArrayChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0
}
impl private::PrivateSeriesNumeric for SeriesWrap<BooleanChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
let repr = self
.0
.cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
.unwrap()
.u32()
.unwrap()
.clone()
.clone();

Some(BitRepr::Small(repr))
}
}
7 changes: 5 additions & 2 deletions crates/polars-core/src/series/implementations/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ impl NullChunked {
}
}
impl PrivateSeriesNumeric for NullChunked {
fn bit_repr_small(&self) -> UInt32Chunked {
UInt32Chunked::full_null(self.name.as_ref(), self.len())
fn bit_repr(&self) -> Option<BitRepr> {
Some(BitRepr::Small(UInt32Chunked::full_null(
self.name.as_ref(),
self.len(),
)))
}
}

Expand Down
8 changes: 6 additions & 2 deletions crates/polars-core/src/series/implementations/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ use std::borrow::Cow;

use ahash::RandomState;

use super::MetadataFlags;
use super::{BitRepr, MetadataFlags};
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::object::PolarsObjectSafe;
use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner};
use crate::prelude::*;
use crate::series::implementations::SeriesWrap;
use crate::series::private::{PrivateSeries, PrivateSeriesNumeric};

impl<T: PolarsObject> PrivateSeriesNumeric for SeriesWrap<ObjectChunked<T>> {}
// Object series do not expose a bit representation.
impl<T: PolarsObject> PrivateSeriesNumeric for SeriesWrap<ObjectChunked<T>> {
/// Always returns `None`: no `BitRepr` is provided for object series, so
/// callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl<T> PrivateSeries for SeriesWrap<ObjectChunked<T>>
where
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ unsafe impl IntoSeries for StructChunked {
}
}

impl PrivateSeriesNumeric for SeriesWrap<StructChunked> {}
// Struct series do not expose a bit representation.
impl PrivateSeriesNumeric for SeriesWrap<StructChunked> {
/// Always returns `None`: no `BitRepr` is provided for struct series, so
/// callers must handle the absent case.
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl private::PrivateSeries for SeriesWrap<StructChunked> {
fn compute_len(&mut self) {
Expand Down
12 changes: 2 additions & 10 deletions crates/polars-core/src/series/implementations/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,7 @@ impl SeriesTrait for SeriesWrap<TimeChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<TimeChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}

fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
/// Returns the bit representation of this time series.
///
/// Delegates to `to_bit_repr` on the wrapped array and wraps the result in
/// `Some`; this implementation never returns `None`.
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}
Loading
Loading