From 2a1c9bf2799b780c79a4dc331305c40ba5bead59 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 22 Mar 2023 18:53:31 +0000 Subject: [PATCH] Add PrimitiveArray::new (#3879) --- arrow-arith/src/arity.rs | 26 ++----- arrow-array/src/array/primitive_array.rs | 95 ++++++++++++++---------- arrow-buffer/src/buffer/scalar.rs | 7 ++ 3 files changed, 67 insertions(+), 61 deletions(-) diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 0a8815cc8059..782c8270cf85 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -23,25 +23,10 @@ use arrow_array::types::ArrowDictionaryKeyType; use arrow_array::*; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::{Buffer, MutableBuffer}; -use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_data::ArrayData; use arrow_schema::ArrowError; use std::sync::Arc; -#[inline] -unsafe fn build_primitive_array( - len: usize, - buffer: Buffer, - nulls: Option, -) -> PrimitiveArray { - PrimitiveArray::from( - ArrayDataBuilder::new(O::DATA_TYPE) - .len(len) - .nulls(nulls) - .buffers(vec![buffer]) - .build_unchecked(), - ) -} - /// See [`PrimitiveArray::unary`] pub fn unary(array: &PrimitiveArray, op: F) -> PrimitiveArray where @@ -209,7 +194,6 @@ where "Cannot perform binary operation on arrays of different length".to_string(), )); } - let len = a.len(); if a.is_empty() { return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); @@ -224,8 +208,7 @@ where // Soundness // `values` is an iterator with a known size from a PrimitiveArray let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; - - Ok(unsafe { build_primitive_array(len, buffer, nulls) }) + Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls)) } /// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating @@ -328,7 +311,8 @@ where Ok::<_, ArrowError>(()) })?; - Ok(unsafe { build_primitive_array(len, buffer.finish(), Some(nulls)) }) + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls))) } } @@ -412,7 +396,7 @@ where buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?); }; } - Ok(unsafe { build_primitive_array(len, buffer.into(), None) }) + Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), None)) } /// This intentional inline(never) attribute helps LLVM optimize the loop. diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 241e2a051197..9fbbe4dd96d4 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -29,7 +29,7 @@ use arrow_buffer::{ i256, ArrowNativeType, BooleanBuffer, Buffer, NullBuffer, ScalarBuffer, }; use arrow_data::bit_iterator::try_for_each_valid_idx; -use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType}; use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime}; use half::f16; @@ -251,19 +251,58 @@ pub struct PrimitiveArray { /// Underlying ArrayData data: ArrayData, /// Values data - raw_values: ScalarBuffer, + values: ScalarBuffer, } impl Clone for PrimitiveArray { fn clone(&self) -> Self { Self { data: self.data.clone(), - raw_values: self.raw_values.clone(), + values: self.values.clone(), } } } impl PrimitiveArray { + /// Create a new [`PrimitiveArray`] from the provided data_type, values, nulls + /// + /// # Panics + /// + /// Panics if: + /// - `values.len() != nulls.len()` + /// - `!Self::is_compatible(data_type)` + pub fn new( + data_type: DataType, + values: ScalarBuffer, + nulls: Option, + ) -> Self { + Self::assert_compatible(&data_type); + if let Some(n) = nulls.as_ref() { + assert_eq!(values.len(), n.len()); + } + + // TODO: Don't store ArrayData inside arrays (#3880) + let data = unsafe { + ArrayData::builder(data_type) + .len(values.len()) + .nulls(nulls) + .buffers(vec![values.inner().clone()]) + .build_unchecked() + }; + + Self { data, values } + } + + /// Asserts that `data_type` is compatible with `Self` + fn assert_compatible(data_type: &DataType) { + assert!( + Self::is_compatible(data_type), + "PrimitiveArray expected ArrayData with type {} got {}", + T::DATA_TYPE, + data_type + ); + } + /// Returns the length of this array. #[inline] pub fn len(&self) -> usize { @@ -278,7 +317,7 @@ impl PrimitiveArray { /// Returns the values of this array #[inline] pub fn values(&self) -> &ScalarBuffer { - &self.raw_values + &self.values } /// Returns a new primitive array builder @@ -308,7 +347,7 @@ impl PrimitiveArray { /// caller must ensure that the passed in offset is less than the array len() #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> T::Native { - *self.raw_values.get_unchecked(i) + *self.values.get_unchecked(i) } /// Returns the primitive value at index `i`. @@ -346,7 +385,7 @@ impl PrimitiveArray { pub fn from_value(value: T::Native, count: usize) -> Self { unsafe { let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value)); - build_primitive_array(count, val_buf, None) + Self::new(T::DATA_TYPE, val_buf.into(), None) } } @@ -422,7 +461,6 @@ impl PrimitiveArray { F: Fn(T::Native) -> O::Native, { let data = self.data(); - let len = self.len(); let nulls = data.nulls().cloned(); let values = self.values().iter().map(|v| op(*v)); @@ -432,7 +470,7 @@ impl PrimitiveArray { // Soundness // `values` is an iterator with a known size because arrays are sized. let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; - unsafe { build_primitive_array(len, buffer, nulls) } + PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls) } /// Applies an unary and infallible function to a mutable primitive array. @@ -495,7 +533,8 @@ impl PrimitiveArray { None => (0..len).try_for_each(f)?, } - Ok(unsafe { build_primitive_array(len, buffer.finish(), nulls) }) + let values = buffer.finish().into(); + Ok(PrimitiveArray::new(O::DATA_TYPE, values, nulls)) } /// Applies an unary and fallible function to all valid values in a mutable primitive array. @@ -579,13 +618,9 @@ impl PrimitiveArray { }); let nulls = BooleanBuffer::new(null_builder.finish(), 0, len); - unsafe { - build_primitive_array( - len, - buffer.finish(), - Some(NullBuffer::new_unchecked(nulls, out_null_count)), - ) - } + let values = buffer.finish().into(); + let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) }; + PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls)) } /// Returns `PrimitiveBuilder` of this primitive array for mutating its values if the underlying @@ -599,7 +634,7 @@ impl PrimitiveArray { .slice_with_length(self.data.offset() * element_len, len * element_len); drop(self.data); - drop(self.raw_values); + drop(self.values); let try_mutable_null_buffer = match null_bit_buffer { None => Ok(None), @@ -647,21 +682,6 @@ impl PrimitiveArray { } } -#[inline] -unsafe fn build_primitive_array( - len: usize, - buffer: Buffer, - nulls: Option, -) -> PrimitiveArray { - PrimitiveArray::from( - ArrayDataBuilder::new(O::DATA_TYPE) - .len(len) - .buffers(vec![buffer]) - .nulls(nulls) - .build_unchecked(), - ) -} - impl From> for ArrayData { fn from(array: PrimitiveArray) -> Self { array.data @@ -1052,21 +1072,16 @@ impl PrimitiveArray { /// Constructs a `PrimitiveArray` from an array data reference. impl From for PrimitiveArray { fn from(data: ArrayData) -> Self { - assert!( - Self::is_compatible(data.data_type()), - "PrimitiveArray expected ArrayData with type {} got {}", - T::DATA_TYPE, - data.data_type() - ); + Self::assert_compatible(data.data_type()); assert_eq!( data.buffers().len(), 1, "PrimitiveArray data should contain a single buffer only (values buffer)" ); - let raw_values = + let values = ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()); - Self { data, raw_values } + Self { data, values } } } diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 4c16a736b10b..1a4680111bd1 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -17,6 +17,7 @@ use crate::buffer::Buffer; use crate::native::ArrowNativeType; +use crate::MutableBuffer; use std::fmt::Formatter; use std::marker::PhantomData; use std::ops::Deref; @@ -96,6 +97,12 @@ impl AsRef<[T]> for ScalarBuffer { } } +impl From for ScalarBuffer { + fn from(value: MutableBuffer) -> Self { + Buffer::from(value).into() + } +} + impl From for ScalarBuffer { fn from(buffer: Buffer) -> Self { let align = std::mem::align_of::();