diff --git a/guide/src/high_level.md b/guide/src/high_level.md index 80652038b0f..44ab5831843 100644 --- a/guide/src/high_level.md +++ b/guide/src/high_level.md @@ -80,7 +80,7 @@ The following arrays are supported: * `NullArray` (just holds nulls) * `BooleanArray` (booleans) -* `PrimitiveArray` (for ints, floats) +* `PrimitiveArray` (for ints, floats, decimal) * `Utf8Array` and `Utf8Array` (for strings) * `BinaryArray` and `BinaryArray` (for opaque binaries) * `FixedSizeBinaryArray` (like `BinaryArray`, but fixed size) @@ -124,7 +124,7 @@ There is a one to one relationship between each variant of `PhysicalType` (an en an each implementation of `Array` (a struct): | `PhysicalType` | `Array` | -|-------------------|------------------------| +| ----------------- | ---------------------- | | `Primitive(_)` | `PrimitiveArray<_>` | | `Binary` | `BinaryArray` | | `LargeBinary` | `BinaryArray` | diff --git a/src/array/primitive/fmt.rs b/src/array/primitive/fmt.rs index 5a94b399424..1ee079cbba6 100644 --- a/src/array/primitive/fmt.rs +++ b/src/array/primitive/fmt.rs @@ -1,7 +1,7 @@ use std::fmt::{Debug, Formatter, Result, Write}; use crate::array::Array; -use crate::datatypes::{IntervalUnit, TimeUnit}; +use crate::datatypes::{DecimalType, IntervalUnit, TimeUnit}; use crate::types::{days_ms, months_days_ns}; use super::super::super::temporal_conversions; @@ -104,7 +104,27 @@ pub fn get_write_value<'a, T: NativeType, F: Write>( Duration(TimeUnit::Millisecond) => dyn_primitive!(array, i64, |x| format!("{}ms", x)), Duration(TimeUnit::Microsecond) => dyn_primitive!(array, i64, |x| format!("{}us", x)), Duration(TimeUnit::Nanosecond) => dyn_primitive!(array, i64, |x| format!("{}ns", x)), - Decimal(_, scale) => { + Decimal(DecimalType::Int32, _, scale) => { + // The number 999.99 has a precision of 5 and scale of 2 + let scale = *scale as u32; + let display = move |x| { + let base = x / 10i32.pow(scale); + let decimals = x - base * 10i32.pow(scale); + format!("{}.{}", base, 
decimals) + }; + dyn_primitive!(array, i32, display) + } + Decimal(DecimalType::Int64, _, scale) => { + // The number 999.99 has a precision of 5 and scale of 2 + let scale = *scale as u32; + let display = move |x| { + let base = x / 10i64.pow(scale); + let decimals = x - base * 10i64.pow(scale); + format!("{}.{}", base, decimals) + }; + dyn_primitive!(array, i64, display) + } + Decimal(DecimalType::Int128, _, scale) => { // The number 999.99 has a precision of 5 and scale of 2 let scale = *scale as u32; let display = move |x| { @@ -114,6 +134,7 @@ pub fn get_write_value<'a, T: NativeType, F: Write>( }; dyn_primitive!(array, i128, display) } + _ => unreachable!(), } } diff --git a/src/compute/aggregate/min_max.rs b/src/compute/aggregate/min_max.rs index 8b78614a1a7..846d41a0c62 100644 --- a/src/compute/aggregate/min_max.rs +++ b/src/compute/aggregate/min_max.rs @@ -1,5 +1,5 @@ use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; -use crate::datatypes::{DataType, IntervalUnit}; +use crate::datatypes::{DataType, DecimalType, IntervalUnit}; use crate::error::{ArrowError, Result}; use crate::scalar::*; use crate::types::simd::*; @@ -394,7 +394,7 @@ pub fn max(array: &dyn Array) -> Result> { DataType::Float16 => unreachable!(), DataType::Float32 => dyn_primitive!(f32, array, max_primitive), DataType::Float64 => dyn_primitive!(f64, array, max_primitive), - DataType::Decimal(_, _) => dyn_primitive!(i128, array, max_primitive), + DataType::Decimal(DecimalType::Int128, _, _) => dyn_primitive!(i128, array, max_primitive), DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), DataType::Binary => dyn_generic!(BinaryArray, BinaryScalar, array, max_binary), @@ -436,7 +436,7 @@ pub fn min(array: &dyn Array) -> Result> { DataType::Float16 => unreachable!(), DataType::Float32 => dyn_primitive!(f32, array, min_primitive), DataType::Float64 => dyn_primitive!(f64, array, 
min_primitive), - DataType::Decimal(_, _) => dyn_primitive!(i128, array, min_primitive), + DataType::Decimal(DecimalType::Int128, _, _) => dyn_primitive!(i128, array, min_primitive), DataType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), DataType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), DataType::Binary => dyn_generic!(BinaryArray, BinaryScalar, array, min_binary), diff --git a/src/compute/arithmetics/decimal/add.rs b/src/compute/arithmetics/decimal/add.rs index 84bf8f4aead..164e3a4fed7 100644 --- a/src/compute/arithmetics/decimal/add.rs +++ b/src/compute/arithmetics/decimal/add.rs @@ -6,6 +6,7 @@ use crate::{ arity::{binary, binary_checked}, utils::{check_same_len, combine_validities}, }, + datatypes::DecimalType, }; use crate::{ datatypes::DataType, @@ -25,13 +26,13 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// ``` /// use arrow2::compute::arithmetics::decimal::add; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = add(&a, &b); -/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -64,13 +65,13 @@ pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> 
PrimitiveA /// ``` /// use arrow2::compute::arithmetics::decimal::saturating_add; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = saturating_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -108,13 +109,13 @@ pub fn saturating_add( /// ``` /// use arrow2::compute::arithmetics::decimal::checked_add; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = checked_add(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(33300i128), 
None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -172,12 +173,12 @@ impl ArraySaturatingAdd> for PrimitiveArray { /// ``` /// use arrow2::compute::arithmetics::decimal::adaptive_add; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); -/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); +/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); +/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3)); /// let result = adaptive_add(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); +/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3)); /// /// assert_eq!(result, expected); /// ``` @@ -188,7 +189,7 @@ pub fn adaptive_add( check_same_len(lhs, rhs)?; let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + if let (DataType::Decimal(_, lhs_p, lhs_s), DataType::Decimal(_, rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) { (*lhs_p, *lhs_s, *rhs_p, *rhs_s) @@ -237,7 +238,7 @@ pub fn adaptive_add( let validity = combine_validities(lhs.validity(), rhs.validity()); Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), + DataType::Decimal(DecimalType::Int128, res_p, res_s), values.into(), validity, )) diff --git a/src/compute/arithmetics/decimal/div.rs b/src/compute/arithmetics/decimal/div.rs index 224209084d5..7e3319e2a0a 100644 --- a/src/compute/arithmetics/decimal/div.rs +++ 
b/src/compute/arithmetics/decimal/div.rs @@ -8,7 +8,7 @@ use crate::{ arity::{binary, binary_checked, unary}, utils::{check_same_len, combine_validities}, }, - datatypes::DataType, + datatypes::{DataType, DecimalType}, error::{ArrowError, Result}, scalar::{PrimitiveScalar, Scalar}, }; @@ -25,13 +25,13 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// ``` /// use arrow2::compute::arithmetics::decimal::div; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = div(&a, &b); -/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -121,13 +121,13 @@ pub fn div_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> Pr /// ``` /// use arrow2::compute::arithmetics::decimal::saturating_div; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(999_99i128), 
Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = saturating_div(&a, &b); -/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -170,13 +170,13 @@ pub fn saturating_div( /// ``` /// use arrow2::compute::arithmetics::decimal::checked_div; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = checked_div(&a, &b); -/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -232,12 +232,12 @@ impl ArrayCheckedDiv> for PrimitiveArray { /// ``` /// use arrow2::compute::arithmetics::decimal::adaptive_div; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); -/// let b = 
PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); +/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); +/// let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); /// let result = adaptive_div(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4)); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 9, 4)); /// /// assert_eq!(result, expected); /// ``` @@ -248,7 +248,7 @@ pub fn adaptive_div( check_same_len(lhs, rhs)?; let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + if let (DataType::Decimal(_, lhs_p, lhs_s), DataType::Decimal(_, rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) { (*lhs_p, *lhs_s, *rhs_p, *rhs_s) @@ -302,7 +302,7 @@ pub fn adaptive_div( let validity = combine_validities(lhs.validity(), rhs.validity()); Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), + DataType::Decimal(DecimalType::Int128, res_p, res_s), values.into(), validity, )) diff --git a/src/compute/arithmetics/decimal/mod.rs b/src/compute/arithmetics/decimal/mod.rs index ade2d0cec9c..045df4fb3db 100644 --- a/src/compute/arithmetics/decimal/mod.rs +++ b/src/compute/arithmetics/decimal/mod.rs @@ -36,7 +36,7 @@ fn number_digits(num: i128) -> usize { } fn get_parameters(lhs: &DataType, rhs: &DataType) -> Result<(usize, usize)> { - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + if let (DataType::Decimal(_, lhs_p, lhs_s), DataType::Decimal(_, rhs_p, rhs_s)) = (lhs.to_logical_type(), rhs.to_logical_type()) { if lhs_p == rhs_p && lhs_s == rhs_s { diff --git a/src/compute/arithmetics/decimal/mul.rs b/src/compute/arithmetics/decimal/mul.rs index 42f21368bb8..eefe11da7b6 100644 --- a/src/compute/arithmetics/decimal/mul.rs +++ b/src/compute/arithmetics/decimal/mul.rs @@ 
-8,7 +8,7 @@ use crate::{ arity::{binary, binary_checked, unary}, utils::{check_same_len, combine_validities}, }, - datatypes::DataType, + datatypes::{DataType, DecimalType}, error::{ArrowError, Result}, scalar::{PrimitiveScalar, Scalar}, }; @@ -24,13 +24,13 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// ``` /// use arrow2::compute::arithmetics::decimal::mul; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = mul(&a, &b); -/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -125,13 +125,13 @@ pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> Pr /// ``` /// use arrow2::compute::arithmetics::decimal::saturating_mul; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(999_99i128), 
Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = saturating_mul(&a, &b); -/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -175,13 +175,13 @@ pub fn saturating_mul( /// ``` /// use arrow2::compute::arithmetics::decimal::checked_mul; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = checked_mul(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -244,12 +244,12 @@ impl ArraySaturatingMul> for PrimitiveArray { /// ``` /// use arrow2::compute::arithmetics::decimal::adaptive_mul; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = 
PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(6, 1)); -/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(5, 3)); +/// let a = PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 1)); +/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 3)); /// let result = adaptive_mul(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(9, 3)); +/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(DecimalType::Int128, 9, 3)); /// /// assert_eq!(result, expected); /// ``` @@ -260,7 +260,7 @@ pub fn adaptive_mul( check_same_len(lhs, rhs)?; let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + if let (DataType::Decimal(_, lhs_p, lhs_s), DataType::Decimal(_, rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) { (*lhs_p, *lhs_s, *rhs_p, *rhs_s) @@ -314,7 +314,7 @@ pub fn adaptive_mul( let validity = combine_validities(lhs.validity(), rhs.validity()); Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), + DataType::Decimal(DecimalType::Int128, res_p, res_s), values.into(), validity, )) diff --git a/src/compute/arithmetics/decimal/sub.rs b/src/compute/arithmetics/decimal/sub.rs index c027dc0c51a..e0968be4b34 100644 --- a/src/compute/arithmetics/decimal/sub.rs +++ b/src/compute/arithmetics/decimal/sub.rs @@ -7,7 +7,7 @@ use crate::{ arity::{binary, binary_checked}, utils::{check_same_len, combine_validities}, }, - datatypes::DataType, + datatypes::{DataType, DecimalType}, error::{ArrowError, Result}, }; @@ -22,13 +22,13 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// ``` /// use arrow2::compute::arithmetics::decimal::sub; /// use arrow2::array::PrimitiveArray; -/// use 
arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = sub(&a, &b); -/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -62,13 +62,13 @@ pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveA /// ``` /// use arrow2::compute::arithmetics::decimal::saturating_sub; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = saturating_sub(&a, &b); -/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, 
Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -128,13 +128,13 @@ impl ArraySaturatingSub> for PrimitiveArray { /// ``` /// use arrow2::compute::arithmetics::decimal::checked_sub; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// let result = checked_sub(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); /// /// assert_eq!(result, expected); /// ``` @@ -171,12 +171,12 @@ pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> Pr /// ``` /// use arrow2::compute::arithmetics::decimal::adaptive_sub; /// use arrow2::array::PrimitiveArray; -/// use arrow2::datatypes::DataType; +/// use arrow2::datatypes::{DataType, DecimalType}; /// -/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); -/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); +/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); +/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); /// let result = adaptive_sub(&a, &b).unwrap(); -/// let 
expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 4)); /// /// assert_eq!(result, expected); /// ``` @@ -187,7 +187,7 @@ pub fn adaptive_sub( check_same_len(lhs, rhs)?; let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + if let (DataType::Decimal(_, lhs_p, lhs_s), DataType::Decimal(_, rhs_p, rhs_s)) = (lhs.data_type(), rhs.data_type()) { (*lhs_p, *lhs_s, *rhs_p, *rhs_s) @@ -237,7 +237,7 @@ pub fn adaptive_sub( let validity = combine_validities(lhs.validity(), rhs.validity()); Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), + DataType::Decimal(DecimalType::Int128, res_p, res_s), values.into(), validity, )) diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index 5766385d8a4..d6ec3ee9800 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -19,7 +19,7 @@ pub mod time; use crate::{ array::{Array, DictionaryArray, PrimitiveArray}, bitmap::Bitmap, - datatypes::{DataType, IntervalUnit, TimeUnit}, + datatypes::{DataType, DecimalType, IntervalUnit, TimeUnit}, scalar::{PrimitiveScalar, Scalar}, }; @@ -55,7 +55,7 @@ macro_rules! arith { (Float32, Float32) => primitive!(lhs, rhs, $op, f32), (Float64, Float64) => primitive!(lhs, rhs, $op, f64), $ ( - (Decimal(_, _), Decimal(_, _)) => { + (Decimal(DecimalType::Int128, _, _), Decimal(DecimalType::Int128, _, _)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); Box::new(decimal::$op_decimal(lhs, rhs)) as Box @@ -148,7 +148,7 @@ macro_rules! 
arith_scalar { (Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32), (Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64), $ ( - (Decimal(_, _), Decimal(_, _)) => { + (Decimal(DecimalType::Int128, _, _), Decimal(DecimalType::Int128, _, _)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); Box::new(decimal::$op_decimal(lhs, rhs)) as Box @@ -242,7 +242,10 @@ pub fn can_add(lhs: &DataType, rhs: &DataType) -> bool { | (Float64, Float64) | (Float32, Float32) | (Duration(_), Duration(_)) - | (Decimal(_, _), Decimal(_, _)) + | ( + Decimal(DecimalType::Int128, _, _), + Decimal(DecimalType::Int128, _, _) + ) | (Date32, Duration(_)) | (Date64, Duration(_)) | (Time32(TimeUnit::Millisecond), Duration(_)) @@ -303,7 +306,10 @@ pub fn can_sub(lhs: &DataType, rhs: &DataType) -> bool { | (Float64, Float64) | (Float32, Float32) | (Duration(_), Duration(_)) - | (Decimal(_, _), Decimal(_, _)) + | ( + Decimal(DecimalType::Int128, _, _), + Decimal(DecimalType::Int128, _, _) + ) | (Date32, Duration(_)) | (Date64, Duration(_)) | (Time32(TimeUnit::Millisecond), Duration(_)) @@ -347,7 +353,10 @@ pub fn can_mul(lhs: &DataType, rhs: &DataType) -> bool { | (UInt64, UInt64) | (Float64, Float64) | (Float32, Float32) - | (Decimal(_, _), Decimal(_, _)) + | ( + Decimal(DecimalType::Int128, _, _), + Decimal(DecimalType::Int128, _, _) + ) ) } diff --git a/src/compute/cast/decimal_to.rs b/src/compute/cast/decimal_to.rs index 219d99cff1c..58a7aa8d892 100644 --- a/src/compute/cast/decimal_to.rs +++ b/src/compute/cast/decimal_to.rs @@ -2,7 +2,10 @@ use num_traits::{AsPrimitive, Float, NumCast}; use crate::error::Result; use crate::types::NativeType; -use crate::{array::*, datatypes::DataType}; +use crate::{ + array::*, + datatypes::{DataType, DecimalType}, +}; #[inline] fn decimal_to_decimal_impl Option>( @@ -27,8 +30,11 @@ fn decimal_to_decimal_impl Option>( }) }) }); - PrimitiveArray::::from_trusted_len_iter(values) - 
.to(DataType::Decimal(to_precision, to_scale)) + PrimitiveArray::::from_trusted_len_iter(values).to(DataType::Decimal( + DecimalType::Int128, + to_precision, + to_scale, + )) } /// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow @@ -38,7 +44,7 @@ pub fn decimal_to_decimal( to_scale: usize, ) -> PrimitiveArray { let (from_precision, from_scale) = - if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + if let DataType::Decimal(DecimalType::Int128, p, s) = from.data_type().to_logical_type() { (*p, *s) } else { panic!("internal error: i128 is always a decimal") @@ -46,7 +52,11 @@ pub fn decimal_to_decimal( if to_scale == from_scale && to_precision >= from_precision { // fast path - return from.clone().to(DataType::Decimal(to_precision, to_scale)); + return from.clone().to(DataType::Decimal( + DecimalType::Int128, + to_precision, + to_scale, + )); } // todo: other fast paths include increasing scale and precision by so that // a number will never overflow (validity is preserved) @@ -85,11 +95,12 @@ where T: NativeType + Float, f64: AsPrimitive, { - let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { - (*p, *s) - } else { - panic!("internal error: i128 is always a decimal") - }; + let (_, from_scale) = + if let DataType::Decimal(DecimalType::Int128, p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; let div = 10_f64.powi(from_scale as i32); let values = from @@ -115,11 +126,12 @@ pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray where T: NativeType + NumCast, { - let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { - (*p, *s) - } else { - panic!("internal error: i128 is always a decimal") - }; + let (_, from_scale) = + if let DataType::Decimal(DecimalType::Int128, p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal 
error: i128 is always a decimal") + }; let factor = 10_i128.pow(from_scale as u32); let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor))); diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 71e21fa9a29..73085d8db66 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -145,7 +145,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt8, Int64) => true, (UInt8, Float32) => true, (UInt8, Float64) => true, - (UInt8, Decimal(_, _)) => true, + (UInt8, Decimal(DecimalType::Int128, _, _)) => true, (UInt16, UInt8) => true, (UInt16, UInt32) => true, @@ -156,7 +156,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt16, Int64) => true, (UInt16, Float32) => true, (UInt16, Float64) => true, - (UInt16, Decimal(_, _)) => true, + (UInt16, Decimal(DecimalType::Int128, _, _)) => true, (UInt32, UInt8) => true, (UInt32, UInt16) => true, @@ -167,7 +167,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt32, Int64) => true, (UInt32, Float32) => true, (UInt32, Float64) => true, - (UInt32, Decimal(_, _)) => true, + (UInt32, Decimal(DecimalType::Int128, _, _)) => true, (UInt64, UInt8) => true, (UInt64, UInt16) => true, @@ -178,7 +178,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (UInt64, Int64) => true, (UInt64, Float32) => true, (UInt64, Float64) => true, - (UInt64, Decimal(_, _)) => true, + (UInt64, Decimal(DecimalType::Int128, _, _)) => true, (Int8, UInt8) => true, (Int8, UInt16) => true, @@ -189,7 +189,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int8, Int64) => true, (Int8, Float32) => true, (Int8, Float64) => true, - (Int8, Decimal(_, _)) => true, + (Int8, Decimal(DecimalType::Int128, _, _)) => true, (Int16, UInt8) => true, (Int16, UInt16) => true, @@ -200,7 +200,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int16, Int64) => true, (Int16, 
Float32) => true, (Int16, Float64) => true, - (Int16, Decimal(_, _)) => true, + (Int16, Decimal(DecimalType::Int128, _, _)) => true, (Int32, UInt8) => true, (Int32, UInt16) => true, @@ -211,7 +211,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int32, Int64) => true, (Int32, Float32) => true, (Int32, Float64) => true, - (Int32, Decimal(_, _)) => true, + (Int32, Decimal(DecimalType::Int128, _, _)) => true, (Int64, UInt8) => true, (Int64, UInt16) => true, @@ -222,7 +222,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Int64, Int32) => true, (Int64, Float32) => true, (Int64, Float64) => true, - (Int64, Decimal(_, _)) => true, + (Int64, Decimal(DecimalType::Int128, _, _)) => true, (Float32, UInt8) => true, (Float32, UInt16) => true, @@ -233,7 +233,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Float32, Int32) => true, (Float32, Int64) => true, (Float32, Float64) => true, - (Float32, Decimal(_, _)) => true, + (Float32, Decimal(DecimalType::Int128, _, _)) => true, (Float64, UInt8) => true, (Float64, UInt16) => true, @@ -244,10 +244,10 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Float64, Int32) => true, (Float64, Int64) => true, (Float64, Float32) => true, - (Float64, Decimal(_, _)) => true, + (Float64, Decimal(DecimalType::Int128, _, _)) => true, ( - Decimal(_, _), + Decimal(DecimalType::Int128, _, _), UInt8 | UInt16 | UInt32 @@ -258,7 +258,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { | Int64 | Float32 | Float64 - | Decimal(_, _), + | Decimal(DecimalType::Int128, _, _), ) => true, // end numeric casts @@ -676,7 +676,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt8, Float64) => primitive_to_primitive_dyn::(array, 
to_type, as_options), - (UInt8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (UInt8, Decimal(DecimalType::Int128, p, s)) => integer_to_decimal_dyn::(array, *p, *s), (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), @@ -687,7 +687,9 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (UInt16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (UInt16, Decimal(DecimalType::Int128, p, s)) => { + integer_to_decimal_dyn::(array, *p, *s) + } (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -698,7 +700,9 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (UInt32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (UInt32, Decimal(DecimalType::Int128, p, s)) => { + integer_to_decimal_dyn::(array, *p, *s) + } (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -709,7 +713,9 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (UInt64, Decimal(p, s)) => 
integer_to_decimal_dyn::(array, *p, *s), + (UInt64, Decimal(DecimalType::Int128, p, s)) => { + integer_to_decimal_dyn::(array, *p, *s) + } (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -720,7 +726,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (Int8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (Int8, Decimal(DecimalType::Int128, p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -731,7 +737,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (Int16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (Int16, Decimal(DecimalType::Int128, p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -742,7 +748,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (Int32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (Int32, Decimal(DecimalType::Int128, 
p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -753,7 +759,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + (Int64, Decimal(DecimalType::Int128, p, s)) => integer_to_decimal_dyn::(array, *p, *s), (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -764,7 +770,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), - (Float32, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + (Float32, Decimal(DecimalType::Int128, p, s)) => float_to_decimal_dyn::(array, *p, *s), (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), @@ -775,19 +781,21 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), - (Float64, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), - - (Decimal(_, _), UInt8) => decimal_to_integer_dyn::(array), - (Decimal(_, _), UInt16) => decimal_to_integer_dyn::(array), 
- (Decimal(_, _), UInt32) => decimal_to_integer_dyn::(array), - (Decimal(_, _), UInt64) => decimal_to_integer_dyn::(array), - (Decimal(_, _), Int8) => decimal_to_integer_dyn::(array), - (Decimal(_, _), Int16) => decimal_to_integer_dyn::(array), - (Decimal(_, _), Int32) => decimal_to_integer_dyn::(array), - (Decimal(_, _), Int64) => decimal_to_integer_dyn::(array), - (Decimal(_, _), Float32) => decimal_to_float_dyn::(array), - (Decimal(_, _), Float64) => decimal_to_float_dyn::(array), - (Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s), + (Float64, Decimal(DecimalType::Int128, p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Decimal(DecimalType::Int128, _, _), UInt8) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), UInt16) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), UInt32) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), UInt64) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Int8) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Int16) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Int32) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Int64) => decimal_to_integer_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Float32) => decimal_to_float_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Float64) => decimal_to_float_dyn::(array), + (Decimal(DecimalType::Int128, _, _), Decimal(DecimalType::Int128, to_p, to_s)) => { + decimal_to_decimal_dyn(array, *to_p, *to_s) + } // end numeric casts // temporal casts diff --git a/src/compute/cast/primitive_to.rs b/src/compute/cast/primitive_to.rs index fea42348c2c..75dc2671e55 100644 --- a/src/compute/cast/primitive_to.rs +++ b/src/compute/cast/primitive_to.rs @@ -9,7 +9,7 @@ use crate::{ array::*, bitmap::Bitmap, compute::arity::unary, - datatypes::{DataType, TimeUnit}, + datatypes::{DataType, 
DecimalType, TimeUnit}, temporal_conversions::*, types::NativeType, }; @@ -182,8 +182,11 @@ pub fn integer_to_decimal>( }) }); - PrimitiveArray::::from_trusted_len_iter(values) - .to(DataType::Decimal(to_precision, to_scale)) + PrimitiveArray::::from_trusted_len_iter(values).to(DataType::Decimal( + DecimalType::Int128, + to_precision, + to_scale, + )) } pub(super) fn integer_to_decimal_dyn( @@ -227,8 +230,11 @@ where }) }); - PrimitiveArray::::from_trusted_len_iter(values) - .to(DataType::Decimal(to_precision, to_scale)) + PrimitiveArray::::from_trusted_len_iter(values).to(DataType::Decimal( + DecimalType::Int128, + to_precision, + to_scale, + )) } pub(super) fn float_to_decimal_dyn( diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index 9ed47c16a79..b31346fcf2f 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -476,7 +476,7 @@ fn can_partial_eq_and_ord(data_type: &DataType) -> bool { | DataType::Float64 | DataType::Utf8 | DataType::LargeUtf8 - | DataType::Decimal(_, _) + | DataType::Decimal(_, _, _) | DataType::Binary | DataType::LargeBinary ) diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs index b9ae790d0fd..b739438057c 100644 --- a/src/compute/take/mod.rs +++ b/src/compute/take/mod.rs @@ -127,7 +127,7 @@ pub fn can_take(data_type: &DataType) -> bool { | DataType::Float16 | DataType::Float32 | DataType::Float64 - | DataType::Decimal(_, _) + | DataType::Decimal(_, _, _) | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index e7e2c9b98e8..30953df059a 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -149,11 +149,11 @@ pub enum DataType { /// /// The `bool` value indicates the `Dictionary` is sorted if set to `true`. 
Dictionary(IntegerType, Box, bool), - /// Decimal value with precision and scale + /// Decimal value with its physical representation, precision and scale /// precision is the number of digits in the number and /// scale is the number of decimal places. /// The number 999.99 has a precision of 5 and scale of 2. - Decimal(usize, usize), + Decimal(DecimalType, usize, usize), /// Extension type. Extension(String, Box, Option), } @@ -217,6 +217,28 @@ pub enum IntervalUnit { MonthDayNano, } +/// The decimal representations supported by this crate +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum DecimalType { + /// 32 bit integer + Int32, + /// 64 bit integer + Int64, + /// 128 bit integer + Int128, +} + +impl From for PhysicalType { + fn from(width: DecimalType) -> Self { + match width { + DecimalType::Int32 => PhysicalType::Primitive(PrimitiveType::Int32), + DecimalType::Int64 => PhysicalType::Primitive(PrimitiveType::Int64), + DecimalType::Int128 => PhysicalType::Primitive(PrimitiveType::Int128), + } + } +} + impl DataType { /// the [`PhysicalType`] of this [`DataType`]. 
pub fn to_physical_type(&self) -> PhysicalType { @@ -232,7 +254,7 @@ impl DataType { Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => { PhysicalType::Primitive(PrimitiveType::Int64) } - Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128), + Decimal(type_, _, _) => (*type_).into(), UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8), UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16), UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32), @@ -298,7 +320,7 @@ impl From for DataType { PrimitiveType::UInt16 => DataType::UInt16, PrimitiveType::UInt32 => DataType::UInt32, PrimitiveType::UInt64 => DataType::UInt64, - PrimitiveType::Int128 => DataType::Decimal(32, 32), + PrimitiveType::Int128 => DataType::Decimal(DecimalType::Int128, 32, 32), PrimitiveType::Float32 => DataType::Float32, PrimitiveType::Float64 => DataType::Float64, PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime), diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index a65d5568b7a..84b14392058 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -2,7 +2,8 @@ use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr} use crate::{ datatypes::{ - DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode, + DataType, DecimalType, Extension, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, + UnionMode, }, error::{ArrowError, Result}, }; @@ -310,30 +311,41 @@ unsafe fn to_data_type(schema: &ArrowSchema) -> Result { DataType::FixedSizeList(Box::new(child), size) } else if parts.len() == 2 && parts[0] == "d" { let parts = parts[1].split(',').collect::>(); - if parts.len() < 2 || parts.len() > 3 { + if parts.len() != 2 && parts.len() != 3 { return Err(ArrowError::OutOfSpec( "Decimal must contain 2 or 3 comma-separated values".to_string(), )); }; - if parts.len() == 3 { - let bit_width = parts[0].parse::().map_err(|_| { - ArrowError::OutOfSpec( - "Decimal bit width is not a valid integer".to_string(), - 
) - })?; - if bit_width != 128 { - return Err(ArrowError::OutOfSpec( - "Decimal256 is not supported".to_string(), - )); - } - } + let precision = parts[0].parse::().map_err(|_| { ArrowError::OutOfSpec("Decimal precision is not a valid integer".to_string()) })?; let scale = parts[1].parse::().map_err(|_| { ArrowError::OutOfSpec("Decimal scale is not a valid integer".to_string()) })?; - DataType::Decimal(precision, scale) + + let decimal_type = if parts.len() == 2 { + DecimalType::Int128 + } else if parts.len() == 3 { + let bit_width = parts[2].parse::().map_err(|_| { + ArrowError::OutOfSpec("Decimal bitwidth is not a valid integer".to_string()) + })?; + match bit_width { + 32 => DecimalType::Int32, + 64 => DecimalType::Int64, + 128 => DecimalType::Int128, + _ => { + return Err(ArrowError::OutOfSpec( + "Decimal256 is not supported".to_string(), + )) + } + } + } else { + return Err(ArrowError::OutOfSpec( + "Decimal must contain 2 or 3 comma-separated values".to_string(), + )); + }; + DataType::Decimal(decimal_type, precision, scale) } else if !parts.is_empty() && ((parts[0] == "+us") || (parts[0] == "+ud")) { // union let mode = UnionMode::sparse(parts[0] == "+us"); @@ -415,7 +427,11 @@ fn to_format(data_type: &DataType) -> String { tz.as_ref().map(|x| x.as_ref()).unwrap_or("") ) } - DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale), + DataType::Decimal(type_, precision, scale) => match type_ { + DecimalType::Int32 => format!("d:{},{},{}", precision, scale, 32), + DecimalType::Int64 => format!("d:{},{},{}", precision, scale, 64), + DecimalType::Int128 => format!("d:{},{}", precision, scale), + }, DataType::List(_) => "+l".to_string(), DataType::LargeList(_) => "+L".to_string(), DataType::Struct(_) => "+s".to_string(), diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs index ff72879d78b..c0e75f71bef 100644 --- a/src/io/avro/read/schema.rs +++ b/src/io/avro/read/schema.rs @@ -76,7 +76,7 @@ fn schema_to_field(schema: 
&AvroSchema, name: Option<&str>, props: Metadata) -> AvroSchema::Bytes(logical) => match logical { Some(logical) => match logical { avro_schema::BytesLogical::Decimal(precision, scale) => { - DataType::Decimal(*precision, *scale) + DataType::Decimal(DecimalType::Int128, *precision, *scale) } }, None => DataType::Binary, @@ -139,7 +139,7 @@ fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) -> AvroSchema::Fixed(Fixed { size, logical, .. }) => match logical { Some(logical) => match logical { avro_schema::FixedLogical::Decimal(precision, scale) => { - DataType::Decimal(*precision, *scale) + DataType::Decimal(DecimalType::Int128, *precision, *scale) } avro_schema::FixedLogical::Duration => { DataType::Interval(IntervalUnit::MonthDayNano) diff --git a/src/io/avro/write/schema.rs b/src/io/avro/write/schema.rs index 23e0949145a..4be27b47fce 100644 --- a/src/io/avro/write/schema.rs +++ b/src/io/avro/write/schema.rs @@ -55,7 +55,9 @@ fn _type_to_schema(data_type: &DataType) -> Result { AvroSchema::Fixed(fixed) } DataType::FixedSizeBinary(size) => AvroSchema::Fixed(Fixed::new("", *size)), - DataType::Decimal(p, s) => AvroSchema::Bytes(Some(BytesLogical::Decimal(*p, *s))), + DataType::Decimal(DecimalType::Int128, p, s) => { + AvroSchema::Bytes(Some(BytesLogical::Decimal(*p, *s))) + } other => { return Err(ArrowError::NotYetImplemented(format!( "write {:?} to avro", diff --git a/src/io/csv/read_utils.rs b/src/io/csv/read_utils.rs index 5ecdc0e8d5f..39c7988c3b4 100644 --- a/src/io/csv/read_utils.rs +++ b/src/io/csv/read_utils.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use chrono::Datelike; +use lexical_core::FromLexical; // Ideally this trait should not be needed and both `csv` and `csv_async` crates would share // the same `ByteRecord` struct. 
Unfortunately, they do not and thus we must use generics @@ -15,7 +16,7 @@ use crate::{ datatypes::*, error::{ArrowError, Result}, temporal_conversions, - types::NativeType, + types::{Decimal, NativeType}, }; use super::utils::RFC3339; @@ -26,7 +27,7 @@ fn to_utf8(bytes: &[u8]) -> Option<&str> { } #[inline] -fn deserialize_primitive( +fn deserialize_primitive( rows: &[B], column: usize, datatype: DataType, @@ -34,6 +35,7 @@ fn deserialize_primitive( ) -> Arc where T: NativeType + lexical_core::FromLexical, + B: ByteRecordGeneric, F: Fn(&[u8]) -> Option, { let iter = rows.iter().map(|row| match row.get(column) { @@ -56,13 +58,28 @@ fn significant_bytes(bytes: &[u8]) -> usize { /// Deserializes bytes to a single i128 representing a decimal /// The decimal precision and scale are not checked. #[inline] -fn deserialize_decimal(bytes: &[u8], precision: usize, scale: usize) -> Option { +fn deserialize_decimal( + bytes: &[u8], + precision: usize, + scale: usize, +) -> Option { + let ten = T::one() + + T::one() + + T::one() + + T::one() + + T::one() + + T::one() + + T::one() + + T::one() + + T::one() + + T::one(); + let mut a = bytes.split(|x| *x == b'.'); let lhs = a.next(); let rhs = a.next(); match (lhs, rhs) { - (Some(lhs), Some(rhs)) => lexical_core::parse::(lhs).ok().and_then(|x| { - lexical_core::parse::(rhs) + (Some(lhs), Some(rhs)) => lexical_core::parse::(lhs).ok().and_then(|x| { + lexical_core::parse::(rhs) .ok() .map(|y| (x, lhs, y, rhs)) .and_then(|(lhs, lhs_b, rhs, rhs_b)| { @@ -74,19 +91,19 @@ fn deserialize_decimal(bytes: &[u8], precision: usize, scale: usize) -> Option { if rhs.len() != precision || rhs.len() != scale { return None; } - lexical_core::parse::(rhs).ok() + lexical_core::parse::(rhs).ok() } (Some(lhs), None) => { if lhs.len() != precision || scale != 0 { return None; } - lexical_core::parse::(lhs).ok() + lexical_core::parse::(lhs).ok() } (None, None) => None, } @@ -241,9 +258,22 @@ pub(crate) fn deserialize_column( }) }) } - Decimal(precision, 
scale) => deserialize_primitive(rows, column, datatype, |x| { - deserialize_decimal(x, precision, scale) - }), + Decimal(DecimalType::Int32, precision, scale) => { + deserialize_primitive::(rows, column, datatype, |x| { + deserialize_decimal(x, precision, scale) + }) + } + Decimal(DecimalType::Int64, precision, scale) => { + deserialize_primitive::(rows, column, datatype, |x| { + deserialize_decimal(x, precision, scale) + }) + } + Decimal(DecimalType::Int128, precision, scale) => { + deserialize_primitive::(rows, column, datatype, |x| { + deserialize_decimal(x, precision, scale) + }) + } + Utf8 => deserialize_utf8::(rows, column), LargeUtf8 => deserialize_utf8::(rows, column), Binary => deserialize_binary::(rows, column), diff --git a/src/io/ipc/read/schema.rs b/src/io/ipc/read/schema.rs index bde3deba1e2..f3dfdb8cd33 100644 --- a/src/io/ipc/read/schema.rs +++ b/src/io/ipc/read/schema.rs @@ -2,8 +2,8 @@ use arrow_format::ipc::planus::ReadAsRoot; use crate::{ datatypes::{ - get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, - TimeUnit, UnionMode, + get_extension, DataType, DecimalType, Extension, Field, IntegerType, IntervalUnit, + Metadata, Schema, TimeUnit, UnionMode, }, error::{ArrowError, Result}, }; @@ -197,8 +197,18 @@ fn get_data_type( (DataType::Duration(time_unit), IpcField::default()) } Decimal(decimal) => { - let data_type = - DataType::Decimal(decimal.precision()? as usize, decimal.scale()? as usize); + let bit_width = decimal.bit_width()?; + let type_ = match bit_width { + 32 => DecimalType::Int32, + 64 => DecimalType::Int64, + 128 => DecimalType::Int128, + _ => return Err(ArrowError::nyi("Decimal 256 not supported")), + }; + let data_type = DataType::Decimal( + type_, + decimal.precision()? as usize, + decimal.scale()? 
as usize, + ); (data_type, IpcField::default()) } List(_) => { diff --git a/src/io/ipc/write/schema.rs b/src/io/ipc/write/schema.rs index 0636cb9e6c0..d6c296c581b 100644 --- a/src/io/ipc/write/schema.rs +++ b/src/io/ipc/write/schema.rs @@ -1,7 +1,7 @@ use arrow_format::ipc::planus::Builder; use crate::datatypes::{ - DataType, Field, IntegerType, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, + DataType, DecimalType, Field, IntegerType, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, }; use crate::io::ipc::endianess::is_native_little_endian; @@ -192,11 +192,18 @@ fn serialize_type(data_type: &DataType) -> arrow_format::ipc::Type { Float64 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { precision: ipc::Precision::Double, })), - Decimal(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { - precision: *precision as i32, - scale: *scale as i32, - bit_width: 128, - })), + Decimal(type_, precision, scale) => { + let bit_width = match type_ { + DecimalType::Int32 => 32, + DecimalType::Int64 => 64, + DecimalType::Int128 => 128, + }; + ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width, + })) + } Binary => ipc::Type::Binary(Box::new(ipc::Binary {})), LargeBinary => ipc::Type::LargeBinary(Box::new(ipc::LargeBinary {})), Utf8 => ipc::Type::Utf8(Box::new(ipc::Utf8 {})), @@ -281,7 +288,7 @@ fn serialize_children(data_type: &DataType, ipc_field: &IpcField) -> Vec vec![], + | Decimal(_, _, _) => vec![], FixedSizeList(inner, _) | LargeList(inner) | List(inner) | Map(inner, _) => { vec![serialize_field(inner, &ipc_field.fields[0])] } diff --git a/src/io/json_integration/read/schema.rs b/src/io/json_integration/read/schema.rs index 96b2e97b8ce..cc7343553ea 100644 --- a/src/io/json_integration/read/schema.rs +++ b/src/io/json_integration/read/schema.rs @@ -2,6 +2,7 @@ use serde_derive::Deserialize; use serde_json::Value; use crate::{ + datatypes::DecimalType, error::{ArrowError, Result}, 
io::ipc::IpcField, }; @@ -174,20 +175,34 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { "largeutf8" => LargeUtf8, "decimal" => { // return a list with any type as its child isn't defined in the map - let precision = match item.get("precision") { - Some(p) => Ok(p.as_u64().unwrap() as usize), - None => Err(ArrowError::OutOfSpec( - "Expecting a precision for decimal".to_string(), - )), - }; - let scale = match item.get("scale") { - Some(s) => Ok(s.as_u64().unwrap() as usize), - _ => Err(ArrowError::OutOfSpec( - "Expecting a scale for decimal".to_string(), - )), + let precision = item + .get("precision") + .map(|p| p.as_u64().unwrap() as usize) + .ok_or_else(|| { + ArrowError::OutOfSpec("Expecting a precision for decimal".to_string()) + })?; + let scale = item + .get("scale") + .map(|p| p.as_u64().unwrap() as usize) + .ok_or_else(|| { + ArrowError::OutOfSpec("Expecting a scale for decimal".to_string()) + })?; + let bitwidth = item + .get("bitWidth") + .map(|p| p.as_u64().unwrap() as usize) + .unwrap_or(128); + let type_ = match bitwidth { + 32 => DecimalType::Int32, + 64 => DecimalType::Int64, + 128 => DecimalType::Int128, + _ => { + return Err(ArrowError::OutOfSpec( + "Expecting a bitwidth for decimal".to_string(), + )) + } }; - DataType::Decimal(precision?, scale?) 
+ DataType::Decimal(type_, precision, scale) } "floatingpoint" => match item.get("precision") { Some(p) if p == "HALF" => DataType::Float16, diff --git a/src/io/json_integration/write/schema.rs b/src/io/json_integration/write/schema.rs index f4297c3fb0f..3ff15200083 100644 --- a/src/io/json_integration/write/schema.rs +++ b/src/io/json_integration/write/schema.rs @@ -1,6 +1,6 @@ use serde_json::{json, Map, Value}; -use crate::datatypes::{DataType, Field, IntervalUnit, Metadata, Schema, TimeUnit}; +use crate::datatypes::{DataType, DecimalType, Field, IntervalUnit, Metadata, Schema, TimeUnit}; use crate::io::ipc::IpcField; use crate::io::json_integration::ArrowJsonSchema; @@ -86,8 +86,13 @@ fn serialize_data_type(data_type: &DataType) -> Value { TimeUnit::Nanosecond => "NANOSECOND", }}), DataType::Dictionary(_, _, _) => json!({ "name": "dictionary"}), - DataType::Decimal(precision, scale) => { - json!({"name": "decimal", "precision": precision, "scale": scale}) + DataType::Decimal(decimal, precision, scale) => { + let bitwidth = match decimal { + DecimalType::Int32 => 32, + DecimalType::Int64 => 64, + DecimalType::Int128 => 128, + }; + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": bitwidth}) } DataType::Extension(_, inner_data_type, _) => serialize_data_type(inner_data_type), } diff --git a/src/io/odbc/read/schema.rs b/src/io/odbc/read/schema.rs index dba4c233738..2b9aa9b94c0 100644 --- a/src/io/odbc/read/schema.rs +++ b/src/io/odbc/read/schema.rs @@ -1,4 +1,4 @@ -use crate::datatypes::{DataType, Field, TimeUnit}; +use crate::datatypes::{DataType, DecimalType, Field, TimeUnit}; use crate::error::Result; use super::super::api; @@ -41,7 +41,7 @@ fn column_to_data_type(data_type: &api::DataType) -> DataType { | OdbcDataType::Decimal { precision: p @ 0..=38, scale, - } => DataType::Decimal(*p, (*scale) as usize), + } => DataType::Decimal(DecimalType::Int128, *p, (*scale) as usize), OdbcDataType::Integer => DataType::Int32, 
OdbcDataType::SmallInt => DataType::Int16, OdbcDataType::Real | OdbcDataType::Float { precision: 0..=24 } => DataType::Float32, diff --git a/src/io/parquet/read/deserialize/simple.rs b/src/io/parquet/read/deserialize/simple.rs index 9544f16bff1..73fe5ab1554 100644 --- a/src/io/parquet/read/deserialize/simple.rs +++ b/src/io/parquet/read/deserialize/simple.rs @@ -9,7 +9,7 @@ use parquet2::{ use crate::{ array::{Array, BinaryArray, DictionaryKey, MutablePrimitiveArray, PrimitiveArray, Utf8Array}, - datatypes::{DataType, IntervalUnit, TimeUnit}, + datatypes::{DataType, DecimalType, IntervalUnit, TimeUnit}, error::{ArrowError, Result}, types::NativeType, }; @@ -130,7 +130,7 @@ pub fn page_iter_to_arrays<'a, I: 'a + DataPages>( FixedSizeBinary(_) => dyn_iter(fixed_size_binary::Iter::new(pages, data_type, chunk_size)), - Decimal(_, _) => match physical_type { + Decimal(DecimalType::Int128, _, _) => match physical_type { PhysicalType::Int32 => dyn_iter(iden(primitive::Iter::new( pages, data_type, diff --git a/src/io/parquet/read/schema/convert.rs b/src/io/parquet/read/schema/convert.rs index ae2d66a1b9b..03179a0c9e6 100644 --- a/src/io/parquet/read/schema/convert.rs +++ b/src/io/parquet/read/schema/convert.rs @@ -7,7 +7,7 @@ use parquet2::schema::{ Repetition, }; -use crate::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use crate::datatypes::{DataType, DecimalType, Field, IntervalUnit, TimeUnit}; /// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain /// any physical column. 
@@ -33,7 +33,7 @@ fn from_int32( _ => DataType::Int32, }, (Some(LogicalType::DECIMAL(t)), _) => { - DataType::Decimal(t.precision as usize, t.scale as usize) + DataType::Decimal(DecimalType::Int128, t.precision as usize, t.scale as usize) } (Some(LogicalType::DATE(_)), _) => DataType::Date32, (Some(LogicalType::TIME(t)), _) => match t.unit { @@ -52,7 +52,7 @@ fn from_int32( (_, Some(PrimitiveConvertedType::Date)) => DataType::Date32, (_, Some(PrimitiveConvertedType::TimeMillis)) => DataType::Time32(TimeUnit::Millisecond), (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { - DataType::Decimal(*precision as usize, *scale as usize) + DataType::Decimal(DecimalType::Int128, *precision as usize, *scale as usize) } (_, _) => DataType::Int32, } @@ -106,7 +106,7 @@ fn from_int64( _ => DataType::Int64, }, (Some(LogicalType::DECIMAL(t)), _) => { - DataType::Decimal(t.precision as usize, t.scale as usize) + DataType::Decimal(DecimalType::Int128, t.precision as usize, t.scale as usize) } // handle converted types: (_, Some(PrimitiveConvertedType::TimeMicros)) => DataType::Time64(TimeUnit::Microsecond), @@ -119,7 +119,7 @@ fn from_int64( (_, Some(PrimitiveConvertedType::Int64)) => DataType::Int64, (_, Some(PrimitiveConvertedType::Uint64)) => DataType::UInt64, (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { - DataType::Decimal(*precision as usize, *scale as usize) + DataType::Decimal(DecimalType::Int128, *precision as usize, *scale as usize) } (_, _) => DataType::Int64, @@ -150,10 +150,10 @@ fn from_fixed_len_byte_array( ) -> DataType { match (logical_type, converted_type) { (Some(LogicalType::DECIMAL(t)), _) => { - DataType::Decimal(t.precision as usize, t.scale as usize) + DataType::Decimal(DecimalType::Int128, t.precision as usize, t.scale as usize) } (None, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { - DataType::Decimal(*precision as usize, *scale as usize) + DataType::Decimal(DecimalType::Int128, *precision as usize, *scale 
as usize) } (None, Some(PrimitiveConvertedType::Interval)) => { // There is currently no reliable way of determining which IntervalUnit diff --git a/src/io/parquet/read/statistics/fixlen.rs b/src/io/parquet/read/statistics/fixlen.rs index 2cbe8e97751..5d77f9be53e 100644 --- a/src/io/parquet/read/statistics/fixlen.rs +++ b/src/io/parquet/read/statistics/fixlen.rs @@ -2,7 +2,7 @@ use std::any::Any; use std::convert::{TryFrom, TryInto}; use super::primitive::PrimitiveStatistics; -use crate::datatypes::DataType; +use crate::datatypes::{DataType, DecimalType}; use crate::error::{ArrowError, Result}; use parquet2::{ schema::types::PhysicalType, @@ -104,7 +104,9 @@ pub(super) fn statistics_from_fix_len( ) -> Result> { use DataType::*; Ok(match data_type { - Decimal(_, _) => Box::new(PrimitiveStatistics::::try_from((stats, data_type))?), + Decimal(DecimalType::Int128, _, _) => { + Box::new(PrimitiveStatistics::::try_from((stats, data_type))?) + } FixedSizeBinary(_) => Box::new(FixedLenStatistics::from(stats)), other => { return Err(ArrowError::NotYetImplemented(format!( diff --git a/src/io/parquet/read/statistics/primitive.rs b/src/io/parquet/read/statistics/primitive.rs index 91a630692df..7341b400288 100644 --- a/src/io/parquet/read/statistics/primitive.rs +++ b/src/io/parquet/read/statistics/primitive.rs @@ -1,4 +1,4 @@ -use crate::datatypes::TimeUnit; +use crate::datatypes::{DecimalType, TimeUnit}; use crate::{datatypes::DataType, types::NativeType}; use parquet2::schema::types::{ LogicalType, ParquetType, TimeUnit as ParquetTimeUnit, TimestampType, @@ -69,7 +69,9 @@ pub(super) fn statistics_from_i32( UInt32 => Box::new(PrimitiveStatistics::::from((stats, data_type))), Int8 => Box::new(PrimitiveStatistics::::from((stats, data_type))), Int16 => Box::new(PrimitiveStatistics::::from((stats, data_type))), - Decimal(_, _) => Box::new(PrimitiveStatistics::::from((stats, data_type))), + Decimal(DecimalType::Int128, _, _) => { + Box::new(PrimitiveStatistics::::from((stats, 
data_type))) + } _ => Box::new(PrimitiveStatistics::::from((stats, data_type))), }) } @@ -126,7 +128,9 @@ pub(super) fn statistics_from_i64( .max_value .map(|x| timestamp(stats.descriptor.type_(), time_unit, x)), }), - Decimal(_, _) => Box::new(PrimitiveStatistics::::from((stats, data_type))), + Decimal(DecimalType::Int128, _, _) => { + Box::new(PrimitiveStatistics::::from((stats, data_type))) + } _ => Box::new(PrimitiveStatistics::::from((stats, data_type))), }) } diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs index 7c31f27fc52..b7d8a9a195c 100644 --- a/src/io/parquet/write/mod.rs +++ b/src/io/parquet/write/mod.rs @@ -245,7 +245,12 @@ pub fn array_to_page( options, descriptor, ), - DataType::Decimal(precision, _) => { + DataType::Decimal(type_, precision, _) => { + if *type_ != DecimalType::Int128 { + return Err(ArrowError::nyi( + "Only decimal 128 supported to write to parquet", + )); + } let precision = *precision; let array = array .as_any() diff --git a/src/io/parquet/write/schema.rs b/src/io/parquet/write/schema.rs index ea05e7e19cd..d201833a381 100644 --- a/src/io/parquet/write/schema.rs +++ b/src/io/parquet/write/schema.rs @@ -2,15 +2,15 @@ use parquet2::{ metadata::KeyValue, schema::{ types::{ - DecimalType, IntType, LogicalType, ParquetType, PhysicalType, PrimitiveConvertedType, - TimeType, TimeUnit as ParquetTimeUnit, TimestampType, + DecimalType as ParquetDecimalType, IntType, LogicalType, ParquetType, PhysicalType, + PrimitiveConvertedType, TimeType, TimeUnit as ParquetTimeUnit, TimestampType, }, Repetition, }, }; use crate::{ - datatypes::{DataType, Field, Schema, TimeUnit}, + datatypes::{DataType, DecimalType, Field, Schema, TimeUnit}, error::{ArrowError, Result}, io::ipc::write::default_ipc_fields, io::ipc::write::schema_to_bytes, @@ -290,10 +290,15 @@ pub fn to_parquet_type(field: &Field) -> Result { None, None, )?), - DataType::Decimal(precision, scale) => { + DataType::Decimal(type_, precision, scale) => { + if *type_ != 
DecimalType::Int128 { + return Err(ArrowError::nyi( + "Only decimal 128 implemented to write to parquet", + )); + } let precision = *precision; let scale = *scale; - let logical_type = Some(LogicalType::DECIMAL(DecimalType { + let logical_type = Some(LogicalType::DECIMAL(ParquetDecimalType { scale: scale as i32, precision: precision as i32, })); diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index e2cf20ee4f7..e0ea225883e 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -34,12 +34,104 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { return false; } +<<<<<<< HEAD use PhysicalType::*; match lhs.data_type().to_physical_type() { Null => dyn_eq!(NullScalar, lhs, rhs), Boolean => dyn_eq!(BooleanScalar, lhs, rhs), Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { dyn_eq!(PrimitiveScalar<$T>, lhs, rhs) +======= + match lhs.data_type() { + DataType::Null => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + DataType::Boolean => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + lhs == rhs + } + DataType::UInt8 => { + dyn_eq!(u8, lhs, rhs) + } + DataType::UInt16 => { + dyn_eq!(u16, lhs, rhs) + } + DataType::UInt32 => { + dyn_eq!(u32, lhs, rhs) + } + DataType::UInt64 => { + dyn_eq!(u64, lhs, rhs) + } + DataType::Int8 => { + dyn_eq!(i8, lhs, rhs) + } + DataType::Int16 => { + dyn_eq!(i16, lhs, rhs) + } + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + dyn_eq!(i32, lhs, rhs) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => { + dyn_eq!(i64, lhs, rhs) + } + DataType::Decimal(_, _, _) => { + dyn_eq!(i128, lhs, rhs) + } + DataType::Interval(IntervalUnit::DayTime) => { + dyn_eq!(days_ms, lhs, rhs) + } + DataType::Float16 => unreachable!(), + DataType::Float32 => { 
+ dyn_eq!(f32, lhs, rhs) + } + DataType::Float64 => { + dyn_eq!(f64, lhs, rhs) + } + DataType::Utf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::List(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::LargeList(_) => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs + } + DataType::Dictionary(key_type, _, _) => match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + lhs == rhs +>>>>>>> Added support for decimal 32 and 64 }), Utf8 => dyn_eq!(Utf8Scalar, lhs, rhs), LargeUtf8 => dyn_eq!(Utf8Scalar, lhs, rhs), diff --git a/src/types/mod.rs b/src/types/mod.rs index 93e5b0667ce..b2a10b9bf2d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -82,3 +82,30 @@ mod private { impl Sealed for super::days_ms {} impl Sealed for super::months_days_ns {} } + +/// Trait describing the three [`NativeType`]s that can be used as decimal representations +pub trait Decimal: NativeType + num_traits::Signed + num_traits::Pow { + /// The 10 + fn ten() -> Self; +} + +impl Decimal for i32 { + #[inline] + fn ten() -> Self { + 10 + } +} + +impl Decimal for i64 { + #[inline] + fn ten() -> Self { + 10 + } +} + +impl Decimal for i128 { + #[inline] + fn ten() -> Self { + 10 + } +} diff --git 
a/tests/it/array/primitive/mod.rs b/tests/it/array/primitive/mod.rs index b2080577eae..560fb89c5dc 100644 --- a/tests/it/array/primitive/mod.rs +++ b/tests/it/array/primitive/mod.rs @@ -242,20 +242,28 @@ fn debug_duration_ns() { } #[test] -fn debug_decimal() { - let array = Int128Array::from(&[Some(12345), None, Some(23456)]).to(DataType::Decimal(5, 2)); +fn debug_decimal128() { + let array = Int128Array::from(&[Some(12345), None, Some(23456)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); assert_eq!( format!("{:?}", array), - "Decimal(5, 2)[123.45, None, 234.56]" + "Decimal(Int128, 5, 2)[123.45, None, 234.56]" ); } #[test] -fn debug_decimal1() { - let array = Int128Array::from(&[Some(12345), None, Some(23456)]).to(DataType::Decimal(5, 1)); +fn debug_decimal32() { + let array = Int32Array::from(&[Some(12345), None, Some(23456)]).to(DataType::Decimal( + DecimalType::Int32, + 5, + 1, + )); assert_eq!( format!("{:?}", array), - "Decimal(5, 1)[1234.5, None, 2345.6]" + "Decimal(Int32, 5, 1)[1234.5, None, 2345.6]" ); } diff --git a/tests/it/compute/arithmetics/decimal/add.rs b/tests/it/compute/arithmetics/decimal/add.rs index 45af77b1519..e3022b25cd8 100644 --- a/tests/it/compute/arithmetics/decimal/add.rs +++ b/tests/it/compute/arithmetics/decimal/add.rs @@ -3,19 +3,19 @@ use arrow2::array::*; use arrow2::compute::arithmetics::decimal::{adaptive_add, add, checked_add, saturating_add}; use arrow2::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; -use arrow2::datatypes::DataType; +use arrow2::datatypes::{DataType, DecimalType}; #[test] fn test_add_normal() { let a = PrimitiveArray::from([Some(11111i128), Some(11100i128), None, Some(22200i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(22200i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = add(&a, &b); let expected 
= PrimitiveArray::from([Some(33333i128), Some(33300i128), None, Some(33300i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -27,30 +27,31 @@ fn test_add_normal() { #[test] #[should_panic] fn test_add_decimal_wrong_precision() { - let a = PrimitiveArray::from([None]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([None]).to(DataType::Decimal(6, 2)); + let a = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 6, 2)); add(&a, &b); } #[test] #[should_panic(expected = "Overflow in addition presented for precision 5")] fn test_add_panic() { - let a = PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(1i128)]).to(DataType::Decimal(5, 2)); + let a = + PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([Some(1i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); let _ = add(&a, &b); } #[test] fn test_add_saturating() { let a = PrimitiveArray::from([Some(11111i128), Some(11100i128), None, Some(22200i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(22200i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = saturating_add(&a, &b); let expected = PrimitiveArray::from([Some(33333i128), Some(33300i128), None, Some(33300i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -67,14 +68,14 @@ fn test_add_saturating_overflow() { Some(99999i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([ Some(00001i128), Some(00100i128), 
Some(10000i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = saturating_add(&a, &b); @@ -84,7 +85,7 @@ fn test_add_saturating_overflow() { Some(99999i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -96,14 +97,14 @@ fn test_add_saturating_overflow() { #[test] fn test_add_checked() { let a = PrimitiveArray::from([Some(11111i128), Some(11100i128), None, Some(22200i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(22200i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = checked_add(&a, &b); let expected = PrimitiveArray::from([Some(33333i128), Some(33300i128), None, Some(33300i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -114,10 +115,19 @@ fn test_add_checked() { #[test] fn test_add_checked_overflow() { - let a = PrimitiveArray::from([Some(1i128), Some(99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(1i128), Some(1i128)]).to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(1i128), Some(99999i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); + let b = PrimitiveArray::from([Some(1i128), Some(1i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); let result = checked_add(&a, &b); - let expected = PrimitiveArray::from([Some(2i128), None]).to(DataType::Decimal(5, 2)); + let expected = + PrimitiveArray::from([Some(2i128), None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); // Testing trait @@ -131,51 +141,86 @@ fn test_add_adaptive() { // 11111.11 -> 7, 2 // ----------------- // 11122.2211 -> 9, 4 - let a = 
PrimitiveArray::from([Some(11_1111i128)]).to(DataType::Decimal(6, 4)); - let b = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); + let a = + PrimitiveArray::from([Some(11_1111i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); + let b = + PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); let result = adaptive_add(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(11122_2211i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(11122_2211i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 4, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 0.1111 -> 5, 4 // 11111.0 -> 6, 1 // ----------------- // 11111.1111 -> 9, 4 - let a = PrimitiveArray::from([Some(1111i128)]).to(DataType::Decimal(5, 4)); - let b = PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(6, 1)); + let a = PrimitiveArray::from([Some(1111i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 4)); + let b = + PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 1)); let result = adaptive_add(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(11111_1111i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(11111_1111i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 4, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 11111.11 -> 7, 2 // 11111.111 -> 8, 3 // ----------------- // 22222.221 -> 8, 3 - let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); + let a = + 
PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); let result = adaptive_add(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); + let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(8, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 8, 3) + ); // 99.9999 -> 6, 4 // 00.0001 -> 6, 4 // ----------------- // 100.0000 -> 7, 4 - let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); - let b = PrimitiveArray::from([Some(00_0001i128)]).to(DataType::Decimal(6, 4)); + let a = + PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); + let b = + PrimitiveArray::from([Some(00_0001i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); let result = adaptive_add(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); + let expected = + PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 4)); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(7, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 7, 4) + ); } diff --git a/tests/it/compute/arithmetics/decimal/div.rs b/tests/it/compute/arithmetics/decimal/div.rs index 55cca8d303d..007ca92c5bf 100644 --- a/tests/it/compute/arithmetics/decimal/div.rs +++ b/tests/it/compute/arithmetics/decimal/div.rs @@ -3,7 +3,7 @@ use arrow2::array::*; use arrow2::compute::arithmetics::decimal::{adaptive_div, checked_div, div, saturating_div}; use arrow2::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; -use arrow2::datatypes::DataType; +use 
arrow2::datatypes::{DataType, DecimalType}; #[test] fn test_divide_normal() { @@ -19,7 +19,7 @@ fn test_divide_normal() { Some(30_000i128), Some(123_456i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let b = PrimitiveArray::from([ Some(123_456i128), @@ -29,7 +29,7 @@ fn test_divide_normal() { Some(4_000i128), Some(654_321i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let result = div(&a, &b); let expected = PrimitiveArray::from([ @@ -40,7 +40,7 @@ fn test_divide_normal() { Some(7_500i128), Some(0_188i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); assert_eq!(result, expected); @@ -52,16 +52,18 @@ fn test_divide_normal() { #[test] #[should_panic] fn test_divide_decimal_wrong_precision() { - let a = PrimitiveArray::from([None]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([None]).to(DataType::Decimal(6, 2)); + let a = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 6, 2)); div(&a, &b); } #[test] #[should_panic(expected = "Overflow in multiplication presented for precision 5")] fn test_divide_panic() { - let a = PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(000_01i128)]).to(DataType::Decimal(5, 2)); + let a = + PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = + PrimitiveArray::from([Some(000_01i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); div(&a, &b); } @@ -75,7 +77,7 @@ fn test_divide_saturating() { Some(30_000i128), Some(123_456i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let b = PrimitiveArray::from([ Some(123_456i128), @@ -85,7 +87,7 @@ fn test_divide_saturating() { Some(4_000i128), Some(654_321i128), ]) - .to(DataType::Decimal(7, 
3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let result = saturating_div(&a, &b); let expected = PrimitiveArray::from([ @@ -96,7 +98,7 @@ fn test_divide_saturating() { Some(7_500i128), Some(0_188i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); assert_eq!(result, expected); } @@ -110,7 +112,7 @@ fn test_divide_saturating_overflow() { Some(99999i128), Some(99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([ Some(-00001i128), Some(00001i128), @@ -118,7 +120,7 @@ fn test_divide_saturating_overflow() { Some(-00020i128), Some(00000i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = saturating_div(&a, &b); @@ -129,7 +131,7 @@ fn test_divide_saturating_overflow() { Some(-99999i128), Some(00000i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); } @@ -144,7 +146,7 @@ fn test_divide_checked() { Some(30_000i128), Some(123_456i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let b = PrimitiveArray::from([ Some(123_456i128), @@ -154,7 +156,7 @@ fn test_divide_checked() { Some(4_000i128), Some(654_321i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); let result = div(&a, &b); let expected = PrimitiveArray::from([ @@ -165,7 +167,7 @@ fn test_divide_checked() { Some(7_500i128), Some(0_188i128), ]) - .to(DataType::Decimal(7, 3)); + .to(DataType::Decimal(DecimalType::Int128, 7, 3)); assert_eq!(result, expected); } @@ -173,12 +175,19 @@ fn test_divide_checked() { #[test] fn test_divide_checked_overflow() { let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]) - .to(DataType::Decimal(5, 2)); - let b = - PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); + 
.to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); let result = checked_div(&a, &b); - let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2)); + let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); assert_eq!(result, expected); @@ -193,38 +202,59 @@ fn test_divide_adaptive() { // 10.0000 -> 6, 4 // ----------------- // 100.0000 -> 9, 4 - let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); + let a = + PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = + PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); let result = adaptive_div(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4)); + let expected = + PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 9, 4)); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 11111.0 -> 6, 1 // 10.002 -> 5, 3 // ----------------- // 1110.877 -> 8, 3 - let a = PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(6, 1)); - let b = PrimitiveArray::from([Some(10_002i128)]).to(DataType::Decimal(5, 3)); + let a = + PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 1)); + let b = + PrimitiveArray::from([Some(10_002i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 3)); let result = adaptive_div(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(1110_877i128)]).to(DataType::Decimal(8, 3)); + let expected = + 
PrimitiveArray::from([Some(1110_877i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3)); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(8, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 8, 3) + ); // 12345.67 -> 7, 2 // 12345.678 -> 8, 3 // ----------------- // 0.999 -> 8, 3 - let a = PrimitiveArray::from([Some(12345_67i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(12345_678i128)]).to(DataType::Decimal(8, 3)); + let a = + PrimitiveArray::from([Some(12345_67i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = PrimitiveArray::from([Some(12345_678i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); let result = adaptive_div(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(0_999i128)]).to(DataType::Decimal(8, 3)); + let expected = + PrimitiveArray::from([Some(0_999i128)]).to(DataType::Decimal(DecimalType::Int128, 8, 3)); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(8, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 8, 3) + ); } diff --git a/tests/it/compute/arithmetics/decimal/mul.rs b/tests/it/compute/arithmetics/decimal/mul.rs index a4b4a71b257..3b3a087182f 100644 --- a/tests/it/compute/arithmetics/decimal/mul.rs +++ b/tests/it/compute/arithmetics/decimal/mul.rs @@ -3,7 +3,7 @@ use arrow2::array::*; use arrow2::compute::arithmetics::decimal::{adaptive_mul, checked_mul, mul, saturating_mul}; use arrow2::compute::arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul}; -use arrow2::datatypes::DataType; +use arrow2::datatypes::{DataType, DecimalType}; #[test] fn test_multiply_normal() { @@ -19,7 +19,7 @@ fn test_multiply_normal() { Some(30_00i128), Some(123_45i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let b = PrimitiveArray::from([ Some(222_22i128), @@ -29,7 +29,7 @@ fn test_multiply_normal() { 
Some(4_00i128), Some(543_21i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let result = mul(&a, &b); let expected = PrimitiveArray::from([ @@ -40,7 +40,7 @@ fn test_multiply_normal() { Some(120_00i128), Some(67059_27i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); assert_eq!(result, expected); @@ -52,16 +52,18 @@ fn test_multiply_normal() { #[test] #[should_panic] fn test_multiply_decimal_wrong_precision() { - let a = PrimitiveArray::from([None]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([None]).to(DataType::Decimal(6, 2)); + let a = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 6, 2)); mul(&a, &b); } #[test] #[should_panic(expected = "Overflow in multiplication presented for precision 5")] fn test_multiply_panic() { - let a = PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(100_00i128)]).to(DataType::Decimal(5, 2)); + let a = + PrimitiveArray::from([Some(99999i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = + PrimitiveArray::from([Some(100_00i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); let _ = mul(&a, &b); } @@ -75,7 +77,7 @@ fn test_multiply_saturating() { Some(30_00i128), Some(123_45i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let b = PrimitiveArray::from([ Some(222_22i128), @@ -85,7 +87,7 @@ fn test_multiply_saturating() { Some(4_00i128), Some(543_21i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let result = saturating_mul(&a, &b); let expected = PrimitiveArray::from([ @@ -96,7 +98,7 @@ fn test_multiply_saturating() { Some(120_00i128), Some(67059_27i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); 
assert_eq!(result, expected); @@ -113,14 +115,14 @@ fn test_multiply_saturating_overflow() { Some(99999i128), Some(99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([ Some(-00100i128), Some(01000i128), Some(10000i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = saturating_mul(&a, &b); @@ -130,7 +132,7 @@ fn test_multiply_saturating_overflow() { Some(99999i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -149,7 +151,7 @@ fn test_multiply_checked() { Some(30_00i128), Some(123_45i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let b = PrimitiveArray::from([ Some(222_22i128), @@ -159,7 +161,7 @@ fn test_multiply_checked() { Some(4_00i128), Some(543_21i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); let result = checked_mul(&a, &b); let expected = PrimitiveArray::from([ @@ -170,7 +172,7 @@ fn test_multiply_checked() { Some(120_00i128), Some(67059_27i128), ]) - .to(DataType::Decimal(7, 2)); + .to(DataType::Decimal(DecimalType::Int128, 7, 2)); assert_eq!(result, expected); @@ -181,10 +183,22 @@ fn test_multiply_checked() { #[test] fn test_multiply_checked_overflow() { - let a = PrimitiveArray::from([Some(99999i128), Some(1_00i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(10000i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(99999i128), Some(1_00i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); + let b = PrimitiveArray::from([Some(10000i128), Some(2_00i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); let result = checked_mul(&a, &b); - let expected = PrimitiveArray::from([None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); + let 
expected = PrimitiveArray::from([None, Some(2_00i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); assert_eq!(result, expected); } @@ -195,38 +209,68 @@ fn test_multiply_adaptive() { // 10.0000 -> 6, 4 // ----------------- // 10000.0000 -> 9, 4 - let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); + let a = + PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = + PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); let result = adaptive_mul(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(10000_0000i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(10000_0000i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 4, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 11111.0 -> 6, 1 // 10.002 -> 5, 3 // ----------------- // 111132.222 -> 9, 3 - let a = PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(6, 1)); - let b = PrimitiveArray::from([Some(10_002i128)]).to(DataType::Decimal(5, 3)); + let a = + PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 1)); + let b = + PrimitiveArray::from([Some(10_002i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 3)); let result = adaptive_mul(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(111132_222i128)]).to(DataType::Decimal(9, 3)); + let expected = PrimitiveArray::from([Some(111132_222i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 3, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 3) + ); // 12345.67 -> 7, 2 // 12345.678 -> 8, 3 // 
----------------- // 152415666.514 -> 11, 3 - let a = PrimitiveArray::from([Some(12345_67i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(12345_678i128)]).to(DataType::Decimal(8, 3)); + let a = + PrimitiveArray::from([Some(12345_67i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = PrimitiveArray::from([Some(12345_678i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); let result = adaptive_mul(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(152415666_514i128)]).to(DataType::Decimal(12, 3)); + let expected = PrimitiveArray::from([Some(152415666_514i128)]).to(DataType::Decimal( + DecimalType::Int128, + 12, + 3, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(12, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 12, 3) + ); } diff --git a/tests/it/compute/arithmetics/decimal/sub.rs b/tests/it/compute/arithmetics/decimal/sub.rs index 343149a5646..38a234bf59a 100644 --- a/tests/it/compute/arithmetics/decimal/sub.rs +++ b/tests/it/compute/arithmetics/decimal/sub.rs @@ -3,19 +3,19 @@ use arrow2::array::*; use arrow2::compute::arithmetics::decimal::{adaptive_sub, checked_sub, saturating_sub, sub}; use arrow2::compute::arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub}; -use arrow2::datatypes::DataType; +use arrow2::datatypes::{DataType, DecimalType}; #[test] fn test_subtract_normal() { let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = sub(&a, &b); let expected = PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 
5, 2)); assert_eq!(result, expected); @@ -27,30 +27,31 @@ fn test_subtract_normal() { #[test] #[should_panic] fn test_subtract_decimal_wrong_precision() { - let a = PrimitiveArray::from([None]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([None]).to(DataType::Decimal(6, 2)); + let a = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([None]).to(DataType::Decimal(DecimalType::Int128, 6, 2)); sub(&a, &b); } #[test] #[should_panic(expected = "Overflow in subtract presented for precision 5")] fn test_subtract_panic() { - let a = PrimitiveArray::from([Some(-99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(1i128)]).to(DataType::Decimal(5, 2)); + let a = + PrimitiveArray::from([Some(-99999i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); + let b = PrimitiveArray::from([Some(1i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); let _ = sub(&a, &b); } #[test] fn test_subtract_saturating() { let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = saturating_sub(&a, &b); let expected = PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -67,14 +68,14 @@ fn test_subtract_saturating_overflow() { Some(-99999i128), Some(99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([ Some(00001i128), Some(00100i128), Some(10000i128), Some(-99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = 
saturating_sub(&a, &b); @@ -84,7 +85,7 @@ fn test_subtract_saturating_overflow() { Some(-99999i128), Some(99999i128), ]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -96,14 +97,14 @@ fn test_subtract_saturating_overflow() { #[test] fn test_subtract_checked() { let a = PrimitiveArray::from([Some(11111i128), Some(22200i128), None, Some(40000i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let b = PrimitiveArray::from([Some(22222i128), Some(11100i128), None, Some(11100i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); let result = checked_sub(&a, &b); let expected = PrimitiveArray::from([Some(-11111i128), Some(11100i128), None, Some(28900i128)]) - .to(DataType::Decimal(5, 2)); + .to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); @@ -114,10 +115,19 @@ fn test_subtract_checked() { #[test] fn test_subtract_checked_overflow() { - let a = PrimitiveArray::from([Some(4i128), Some(-99999i128)]).to(DataType::Decimal(5, 2)); - let b = PrimitiveArray::from([Some(2i128), Some(1i128)]).to(DataType::Decimal(5, 2)); + let a = PrimitiveArray::from([Some(4i128), Some(-99999i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); + let b = PrimitiveArray::from([Some(2i128), Some(1i128)]).to(DataType::Decimal( + DecimalType::Int128, + 5, + 2, + )); let result = checked_sub(&a, &b); - let expected = PrimitiveArray::from([Some(2i128), None]).to(DataType::Decimal(5, 2)); + let expected = + PrimitiveArray::from([Some(2i128), None]).to(DataType::Decimal(DecimalType::Int128, 5, 2)); assert_eq!(result, expected); } @@ -127,51 +137,86 @@ fn test_subtract_adaptive() { // 11111.11 -> 7, 2 // ------------------ // -11099.9989 -> 9, 4 - let a = PrimitiveArray::from([Some(11_1111i128)]).to(DataType::Decimal(6, 4)); - let b = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); + let 
a = + PrimitiveArray::from([Some(11_1111i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); + let b = + PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(-11099_9989i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(-11099_9989i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 4, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 11111.0 -> 6, 1 // 0.1111 -> 5, 4 // ----------------- // 11110.8889 -> 9, 4 - let a = PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(6, 1)); - let b = PrimitiveArray::from([Some(1111i128)]).to(DataType::Decimal(5, 4)); + let a = + PrimitiveArray::from([Some(11111_0i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 1)); + let b = PrimitiveArray::from([Some(1111i128)]).to(DataType::Decimal(DecimalType::Int128, 5, 4)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(11110_8889i128)]).to(DataType::Decimal(9, 4)); + let expected = PrimitiveArray::from([Some(11110_8889i128)]).to(DataType::Decimal( + DecimalType::Int128, + 9, + 4, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(9, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 9, 4) + ); // 11111.11 -> 7, 2 // 11111.111 -> 8, 3 // ----------------- // -00000.001 -> 8, 3 - let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); - let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); + let a = + PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 2)); + let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); let 
result = adaptive_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(-00000_001i128)]).to(DataType::Decimal(8, 3)); + let expected = PrimitiveArray::from([Some(-00000_001i128)]).to(DataType::Decimal( + DecimalType::Int128, + 8, + 3, + )); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(8, 3)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 8, 3) + ); // 99.9999 -> 6, 4 // -00.0001 -> 6, 4 // ----------------- // 100.0000 -> 7, 4 - let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); - let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); + let a = + PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); + let b = + PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(DecimalType::Int128, 6, 4)); let result = adaptive_sub(&a, &b).unwrap(); - let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); + let expected = + PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(DecimalType::Int128, 7, 4)); assert_eq!(result, expected); - assert_eq!(result.data_type(), &DataType::Decimal(7, 4)); + assert_eq!( + result.data_type(), + &DataType::Decimal(DecimalType::Int128, 7, 4) + ); } diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index fedc7a47dc6..e4c529f8b31 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -243,11 +243,16 @@ fn int32_to_decimal() { // 10 and -10 can be represented with precision 1 and scale 0 let array = Int32Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]); - let b = cast(&array, &DataType::Decimal(1, 0), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 1, 0), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); let expected = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - 
.to(DataType::Decimal(1, 0)); + .to(DataType::Decimal(DecimalType::Int128, 1, 0)); assert_eq!(c, &expected) } @@ -263,7 +268,12 @@ fn float32_to_decimal() { None, ]); - let b = cast(&array, &DataType::Decimal(10, 2), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 10, 2), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); let expected = Int128Array::from(&[ @@ -275,7 +285,7 @@ fn float32_to_decimal() { Some(-10001), None, ]) - .to(DataType::Decimal(10, 2)); + .to(DataType::Decimal(DecimalType::Int128, 10, 2)); assert_eq!(c, &expected) } @@ -284,11 +294,16 @@ fn int32_to_decimal_scaled() { // 10 and -10 can't be represented with precision 1 and scale 1 let array = Int32Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]); - let b = cast(&array, &DataType::Decimal(1, 1), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 1, 1), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); - let expected = - Int128Array::from(&[Some(20), None, Some(-20), None, None]).to(DataType::Decimal(1, 1)); + let expected = Int128Array::from(&[Some(20), None, Some(-20), None, None]) + .to(DataType::Decimal(DecimalType::Int128, 1, 1)); assert_eq!(c, &expected) } @@ -296,13 +311,18 @@ fn int32_to_decimal_scaled() { fn decimal_to_decimal() { // increase scale and precision let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - .to(DataType::Decimal(1, 0)); + .to(DataType::Decimal(DecimalType::Int128, 1, 0)); - let b = cast(&array, &DataType::Decimal(2, 1), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 2, 1), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); let expected = Int128Array::from(&[Some(20), Some(100), Some(-20), Some(-100), None]) - .to(DataType::Decimal(2, 1)); + 
.to(DataType::Decimal(DecimalType::Int128, 2, 1)); assert_eq!(c, &expected) } @@ -311,13 +331,18 @@ fn decimal_to_decimal_scaled() { // decrease precision // 10 and -10 can't be represented with precision 1 and scale 1 let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - .to(DataType::Decimal(1, 0)); + .to(DataType::Decimal(DecimalType::Int128, 1, 0)); - let b = cast(&array, &DataType::Decimal(1, 1), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 1, 1), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); - let expected = - Int128Array::from(&[Some(20), None, Some(-20), None, None]).to(DataType::Decimal(1, 1)); + let expected = Int128Array::from(&[Some(20), None, Some(-20), None, None]) + .to(DataType::Decimal(DecimalType::Int128, 1, 1)); assert_eq!(c, &expected) } @@ -326,20 +351,25 @@ fn decimal_to_decimal_fast() { // increase precision // 10 and -10 can't be represented with precision 1 and scale 1 let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - .to(DataType::Decimal(1, 1)); + .to(DataType::Decimal(DecimalType::Int128, 1, 1)); - let b = cast(&array, &DataType::Decimal(2, 1), CastOptions::default()).unwrap(); + let b = cast( + &array, + &DataType::Decimal(DecimalType::Int128, 2, 1), + CastOptions::default(), + ) + .unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); let expected = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - .to(DataType::Decimal(2, 1)); + .to(DataType::Decimal(DecimalType::Int128, 2, 1)); assert_eq!(c, &expected) } #[test] fn decimal_to_float() { let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None]) - .to(DataType::Decimal(2, 1)); + .to(DataType::Decimal(DecimalType::Int128, 2, 1)); let b = cast(&array, &DataType::Float32, CastOptions::default()).unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); @@ -351,7 +381,7 @@ fn 
decimal_to_float() { #[test] fn decimal_to_integer() { let array = Int128Array::from(&[Some(2), Some(10), Some(-2), Some(-10), None, Some(2560)]) - .to(DataType::Decimal(2, 1)); + .to(DataType::Decimal(DecimalType::Int128, 2, 1)); let b = cast(&array, &DataType::Int8, CastOptions::default()).unwrap(); let c = b.as_any().downcast_ref::>().unwrap(); @@ -458,8 +488,8 @@ fn consistency() { Date32, Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond), - Decimal(1, 2), - Decimal(2, 2), + Decimal(DecimalType::Int128, 1, 2), + Decimal(DecimalType::Int128, 2, 2), Date64, Utf8, LargeUtf8, diff --git a/tests/it/ffi/data.rs b/tests/it/ffi/data.rs index 45c2f66cc1a..a9dfa808815 100644 --- a/tests/it/ffi/data.rs +++ b/tests/it/ffi/data.rs @@ -1,6 +1,6 @@ use arrow2::array::*; use arrow2::bitmap::Bitmap; -use arrow2::datatypes::{DataType, Field, TimeUnit}; +use arrow2::datatypes::{DataType, DecimalType, Field, TimeUnit}; use arrow2::{error::Result, ffi}; use std::collections::BTreeMap; use std::sync::Arc; @@ -87,6 +87,32 @@ fn large_utf8() -> Result<()> { test_round_trip(data) } +#[test] +fn decimal128() -> Result<()> { + let data = Int128Array::from(&[Some(2), None, Some(1), None]); + test_round_trip(data) +} + +#[test] +fn decimal64() -> Result<()> { + let data = Int64Array::from(&[Some(2), None, Some(1), None]).to(DataType::Decimal( + DecimalType::Int64, + 2, + 2, + )); + test_round_trip(data) +} + +#[test] +fn decimal32() -> Result<()> { + let data = Int32Array::from(&[Some(2), None, Some(1), None]).to(DataType::Decimal( + DecimalType::Int32, + 2, + 2, + )); + test_round_trip(data) +} + #[test] fn binary() -> Result<()> { let data = diff --git a/tests/it/io/avro/read.rs b/tests/it/io/avro/read.rs index 5efd42518b6..479d28861ac 100644 --- a/tests/it/io/avro/read.rs +++ b/tests/it/io/avro/read.rs @@ -77,7 +77,11 @@ pub(super) fn schema() -> (AvroSchema, Schema) { DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::Utf8), false), false, ), - Field::new("decimal", 
DataType::Decimal(18, 5), false), + Field::new( + "decimal", + DataType::Decimal(DecimalType::Int128, 18, 5), + false, + ), ]); (AvroSchema::parse_str(raw_schema).unwrap(), schema) @@ -113,7 +117,7 @@ pub(super) fn data() -> Chunk> { )), Arc::new( PrimitiveArray::::from_slice([12345678i128, -12345678i128]) - .to(DataType::Decimal(18, 5)), + .to(DataType::Decimal(DecimalType::Int128, 18, 5)), ), ]; diff --git a/tests/it/io/csv/read.rs b/tests/it/io/csv/read.rs index 5bf193c1496..cd674386257 100644 --- a/tests/it/io/csv/read.rs +++ b/tests/it/io/csv/read.rs @@ -135,27 +135,45 @@ fn date64() -> Result<()> { #[test] fn decimal() -> Result<()> { - let result = test_deserialize("1.1,\n1.2,\n1.22,\n1.3,\n", DataType::Decimal(2, 1))?; - let expected = - Int128Array::from(&[Some(11), Some(12), None, Some(13)]).to(DataType::Decimal(2, 1)); + let result = test_deserialize( + "1.1,\n1.2,\n1.22,\n1.3,\n", + DataType::Decimal(DecimalType::Int128, 2, 1), + )?; + let expected = Int128Array::from(&[Some(11), Some(12), None, Some(13)]).to(DataType::Decimal( + DecimalType::Int128, + 2, + 1, + )); assert_eq!(expected, result.as_ref()); Ok(()) } #[test] fn decimal_only_scale() -> Result<()> { - let result = test_deserialize("0.01,\n0.12,\n0.222,\n0.13,\n", DataType::Decimal(2, 2))?; - let expected = - Int128Array::from(&[Some(1), Some(12), None, Some(13)]).to(DataType::Decimal(2, 2)); + let result = test_deserialize( + "0.01,\n0.12,\n0.222,\n0.13,\n", + DataType::Decimal(DecimalType::Int128, 2, 2), + )?; + let expected = Int128Array::from(&[Some(1), Some(12), None, Some(13)]).to(DataType::Decimal( + DecimalType::Int128, + 2, + 2, + )); assert_eq!(expected, result.as_ref()); Ok(()) } #[test] fn decimal_only_integer() -> Result<()> { - let result = test_deserialize("1,\n1.0,\n1.1,\n10.0,\n", DataType::Decimal(1, 0))?; - let expected = - Int128Array::from(&[Some(1), Some(1), None, Some(10)]).to(DataType::Decimal(1, 0)); + let result = test_deserialize( + "1,\n1.0,\n1.1,\n10.0,\n", + 
DataType::Decimal(DecimalType::Int128, 1, 0), + )?; + let expected = Int128Array::from(&[Some(1), Some(1), None, Some(10)]).to(DataType::Decimal( + DecimalType::Int128, + 1, + 0, + )); assert_eq!(expected, result.as_ref()); Ok(()) } diff --git a/tests/it/io/ipc/write/file.rs b/tests/it/io/ipc/write/file.rs index d466c8291bc..90df093de57 100644 --- a/tests/it/io/ipc/write/file.rs +++ b/tests/it/io/ipc/write/file.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow2::array::*; use arrow2::chunk::Chunk; -use arrow2::datatypes::{Field, Schema}; +use arrow2::datatypes::{DataType, DecimalType, Field, Schema}; use arrow2::error::Result; use arrow2::io::ipc::read::{read_file_metadata, FileReader}; use arrow2::io::ipc::{write::*, IpcField}; @@ -365,3 +365,18 @@ fn write_sliced_list() -> Result<()> { let columns = Chunk::try_new(vec![array])?; round_trip(columns, schema, None, None) } + +#[test] +fn write_decimali32() -> Result<()> { + use std::sync::Arc; + let array = Arc::new( + Int32Array::from([Some(1), Some(2), None, Some(4)]).to(DataType::Decimal( + DecimalType::Int32, + 2, + 2, + )), + ) as Arc; + let schema = Schema::from(vec![Field::new("a", array.data_type().clone(), true)]); + let columns = Chunk::try_new(vec![array])?; + round_trip(columns, schema, None, None) +} diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index 5e1b59c7488..11b6a49f1d2 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -339,21 +339,33 @@ pub fn pyarrow_nullable(column: &str) -> Box { .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(9, 0))) + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + DecimalType::Int128, + 9, + 0, + ))) } "decimal_18" => { let values = i64_values .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(18, 0))) + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + 
DecimalType::Int128, + 18, + 0, + ))) } "decimal_26" => { let values = i64_values .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(26, 0))) + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + DecimalType::Int128, + 26, + 0, + ))) } "timestamp_us" => Box::new( PrimitiveArray::::from(i64_values) @@ -419,21 +431,21 @@ pub fn pyarrow_nullable_statistics(column: &str) -> Option> null_count: Some(3), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(9, 0), + data_type: DataType::Decimal(DecimalType::Int128, 9, 0), }), "decimal_18" => Box::new(PrimitiveStatistics:: { distinct_count: None, null_count: Some(3), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(18, 0), + data_type: DataType::Decimal(DecimalType::Int128, 18, 0), }), "decimal_26" => Box::new(PrimitiveStatistics:: { distinct_count: None, null_count: Some(3), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(26, 0), + data_type: DataType::Decimal(DecimalType::Int128, 26, 0), }), "timestamp_us" => Box::new(PrimitiveStatistics:: { data_type: DataType::Timestamp(TimeUnit::Microsecond, None), @@ -488,21 +500,33 @@ pub fn pyarrow_required(column: &str) -> Box { .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(9, 0))) + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + DecimalType::Int128, + 9, + 0, + ))) } "decimal_18" => { let values = i64_values .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(18, 0))) + Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + DecimalType::Int128, + 18, + 0, + ))) } "decimal_26" => { let values = i64_values .iter() .map(|x| x.map(|x| x as i128)) .collect::>(); - Box::new(PrimitiveArray::::from(values).to(DataType::Decimal(26, 0))) + 
Box::new(PrimitiveArray::::from(values).to(DataType::Decimal( + DecimalType::Int128, + 26, + 0, + ))) } _ => unreachable!(), } @@ -534,21 +558,21 @@ pub fn pyarrow_required_statistics(column: &str) -> Option> null_count: Some(0), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(9, 0), + data_type: DataType::Decimal(DecimalType::Int128, 9, 0), }), "decimal_18" => Box::new(PrimitiveStatistics:: { distinct_count: None, null_count: Some(0), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(18, 0), + data_type: DataType::Decimal(DecimalType::Int128, 18, 0), }), "decimal_26" => Box::new(PrimitiveStatistics:: { distinct_count: None, null_count: Some(0), min_value: Some(0i128), max_value: Some(9i128), - data_type: DataType::Decimal(26, 0), + data_type: DataType::Decimal(DecimalType::Int128, 26, 0), }), _ => unreachable!(), })