Skip to content

Commit

Permalink
genericize downcast_iter with HasUnderlyingArray
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp committed Aug 15, 2023
1 parent e8a9aab commit f3a90d1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 159 deletions.
169 changes: 10 additions & 159 deletions crates/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use std::marker::PhantomData;

use arrow::array::*;

#[cfg(feature = "object")]
use crate::chunked_array::object::ObjectArray;
use crate::prelude::*;
use crate::utils::index_to_chunked_index;

Expand Down Expand Up @@ -46,18 +44,17 @@ impl<'a, T> Chunks<'a, T> {
}

#[doc(hidden)]
impl<T> ChunkedArray<T>
impl<T: PolarsDataType> ChunkedArray<T>
where
T: PolarsNumericType,
Self: HasUnderlyingArray,
{
pub fn downcast_iter(
&self,
) -> impl Iterator<Item = &PrimitiveArray<T::Native>> + DoubleEndedIterator {
) -> impl Iterator<Item = &<Self as HasUnderlyingArray>::ArrayT> + DoubleEndedIterator {
self.chunks.iter().map(|arr| {
// Safety:
// This should be the array type in PolarsNumericType
// SAFETY: HasUnderlyingArray guarantees this is correct.
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const PrimitiveArray<T::Native>) }
unsafe { &*(arr as *const dyn Array as *const <Self as HasUnderlyingArray>::ArrayT) }
})
}

Expand All @@ -67,165 +64,19 @@ where
/// * the flags (sorted, etc) remain correct.
pub unsafe fn downcast_iter_mut(
&mut self,
) -> impl Iterator<Item = &mut PrimitiveArray<T::Native>> + DoubleEndedIterator {
) -> impl Iterator<Item = &mut <Self as HasUnderlyingArray>::ArrayT> + DoubleEndedIterator {
self.chunks.iter_mut().map(|arr| {
// Safety:
// This should be the array type in PolarsNumericType
// SAFETY: HasUnderlyingArray guarantees this is correct.
let arr = &mut **arr;
&mut *(arr as *mut dyn Array as *mut PrimitiveArray<T::Native>)
&mut *(arr as *mut dyn Array as *mut <Self as HasUnderlyingArray>::ArrayT)
})
}

pub fn downcast_chunks(&self) -> Chunks<'_, PrimitiveArray<T::Native>> {
Chunks::new(&self.chunks)
}

/// Get the index of the chunk and the index of the value in that chunk
#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[doc(hidden)]
impl BooleanChunked {
pub fn downcast_iter(&self) -> impl Iterator<Item = &BooleanArray> + DoubleEndedIterator {
self.chunks.iter().map(|arr| {
// Safety:
// This should be the array type in BooleanChunked
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const BooleanArray) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, BooleanArray> {
Chunks::new(&self.chunks)
}

#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[doc(hidden)]
impl Utf8Chunked {
pub fn downcast_iter(&self) -> impl Iterator<Item = &Utf8Array<i64>> + DoubleEndedIterator {
// Safety:
// This is the array type that must be in a Utf8Chunked
self.chunks.iter().map(|arr| {
// Safety:
// This should be the array type in Utf8Chunked
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const Utf8Array<i64>) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, Utf8Array<i64>> {
Chunks::new(&self.chunks)
}

#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[doc(hidden)]
impl BinaryChunked {
pub fn downcast_iter(&self) -> impl Iterator<Item = &BinaryArray<i64>> + DoubleEndedIterator {
// Safety:
// This is the array type that must be in a BinaryChunked
self.chunks.iter().map(|arr| {
// Safety:
// This should be the array type in BinaryChunked
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const BinaryArray<i64>) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, BinaryArray<i64>> {
Chunks::new(&self.chunks)
}

#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[doc(hidden)]
impl ListChunked {
pub fn downcast_iter(&self) -> impl Iterator<Item = &ListArray<i64>> + DoubleEndedIterator {
// Safety:
// This is the array type that must be in a ListChunked
self.chunks.iter().map(|arr| {
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const ListArray<i64>) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, ListArray<i64>> {
Chunks::new(&self.chunks)
}

#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[cfg(feature = "dtype-array")]
#[doc(hidden)]
impl ArrayChunked {
pub fn downcast_iter(&self) -> impl Iterator<Item = &FixedSizeListArray> + DoubleEndedIterator {
// Safety:
// This is the array type that must be in a ArrayChunked
self.chunks.iter().map(|arr| {
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const FixedSizeListArray) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, FixedSizeListArray> {
Chunks::new(&self.chunks)
}

#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
return (0, index);
}
index_to_chunked_index(self.downcast_iter().map(|arr| arr.len()), index)
}
}

#[cfg(feature = "object")]
#[doc(hidden)]
impl<T> ObjectChunked<T>
where
T: PolarsObject,
{
pub fn downcast_iter(&self) -> impl Iterator<Item = &ObjectArray<T>> + DoubleEndedIterator {
self.chunks.iter().map(|arr| {
let arr = &**arr;
unsafe { &*(arr as *const dyn Array as *const ObjectArray<T>) }
})
}
pub fn downcast_chunks(&self) -> Chunks<'_, ObjectArray<T>> {
pub fn downcast_chunks(&self) -> Chunks<'_, <Self as HasUnderlyingArray>::ArrayT> {
Chunks::new(&self.chunks)
}

/// Get the index of the chunk and the index of the value in that chunk.
#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunks.len() == 1 {
Expand Down
35 changes: 35 additions & 0 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,38 @@ unsafe impl StaticallyMatchesPolarsType<BooleanType> for BooleanArray {}
unsafe impl StaticallyMatchesPolarsType<ListType> for ListArray<i64> {}
#[cfg(feature = "dtype-array")]
unsafe impl StaticallyMatchesPolarsType<FixedSizeListType> for FixedSizeListArray {}

#[doc(hidden)]
pub unsafe trait HasUnderlyingArray {
type ArrayT: Array;
}

unsafe impl<T: PolarsNumericType> HasUnderlyingArray for ChunkedArray<T> {
type ArrayT = PrimitiveArray<T::Native>;
}

unsafe impl HasUnderlyingArray for BooleanChunked {
type ArrayT = BooleanArray;
}

unsafe impl HasUnderlyingArray for Utf8Chunked {
type ArrayT = Utf8Array<i64>;
}

unsafe impl HasUnderlyingArray for BinaryChunked {
type ArrayT = BinaryArray<i64>;
}

unsafe impl HasUnderlyingArray for ListChunked {
type ArrayT = ListArray<i64>;
}

#[cfg(feature = "dtype-array")]
unsafe impl HasUnderlyingArray for ArrayChunked {
type ArrayT = FixedSizeListArray;
}

#[cfg(feature = "object")]
unsafe impl<T: PolarsObject> HasUnderlyingArray for ObjectChunked<T> {
type ArrayT = crate::chunked_array::object::ObjectArray<T>;
}

0 comments on commit f3a90d1

Please sign in to comment.