Skip to content

Commit

Permalink
fix: deduplicate recursive growables (#14264)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 4, 2024
1 parent b9d4714 commit fd781eb
Showing 1 changed file with 54 additions and 20 deletions.
74 changes: 54 additions & 20 deletions crates/polars-arrow/src/array/growable/binview.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use polars_utils::aliases::PlIndexSet;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::Growable;
use crate::array::binview::{BinaryViewArrayGeneric, View, ViewType};
use crate::array::growable::utils::{extend_validity, prepare_validity};
Expand All @@ -8,14 +13,35 @@ use crate::bitmap::MutableBitmap;
use crate::buffer::Buffer;
use crate::datatypes::ArrowDataType;

/// Set key wrapping a borrowed data buffer.
///
/// The `Hash` and `PartialEq` impls for this type look only at the address
/// returned by `Buffer::as_ptr`, so two keys compare equal when they refer to
/// the same underlying allocation start, regardless of the bytes stored there.
struct BufferKey<'a> {
// Borrowed buffer; cloned back into an owned `Buffer<u8>` when the final array is built.
inner: &'a Buffer<u8>,
}

impl Hash for BufferKey<'_> {
    /// Feeds the buffer's `as_ptr()` address into the hasher as a single `u64`.
    ///
    /// Only the address is hashed (never the buffer contents), which keeps the
    /// hash consistent with the pointer-based equality of `BufferKey`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let address = self.inner.as_ptr() as u64;
        state.write_u64(address);
    }
}

impl PartialEq for BufferKey<'_> {
    /// Two keys are equal when their buffers report the same `as_ptr()` address.
    ///
    /// Buffer contents and lengths are not compared.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.inner.as_ptr();
        let rhs = other.inner.as_ptr();
        std::ptr::eq(lhs, rhs)
    }
}

// Marker impl: address comparison is reflexive, symmetric and transitive,
// so `BufferKey` can be used as a key in hash-based sets.
impl Eq for BufferKey<'_> {}

/// Concrete [`Growable`] for the [`BinaryViewArrayGeneric`].
pub struct GrowableBinaryViewArray<'a, T: ViewType + ?Sized> {
arrays: Vec<&'a BinaryViewArrayGeneric<T>>,
data_type: ArrowDataType,
validity: Option<MutableBitmap>,
views: Vec<View>,
buffers: Vec<Buffer<u8>>,
buffers_idx_offsets: Vec<u32>,
// Buffers are kept in an index set keyed by buffer address so that a buffer
// shared by several input arrays is stored only once.
// A growable can be constructed from many chunks that all reference the
// same buffers, which would otherwise duplicate them in the output.
// See: #14201
buffers: PlIndexSet<BufferKey<'a>>,
total_bytes_len: usize,
total_buffer_len: usize,
}
Expand All @@ -37,21 +63,16 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> {
use_validity = true;
};

let mut cum_sum = 0;
let cum_offset = arrays
.iter()
.map(|binview| {
let out = cum_sum;
cum_sum += binview.data_buffers().len() as u32;
out
})
.collect::<Vec<_>>();

let buffers = arrays
.iter()
.flat_map(|array| array.data_buffers().as_ref())
.cloned()
.collect::<Vec<_>>();
.flat_map(|array| {
array
.data_buffers()
.as_ref()
.iter()
.map(|buf| BufferKey { inner: buf })
})
.collect::<PlIndexSet<_>>();
let total_buffer_len = arrays
.iter()
.map(|arr| arr.data_buffers().len())
Expand All @@ -63,7 +84,6 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> {
validity: prepare_validity(use_validity, capacity),
views: Vec::with_capacity(capacity),
buffers,
buffers_idx_offsets: cum_offset,
total_bytes_len: 0,
total_buffer_len,
}
Expand All @@ -77,7 +97,12 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> {
BinaryViewArrayGeneric::<T>::new_unchecked(
self.data_type.clone(),
views.into(),
Arc::from(buffers),
Arc::from(
buffers
.into_iter()
.map(|buf| buf.inner.clone())
.collect::<Vec<_>>(),
),
validity.map(|v| v.into()),
self.total_bytes_len,
self.total_buffer_len,
Expand All @@ -90,6 +115,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> {
/// doesn't check bounds
pub unsafe fn extend_unchecked(&mut self, index: usize, start: usize, len: usize) {
let array = *self.arrays.get_unchecked(index);
let local_buffers = array.data_buffers();

extend_validity(&mut self.validity, array, start, len);

Expand All @@ -102,8 +128,11 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> {
self.total_bytes_len += len;

if len > 12 {
let buffer_idx = *self.buffers_idx_offsets.get_unchecked(index);
view.buffer_idx += buffer_idx;
let buffer = local_buffers.get_unchecked_release(view.buffer_idx as usize);
let key = BufferKey { inner: buffer };
let idx = self.buffers.get_full(&key).unwrap_unchecked_release().0;

view.buffer_idx = idx as u32;
}
view
}));
Expand Down Expand Up @@ -163,7 +192,12 @@ impl<'a, T: ViewType + ?Sized> From<GrowableBinaryViewArray<'a, T>> for BinaryVi
BinaryViewArrayGeneric::<T>::new_unchecked(
val.data_type,
val.views.into(),
Arc::from(val.buffers),
Arc::from(
val.buffers
.into_iter()
.map(|buf| buf.inner.clone())
.collect::<Vec<_>>(),
),
val.validity.map(|v| v.into()),
val.total_bytes_len,
val.total_buffer_len,
Expand Down

0 comments on commit fd781eb

Please sign in to comment.