Skip to content

Commit

Permalink
fix(rust): bubble error when no available bitrepr
Browse files Browse the repository at this point in the history
This PR refactors the code surrounding bit representations and properly bubbles
errors up if no bit representation is defined.

This resolves the panic in #14826, but does not implement the wanted behavior.

I am not sure a test case is useful here as the behavior should be added in the end.
  • Loading branch information
coastalwhite committed Jun 21, 2024
1 parent 8a6bf4b commit 4e26830
Show file tree
Hide file tree
Showing 19 changed files with 339 additions and 236 deletions.
77 changes: 45 additions & 32 deletions crates/polars-core/src/chunked_array/ops/bit_repr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use arrow::buffer::Buffer;

use crate::prelude::*;
use crate::series::BitRepr;

/// Reinterprets the type of a [`ChunkedArray`]. T and U must have the same size
/// and alignment.
Expand Down Expand Up @@ -103,41 +104,41 @@ impl<T> ToBitRepr for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn bit_repr_is_large() -> bool {
std::mem::size_of::<T::Native>() == 8
}
fn to_bit_repr(&self) -> BitRepr {
let is_large = std::mem::size_of::<T::Native>() == 8;

fn bit_repr_large(&self) -> UInt64Chunked {
if std::mem::size_of::<T::Native>() == 8 {
if is_large {
if matches!(self.dtype(), DataType::UInt64) {
let ca = self.clone();
// Convince the compiler we are this type. This keeps flags.
return unsafe { std::mem::transmute::<ChunkedArray<T>, UInt64Chunked>(ca) };
return BitRepr::Large(unsafe {
std::mem::transmute::<ChunkedArray<T>, UInt64Chunked>(ca)
});
}
reinterpret_chunked_array(self)
} else {
unreachable!()
}
}

fn bit_repr_small(&self) -> UInt32Chunked {
if std::mem::size_of::<T::Native>() == 4 {
if matches!(self.dtype(), DataType::UInt32) {
let ca = self.clone();
// Convince the compiler we are this type. This preserves flags.
return unsafe { std::mem::transmute::<ChunkedArray<T>, UInt32Chunked>(ca) };
}
reinterpret_chunked_array(self)
BitRepr::Large(reinterpret_chunked_array(self))
} else {
// SAFETY: an unchecked cast to uint32 (which has no invariants) is
// always sound.
unsafe {
self.cast_unchecked(&DataType::UInt32)
.unwrap()
.u32()
.unwrap()
.clone()
}
BitRepr::Small(if std::mem::size_of::<T::Native>() == 4 {
if matches!(self.dtype(), DataType::UInt32) {
let ca = self.clone();
// Convince the compiler we are this type. This preserves flags.
return BitRepr::Small(unsafe {
std::mem::transmute::<ChunkedArray<T>, UInt32Chunked>(ca)
});
}

reinterpret_chunked_array(self)
} else {
// SAFETY: an unchecked cast to uint32 (which has no invariants) is
// always sound.
unsafe {
self.cast_unchecked(&DataType::UInt32)
.unwrap()
.u32()
.unwrap()
.clone()
}
})
}
}
}
Expand All @@ -160,7 +161,10 @@ impl Reinterpret for Int64Chunked {
}

fn reinterpret_unsigned(&self) -> Series {
self.bit_repr_large().into_series()
let BitRepr::Large(b) = self.to_bit_repr() else {
unreachable!()
};
b.into_series()
}
}

Expand All @@ -183,7 +187,10 @@ impl Reinterpret for Int32Chunked {
}

fn reinterpret_unsigned(&self) -> Series {
self.bit_repr_small().into_series()
let BitRepr::Small(b) = self.to_bit_repr() else {
unreachable!()
};
b.into_series()
}
}

Expand Down Expand Up @@ -250,7 +257,10 @@ impl Float32Chunked {
where
F: Fn(&Series) -> Series,
{
let s = self.bit_repr_small().into_series();
let BitRepr::Small(s) = self.to_bit_repr() else {
unreachable!()
};
let s = s.into_series();
let out = f(&s);
let out = out.u32().unwrap();
out._reinterpret_float().into()
Expand All @@ -261,7 +271,10 @@ impl Float64Chunked {
where
F: Fn(&Series) -> Series,
{
let s = self.bit_repr_large().into_series();
let BitRepr::Large(s) = self.to_bit_repr() else {
unreachable!()
};
let s = s.into_series();
let out = f(&s);
let out = out.u64().unwrap();
out._reinterpret_float().into()
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use serde::{Deserialize, Serialize};
pub use sort::options::*;

use crate::chunked_array::cast::CastOptions;
use crate::series::IsSorted;
use crate::series::{BitRepr, IsSorted};
#[cfg(feature = "reinterpret")]
pub trait Reinterpret {
fn reinterpret_signed(&self) -> Series {
Expand All @@ -59,10 +59,7 @@ pub trait Reinterpret {
/// This is useful in hashing context and reduces no.
/// of compiled code paths.
pub(crate) trait ToBitRepr {
fn bit_repr_is_large() -> bool;

fn bit_repr_large(&self) -> UInt64Chunked;
fn bit_repr_small(&self) -> UInt32Chunked;
fn to_bit_repr(&self) -> BitRepr;
}

pub trait ChunkAnyValue {
Expand Down
9 changes: 7 additions & 2 deletions crates/polars-core/src/frame/group_by/into_groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::*;
use crate::chunked_array::cast::CastOptions;
use crate::config::verbose;
use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered;
use crate::series::BitRepr;
use crate::utils::flatten::flatten_par;

/// Used to create the tuples for a group_by operation.
Expand Down Expand Up @@ -163,11 +164,15 @@ where
num_groups_proxy(ca, multithreaded, sorted)
},
DataType::Int64 => {
let ca = self.bit_repr_large();
let BitRepr::Large(ca) = self.to_bit_repr() else {
unreachable!()
};
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Int32 => {
let ca = self.bit_repr_small();
let BitRepr::Small(ca) = self.to_bit_repr() else {
unreachable!()
};
num_groups_proxy(&ca, multithreaded, sorted)
},
DataType::Float64 => {
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-core/src/series/implementations/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,7 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<CategoricalChunked> {
fn bit_repr_is_large(&self) -> bool {
false
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0.physical().clone()
fn bit_repr(&self) -> Option<BitRepr> {
Some(BitRepr::Small(self.0.physical().clone()))
}
}
12 changes: 2 additions & 10 deletions crates/polars-core/src/series/implementations/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,7 @@ impl SeriesTrait for SeriesWrap<DateChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DateChunked> {
fn bit_repr_is_large(&self) -> bool {
false
}

fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}
7 changes: 2 additions & 5 deletions crates/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@ unsafe impl IntoSeries for DatetimeChunked {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DatetimeChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}

Expand Down
6 changes: 5 additions & 1 deletion crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ unsafe impl IntoSeries for DecimalChunked {
}
}

impl private::PrivateSeriesNumeric for SeriesWrap<DecimalChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<DecimalChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl SeriesWrap<DecimalChunked> {
fn apply_physical_to_s<F: Fn(&Int128Chunked) -> Int128Chunked>(&self, f: F) -> Series {
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@ unsafe impl IntoSeries for DurationChunked {
}

impl private::PrivateSeriesNumeric for SeriesWrap<DurationChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}

Expand Down
50 changes: 32 additions & 18 deletions crates/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,33 +484,47 @@ impl_dyn_series!(Int32Chunked);
impl_dyn_series!(Int64Chunked);

impl<T: PolarsNumericType> private::PrivateSeriesNumeric for SeriesWrap<ChunkedArray<T>> {
fn bit_repr_is_large(&self) -> bool {
ChunkedArray::<T>::bit_repr_is_large()
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

impl private::PrivateSeriesNumeric for SeriesWrap<StringChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryOffsetChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}
impl private::PrivateSeriesNumeric for SeriesWrap<ListChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl private::PrivateSeriesNumeric for SeriesWrap<StringChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BinaryOffsetChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<ListChunked> {}
#[cfg(feature = "dtype-array")]
impl private::PrivateSeriesNumeric for SeriesWrap<ArrayChunked> {}
impl private::PrivateSeriesNumeric for SeriesWrap<BooleanChunked> {
fn bit_repr_is_large(&self) -> bool {
false
impl private::PrivateSeriesNumeric for SeriesWrap<ArrayChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
fn bit_repr_small(&self) -> UInt32Chunked {
self.0
}
impl private::PrivateSeriesNumeric for SeriesWrap<BooleanChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
let repr = self
.0
.cast_with_options(&DataType::UInt32, CastOptions::NonStrict)
.unwrap()
.u32()
.unwrap()
.clone()
.clone();

Some(BitRepr::Small(repr))
}
}
7 changes: 5 additions & 2 deletions crates/polars-core/src/series/implementations/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ impl NullChunked {
}
}
impl PrivateSeriesNumeric for NullChunked {
fn bit_repr_small(&self) -> UInt32Chunked {
UInt32Chunked::full_null(self.name.as_ref(), self.len())
fn bit_repr(&self) -> Option<BitRepr> {
Some(BitRepr::Small(UInt32Chunked::full_null(
self.name.as_ref(),
self.len(),
)))
}
}

Expand Down
8 changes: 6 additions & 2 deletions crates/polars-core/src/series/implementations/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@ use std::borrow::Cow;

use ahash::RandomState;

use super::MetadataFlags;
use super::{BitRepr, MetadataFlags};
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::object::PolarsObjectSafe;
use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner};
use crate::prelude::*;
use crate::series::implementations::SeriesWrap;
use crate::series::private::{PrivateSeries, PrivateSeriesNumeric};

impl<T: PolarsObject> PrivateSeriesNumeric for SeriesWrap<ObjectChunked<T>> {}
impl<T: PolarsObject> PrivateSeriesNumeric for SeriesWrap<ObjectChunked<T>> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl<T> PrivateSeries for SeriesWrap<ObjectChunked<T>>
where
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ unsafe impl IntoSeries for StructChunked {
}
}

impl PrivateSeriesNumeric for SeriesWrap<StructChunked> {}
impl PrivateSeriesNumeric for SeriesWrap<StructChunked> {
fn bit_repr(&self) -> Option<BitRepr> {
None
}
}

impl private::PrivateSeries for SeriesWrap<StructChunked> {
fn compute_len(&mut self) {
Expand Down
12 changes: 2 additions & 10 deletions crates/polars-core/src/series/implementations/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,7 @@ impl SeriesTrait for SeriesWrap<TimeChunked> {
}

impl private::PrivateSeriesNumeric for SeriesWrap<TimeChunked> {
fn bit_repr_is_large(&self) -> bool {
true
}

fn bit_repr_large(&self) -> UInt64Chunked {
self.0.bit_repr_large()
}

fn bit_repr_small(&self) -> UInt32Chunked {
self.0.bit_repr_small()
fn bit_repr(&self) -> Option<BitRepr> {
Some(self.0.to_bit_repr())
}
}
Loading

0 comments on commit 4e26830

Please sign in to comment.