Skip to content

Commit

Permalink
Add AnyDictionary Abstraction and Take ArrayRef in DictionaryArray::w…
Browse files Browse the repository at this point in the history
…ith_values (apache#4707)

* Add AnyDictionary Abstraction

* Review feedback

* Move to AsArray
  • Loading branch information
tustvold authored Aug 17, 2023
1 parent 31c81c5 commit f0200db
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 19 deletions.
8 changes: 6 additions & 2 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op);
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(Arc::new(values))))
}

/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
Expand All @@ -105,10 +105,11 @@ where

let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?;
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(Arc::new(values))))
}

/// Applies an infallible unary function to an array with primitive values.
#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
where
T: ArrowPrimitiveType,
Expand All @@ -134,6 +135,7 @@ where
}

/// Applies a fallible unary function to an array with primitive values.
#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef, ArrowError>
where
T: ArrowPrimitiveType,
Expand Down Expand Up @@ -436,6 +438,7 @@ mod tests {
use arrow_array::types::*;

#[test]
#[allow(deprecated)]
fn test_unary_f64_slice() {
let input =
Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]);
Expand All @@ -455,6 +458,7 @@ mod tests {
}

#[test]
#[allow(deprecated)]
fn test_unary_dict_and_unary_dyn() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();
builder.append(5).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion arrow-arith/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ where
downcast_dictionary_array!(
array => {
let values = time_fraction_dyn(array.values(), name, op)?;
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(values)))
}
dt => return_compute_error_with!(format!("{name} does not support"), dt),
)
Expand Down
116 changes: 101 additions & 15 deletions arrow-array/src/array/dictionary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// Panics if `values` has a length less than the current values
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
/// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor};
/// # use arrow_array::types::{Int32Type, Int8Type};
Expand All @@ -451,7 +452,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64);
///
/// // Create a Dict(Int32,
/// let new = dictionary.with_values(&values);
/// let new = dictionary.with_values(Arc::new(values));
///
/// // Verify values are as expected
/// let new_typed = new.downcast_dict::<Int64Array>().unwrap();
Expand All @@ -460,21 +461,18 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// }
/// ```
///
pub fn with_values(&self, values: &dyn Array) -> Self {
pub fn with_values(&self, values: ArrayRef) -> Self {
assert!(values.len() >= self.values.len());

let builder = self
.to_data()
.into_builder()
.data_type(DataType::Dictionary(
Box::new(K::DATA_TYPE),
Box::new(values.data_type().clone()),
))
.child_data(vec![values.to_data()]);

// SAFETY:
// Offsets were valid before and verified length is greater than or equal
Self::from(unsafe { builder.build_unchecked() })
let data_type = DataType::Dictionary(
Box::new(K::DATA_TYPE),
Box::new(values.data_type().clone()),
);
Self {
data_type,
keys: self.keys.clone(),
values,
is_ordered: false,
}
}

/// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating
Expand Down Expand Up @@ -930,6 +928,94 @@ where
}
}

/// A [`DictionaryArray`] with the key type erased
///
/// This can be used to efficiently implement kernels for all possible dictionary
/// keys without needing to create specialized implementations for each key type
///
/// For example
///
/// ```
/// # use arrow_array::*;
/// # use arrow_array::cast::AsArray;
/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
/// # use arrow_array::types::*;
/// # use arrow_schema::ArrowError;
/// # use std::sync::Arc;
///
/// fn to_string(a: &dyn Array) -> Result<ArrayRef, ArrowError> {
/// if let Some(d) = a.as_any_dictionary_opt() {
/// // Recursively handle dictionary input
/// let r = to_string(d.values().as_ref())?;
/// return Ok(d.with_values(r));
/// }
/// downcast_primitive_array! {
/// a => Ok(Arc::new(a.iter().map(|x| x.map(|x| x.to_string())).collect::<StringArray>())),
/// d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported")))
/// }
/// }
///
/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap();
/// let actual = result.as_string::<i32>().iter().map(Option::unwrap).collect::<Vec<_>>();
/// assert_eq!(actual, &["1", "2", "3"]);
///
/// let mut dict = PrimitiveDictionaryBuilder::<Int32Type, UInt16Type>::new();
/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]);
/// let dict = dict.finish();
///
/// let r = to_string(&dict).unwrap();
/// let r = r.as_dictionary::<Int32Type>().downcast_dict::<StringArray>().unwrap();
/// assert_eq!(r.keys(), dict.keys()); // Keys are the same
///
/// let actual = r.into_iter().map(Option::unwrap).collect::<Vec<_>>();
/// assert_eq!(actual, &["1", "1", "2", "3", "2"]);
/// ```
///
/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`]
pub trait AnyDictionaryArray: Array {
/// Returns the primitive keys of this dictionary as an [`Array`]
fn keys(&self) -> &dyn Array;

/// Returns the values of this dictionary
fn values(&self) -> &ArrayRef;

/// Returns the keys of this dictionary as usize
///
/// The values for nulls will be arbitrary, but are guaranteed
/// to be in the range `0..self.values.len()`
///
/// # Panic
///
/// Panics if `values.len() == 0`
fn normalized_keys(&self) -> Vec<usize>;

/// Create a new [`DictionaryArray`] replacing `values` with the new values
///
/// See [`DictionaryArray::with_values`]
fn with_values(&self, values: ArrayRef) -> ArrayRef;
}

impl<K: ArrowDictionaryKeyType> AnyDictionaryArray for DictionaryArray<K> {
fn keys(&self) -> &dyn Array {
&self.keys
}

fn values(&self) -> &ArrayRef {
self.values()
}

fn normalized_keys(&self) -> Vec<usize> {
let v_len = self.values().len();
assert_ne!(v_len, 0);
let iter = self.keys().values().iter();
iter.map(|x| x.as_usize().min(v_len)).collect()
}

fn with_values(&self, values: ArrayRef) -> ArrayRef {
Arc::new(self.with_values(values))
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
20 changes: 20 additions & 0 deletions arrow-array/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,14 @@ pub trait AsArray: private::Sealed {
fn as_dictionary<K: ArrowDictionaryKeyType>(&self) -> &DictionaryArray<K> {
self.as_dictionary_opt().expect("dictionary array")
}

/// Downcasts this to a [`AnyDictionaryArray`] returning `None` if not possible
fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>;

/// Downcasts this to a [`AnyDictionaryArray`] panicking if not possible
fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray {
self.as_any_dictionary_opt().expect("any dictionary array")
}
}

impl private::Sealed for dyn Array + '_ {}
Expand Down Expand Up @@ -874,6 +882,14 @@ impl AsArray for dyn Array + '_ {
) -> Option<&DictionaryArray<K>> {
self.as_any().downcast_ref()
}

fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> {
let array = self;
downcast_dictionary_array! {
array => Some(array),
_ => None
}
}
}

impl private::Sealed for ArrayRef {}
Expand Down Expand Up @@ -915,6 +931,10 @@ impl AsArray for ArrayRef {
) -> Option<&DictionaryArray<K>> {
self.as_ref().as_dictionary_opt()
}

fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> {
self.as_ref().as_any_dictionary_opt()
}
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion arrow-row/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,7 @@ mod tests {
// Construct dictionary with a timezone
let dict = a.finish();
let values = TimestampNanosecondArray::from(dict.values().to_data());
let dict_with_tz = dict.with_values(&values.with_timezone("+02:00"));
let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00")));
let d = DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Timestamp(
Expand Down

0 comments on commit f0200db

Please sign in to comment.