Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Basic support for Arrow 128-bit Decimal. #1129

Merged
merged 18 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ def null(cls) -> DataType:
"""Creates the Null DataType: Always the ``Null`` value"""
return cls._from_pydatatype(PyDataType.null())

@classmethod
def decimal128(cls, precision: int, scale: int) -> DataType:
"""Fixed-precision decimal."""
return cls._from_pydatatype(PyDataType.decimal128(precision, scale))

@classmethod
def date(cls) -> DataType:
"""Create a Date DataType: A date with a year, month and day"""
Expand Down Expand Up @@ -314,6 +319,8 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
return cls.bool()
elif pa.types.is_null(arrow_type):
return cls.null()
elif pa.types.is_decimal128(arrow_type):
return cls.decimal128(arrow_type.precision, arrow_type.scale)
elif pa.types.is_date32(arrow_type):
return cls.date()
elif pa.types.is_timestamp(arrow_type):
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ where
impl<T> Add for &DataArray<T>
where
T: DaftNumericType,
T::Native: basic::NativeArithmetics,
{
type Output = DaftResult<DataArray<T>>;
fn add(self, rhs: Self) -> Self::Output {
Expand All @@ -94,6 +95,7 @@ impl Add for &Utf8Array {
impl<T> Sub for &DataArray<T>
where
T: DaftNumericType,
T::Native: basic::NativeArithmetics,
{
type Output = DaftResult<DataArray<T>>;
fn sub(self, rhs: Self) -> Self::Output {
Expand All @@ -104,6 +106,7 @@ where
impl<T> Mul for &DataArray<T>
where
T: DaftNumericType,
T::Native: basic::NativeArithmetics,
{
type Output = DaftResult<DataArray<T>>;
fn mul(self, rhs: Self) -> Self::Output {
Expand Down Expand Up @@ -148,6 +151,7 @@ where
impl<T> Rem for &DataArray<T>
where
T: DaftNumericType,
T::Native: basic::NativeArithmetics,
{
type Output = DaftResult<DataArray<T>>;
fn rem(self, rhs: Self) -> Self::Output {
Expand Down
5 changes: 3 additions & 2 deletions src/daft-core/src/array/ops/as_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftNumericType, FixedSizeListArray, ListArray, StructArray,
Utf8Array,
Expand Down Expand Up @@ -69,6 +69,7 @@ impl_asarrow_dataarray!(StructArray, array::StructArray);
#[cfg(feature = "python")]
impl_asarrow_dataarray!(PythonArray, PseudoArrowArray<pyo3::PyObject>);

impl_asarrow_logicalarray!(Decimal128Array, array::PrimitiveArray<i128>);
impl_asarrow_logicalarray!(DateArray, array::PrimitiveArray<i32>);
impl_asarrow_logicalarray!(DurationArray, array::PrimitiveArray<i64>);
impl_asarrow_logicalarray!(TimestampArray, array::PrimitiveArray<i64>);
Expand Down
106 changes: 84 additions & 22 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
LogicalArray, TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, LogicalArray, TimestampArray,
},
DaftLogicalType,
},
datatypes::{DaftArrowBackedType, DataType, Field, Utf8Array},
series::Series,
with_match_arrow_daft_types, with_match_daft_logical_types,
with_match_arrow_daft_types, with_match_daft_logical_primitive_types,
with_match_daft_logical_types,
};
use common_error::{DaftError, DaftResult};

Expand Down Expand Up @@ -54,16 +55,35 @@ where
// Get the result of the Arrow Logical->Target cast.
let result_arrow_array = {
// First, get corresponding Arrow LogicalArray of source DataArray
let source_arrow_array = cast(
to_cast.physical.data(),
&source_arrow_type,
CastOptions {
wrapped: true,
partial: false,
},
)?;
use DataType::*;
let source_arrow_array = match source_dtype {
// Wrapped primitives
Decimal128(..) | Date | Timestamp(..) | Duration(..) => {
with_match_daft_logical_primitive_types!(source_dtype, |$T| {
use arrow2::array::Array;
to_cast
.physical
.data()
.as_any()
.downcast_ref::<arrow2::array::PrimitiveArray<$T>>()
.unwrap()
.clone()
.to(source_arrow_type)
.to_boxed()
})
}
_ => cast(
to_cast.physical.data(),
&source_arrow_type,
CastOptions {
wrapped: true,
partial: false,
},
)?,
};

// Then, cast source Arrow LogicalArray to target Arrow LogicalArray.

cast(
source_arrow_array.as_ref(),
&target_arrow_type,
Expand All @@ -75,19 +95,35 @@ where
};

// If the target type is also Logical, get the Arrow Physical.
let target_physical_type = dtype.to_physical().to_arrow()?;
let result_arrow_physical_array = {
if target_physical_type == target_arrow_type {
result_arrow_array
if dtype.is_logical() {
use DataType::*;
let target_physical_type = dtype.to_physical().to_arrow()?;
match dtype {
// Primitive wrapper types: change the arrow2 array's type field to primitive
Decimal128(..) | Date | Timestamp(..) | Duration(..) => {
with_match_daft_logical_primitive_types!(dtype, |$P| {
use arrow2::array::Array;
result_arrow_array
.as_any()
.downcast_ref::<arrow2::array::PrimitiveArray<$P>>()
.unwrap()
.clone()
.to(target_physical_type)
.to_boxed()
})
}
_ => arrow2::compute::cast::cast(
result_arrow_array.as_ref(),
&target_physical_type,
arrow2::compute::cast::CastOptions {
wrapped: true,
partial: false,
},
)?,
}
} else {
cast(
result_arrow_array.as_ref(),
&target_physical_type,
CastOptions {
wrapped: true,
partial: false,
},
)?
result_arrow_array
}
};

Expand Down Expand Up @@ -244,6 +280,22 @@ impl DateArray {
}
}

pub(super) fn decimal128_to_str(val: i128, _precision: u8, scale: i8) -> String {
if scale < 0 {
unimplemented!();
} else {
let modulus = i128::pow(10, scale as u32);
let integral = val / modulus;
if scale == 0 {
format!("{}", integral)
} else {
let decimals = (val % modulus).abs();
let scale = scale as usize;
format!("{}.{:0scale$}", integral, decimals)
}
}
}

pub(super) fn timestamp_to_str_naive(val: i64, unit: &TimeUnit) -> String {
let chrono_ts = {
arrow2::temporal_conversions::timestamp_to_naive_datetime(val, unit.to_arrow().unwrap())
Expand Down Expand Up @@ -342,6 +394,16 @@ impl DurationArray {
}
}

impl Decimal128Array {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
match dtype {
#[cfg(feature = "python")]
DataType::Python => cast_logical_to_python_array(self, dtype),
_ => arrow_logical_cast(self, dtype),
}
}
}

#[cfg(feature = "python")]
macro_rules! pycast_then_arrowcast {
($self:expr, $daft_type:expr, $pytype_str:expr) => {
Expand Down
5 changes: 3 additions & 2 deletions src/daft-core/src/array/ops/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray, ListArray,
NullArray, StructArray, Utf8Array,
Expand Down Expand Up @@ -61,6 +61,7 @@ impl_array_get!(BooleanArray, bool);
impl_array_get!(BinaryArray, &[u8]);
impl_array_get!(ListArray, Box<dyn arrow2::array::Array>);
impl_array_get!(FixedSizeListArray, Box<dyn arrow2::array::Array>);
impl_array_get!(Decimal128Array, i128);
impl_array_get!(DateArray, i32);
impl_array_get!(DurationArray, i64);
impl_array_get!(TimestampArray, i64);
Expand Down
4 changes: 3 additions & 1 deletion src/daft-core/src/array/ops/if_else.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::array::DataArray;
use crate::datatypes::logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray, TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
};
use crate::datatypes::{
BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, ExtensionArray, Field,
Expand Down Expand Up @@ -334,6 +335,7 @@ macro_rules! impl_logicalarray_if_else {
};
}

impl_logicalarray_if_else!(Decimal128Array);
impl_logicalarray_if_else!(DateArray);
impl_logicalarray_if_else!(DurationArray);
impl_logicalarray_if_else!(TimestampArray);
Expand Down
20 changes: 18 additions & 2 deletions src/daft-core/src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeListArray,
ImageFormat, ListArray, NullArray, StructArray, Utf8Array,
Expand Down Expand Up @@ -188,6 +188,21 @@ impl TimestampArray {
}
}

impl Decimal128Array {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
|| "None".to_string(),
|val| -> String {
use crate::array::ops::cast::decimal128_to_str;
use crate::datatypes::DataType::Decimal128;
let Decimal128(precision, scale) = &self.field.dtype else { panic!("Wrong dtype for Decimal128Array: {}", self.field.dtype) };
decimal128_to_str(val, *precision as u8, *scale as i8)
}
);
Ok(res)
}
}

// Default implementation of html_value: html escape the str_value.
macro_rules! impl_array_html_value {
($ArrayT:ty) => {
Expand All @@ -210,6 +225,7 @@ impl_array_html_value!(ListArray);
impl_array_html_value!(FixedSizeListArray);
impl_array_html_value!(StructArray);
impl_array_html_value!(ExtensionArray);
impl_array_html_value!(Decimal128Array);
impl_array_html_value!(DateArray);
impl_array_html_value!(DurationArray);
impl_array_html_value!(TimestampArray);
Expand Down
13 changes: 10 additions & 3 deletions src/daft-core/src/array/ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray,
FixedSizeListArray, Float32Array, Float64Array, ListArray, NullArray, StructArray,
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn build_multi_array_bicompare(
impl<T> DataArray<T>
where
T: DaftIntegerType,
<T as DaftNumericType>::Native: arrow2::types::Index,
<T as DaftNumericType>::Native: Ord,
{
pub fn argsort<I>(&self, descending: bool) -> DaftResult<DataArray<I>>
where
Expand Down Expand Up @@ -594,6 +594,13 @@ impl PythonArray {
}
}

impl Decimal128Array {
pub fn sort(&self, descending: bool) -> DaftResult<Self> {
let new_array = self.physical.sort(descending)?;
Ok(Self::new(self.field.clone(), new_array))
}
}

impl DateArray {
pub fn sort(&self, descending: bool) -> DaftResult<Self> {
let new_array = self.physical.sort(descending)?;
Expand Down
5 changes: 3 additions & 2 deletions src/daft-core/src/array/ops/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
array::DataArray,
datatypes::{
logical::{
DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, ImageArray,
TimestampArray,
DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray,
ImageArray, TimestampArray,
},
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray,
FixedSizeListArray, ListArray, NullArray, StructArray, Utf8Array,
Expand Down Expand Up @@ -66,6 +66,7 @@ impl_dataarray_take!(FixedSizeListArray);
impl_dataarray_take!(NullArray);
impl_dataarray_take!(StructArray);
impl_dataarray_take!(ExtensionArray);
impl_logicalarray_take!(Decimal128Array);
impl_logicalarray_take!(DateArray);
impl_logicalarray_take!(DurationArray);
impl_logicalarray_take!(TimestampArray);
Expand Down
Loading