Skip to content

Commit

Permalink
feat: move Enum/Categorical categories to binview (#13882)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 22, 2024
1 parent 6b1fcaa commit 3dc9018
Show file tree
Hide file tree
Showing 32 changed files with 496 additions and 171 deletions.
19 changes: 18 additions & 1 deletion crates/polars-arrow/src/array/binview/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::BinaryViewArrayGeneric;
use crate::array::binview::ViewType;
use crate::array::{ArrayAccessor, ArrayValuesIter};
use crate::array::{ArrayAccessor, ArrayValuesIter, MutableBinaryViewArray};
use crate::bitmap::utils::{BitmapIter, ZipValidity};

unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for BinaryViewArrayGeneric<T> {
Expand Down Expand Up @@ -28,3 +28,20 @@ impl<'a, T: ViewType + ?Sized> IntoIterator for &'a BinaryViewArrayGeneric<T> {
self.iter()
}
}

unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for MutableBinaryViewArray<T> {
type Item = &'a T;

#[inline]
unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item {
self.value_unchecked(index)
}

#[inline]
fn len(&self) -> usize {
self.views().len()
}
}

/// Iterator of values of an [`MutableBinaryViewArray`].
pub type MutableBinaryViewValueIter<'a, T> = ArrayValuesIter<'a, MutableBinaryViewArray<T>>;
15 changes: 15 additions & 0 deletions crates/polars-arrow/src/array/binview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,21 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
self
}
}

pub fn make_mut(self) -> MutableBinaryViewArray<T> {
let views = self.views.make_mut();
let completed_buffers = self.buffers.to_vec();
let validity = self.validity.map(|bitmap| bitmap.make_mut());
MutableBinaryViewArray {
views,
completed_buffers,
in_progress_buffer: vec![],
validity,
phantom: Default::default(),
total_bytes_len: self.total_bytes_len.load(Ordering::Relaxed) as usize,
total_buffer_len: self.total_buffer_len,
}
}
}

impl BinaryViewArray {
Expand Down
64 changes: 56 additions & 8 deletions crates/polars-arrow/src/array/binview/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use polars_error::PolarsResult;
use polars_utils::slice::GetSaferUnchecked;

use crate::array::binview::iterator::MutableBinaryViewValueIter;
use crate::array::binview::view::validate_utf8_only;
use crate::array::binview::{BinaryViewArrayGeneric, ViewType};
use crate::array::{Array, MutableArray};
Expand All @@ -17,15 +18,15 @@ use crate::trusted_len::TrustedLen;
const DEFAULT_BLOCK_SIZE: usize = 8 * 1024;

pub struct MutableBinaryViewArray<T: ViewType + ?Sized> {
views: Vec<u128>,
completed_buffers: Vec<Buffer<u8>>,
in_progress_buffer: Vec<u8>,
validity: Option<MutableBitmap>,
phantom: std::marker::PhantomData<T>,
pub(super) views: Vec<u128>,
pub(super) completed_buffers: Vec<Buffer<u8>>,
pub(super) in_progress_buffer: Vec<u8>,
pub(super) validity: Option<MutableBitmap>,
pub(super) phantom: std::marker::PhantomData<T>,
/// Total bytes length if we would concatenate them all.
total_bytes_len: usize,
pub(super) total_bytes_len: usize,
/// Total bytes in the buffer (excluding remaining capacity)
total_buffer_len: usize,
pub(super) total_buffer_len: usize,
}

impl<T: ViewType + ?Sized> Clone for MutableBinaryViewArray<T> {
Expand Down Expand Up @@ -87,10 +88,16 @@ impl<T: ViewType + ?Sized> MutableBinaryViewArray<T> {
}
}

pub fn views(&mut self) -> &mut Vec<u128> {
#[inline]
pub fn views_mut(&mut self) -> &mut Vec<u128> {
&mut self.views
}

#[inline]
pub fn views(&self) -> &[u128] {
&self.views
}

pub fn validity(&mut self) -> Option<&mut MutableBitmap> {
self.validity.as_mut()
}
Expand Down Expand Up @@ -312,6 +319,47 @@ impl<T: ViewType + ?Sized> MutableBinaryViewArray<T> {
pub fn freeze(self) -> BinaryViewArrayGeneric<T> {
self.into()
}

/// Returns the element at index `i`
/// # Safety
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked(i);
let len = v as u32;

// view layout:
// length: 4 bytes
// prefix: 4 bytes
// buffer_index: 4 bytes
// offset: 4 bytes

// inlined layout:
// length: 4 bytes
// data: 12 bytes
let bytes = if len <= 12 {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
} else {
let buffer_idx = ((v >> 64) as u32) as usize;
let offset = (v >> 96) as u32;

let data = if buffer_idx == self.completed_buffers.len() {
self.in_progress_buffer.as_slice()
} else {
self.completed_buffers.get_unchecked_release(buffer_idx)
};

let offset = offset as usize;
data.get_unchecked(offset..offset + len as usize)
};
T::from_bytes_unchecked(bytes)
}

/// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity
pub fn values_iter(&self) -> MutableBinaryViewValueIter<T> {
MutableBinaryViewValueIter::new(self)
}
}

impl MutableBinaryViewArray<[u8]> {
Expand Down
30 changes: 29 additions & 1 deletion crates/polars-arrow/src/array/dictionary/typed_iterator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use polars_error::{polars_err, PolarsResult};

use super::DictionaryKey;
use crate::array::{Array, PrimitiveArray, Utf8Array};
use crate::array::{Array, PrimitiveArray, Utf8Array, Utf8ViewArray};
use crate::trusted_len::TrustedLen;
use crate::types::Offset;

Expand Down Expand Up @@ -48,6 +48,34 @@ impl<O: Offset> DictValue for Utf8Array<O> {
}
}

impl DictValue for Utf8ViewArray {
type IterValue<'a> = &'a str;

unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> {
self.value_unchecked(item)
}

fn downcast_values(array: &dyn Array) -> PolarsResult<&Self>
where
Self: Sized,
{
array
.as_any()
.downcast_ref::<Self>()
.ok_or_else(
|| polars_err!(InvalidOperation: "could not convert array to dictionary value"),
)
.map(|arr| {
assert_eq!(
arr.null_count(),
0,
"null values in values not supported in iteration"
);
arr
})
}
}

/// Iterator of values of an `ListArray`.
pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> {
keys: &'a PrimitiveArray<K>,
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-arrow/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,15 @@ impl<T> Buffer<T> {
}
}

impl<T: Clone> Buffer<T> {
pub fn make_mut(self) -> Vec<T> {
match self.into_mut() {
Either::Right(v) => v,
Either::Left(same) => same.as_slice().to_vec(),
}
}
}

impl<T: Zero + Copy> Buffer<T> {
pub fn zeroed(len: usize) -> Self {
vec![T::zero(); len].into()
Expand Down
17 changes: 16 additions & 1 deletion crates/polars-arrow/src/io/ipc/write/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
set_variadic_buffer_counts(counts, array.values().as_ref())
},
ArrowDataType::Dictionary(_, _, _) => {
let array = array
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
set_variadic_buffer_counts(counts, array.values().as_ref())
},
_ => (),
}
}
Expand Down Expand Up @@ -326,6 +333,14 @@ fn dictionary_batch_to_bytes<K: DictionaryKey>(
let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
let mut arrow_data: Vec<u8> = vec![];
let mut variadic_buffer_counts = vec![];
set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());

let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
None
} else {
Some(variadic_buffer_counts)
};

let length = write_dictionary(
array,
Expand All @@ -350,7 +365,7 @@ fn dictionary_batch_to_bytes<K: DictionaryKey>(
nodes: Some(nodes),
buffers: Some(buffers),
compression,
variadic_buffer_counts: None,
variadic_buffer_counts,
})),
is_delta: false,
},
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/pushable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl<T: ViewType + ?Sized> Pushable<&T> for MutableBinaryViewArray<T> {
MutableBinaryViewArray::push_value(self, value);

// And then use that new view to extend
let views = self.views();
let views = self.views_mut();
let view = *views.last().unwrap();

let remaining = additional - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ struct ListLocalCategoricalChunkedBuilder {
inner: ListPrimitiveChunkedBuilder<UInt32Type>,
idx_lookup: PlHashMap<KeyWrapper, ()>,
ordering: CategoricalOrdering,
categories: MutableUtf8Array<i64>,
categories: MutablePlString,
categories_hash: u128,
}

Expand Down Expand Up @@ -126,7 +126,7 @@ impl ListLocalCategoricalChunkedBuilder {
ListLocalCategoricalChunkedBuilder::get_hash_builder(),
),
ordering,
categories: MutableUtf8Array::with_capacity(capacity),
categories: MutablePlString::with_capacity(capacity),
categories_hash: hash,
}
}
Expand Down Expand Up @@ -206,7 +206,7 @@ impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder {
}

fn finish(&mut self) -> ListChunked {
let categories: Utf8Array<i64> = std::mem::take(&mut self.categories).into();
let categories: Utf8ViewArray = std::mem::take(&mut self.categories).into();
let rev_map = RevMapping::build_local(categories);
let inner_dtype = DataType::Categorical(Some(Arc::new(rev_map)), self.ordering);
let mut ca = self.inner.finish();
Expand Down
20 changes: 10 additions & 10 deletions crates/polars-core/src/chunked_array/comparison/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, Comp
str_compare_function: CompareString,
) -> PolarsResult<BooleanChunked>
where
CompareStringSingle: Fn(&Utf8Array<i64>, &str) -> Bitmap,
CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult<BooleanChunked>,
CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked,
Expand Down Expand Up @@ -273,7 +273,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
rhs,
|s1, s2| CategoricalChunked::gt(s1, s2),
UInt32Chunked::gt,
Utf8Array::tot_gt_kernel_broadcast,
Utf8ViewArray::tot_gt_kernel_broadcast,
StringChunked::gt,
)
}
Expand All @@ -284,7 +284,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
rhs,
|s1, s2| CategoricalChunked::gt_eq(s1, s2),
UInt32Chunked::gt_eq,
Utf8Array::tot_ge_kernel_broadcast,
Utf8ViewArray::tot_ge_kernel_broadcast,
StringChunked::gt_eq,
)
}
Expand All @@ -295,7 +295,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
rhs,
|s1, s2| CategoricalChunked::lt(s1, s2),
UInt32Chunked::lt,
Utf8Array::tot_lt_kernel_broadcast,
Utf8ViewArray::tot_lt_kernel_broadcast,
StringChunked::lt,
)
}
Expand All @@ -306,7 +306,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
rhs,
|s1, s2| CategoricalChunked::lt_eq(s1, s2),
UInt32Chunked::lt_eq,
Utf8Array::tot_le_kernel_broadcast,
Utf8ViewArray::tot_le_kernel_broadcast,
StringChunked::lt_eq,
)
}
Expand Down Expand Up @@ -348,7 +348,7 @@ fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>(
str_single_compare_function: CompareStringSingle,
) -> PolarsResult<BooleanChunked>
where
CompareStringSingle: Fn(&Utf8Array<i64>, &str) -> Bitmap,
CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap,
ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked,
{
let rev_map = lhs.get_rev_map();
Expand Down Expand Up @@ -421,7 +421,7 @@ impl ChunkCompare<&str> for CategoricalChunked {
self,
rhs,
UInt32Chunked::gt,
Utf8Array::tot_gt_kernel_broadcast,
Utf8ViewArray::tot_gt_kernel_broadcast,
)
}

Expand All @@ -430,7 +430,7 @@ impl ChunkCompare<&str> for CategoricalChunked {
self,
rhs,
UInt32Chunked::gt_eq,
Utf8Array::tot_ge_kernel_broadcast,
Utf8ViewArray::tot_ge_kernel_broadcast,
)
}

Expand All @@ -439,7 +439,7 @@ impl ChunkCompare<&str> for CategoricalChunked {
self,
rhs,
UInt32Chunked::lt,
Utf8Array::tot_lt_kernel_broadcast,
Utf8ViewArray::tot_lt_kernel_broadcast,
)
}

Expand All @@ -448,7 +448,7 @@ impl ChunkCompare<&str> for CategoricalChunked {
self,
rhs,
UInt32Chunked::lt_eq,
Utf8Array::tot_le_kernel_broadcast,
Utf8ViewArray::tot_le_kernel_broadcast,
)
}
}
Loading

0 comments on commit 3dc9018

Please sign in to comment.