From 525ceca7ec5c800baeac8131f354d490c2643818 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 21 Jan 2024 10:43:47 +0100 Subject: [PATCH 1/7] feat: move Enum/Categorical categories to binview --- .../src/array/binview/iterator.rs | 19 +++++- crates/polars-arrow/src/array/binview/mod.rs | 15 +++++ .../polars-arrow/src/array/binview/mutable.rs | 64 ++++++++++++++++--- crates/polars-arrow/src/buffer/immutable.rs | 9 +++ crates/polars-arrow/src/pushable.rs | 2 +- .../chunked_array/builder/list/categorical.rs | 6 +- .../chunked_array/comparison/categorical.rs | 20 +++--- .../logical/categorical/builder.rs | 16 ++--- .../logical/categorical/merge.rs | 33 ++-------- .../chunked_array/logical/categorical/mod.rs | 2 +- .../logical/categorical/revmap.rs | 28 ++++---- .../src/chunked_array/ops/compare_inner.rs | 4 +- crates/polars-core/src/datatypes/any_value.rs | 2 +- crates/polars-core/src/datatypes/dtype.rs | 2 +- crates/polars-core/src/series/from.rs | 2 +- 15 files changed, 146 insertions(+), 78 deletions(-) diff --git a/crates/polars-arrow/src/array/binview/iterator.rs b/crates/polars-arrow/src/array/binview/iterator.rs index 5e53fb8fec67..26587d5c1b72 100644 --- a/crates/polars-arrow/src/array/binview/iterator.rs +++ b/crates/polars-arrow/src/array/binview/iterator.rs @@ -1,6 +1,6 @@ use super::BinaryViewArrayGeneric; use crate::array::binview::ViewType; -use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::array::{ArrayAccessor, ArrayValuesIter, MutableBinaryViewArray}; use crate::bitmap::utils::{BitmapIter, ZipValidity}; unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for BinaryViewArrayGeneric { @@ -28,3 +28,20 @@ impl<'a, T: ViewType + ?Sized> IntoIterator for &'a BinaryViewArrayGeneric { self.iter() } } + +unsafe impl<'a, T: ViewType + ?Sized> ArrayAccessor<'a> for MutableBinaryViewArray { + type Item = &'a T; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.views().len() + } +} + +/// Iterator of values of an [`MutableBinaryViewArray`]. +pub type MutableBinaryViewValueIter<'a, T> = ArrayValuesIter<'a, MutableBinaryViewArray>; diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index 44b0de62217f..3da2863bd5eb 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -396,6 +396,21 @@ impl BinaryViewArrayGeneric { self } } + + pub fn make_mut(self) -> MutableBinaryViewArray { + let views = self.views.make_mut(); + let completed_buffers = self.buffers.to_vec(); + let validity = self.validity.map(|bitmap| bitmap.make_mut()); + MutableBinaryViewArray { + views, + completed_buffers, + in_progress_buffer: vec![], + validity, + phantom: Default::default(), + total_bytes_len: self.total_bytes_len.load(Ordering::Relaxed) as usize, + total_buffer_len: self.total_buffer_len, + } + } } impl BinaryViewArray { diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs index d95c24136fff..5894ae5d4270 100644 --- a/crates/polars-arrow/src/array/binview/mutable.rs +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use polars_error::PolarsResult; use polars_utils::slice::GetSaferUnchecked; +use crate::array::binview::iterator::MutableBinaryViewValueIter; use crate::array::binview::view::validate_utf8_only; use crate::array::binview::{BinaryViewArrayGeneric, ViewType}; use crate::array::{Array, MutableArray}; @@ -17,15 +18,15 @@ use crate::trusted_len::TrustedLen; const DEFAULT_BLOCK_SIZE: usize = 8 * 1024; pub struct MutableBinaryViewArray { - views: Vec, - completed_buffers: Vec>, - in_progress_buffer: Vec, - validity: Option, - phantom: std::marker::PhantomData, + pub(super) views: Vec, + pub(super) completed_buffers: Vec>, + pub(super) in_progress_buffer: Vec, + pub(super) validity: Option, + pub(super) phantom: std::marker::PhantomData, /// Total bytes length if we would concatenate them all. - total_bytes_len: usize, + pub(super) total_bytes_len: usize, /// Total bytes in the buffer (excluding remaining capacity) - total_buffer_len: usize, + pub(super) total_buffer_len: usize, } impl Clone for MutableBinaryViewArray { @@ -87,10 +88,16 @@ impl MutableBinaryViewArray { } } - pub fn views(&mut self) -> &mut Vec { + #[inline] + pub fn views_mut(&mut self) -> &mut Vec { &mut self.views } + #[inline] + pub fn views(&self) -> &[u128] { + &self.views + } + pub fn validity(&mut self) -> Option<&mut MutableBitmap> { self.validity.as_mut() } @@ -312,6 +319,47 @@ impl MutableBinaryViewArray { pub fn freeze(self) -> BinaryViewArrayGeneric { self.into() } + + /// Returns the element at index `i` + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &T { + let v = *self.views.get_unchecked(i); + let len = v as u32; + + // view layout: + // length: 4 bytes + // prefix: 4 bytes + // buffer_index: 4 bytes + // offset: 4 bytes + + // inlined layout: + // length: 4 bytes + // data: 12 bytes + let bytes = if len <= 12 { + let ptr = self.views.as_ptr() as *const u8; + std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize) + } else { + let buffer_idx = ((v >> 64) as u32) as usize; + let offset = (v >> 96) as u32; + + let data = if buffer_idx == self.completed_buffers.len() { + self.in_progress_buffer.as_slice() + } else { + self.completed_buffers.get_unchecked_release(buffer_idx) + }; + + let offset = offset as usize; + data.get_unchecked(offset..offset + len as usize) + }; + T::from_bytes_unchecked(bytes) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> MutableBinaryViewValueIter { + MutableBinaryViewValueIter::new(self) + } } impl MutableBinaryViewArray<[u8]> { diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs index 5371cc71030c..15d7b0935edc 100644 --- a/crates/polars-arrow/src/buffer/immutable.rs +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -244,6 +244,15 @@ impl Buffer { } } +impl Buffer { + pub fn make_mut(self) -> Vec { + match self.into_mut() { + Either::Right(v) => v, + Either::Left(same) => same.as_slice().to_vec(), + } + } +} + impl Buffer { pub fn zeroed(len: usize) -> Self { vec![T::zero(); len].into() diff --git a/crates/polars-arrow/src/pushable.rs b/crates/polars-arrow/src/pushable.rs index 0688a6956be5..db71d8726a8a 100644 --- a/crates/polars-arrow/src/pushable.rs +++ b/crates/polars-arrow/src/pushable.rs @@ -145,7 +145,7 @@ impl Pushable<&T> for MutableBinaryViewArray { MutableBinaryViewArray::push_value(self, value); // And then use that new view to extend - let views = self.views(); + let views = self.views_mut(); let view = *views.last().unwrap(); let remaining = additional - 1; diff --git a/crates/polars-core/src/chunked_array/builder/list/categorical.rs b/crates/polars-core/src/chunked_array/builder/list/categorical.rs index 586f212241bb..37f1fb03d8ce 100644 --- a/crates/polars-core/src/chunked_array/builder/list/categorical.rs +++ b/crates/polars-core/src/chunked_array/builder/list/categorical.rs @@ -94,7 +94,7 @@ struct ListLocalCategoricalChunkedBuilder { inner: ListPrimitiveChunkedBuilder, idx_lookup: PlHashMap, ordering: CategoricalOrdering, - categories: MutableUtf8Array, + categories: MutablePlString, categories_hash: u128, } @@ -126,7 +126,7 @@ impl ListLocalCategoricalChunkedBuilder { ListLocalCategoricalChunkedBuilder::get_hash_builder(), ), ordering, - categories: MutableUtf8Array::with_capacity(capacity), + categories: MutablePlString::with_capacity(capacity), categories_hash: hash, } } @@ -206,7 +206,7 @@ impl ListBuilderTrait for ListLocalCategoricalChunkedBuilder { } fn finish(&mut self) -> ListChunked { - let categories: Utf8Array = std::mem::take(&mut self.categories).into(); + let categories: Utf8ViewArray = std::mem::take(&mut self.categories).into(); let rev_map = RevMapping::build_local(categories); let inner_dtype = DataType::Categorical(Some(Arc::new(rev_map)), self.ordering); let mut ca = self.inner.finish(); diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index d93114c62a92..df238839aa88 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -193,7 +193,7 @@ fn cat_str_compare_helper<'a, CompareCat, ComparePhys, CompareStringSingle, Comp str_compare_function: CompareString, ) -> PolarsResult where - CompareStringSingle: Fn(&Utf8Array, &str) -> Bitmap, + CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, CompareCat: Fn(&CategoricalChunked, &CategoricalChunked) -> PolarsResult, CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, @@ -273,7 +273,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { rhs, |s1, s2| CategoricalChunked::gt(s1, s2), UInt32Chunked::gt, - Utf8Array::tot_gt_kernel_broadcast, + Utf8ViewArray::tot_gt_kernel_broadcast, StringChunked::gt, ) } @@ -284,7 +284,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { rhs, |s1, s2| CategoricalChunked::gt_eq(s1, s2), UInt32Chunked::gt_eq, - Utf8Array::tot_ge_kernel_broadcast, + Utf8ViewArray::tot_ge_kernel_broadcast, StringChunked::gt_eq, ) } @@ -295,7 +295,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { rhs, |s1, s2| CategoricalChunked::lt(s1, s2), UInt32Chunked::lt, - Utf8Array::tot_lt_kernel_broadcast, + Utf8ViewArray::tot_lt_kernel_broadcast, StringChunked::lt, ) } @@ -306,7 +306,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { rhs, |s1, s2| CategoricalChunked::lt_eq(s1, s2), UInt32Chunked::lt_eq, - Utf8Array::tot_le_kernel_broadcast, + Utf8ViewArray::tot_le_kernel_broadcast, StringChunked::lt_eq, ) } @@ -348,7 +348,7 @@ fn cat_single_str_compare_helper<'a, ComparePhys, CompareStringSingle>( str_single_compare_function: CompareStringSingle, ) -> PolarsResult where - CompareStringSingle: Fn(&Utf8Array, &str) -> Bitmap, + CompareStringSingle: Fn(&Utf8ViewArray, &str) -> Bitmap, ComparePhys: Fn(&UInt32Chunked, u32) -> BooleanChunked, { let rev_map = lhs.get_rev_map(); @@ -421,7 +421,7 @@ impl ChunkCompare<&str> for CategoricalChunked { self, rhs, UInt32Chunked::gt, - Utf8Array::tot_gt_kernel_broadcast, + Utf8ViewArray::tot_gt_kernel_broadcast, ) } @@ -430,7 +430,7 @@ impl ChunkCompare<&str> for CategoricalChunked { self, rhs, UInt32Chunked::gt_eq, - Utf8Array::tot_ge_kernel_broadcast, + Utf8ViewArray::tot_ge_kernel_broadcast, ) } @@ -439,7 +439,7 @@ impl ChunkCompare<&str> for CategoricalChunked { self, rhs, UInt32Chunked::lt, - Utf8Array::tot_lt_kernel_broadcast, + Utf8ViewArray::tot_lt_kernel_broadcast, ) } @@ -448,7 +448,7 @@ impl ChunkCompare<&str> for CategoricalChunked { self, rhs, UInt32Chunked::lt_eq, - Utf8Array::tot_le_kernel_broadcast, + Utf8ViewArray::tot_le_kernel_broadcast, ) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index d69fbd025d4e..e0dcf6c1d586 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -16,7 +16,7 @@ pub struct CategoricalChunkedBuilder { cat_builder: UInt32Vec, name: String, ordering: CategoricalOrdering, - categories: MutableUtf8Array, + categories: MutablePlString, // hashmap utilized by the local builder local_mapping: PlHashMap, } @@ -27,7 +27,7 @@ impl CategoricalChunkedBuilder { cat_builder: UInt32Vec::with_capacity(capacity), name: name.to_string(), ordering, - categories: MutableUtf8Array::::with_capacity(_HASHMAP_INIT_SIZE), + categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE), local_mapping: PlHashMap::with_capacity_and_hasher( capacity / 10, StringCache::get_hash_builder(), @@ -125,7 +125,7 @@ impl CategoricalChunkedBuilder { } } - let categories: Utf8Array = std::mem::take(&mut self.categories).into(); + let categories = std::mem::take(&mut self.categories).freeze(); // we will create a mapping from our local categoricals to global categoricals // and a mapping from global categoricals to our local categoricals @@ -237,7 +237,7 @@ impl CategoricalChunked { let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), _HASHMAP_INIT_SIZE); let mut rev_map = PlHashMap::with_capacity(cap); - let mut str_values = MutableUtf8Array::with_capacities(cap, cap * 24); + let mut str_values = MutablePlString::with_capacity(cap); for arr in cats.downcast_iter() { for cat in arr.into_iter().flatten().copied() { @@ -260,7 +260,7 @@ impl CategoricalChunked { name: &str, keys: impl IntoIterator> + Send, capacity: usize, - values: &Utf8Array, + values: &Utf8ViewArray, ordering: CategoricalOrdering, ) -> Self { // Vec where the index is local and the value is the global index @@ -304,7 +304,7 @@ impl CategoricalChunked { pub(crate) unsafe fn from_keys_and_values_local( name: &str, keys: &PrimitiveArray, - values: &Utf8Array, + values: &Utf8ViewArray, ordering: CategoricalOrdering, ) -> CategoricalChunked { CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -319,7 +319,7 @@ impl CategoricalChunked { pub(crate) unsafe fn from_keys_and_values( name: &str, keys: &PrimitiveArray, - values: &Utf8Array, + values: &Utf8ViewArray, ordering: CategoricalOrdering, ) -> Self { if !using_string_cache() { @@ -339,7 +339,7 @@ impl CategoricalChunked { /// This will error if a string is not in the fixed list of categories pub fn from_string_to_enum( values: &StringChunked, - categories: &Utf8Array, + categories: &Utf8ViewArray, ordering: CategoricalOrdering, ) -> PolarsResult { polars_ensure!(categories.null_count() == 0, ComputeError: "categories can not contain null values"); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index 964ced1c6a62..bcf03543239d 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -1,39 +1,14 @@ use std::sync::Arc; -use arrow::bitmap::MutableBitmap; -use arrow::offset::Offsets; - use super::*; -fn slots_to_mut(slots: &Utf8Array) -> MutableUtf8Array { - // safety: invariants don't change, just the type - let offset_buf = unsafe { Offsets::new_unchecked(slots.offsets().as_slice().to_vec()) }; - let values_buf = slots.values().as_slice().to_vec(); - - let validity_buf = if let Some(validity) = slots.validity() { - let mut validity_buf = MutableBitmap::new(); - let (b, offset, len) = validity.as_slice(); - validity_buf.extend_from_slice(b, offset, len); - Some(validity_buf) - } else { - None - }; - - // Safety - // all offsets are valid and the u8 data is valid utf8 - unsafe { - MutableUtf8Array::new_unchecked( - ArrowDataType::LargeUtf8, - offset_buf, - values_buf, - validity_buf, - ) - } +fn slots_to_mut(slots: &Utf8ViewArray) -> MutablePlString { + slots.clone().make_mut() } struct State { map: PlHashMap, - slots: MutableUtf8Array, + slots: MutablePlString, } #[derive(Default)] @@ -111,7 +86,7 @@ impl GlobalRevMapMerger { } fn merge_local_rhs_categorical<'a>( - categories: &'a Utf8Array, + categories: &'a Utf8ViewArray, ca_right: &'a CategoricalChunked, ) -> Result<(UInt32Chunked, Arc), PolarsError> { // Counterpart of the GlobalRevmapMerger. diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 9e7a2605a8ad..6ac3b7ab926e 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -125,7 +125,7 @@ impl CategoricalChunked { } // Convert to fixed enum. In case a value is not in the categories return Error - pub fn to_enum(&self, categories: &Utf8Array, hash: u128) -> PolarsResult { + pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> PolarsResult { // Fast paths match self.get_rev_map().as_ref() { RevMapping::Enum(_, cur_hash) if hash == *cur_hash => return Ok(self.clone()), diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs index 8c4606492ecb..3a1bc44fc9f7 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs @@ -1,4 +1,5 @@ use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash, Hasher}; use ahash::RandomState; use arrow::array::*; @@ -23,11 +24,11 @@ pub enum CategoricalOrdering { pub enum RevMapping { /// Hashmap: maps the indexes from the global cache/categorical array to indexes in the local Utf8Array /// Utf8Array: caches the string values - Global(PlHashMap, Utf8Array, u32), + Global(PlHashMap, Utf8ViewArray, u32), /// Utf8Array: caches the string values and a hash of all values for quick comparison - Local(Utf8Array, u128), + Local(Utf8ViewArray, u128), /// Utf8Array: fixed user defined array of categories which caches the string values - Enum(Utf8Array, u128), + Enum(Utf8ViewArray, u128), } impl Debug for RevMapping { @@ -49,7 +50,7 @@ impl Debug for RevMapping { impl Default for RevMapping { fn default() -> Self { let slice: &[Option<&str>] = &[]; - let cats = Utf8Array::::from(slice); + let cats = Utf8ViewArray::from_slice(slice); if using_string_cache() { let cache = &mut crate::STRING_CACHE.lock_map(); let id = cache.uuid; @@ -76,26 +77,29 @@ impl RevMapping { } /// Get the categories in this [`RevMapping`] - pub fn get_categories(&self) -> &Utf8Array { + pub fn get_categories(&self) -> &Utf8ViewArray { match self { Self::Global(_, a, _) => a, Self::Local(a, _) | Self::Enum(a, _) => a, } } - fn build_hash(categories: &Utf8Array) -> u128 { - let hash_builder = RandomState::with_seed(0); - let value_hash = hash_builder.hash_one(categories.values().as_slice()); - let offset_hash = hash_builder.hash_one(categories.offsets().as_slice()); - (value_hash as u128) << 64 | (offset_hash as u128) + fn build_hash(categories: &Utf8ViewArray) -> u128 { + // TODO! we must also validate the cases of duplicates! + let mut hb = RandomState::with_seed(0).build_hasher(); + categories.values_iter().for_each(|val| { + val.hash(&mut hb); + }); + let hash = hb.finish(); + (hash as u128) << 64 | (categories.total_bytes_len() as u128) } - pub fn build_enum(categories: Utf8Array) -> Self { + pub fn build_enum(categories: Utf8ViewArray) -> Self { let hash = Self::build_hash(&categories); Self::Enum(categories, hash) } - pub fn build_local(categories: Utf8Array) -> Self { + pub fn build_local(categories: Utf8ViewArray) -> Self { let hash = Self::build_hash(&categories); Self::Local(categories, hash) } diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index b58fa981f32b..d508a881f34f 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -122,7 +122,7 @@ where #[cfg(feature = "dtype-categorical")] struct LocalCategorical<'a> { - rev_map: &'a Utf8Array, + rev_map: &'a Utf8ViewArray, cats: &'a UInt32Chunked, } @@ -138,7 +138,7 @@ impl<'a> GetInner for LocalCategorical<'a> { #[cfg(feature = "dtype-categorical")] struct GlobalCategorical<'a> { p1: &'a PlHashMap, - p2: &'a Utf8Array, + p2: &'a Utf8ViewArray, cats: &'a UInt32Chunked, } diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 49d3119cf6c0..cb060fb2c380 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -72,7 +72,7 @@ pub enum AnyValue<'a> { #[cfg(feature = "dtype-categorical")] // If syncptr is_null the data is in the rev-map // otherwise it is in the array pointer - Categorical(u32, &'a RevMapping, SyncPtr>), + Categorical(u32, &'a RevMapping, SyncPtr), /// Nested type, contains arrays that are filled with one of the datatypes. List(Series), #[cfg(feature = "dtype-array")] diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index d33ad4c077d8..bfe6299eba67 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -491,7 +491,7 @@ pub(crate) fn can_extend_dtype(left: &DataType, right: &DataType) -> PolarsResul } #[cfg(feature = "dtype-categorical")] -pub fn create_enum_data_type(categories: Utf8Array) -> DataType { +pub fn create_enum_data_type(categories: Utf8ViewArray) -> DataType { let rev_map = RevMapping::build_enum(categories.clone()); DataType::Categorical(Some(Arc::new(rev_map)), Default::default()) } diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 159e14652c01..086aeafa75ef 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -335,7 +335,7 @@ impl Series { ), }; let keys = keys.as_any().downcast_ref::>().unwrap(); - let values = values.as_any().downcast_ref::>().unwrap(); + let values = values.as_any().downcast_ref::().unwrap(); // Safety // the invariants of an Arrow Dictionary guarantee the keys are in bounds From 4869cff083e54ed746b5a6f1443bae42e6858066 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 21 Jan 2024 11:03:52 +0100 Subject: [PATCH 2/7] serde [skip ci] --- crates/polars-core/src/chunked_array/ops/any_value.rs | 2 +- crates/polars-core/src/datatypes/_serde.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs index 92732bb3422f..82fc93a4227e 100644 --- a/crates/polars-core/src/chunked_array/ops/any_value.rs +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -141,7 +141,7 @@ impl<'a> AnyValue<'a> { let keys = arr.keys(); let values = arr.values(); let values = - values.as_any().downcast_ref::>().unwrap(); + values.as_any().downcast_ref::().unwrap(); let arr = &*(keys as *const dyn Array as *const UInt32Array); if arr.is_valid_unchecked(idx) { diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index ee3f894ccbf4..5dbaee05e210 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -32,7 +32,7 @@ impl Serialize for DataType { struct Wrap(T); #[cfg(feature = "dtype-categorical")] -impl serde::Serialize for Wrap> { +impl serde::Serialize for Wrap { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -42,7 +42,7 @@ impl serde::Serialize for Wrap> { } #[cfg(feature = "dtype-categorical")] -impl<'de> serde::Deserialize<'de> for Wrap> { +impl<'de> serde::Deserialize<'de> for Wrap { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, @@ -50,7 +50,7 @@ impl<'de> serde::Deserialize<'de> for Wrap> { struct Utf8Visitor; impl<'de> Visitor<'de> for Utf8Visitor { - type Value = Wrap>; + type Value = Wrap; fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { formatter.write_str("Utf8Visitor string sequence.") @@ -60,7 +60,7 @@ impl<'de> serde::Deserialize<'de> for Wrap> { where A: SeqAccess<'de>, { - let mut utf8array = MutableUtf8Array::with_capacity(seq.size_hint().unwrap_or(10)); + let mut utf8array = MutablePlString::with_capacity(seq.size_hint().unwrap_or(10)); while let Some(key) = seq.next_element()? { let key: Option<&str> = key; utf8array.push(key) @@ -107,7 +107,7 @@ enum SerializableDataType { // some logical types we cannot know statically, e.g. Datetime Unknown, #[cfg(feature = "dtype-categorical")] - Categorical(Option>>, CategoricalOrdering), + Categorical(Option>, CategoricalOrdering), #[cfg(feature = "object")] Object(String), } From 60c44e5ba66e8f85045c2a16f4b318d6be1b6851 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 21 Jan 2024 13:05:38 +0100 Subject: [PATCH 3/7] fix parquet --- .../src/array/dictionary/typed_iterator.rs | 30 ++++- .../polars-arrow/src/io/ipc/write/common.rs | 17 ++- .../chunked_array/logical/categorical/from.rs | 110 ++++++++++-------- crates/polars-core/src/datatypes/dtype.rs | 19 ++- crates/polars-core/src/series/from.rs | 7 +- crates/polars-core/src/series/into.rs | 3 +- .../polars-json/src/json/write/serialize.rs | 13 ++- .../src/arrow/write/binview/mod.rs | 2 +- .../src/arrow/write/dictionary.rs | 21 +++- crates/polars-row/src/encode.rs | 6 +- py-polars/src/conversion.rs | 6 +- py-polars/src/datatypes.rs | 4 +- 12 files changed, 162 insertions(+), 76 deletions(-) diff --git a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs index 27aa12c74be5..6a543968b98d 100644 --- a/crates/polars-arrow/src/array/dictionary/typed_iterator.rs +++ b/crates/polars-arrow/src/array/dictionary/typed_iterator.rs @@ -1,7 +1,7 @@ use polars_error::{polars_err, PolarsResult}; use super::DictionaryKey; -use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::array::{Array, PrimitiveArray, Utf8Array, Utf8ViewArray}; use crate::trusted_len::TrustedLen; use crate::types::Offset; @@ -48,6 +48,34 @@ impl DictValue for Utf8Array { } } +impl DictValue for Utf8ViewArray { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> PolarsResult<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or_else( + || polars_err!(InvalidOperation: "could not convert array to dictionary value"), + ) + .map(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + arr + }) + } +} + /// Iterator of values of an `ListArray`. pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> { keys: &'a PrimitiveArray, diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index f4b8a1c015e7..1d4375280838 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -254,6 +254,13 @@ fn set_variadic_buffer_counts(counts: &mut Vec, array: &dyn Array) { let array = array.as_any().downcast_ref::().unwrap(); set_variadic_buffer_counts(counts, array.values().as_ref()) }, + ArrowDataType::Dictionary(_, _, _) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + set_variadic_buffer_counts(counts, array.values().as_ref()) + }, _ => (), } } @@ -326,6 +333,14 @@ fn dictionary_batch_to_bytes( let mut nodes: Vec = vec![]; let mut buffers: Vec = vec![]; let mut arrow_data: Vec = vec![]; + let mut variadic_buffer_counts = vec![]; + set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref()); + + let variadic_buffer_counts = if variadic_buffer_counts.is_empty() { + None + } else { + Some(variadic_buffer_counts) + }; let length = write_dictionary( array, @@ -350,7 +365,7 @@ fn dictionary_batch_to_bytes( nodes: Some(nodes), buffers: Some(buffers), compression, - variadic_buffer_counts: None, + variadic_buffer_counts, })), is_delta: false, }, diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index 3bc3394aa2e2..d66170467fed 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -1,27 +1,43 @@ use arrow::array::DictionaryArray; -use arrow::compute::cast::{cast, CastOptions}; +use arrow::compute::cast::{cast, utf8view_to_utf8, CastOptions}; use arrow::datatypes::IntegerType; use super::*; -impl From<&CategoricalChunked> for DictionaryArray { - fn from(ca: &CategoricalChunked) -> Self { - let keys = ca.physical().rechunk(); +fn convert_values(arr: &Utf8ViewArray, pl_flavor: bool) -> ArrayRef { + if pl_flavor { + arr.clone().boxed() + } else { + utf8view_to_utf8::(arr).boxed() + } +} + +impl CategoricalChunked { + pub fn to_arrow(&self, pl_flavor: bool, as_i64: bool) -> ArrayRef { + if as_i64 { + self.to_i64(pl_flavor).boxed() + } else { + self.to_u32(pl_flavor).boxed() + } + } + + fn to_u32(&self, pl_flavor: bool) -> DictionaryArray { + let values_dtype = if pl_flavor { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + let keys = self.physical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); - let map = &**ca.get_rev_map(); - let dtype = ArrowDataType::Dictionary( - IntegerType::UInt32, - Box::new(ArrowDataType::LargeUtf8), - false, - ); + let map = &**self.get_rev_map(); + let dtype = ArrowDataType::Dictionary(IntegerType::UInt32, Box::new(values_dtype), false); match map { RevMapping::Local(arr, _) | RevMapping::Enum(arr, _) => { + let values = convert_values(arr, pl_flavor); + // Safety: // the keys are in bounds - unsafe { - DictionaryArray::try_new_unchecked(dtype, keys.clone(), Box::new(arr.clone())) - .unwrap() - } + unsafe { DictionaryArray::try_new_unchecked(dtype, keys.clone(), values).unwrap() } }, RevMapping::Global(reverse_map, values, _uuid) => { let iter = keys @@ -29,41 +45,44 @@ impl From<&CategoricalChunked> for DictionaryArray { .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap())); let keys = PrimitiveArray::from_trusted_len_iter(iter); + let values = convert_values(values, pl_flavor); + // Safety: // the keys are in bounds - unsafe { - DictionaryArray::try_new_unchecked(dtype, keys, Box::new(values.clone())) - .unwrap() - } + unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } }, } } -} -impl From<&CategoricalChunked> for DictionaryArray { - fn from(ca: &CategoricalChunked) -> Self { - let keys = ca.physical().rechunk(); + + fn to_i64(&self, pl_flavor: bool) -> DictionaryArray { + let values_dtype = if pl_flavor { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + let keys = self.physical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); - let map = &**ca.get_rev_map(); - let dtype = ArrowDataType::Dictionary( - IntegerType::UInt32, - Box::new(ArrowDataType::LargeUtf8), - false, - ); + let map = &**self.get_rev_map(); + let dtype = ArrowDataType::Dictionary(IntegerType::Int64, Box::new(values_dtype), false); match map { - // Safety: - // the keys are in bounds - RevMapping::Local(arr, _) | RevMapping::Enum(arr, _) => unsafe { - DictionaryArray::try_new_unchecked( - dtype, - cast(keys, &ArrowDataType::Int64, CastOptions::unchecked()) - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .clone(), - Box::new(arr.clone()), - ) - .unwrap() + RevMapping::Local(arr, _) | RevMapping::Enum(arr, _) => { + let values = convert_values(arr, pl_flavor); + + // Safety: + // the keys are in bounds + unsafe { + DictionaryArray::try_new_unchecked( + dtype, + cast(keys, &ArrowDataType::Int64, CastOptions::unchecked()) + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap() + .clone(), + values, + ) + .unwrap() + } }, RevMapping::Global(reverse_map, values, _uuid) => { let iter = keys @@ -71,12 +90,11 @@ impl From<&CategoricalChunked> for DictionaryArray { .map(|opt_k| opt_k.map(|k| *reverse_map.get(k).unwrap() as i64)); let keys = PrimitiveArray::from_trusted_len_iter(iter); + let values = convert_values(values, pl_flavor); + // Safety: // the keys are in bounds - unsafe { - DictionaryArray::try_new_unchecked(dtype, keys, Box::new(values.clone())) - .unwrap() - } + unsafe { DictionaryArray::try_new_unchecked(dtype, keys, values).unwrap() } }, } } diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index bfe6299eba67..76cb154da5fa 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -339,11 +339,18 @@ impl DataType { polars_bail!(InvalidOperation: "cannot convert Object dtype data to Arrow") }, #[cfg(feature = "dtype-categorical")] - Categorical(_, _) => Ok(ArrowDataType::Dictionary( - IntegerType::UInt32, - Box::new(ArrowDataType::LargeUtf8), - false, - )), + Categorical(_, _) => { + let values = if pl_flavor { + ArrowDataType::Utf8View + } else { + ArrowDataType::LargeUtf8 + }; + Ok(ArrowDataType::Dictionary( + IntegerType::UInt32, + Box::new(values), + false, + )) + }, #[cfg(feature = "dtype-struct")] Struct(fields) => { let fields = fields.iter().map(|fld| fld.to_arrow(pl_flavor)).collect(); @@ -492,7 +499,7 @@ pub(crate) fn can_extend_dtype(left: &DataType, right: &DataType) -> PolarsResul #[cfg(feature = "dtype-categorical")] pub fn create_enum_data_type(categories: Utf8ViewArray) -> DataType { - let rev_map = RevMapping::build_enum(categories.clone()); + let rev_map = RevMapping::build_enum(categories); DataType::Categorical(Some(Arc::new(rev_map)), Default::default()) } diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 086aeafa75ef..dfc86b100e15 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -290,7 +290,10 @@ impl Series { if !matches!( value_type.as_ref(), - ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Null + ArrowDataType::Utf8 + | ArrowDataType::LargeUtf8 + | ArrowDataType::Utf8View + | ArrowDataType::Null ) { polars_bail!( ComputeError: "only string-like values are supported in dictionaries" @@ -303,7 +306,7 @@ impl Series { let keys = arr.keys(); let keys = cast(keys, &ArrowDataType::UInt32).unwrap(); let values = arr.values(); - let values = cast(&**values, &ArrowDataType::LargeUtf8)?; + let values = cast(&**values, &ArrowDataType::Utf8View)?; (keys, values) }}; } diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index 10eb3687ff43..a12d440bfbb4 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -73,8 +73,7 @@ impl Series { ) }; - let arr: DictionaryArray = (&new).into(); - Box::new(arr) as ArrayRef + new.to_arrow(pl_flavor, false) }, #[cfg(feature = "dtype-date")] DataType::Date => cast( diff --git a/crates/polars-json/src/json/write/serialize.rs b/crates/polars-json/src/json/write/serialize.rs index 6347e014c722..77e937b8647f 100644 --- a/crates/polars-json/src/json/write/serialize.rs +++ b/crates/polars-json/src/json/write/serialize.rs @@ -112,12 +112,12 @@ where materialize_serializer(f, array.iter(), offset, take) } -fn dictionary_utf8_serializer<'a, K: DictionaryKey, O: Offset>( +fn dictionary_utf8view_serializer<'a, K: DictionaryKey>( array: &'a DictionaryArray, offset: usize, take: usize, ) -> Box + 'a + Send + Sync> { - let iter = array.iter_typed::>().unwrap().skip(offset); + let iter = array.iter_typed::().unwrap().skip(offset); let f = |x: Option<&str>, buf: &mut Vec| { if let Some(x) = x { utf8::write_str(buf, x).unwrap(); @@ -436,16 +436,17 @@ pub(crate) fn new_serializer<'a>( ArrowDataType::LargeList(_) => { list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) }, - other @ ArrowDataType::Dictionary(k, v, _) => match (k, &**v) { - (IntegerType::UInt32, ArrowDataType::LargeUtf8) => { + ArrowDataType::Dictionary(k, v, _) => match (k, &**v) { + (IntegerType::UInt32, ArrowDataType::Utf8View) => { let array = array .as_any() .downcast_ref::>() .unwrap(); - dictionary_utf8_serializer::(array, offset, take) + dictionary_utf8view_serializer::(array, offset, take) }, _ => { - todo!("Writing {:?} to JSON", other) + // Not produced by polars + unreachable!() }, }, ArrowDataType::Date32 => date_serializer( diff --git a/crates/polars-parquet/src/arrow/write/binview/mod.rs b/crates/polars-parquet/src/arrow/write/binview/mod.rs index 280e2ff9efb5..5b0ab6102c22 100644 --- a/crates/polars-parquet/src/arrow/write/binview/mod.rs +++ b/crates/polars-parquet/src/arrow/write/binview/mod.rs @@ -1,5 +1,5 @@ mod basic; mod nested; -pub use basic::array_to_page; +pub(crate) use basic::{array_to_page, build_statistics, encode_plain}; pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index 9af6be6550d1..cfc5ad888a84 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, DictionaryArray, DictionaryKey}; +use arrow::array::{Array, DictionaryArray, DictionaryKey, Utf8ViewArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::{ArrowDataType, IntegerType}; use num_traits::ToPrimitive; @@ -13,7 +13,7 @@ use super::fixed_len_bytes::{ use super::primitive::{ build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain, }; -use super::{nested, Nested, WriteOptions}; +use super::{binview, nested, Nested, WriteOptions}; use crate::arrow::read::schema::is_nullable; use crate::arrow::write::{slice_nested_leaf, utils}; use crate::parquet::encoding::hybrid_rle::encode_u32; @@ -278,6 +278,23 @@ pub fn array_to_pages( }; (DictPage::new(buffer, array.len(), false), stats) }, + ArrowDataType::Utf8View => { + let array = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .to_binview(); + let mut buffer = vec![]; + binview::encode_plain(&array, &mut buffer); + + let stats = if options.write_statistics { + Some(binview::build_statistics(&array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, ArrowDataType::LargeBinary => { let values = array.values().as_any().downcast_ref().unwrap(); diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index ce48fd2980b7..f4e30e738b30 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -1,6 +1,6 @@ use arrow::array::{ Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, PrimitiveArray, - StructArray, Utf8Array, Utf8ViewArray, + StructArray, Utf8ViewArray, }; use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; @@ -128,7 +128,7 @@ unsafe fn encode_array(array: &dyn Array, field: &SortField, out: &mut RowsEncod .downcast_ref::>() .unwrap(); let iter = array - .iter_typed::>() + .iter_typed::() .unwrap() .map(|opt_s| opt_s.map(|s| s.as_bytes())); crate::variable::encode_iter(iter, out, field) @@ -225,7 +225,7 @@ pub fn allocate_rows_buf( .downcast_ref::>() .unwrap(); let iter = array - .iter_typed::>() + .iter_typed::() .unwrap() .map(|opt_s| opt_s.map(|s| s.as_bytes())); if processed_count == 0 { diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index d58b0a220e11..1e4641adeaf1 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -16,7 +16,6 @@ use polars_core::prelude::{IndexOrder, QuantileInterpolOptions}; use polars_core::utils::arrow::array::Array; use polars_core::utils::arrow::types::NativeType; use polars_lazy::prelude::*; -use polars_rs::export::arrow; #[cfg(feature = "cloud")] use polars_rs::io::cloud::CloudOptions; use polars_utils::total_ord::TotalEq; @@ -515,9 +514,8 @@ impl FromPyObject<'_> for Wrap { let categories = ob.getattr(intern!(py, "categories")).unwrap(); let s = get_series(categories)?; let ca = s.str().map_err(PyPolarsErr::from)?; - let arr = ca.downcast_iter().next().unwrap(); - let categories = arrow::compute::cast::utf8view_to_utf8(arr); - create_enum_data_type(categories) + let categories = ca.downcast_iter().next().unwrap(); + create_enum_data_type(categories.clone()) }, "Date" => DataType::Date, "Time" => DataType::Time, diff --git a/py-polars/src/datatypes.rs b/py-polars/src/datatypes.rs index 4982adbd9298..11474e4c2b9a 100644 --- a/py-polars/src/datatypes.rs +++ b/py-polars/src/datatypes.rs @@ -1,5 +1,5 @@ use polars::prelude::*; -use polars_core::export::arrow::array::Utf8Array; +use polars_core::utils::arrow::array::Utf8ViewArray; use pyo3::{FromPyObject, PyAny, PyResult}; #[cfg(feature = "object")] @@ -33,7 +33,7 @@ pub(crate) enum PyDataType { Binary, Decimal(Option, usize), Array(usize), - Enum(Utf8Array), + Enum(Utf8ViewArray), } impl From<&DataType> for PyDataType { From e3519e9ba1ab806d00c89beb0c531e6e5a36e12e Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 21 Jan 2024 13:10:42 +0100 Subject: [PATCH 4/7] lint --- .../polars-core/src/chunked_array/logical/categorical/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 6ac3b7ab926e..9d79dc28de8a 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -410,8 +410,8 @@ mod test { let ca = ca.cast(&DataType::Categorical(None, Default::default()))?; let ca = ca.categorical().unwrap(); - let arr: DictionaryArray = (ca).into(); - let s = Series::try_from(("foo", Box::new(arr) as ArrayRef))?; + let arr = ca.to_arrow(true, false); + let s = Series::try_from(("foo", arr))?; assert!(matches!(s.dtype(), &DataType::Categorical(_, _))); assert_eq!(s.null_count(), 1); assert_eq!(s.len(), 6); From 8c50f1b903fecc8faadf38a8cfdd6d691bf2b2d3 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 22 Jan 2024 07:54:50 +0100 Subject: [PATCH 5/7] implement dictionary stringview --- .../read/deserialize/binview/dictionary.rs | 164 ++++++++++++++++++ .../src/arrow/read/deserialize/binview/mod.rs | 6 +- .../src/arrow/read/deserialize/nested.rs | 3 + .../src/arrow/read/deserialize/simple.rs | 3 + 4 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 crates/polars-parquet/src/arrow/read/deserialize/binview/dictionary.rs diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/dictionary.rs new file mode 100644 index 000000000000..1951467a24ad --- /dev/null +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/dictionary.rs @@ -0,0 +1,164 @@ +use std::collections::VecDeque; + +use arrow::array::{Array, DictionaryArray, DictionaryKey, MutableBinaryViewArray}; +use arrow::bitmap::MutableBitmap; +use arrow::datatypes::{ArrowDataType, PhysicalType}; +use polars_error::PolarsResult; + +use super::super::dictionary::*; +use super::super::utils::MaybeNext; +use super::super::PagesIter; +use crate::arrow::read::deserialize::nested_utils::{InitNested, NestedState}; +use crate::parquet::page::DictPage; +use crate::read::deserialize::binary::utils::BinaryIter; + +/// An iterator adapter over [`PagesIter`] assumed to be encoded as parquet's dictionary-encoded binary representation +#[derive(Debug)] +pub struct DictIter +where + I: PagesIter, + K: DictionaryKey, +{ + iter: I, + data_type: ArrowDataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, +} + +impl DictIter +where + K: DictionaryKey, + I: PagesIter, +{ + pub fn new( + iter: I, + data_type: ArrowDataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + } + } +} + +fn read_dict(data_type: ArrowDataType, dict: &DictPage) -> Box { + let data_type = match data_type { + ArrowDataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + + let values = BinaryIter::new(&dict.buffer).take(dict.num_values); + + let mut data = MutableBinaryViewArray::<[u8]>::with_capacity(dict.num_values); + for item in values { + data.push_value(item) + } + + match data_type.to_physical_type() { + PhysicalType::Utf8View => data.freeze().to_utf8view().unwrap().boxed(), + _ => unreachable!(), + } +} + +impl Iterator for DictIter +where + I: PagesIter, + K: DictionaryKey, +{ + type Item = PolarsResult>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`] +#[derive(Debug)] +pub struct NestedDictIter +where + I: PagesIter, + K: DictionaryKey, +{ + iter: I, + init: Vec, + data_type: ArrowDataType, + values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, +} + +impl NestedDictIter +where + I: PagesIter, + K: DictionaryKey, +{ + pub fn new( + iter: I, + init: Vec, + data_type: ArrowDataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + } + } +} + +impl Iterator for NestedDictIter +where + I: PagesIter, + K: DictionaryKey, +{ + type Item = PolarsResult<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + loop { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => return Some(Ok(dict)), + MaybeNext::Some(Err(e)) => return Some(Err(e)), + MaybeNext::None => return None, + MaybeNext::More => continue, + } + } + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs index 11bb82dfcee1..1e93e5ae1e42 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/mod.rs @@ -1,5 +1,7 @@ mod basic; +mod dictionary; mod nested; -pub use basic::BinaryViewArrayIter; -pub use nested::NestedIter; +pub(crate) use basic::BinaryViewArrayIter; +pub(crate) use dictionary::{DictIter, NestedDictIter}; +pub(crate) use nested::NestedIter; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index 0341be876fb9..3289700be95d 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -562,6 +562,9 @@ fn dict_read<'a, K: DictionaryKey, I: 'a + PagesIter>( LargeUtf8 | LargeBinary => primitive(binary::NestedDictIter::::new( iter, init, data_type, num_rows, chunk_size, )), + Utf8View => primitive(binview::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), FixedSizeBinary(_) => primitive(fixed_size_binary::NestedDictIter::::new( iter, init, data_type, num_rows, chunk_size, )), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs index abadab3df26c..d3ca4f7b6e37 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -639,6 +639,9 @@ fn dict_read<'a, K: DictionaryKey, I: PagesIter + 'a>( (PhysicalType::ByteArray, LargeUtf8 | LargeBinary) => dyn_iter( binary::DictIter::::new(iter, data_type, num_rows, chunk_size), ), + (PhysicalType::ByteArray, BinaryView) => dyn_iter(binview::DictIter::::new( + iter, data_type, num_rows, chunk_size, + )), (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => dyn_iter( fixed_size_binary::DictIter::::new(iter, data_type, num_rows, chunk_size), ), From 63b38340d52f41a312618148b762fc21334483b6 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 22 Jan 2024 07:57:21 +0100 Subject: [PATCH 6/7] read dict as utfview --- .../polars-parquet/src/arrow/read/schema/metadata.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs index 0dbcd7829753..5b3dd20725cb 100644 --- a/crates/polars-parquet/src/arrow/read/schema/metadata.rs +++ b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -20,13 +20,13 @@ pub fn read_schema_from_metadata(metadata: &mut Metadata) -> PolarsResult