Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support for decimal 32 and 64
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Mar 8, 2022
1 parent 6ed716e commit 568f7f3
Show file tree
Hide file tree
Showing 42 changed files with 1,100 additions and 588 deletions.
4 changes: 2 additions & 2 deletions guide/src/high_level.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ The following arrays are supported:

* `NullArray` (just holds nulls)
* `BooleanArray` (booleans)
* `PrimitiveArray<T>` (for ints, floats)
* `PrimitiveArray<T>` (for ints, floats, decimal)
* `Utf8Array<i32>` and `Utf8Array<i64>` (for strings)
* `BinaryArray<i32>` and `BinaryArray<i64>` (for opaque binaries)
* `FixedSizeBinaryArray` (like `BinaryArray`, but fixed size)
Expand Down Expand Up @@ -124,7 +124,7 @@ There is a one to one relationship between each variant of `PhysicalType` (an en
an each implementation of `Array` (a struct):

| `PhysicalType` | `Array` |
|-------------------|------------------------|
| ----------------- | ---------------------- |
| `Primitive(_)` | `PrimitiveArray<_>` |
| `Binary` | `BinaryArray<i32>` |
| `LargeBinary` | `BinaryArray<i64>` |
Expand Down
25 changes: 23 additions & 2 deletions src/array/primitive/fmt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::{Debug, Formatter, Result, Write};

use crate::array::Array;
use crate::datatypes::{IntervalUnit, TimeUnit};
use crate::datatypes::{DecimalType, IntervalUnit, TimeUnit};
use crate::types::{days_ms, months_days_ns};

use super::super::super::temporal_conversions;
Expand Down Expand Up @@ -104,7 +104,27 @@ pub fn get_write_value<'a, T: NativeType, F: Write>(
Duration(TimeUnit::Millisecond) => dyn_primitive!(array, i64, |x| format!("{}ms", x)),
Duration(TimeUnit::Microsecond) => dyn_primitive!(array, i64, |x| format!("{}us", x)),
Duration(TimeUnit::Nanosecond) => dyn_primitive!(array, i64, |x| format!("{}ns", x)),
Decimal(_, scale) => {
Decimal(DecimalType::Int32, _, scale) => {
// The number 999.99 has a precision of 5 and scale of 2
let scale = *scale as u32;
let display = move |x| {
let base = x / 10i32.pow(scale);
let decimals = x - base * 10i32.pow(scale);
format!("{}.{}", base, decimals)
};
dyn_primitive!(array, i32, display)
}
Decimal(DecimalType::Int64, _, scale) => {
// The number 999.99 has a precision of 5 and scale of 2
let scale = *scale as u32;
let display = move |x| {
let base = x / 10i64.pow(scale);
let decimals = x - base * 10i64.pow(scale);
format!("{}.{}", base, decimals)
};
dyn_primitive!(array, i64, display)
}
Decimal(DecimalType::Int128, _, scale) => {
// The number 999.99 has a precision of 5 and scale of 2
let scale = *scale as u32;
let display = move |x| {
Expand All @@ -114,6 +134,7 @@ pub fn get_write_value<'a, T: NativeType, F: Write>(
};
dyn_primitive!(array, i128, display)
}

_ => unreachable!(),
}
}
Expand Down
125 changes: 65 additions & 60 deletions src/compute/arithmetics/decimal/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
arity::{binary, binary_checked},
utils::{check_same_len, combine_validities},
},
datatypes::DecimalType,
};
use crate::{
datatypes::DataType,
Expand All @@ -26,13 +27,13 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
/// ```
/// use arrow2::compute::arithmetics::decimal::add;
/// use arrow2::array::PrimitiveArray;
/// use arrow2::datatypes::DataType;
/// use arrow2::datatypes::{DataType, DecimalType};
///
/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// let result = add(&a, &b);
/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2));
/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// assert_eq!(result, expected);
/// ```
Expand Down Expand Up @@ -65,13 +66,13 @@ pub fn add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveA
/// ```
/// use arrow2::compute::arithmetics::decimal::saturating_add;
/// use arrow2::array::PrimitiveArray;
/// use arrow2::datatypes::DataType;
/// use arrow2::datatypes::{DataType, DecimalType};
///
/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// let result = saturating_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// assert_eq!(result, expected);
/// ```
Expand Down Expand Up @@ -109,13 +110,13 @@ pub fn saturating_add(
/// ```
/// use arrow2::compute::arithmetics::decimal::checked_add;
/// use arrow2::array::PrimitiveArray;
/// use arrow2::datatypes::DataType;
/// use arrow2::datatypes::{DataType, DecimalType};
///
/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// let result = checked_add(&a, &b);
/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2));
///
/// assert_eq!(result, expected);
/// ```
Expand Down Expand Up @@ -173,12 +174,12 @@ impl ArraySaturatingAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
/// ```
/// use arrow2::compute::arithmetics::decimal::adaptive_add;
/// use arrow2::array::PrimitiveArray;
/// use arrow2::datatypes::DataType;
/// use arrow2::datatypes::{DataType, DecimalType};
///
/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2));
/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3));
/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2));
/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3));
/// let result = adaptive_add(&a, &b).unwrap();
/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3));
/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3));
///
/// assert_eq!(result, expected);
/// ```
Expand All @@ -188,51 +189,55 @@ pub fn adaptive_add(
) -> Result<PrimitiveArray<i128>> {
check_same_len(lhs, rhs)?;

if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
let (lhs_p, lhs_s, rhs_p, rhs_s) = if let (
DataType::Decimal(DecimalType::Int128, lhs_p, lhs_s),
DataType::Decimal(DecimalType::Int128, rhs_p, rhs_s),
) = (lhs.data_type(), rhs.data_type())
{
// The resulting precision is mutable because it could change while
// looping through the iterator
let (mut res_p, res_s, diff) = adjusted_precision_scale(*lhs_p, *lhs_s, *rhs_p, *rhs_s);

let shift = 10i128.pow(diff as u32);
let mut max = max_value(res_p);

let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| {
// Based on the array's scales one of the arguments in the sum has to be shifted
// to the left to match the final scale
let res = if lhs_s > rhs_s {
l + r * shift
} else {
l * shift + r
};

// The precision of the resulting array will change if one of the
// sums during the iteration produces a value bigger than the
// possible value for the initial precision

// 99.9999 -> 6, 4
// 00.0001 -> 6, 4
// -----------------
// 100.0000 -> 7, 4
if res.abs() > max {
res_p = number_digits(res);
max = max_value(res_p);
}
res
});
let values = Buffer::from_trusted_len_iter(iter);

let validity = combine_validities(lhs.validity(), rhs.validity());

Ok(PrimitiveArray::<i128>::new(
DataType::Decimal(res_p, res_s),
values,
validity,
))
(*lhs_p, *lhs_s, *rhs_p, *rhs_s)
} else {
Err(ArrowError::InvalidArgumentError(
return Err(ArrowError::InvalidArgumentError(
"Incorrect data type for the array".to_string(),
))
}
));
};

// The resulting precision is mutable because it could change while
// looping through the iterator
let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);

let shift = 10i128.pow(diff as u32);
let mut max = max_value(res_p);

let iter = lhs.values().iter().zip(rhs.values().iter()).map(|(l, r)| {
// Based on the array's scales one of the arguments in the sum has to be shifted
// to the left to match the final scale
let res = if lhs_s > rhs_s {
l + r * shift
} else {
l * shift + r
};

// The precision of the resulting array will change if one of the
// sums during the iteration produces a value bigger than the
// possible value for the initial precision

// 99.9999 -> 6, 4
// 00.0001 -> 6, 4
// -----------------
// 100.0000 -> 7, 4
if res.abs() > max {
res_p = number_digits(res);
max = max_value(res_p);
}
res
});
let values = Buffer::from_trusted_len_iter(iter);

let validity = combine_validities(lhs.validity(), rhs.validity());

Ok(PrimitiveArray::<i128>::new(
DataType::Decimal(DecimalType::Int128, res_p, res_s),
values,
validity,
))
}
Loading

0 comments on commit 568f7f3

Please sign in to comment.