From c999d582d9dd6b4f80bd3ce3d6be00d951b313f6 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 30 May 2023 16:26:34 +0200 Subject: [PATCH 1/3] WIP: decimal arithmetic --- .../src/compute/arithmetics/decimal/add.rs | 53 ++ .../src/compute/arithmetics/decimal/mod.rs | 77 ++ .../src/compute/arithmetics/mod.rs | 2 + polars/polars-arrow/src/compute/arity.rs | 1 + polars/polars-arrow/src/compute/mod.rs | 45 ++ .../src/chunked_array/arithmetic.rs | 678 ------------------ .../src/chunked_array/arithmetic/decimal.rs | 111 +++ .../src/chunked_array/arithmetic/mod.rs | 256 +++++++ .../src/chunked_array/arithmetic/numeric.rs | 395 ++++++++++ polars/polars-core/src/datatypes/dtype.rs | 4 +- 10 files changed, 943 insertions(+), 679 deletions(-) create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/add.rs create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs create mode 100644 polars/polars-arrow/src/compute/arithmetics/mod.rs create mode 100644 polars/polars-arrow/src/compute/arity.rs delete mode 100644 polars/polars-core/src/chunked_array/arithmetic.rs create mode 100644 polars/polars-core/src/chunked_array/arithmetic/decimal.rs create mode 100644 polars/polars-core/src/chunked_array/arithmetic/mod.rs create mode 100644 polars/polars-core/src/chunked_array/arithmetic/numeric.rs diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs new file mode 100644 index 000000000000..fc2ff3e8998a --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs @@ -0,0 +1,53 @@ +//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. +use arrow::{ + array::PrimitiveArray, + compute::{ + arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}, + arity::{binary, binary_checked}, + }, +}; +use arrow::{ + datatypes::DataType, + error::{Error, Result}, +}; +use arrow::compute::arity::unary; +use polars_error::*; +use crate::compute::{binary_mut, unary_mut}; +use crate::utils::combine_validities_and; + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; + +pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PolarsResult> { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = move |a, b| { + let res: i128 = a + b; + if res.abs() > max { + overflow = true + } + res + }; + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + + Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) +} + + +pub fn add_scalar(lhs: &PrimitiveArray, rhs: i128, rhs_dtype: &DataType) -> PolarsResult> { + let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = move |a| { + let res: i128 = a + rhs; + if res.abs() > max { + overflow = true + } + res + }; + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + + Ok(unary_mut(lhs, op, lhs.data_type().clone())) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs new file mode 100644 index 000000000000..6b0cd14acad6 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs @@ -0,0 +1,77 @@ +use arrow::datatypes::DataType; +use polars_error::{PolarsError, PolarsResult}; + +mod add; + +pub use add::*; + +/// Maximum value that can exist with a selected precision +#[inline] +fn max_value(precision: usize) -> i128 { + 10i128.pow(precision as u32) - 1 +} + +// Calculates the number of digits in a i128 number +fn number_digits(num: i128) -> usize { + let mut num = num.abs(); + let mut digit: i128 = 0; + let base = 10i128; + + while num != 0 { + num /= base; + digit += 1; + } + + digit as usize +} + +fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> { + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.to_logical_type(), rhs.to_logical_type()) + { + if lhs_p == rhs_p && lhs_s == rhs_s { + Ok((*lhs_p, *lhs_s)) + } else { + Err(PolarsError::InvalidOperation( + "Arrays must have the same precision and scale".into(), + )) + } + } else { + unreachable!() + } +} + +/// Returns the adjusted precision and scale for the lhs and rhs precision and +/// scale +fn adjusted_precision_scale( + lhs_p: usize, + lhs_s: usize, + rhs_p: usize, + rhs_s: usize, +) -> (usize, usize, usize) { + // The initial new precision and scale is based on the number of digits + // that lhs and rhs number has before and after the point. The max + // number of digits before and after the point will make the last + // precision and scale of the result + + // Digits before/after point + // before after + // 11.1111 -> 5, 4 -> 2 4 + // 11111.01 -> 7, 2 -> 5 2 + // ----------------- + // 11122.1211 -> 9, 4 -> 5 4 + let lhs_digits_before = lhs_p - lhs_s; + let rhs_digits_before = rhs_p - rhs_s; + + let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before); + + let (res_s, diff) = if lhs_s > rhs_s { + (lhs_s, lhs_s - rhs_s) + } else { + (rhs_s, rhs_s - lhs_s) + }; + + let res_p = res_digits_before + res_s; + + (res_p, res_s, diff) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/mod.rs b/polars/polars-arrow/src/compute/arithmetics/mod.rs new file mode 100644 index 000000000000..16dd409300a3 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "dtype-decimal")] +pub mod decimal; \ No newline at end of file diff --git a/polars/polars-arrow/src/compute/arity.rs b/polars/polars-arrow/src/compute/arity.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/polars/polars-arrow/src/compute/arity.rs @@ -0,0 +1 @@ + diff --git a/polars/polars-arrow/src/compute/mod.rs b/polars/polars-arrow/src/compute/mod.rs index ab57198868ec..46521718db88 100644 --- a/polars/polars-arrow/src/compute/mod.rs +++ b/polars/polars-arrow/src/compute/mod.rs @@ -1,3 +1,8 @@ +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; +use arrow::types::NativeType; +use crate::utils::combine_validities_and; + pub mod bitwise; #[cfg(feature = "compute")] pub mod cast; @@ -5,3 +10,43 @@ pub mod cast; pub mod decimal; pub mod take; pub mod tile; +pub mod arithmetics; +pub mod arity; + +#[inline] +pub fn binary_mut( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + mut op: F, +) -> PrimitiveArray + where + T: NativeType, + D: NativeType, + F: FnMut(T, D) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities_and(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>() + .into(); + + PrimitiveArray::::new(data_type, values, validity) +} + +#[inline] +pub fn unary_mut(array: &PrimitiveArray, mut op: F, data_type: DataType) -> PrimitiveArray + where + I: NativeType, + O: NativeType, + F: FnMut(I) -> O, +{ + let values = array.values().iter().map(|v| op(*v)).collect::>(); + + PrimitiveArray::::new(data_type, values.into(), array.validity().cloned()) +} diff --git a/polars/polars-core/src/chunked_array/arithmetic.rs b/polars/polars-core/src/chunked_array/arithmetic.rs deleted file mode 100644 index cddf5882c3d2..000000000000 --- a/polars/polars-core/src/chunked_array/arithmetic.rs +++ /dev/null @@ -1,678 +0,0 @@ -//! Implementations of arithmetic operations on ChunkedArray's. -use std::ops::{Add, Div, Mul, Rem, Sub}; - -use arrow::array::PrimitiveArray; -use arrow::compute::arithmetics::basic; -#[cfg(feature = "dtype-decimal")] -use arrow::compute::arithmetics::decimal; -use arrow::compute::arity_assign; -use arrow::types::NativeType; -use num_traits::{Num, NumCast, ToPrimitive, Zero}; -use polars_arrow::utils::combine_validities_and; - -use crate::prelude::*; -use crate::series::IsSorted; -use crate::utils::{align_chunks_binary, align_chunks_binary_owned}; - -pub trait ArrayArithmetics -where - Self: NativeType, -{ - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; -} - -macro_rules! native_array_arithmetics { - ($ty: ty) => { - impl ArrayArithmetics for $ty - { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::add(lhs, rhs) - } - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::sub(lhs, rhs) - } - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::mul(lhs, rhs) - } - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::div(lhs, rhs) - } - fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::div_scalar(lhs, rhs) - } - fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - basic::rem(lhs, rhs) - } - fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { - basic::rem_scalar(lhs, rhs) - } - } - }; - ($($ty:ty),*) => { - $(native_array_arithmetics!($ty);)* - } -} - -native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); - -#[cfg(feature = "dtype-decimal")] -impl ArrayArithmetics for i128 { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::add(lhs, rhs) - } - - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::sub(lhs, rhs) - } - - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::mul(lhs, rhs) - } - - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::div(lhs, rhs) - } - - fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - // decimal::div_scalar(lhs, rhs) - todo!("decimal::div_scalar exists, but takes &PrimitiveScalar, not &i128"); - } - - fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } - - fn rem_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - unimplemented!("requires support in arrow2 crate") - } -} - -pub(super) fn arithmetic_helper( - lhs: &ChunkedArray, - rhs: &ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, - F: Fn(T::Native, T::Native) -> T::Native, -{ - let mut ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (lhs, rhs) = align_chunks_binary(lhs, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(lhs, rhs)| Box::new(kernel(lhs, rhs)) as ArrayRef) - .collect(); - lhs.copy_with_chunks(chunks, false, false) - } - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => lhs.apply(|lhs| operation(lhs, rhs)), - } - } - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs) => rhs.apply(|rhs| operation(lhs, rhs)), - } - } - _ => panic!("Cannot apply operation on arrays of different lengths"), - }; - ca.rename(lhs.name()); - ca -} - -/// This assigns to the owned buffer if the ref count is 1 -fn arithmetic_helper_owned( - mut lhs: ChunkedArray, - mut rhs: ChunkedArray, - kernel: Kernel, - operation: F, -) -> ChunkedArray -where - T: PolarsNumericType, - Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), - F: Fn(T::Native, T::Native) -> T::Native, -{ - let ca = match (lhs.len(), rhs.len()) { - (a, b) if a == b => { - let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs); - // safety, we do no t change the lengths - unsafe { - lhs.downcast_iter_mut() - .zip(rhs.downcast_iter_mut()) - .for_each(|(lhs, rhs)| kernel(lhs, rhs)); - } - lhs.set_sorted_flag(IsSorted::Not); - lhs - } - // broadcast right path - (_, 1) => { - let opt_rhs = rhs.get(0); - match opt_rhs { - None => ChunkedArray::full_null(lhs.name(), lhs.len()), - Some(rhs) => { - lhs.apply_mut(|lhs| operation(lhs, rhs)); - lhs - } - } - } - (1, _) => { - let opt_lhs = lhs.get(0); - match opt_lhs { - None => ChunkedArray::full_null(lhs.name(), rhs.len()), - Some(lhs_val) => { - rhs.apply_mut(|rhs| operation(lhs_val, rhs)); - rhs.rename(lhs.name()); - rhs - } - } - } - _ => panic!("Cannot apply operation on arrays of different lengths"), - }; - ca -} - -// Operands on ChunkedArray & ChunkedArray - -impl Add for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::add, - |lhs, rhs| lhs + rhs, - ) - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::div, - |lhs, rhs| lhs / rhs, - ) - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::mul, - |lhs, rhs| lhs * rhs, - ) - } -} - -impl Rem for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::rem, - |lhs, rhs| lhs % rhs, - ) - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper( - self, - rhs, - ::sub, - |lhs, rhs| lhs - rhs, - ) - } -} - -impl Add for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a + b), - |lhs, rhs| lhs + rhs, - ) - } -} - -impl Div for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a / b), - |lhs, rhs| lhs / rhs, - ) - } -} - -impl Mul for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a * b), - |lhs, rhs| lhs * rhs, - ) - } -} - -impl Sub for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - arithmetic_helper_owned( - self, - rhs, - |a, b| arity_assign::binary(a, b, |a, b| a - b), - |lhs, rhs| lhs - rhs, - ) - } -} - -impl Rem for ChunkedArray -where - T: PolarsNumericType, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: Self) -> Self::Output { - (&self).rem(&rhs) - } -} - -// Operands on ChunkedArray & Num - -impl Add for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn add(self, rhs: N) -> Self::Output { - let adder: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply(|val| val + adder); - out.set_sorted_flag(self.is_sorted_flag()); - out - } -} - -impl Sub for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: N) -> Self::Output { - let subber: T::Native = NumCast::from(rhs).unwrap(); - let mut out = self.apply(|val| val - subber); - out.set_sorted_flag(self.is_sorted_flag()); - out - } -} - -impl Div for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let mut out = self - .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); - - if rhs < T::Native::zero() { - out.set_sorted_flag(self.is_sorted_flag().reverse()); - } else { - out.set_sorted_flag(self.is_sorted_flag()); - } - out - } -} - -impl Mul for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn mul(self, rhs: N) -> Self::Output { - // don't set sorted flag as probability of overflow is higher - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - let rhs = ChunkedArray::from_vec("", vec![multiplier]); - self.mul(&rhs) - } -} - -impl Rem for &ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: N) -> Self::Output { - let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); - let rhs = ChunkedArray::from_vec("", vec![rhs]); - self.rem(&rhs) - } -} - -impl Add for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn add(self, rhs: N) -> Self::Output { - (&self).add(rhs) - } -} - -impl Sub for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn sub(self, rhs: N) -> Self::Output { - (&self).sub(rhs) - } -} - -impl Div for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn div(self, rhs: N) -> Self::Output { - (&self).div(rhs) - } -} - -impl Mul for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn mul(mut self, rhs: N) -> Self::Output { - let multiplier: T::Native = NumCast::from(rhs).unwrap(); - self.apply_mut(|val| val * multiplier); - self - } -} - -impl Rem for ChunkedArray -where - T: PolarsNumericType, - N: Num + ToPrimitive, -{ - type Output = ChunkedArray; - - fn rem(self, rhs: N) -> Self::Output { - (&self).rem(rhs) - } -} - -fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { - buf.clear(); - - buf.extend_from_slice(l); - buf.extend_from_slice(r); -} - -impl Add for &Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: Self) -> Self::Output { - unsafe { (self.as_binary() + rhs.as_binary()).to_utf8() } - } -} - -impl Add for Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -impl Add<&str> for &Utf8Chunked { - type Output = Utf8Chunked; - - fn add(self, rhs: &str) -> Self::Output { - unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_utf8() } - } -} - -fn concat_binary(a: &BinaryArray, b: &BinaryArray) -> BinaryArray { - let validity = combine_validities_and(a.validity(), b.validity()); - let mut values = Vec::with_capacity(a.get_values_size() + b.get_values_size()); - let mut offsets = Vec::with_capacity(a.len() + 1); - let mut offset_so_far = 0i64; - offsets.push(offset_so_far); - - for (a, b) in a.values_iter().zip(b.values_iter()) { - values.extend_from_slice(a); - values.extend_from_slice(b); - offset_so_far = values.len() as i64; - offsets.push(offset_so_far) - } - unsafe { BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), validity) } -} - -impl Add for &BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: Self) -> Self::Output { - // broadcasting path rhs - if rhs.len() == 1 { - let rhs = rhs.get(0); - let mut buf = vec![]; - return match rhs { - Some(rhs) => { - self.apply_mut(|s| { - concat_binary_arrs(s, rhs, &mut buf); - let out = buf.as_slice(); - // safety: lifetime is bound to the outer scope and the - // ref is valid for the lifetime of this closure - unsafe { std::mem::transmute::<_, &'static [u8]>(out) } - }) - } - None => BinaryChunked::full_null(self.name(), self.len()), - }; - } - // broadcasting path lhs - if self.len() == 1 { - let lhs = self.get(0); - let mut buf = vec![]; - return match lhs { - Some(lhs) => rhs.apply_mut(|s| { - concat_binary_arrs(lhs, s, &mut buf); - - let out = buf.as_slice(); - // safety: lifetime is bound to the outer scope and the - // ref is valid for the lifetime of this closure - unsafe { std::mem::transmute::<_, &'static [u8]>(out) } - }), - None => BinaryChunked::full_null(self.name(), rhs.len()), - }; - } - - let (lhs, rhs) = align_chunks_binary(self, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(a, b)| Box::new(concat_binary(a, b)) as ArrayRef) - .collect(); - - unsafe { BinaryChunked::from_chunks(self.name(), chunks) } - } -} - -impl Add for BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -impl Add<&[u8]> for &BinaryChunked { - type Output = BinaryChunked; - - fn add(self, rhs: &[u8]) -> Self::Output { - let arr = BinaryArray::::from_slice([rhs]); - let rhs = unsafe { BinaryChunked::from_chunks("", vec![Box::new(arr) as ArrayRef]) }; - self.add(&rhs) - } -} - -fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray { - let validity = combine_validities_and(a.validity(), b.validity()); - - let values = a - .values_iter() - .zip(b.values_iter()) - .map(|(a, b)| a as IdxSize + b as IdxSize) - .collect::>(); - PrimitiveArray::from_data_default(values.into(), validity) -} - -impl Add for &BooleanChunked { - type Output = IdxCa; - - fn add(self, rhs: Self) -> Self::Output { - // broadcasting path rhs - if rhs.len() == 1 { - let rhs = rhs.get(0); - return match rhs { - Some(rhs) => self.apply_cast_numeric(|v| v as IdxSize + rhs as IdxSize), - None => IdxCa::full_null(self.name(), self.len()), - }; - } - // broadcasting path lhs - if self.len() == 1 { - return rhs.add(self); - } - let (lhs, rhs) = align_chunks_binary(self, rhs); - let chunks = lhs - .downcast_iter() - .zip(rhs.downcast_iter()) - .map(|(a, b)| Box::new(add_boolean(a, b)) as ArrayRef) - .collect::>(); - - unsafe { IdxCa::from_chunks(self.name(), chunks) } - } -} - -impl Add for BooleanChunked { - type Output = IdxCa; - - fn add(self, rhs: Self) -> Self::Output { - (&self).add(&rhs) - } -} - -#[cfg(test)] -pub(crate) mod test { - use crate::prelude::*; - - pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { - let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); - let a2 = Int32Chunked::new("a", &[4, 5, 6]); - let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); - a1.append(&a2); - (a1, a3) - } - - #[test] - #[allow(clippy::eq_op)] - fn test_chunk_mismatch() { - let (a1, a2) = create_two_chunked(); - // with different chunks - let _ = &a1 + &a2; - let _ = &a1 - &a2; - let _ = &a1 / &a2; - let _ = &a1 * &a2; - - // with same chunks - let _ = &a1 + &a1; - let _ = &a1 - &a1; - let _ = &a1 / &a1; - let _ = &a1 * &a1; - } -} diff --git a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs new file mode 100644 index 000000000000..bf0f26aa3de8 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -0,0 +1,111 @@ +use crate::prelude::DecimalChunked; +use polars_arrow::compute::arithmetics::decimal; +use super::*; + +impl ArrayArithmetics for i128 { + fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + decimal::add(lhs, rhs).unwrap() + } + + fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + todo!() + // decimal::sub(lhs, rhs) + } + + fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + todo!() + // decimal::mul(lhs, rhs) + } + + fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + todo!() + // decimal::div(lhs, rhs) + } + + fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { + // decimal::div_scalar(lhs, rhs) + todo!("decimal::div_scalar exists, but takes &PrimitiveScalar, not &i128"); + } + + fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!("requires support in arrow2 crate") + } + + fn rem_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { + unimplemented!("requires support in arrow2 crate") + } +} + +impl DecimalChunked { + fn arithmetic_helper(&self, rhs: &DecimalChunked, + kernel: Kernel, + operation_lhs: ScalarKernelLhs, + operation_rhs: ScalarKernelRhs + ) -> PolarsResult + where + Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PolarsResult>, + ScalarKernelLhs: Fn(&PrimitiveArray, i128, &ArrowDataType) -> PolarsResult>, + ScalarKernelRhs: Fn(i128, &ArrowDataType, &PrimitiveArray) -> PolarsResult> + + { + let lhs = self; + + let mut ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs, rhs)| kernel(lhs, rhs).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs_val) => { + let chunks = lhs.downcast_iter().map(|lhs| { + operation_lhs(lhs, rhs_val, &rhs.dtype().to_arrow()).map(|a| Box::new(a) as ArrayRef) + }).collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + }, + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs_val) => { + let chunks = rhs.downcast_iter().map(|rhs| { + operation_rhs(lhs_val, &lhs.dtype().to_arrow(), rhs).map(|a| Box::new(a) as ArrayRef) + }).collect::>()?; + lhs.copy_with_chunks(chunks, false, false) + + } + } + } + _ => polars_bail!(ComputeError: "Cannot apply operation on arrays of different lengths"), + }; + ca.rename(lhs.name()); + Ok(ca.into_decimal_unchecked(self.precision(), self.scale())) + + } + +} + +fn reversed(lhs_val: i128, lhs_dtype: &ArrowDataType, rhs: &PrimitiveArray, op: Kernel) -> PolarsResult> +where Kernel: Fn(&PrimitiveArray, i128, &ArrowDataType) -> PolarsResult>, +{ + op(rhs, lhs_val, lhs_dtype) +} + +impl Add for &DecimalChunked +{ + type Output = PolarsResult; + + fn add(self, rhs: Self) -> Self::Output { + self.arithmetic_helper(rhs, decimal::add, decimal::add_scalar, |lhs_val, lhs_dtype, rhs| reversed(lhs_val, lhs_dtype, rhs, decimal::add_scalar)) + } +} diff --git a/polars/polars-core/src/chunked_array/arithmetic/mod.rs b/polars/polars-core/src/chunked_array/arithmetic/mod.rs new file mode 100644 index 000000000000..63d8c72915b4 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/mod.rs @@ -0,0 +1,256 @@ +//! Implementations of arithmetic operations on ChunkedArray's. +mod numeric; +mod decimal; + +use std::ops::{Add, Div, Mul, Rem, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::compute::arithmetics::basic; +use arrow::compute::arity_assign; +use arrow::types::NativeType; +use num_traits::{Num, NumCast, ToPrimitive, Zero}; +use polars_arrow::utils::combine_validities_and; + +use crate::prelude::*; +use crate::series::IsSorted; +use crate::utils::{align_chunks_binary, align_chunks_binary_owned}; +pub(super) use numeric::arithmetic_helper; + +pub trait ArrayArithmetics +where + Self: NativeType, +{ + fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; + fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray; + fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray; +} + +macro_rules! native_array_arithmetics { + ($ty: ty) => { + impl ArrayArithmetics for $ty + { + fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::add(lhs, rhs) + } + fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::sub(lhs, rhs) + } + fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::mul(lhs, rhs) + } + fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::div(lhs, rhs) + } + fn div_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { + basic::div_scalar(lhs, rhs) + } + fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + basic::rem(lhs, rhs) + } + fn rem_scalar(lhs: &PrimitiveArray, rhs: &Self) -> PrimitiveArray { + basic::rem_scalar(lhs, rhs) + } + } + }; + ($($ty:ty),*) => { + $(native_array_arithmetics!($ty);)* + } +} + +native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); + + + +fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { + buf.clear(); + + buf.extend_from_slice(l); + buf.extend_from_slice(r); +} + +impl Add for &Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: Self) -> Self::Output { + unsafe { (self.as_binary() + rhs.as_binary()).to_utf8() } + } +} + +impl Add for Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&str> for &Utf8Chunked { + type Output = Utf8Chunked; + + fn add(self, rhs: &str) -> Self::Output { + unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_utf8() } + } +} + +fn concat_binary(a: &BinaryArray, b: &BinaryArray) -> BinaryArray { + let validity = combine_validities_and(a.validity(), b.validity()); + let mut values = Vec::with_capacity(a.get_values_size() + b.get_values_size()); + let mut offsets = Vec::with_capacity(a.len() + 1); + let mut offset_so_far = 0i64; + offsets.push(offset_so_far); + + for (a, b) in a.values_iter().zip(b.values_iter()) { + values.extend_from_slice(a); + values.extend_from_slice(b); + offset_so_far = values.len() as i64; + offsets.push(offset_so_far) + } + unsafe { BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), validity) } +} + +impl Add for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + // broadcasting path rhs + if rhs.len() == 1 { + let rhs = rhs.get(0); + let mut buf = vec![]; + return match rhs { + Some(rhs) => { + self.apply_mut(|s| { + concat_binary_arrs(s, rhs, &mut buf); + let out = buf.as_slice(); + // safety: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }) + } + None => BinaryChunked::full_null(self.name(), self.len()), + }; + } + // broadcasting path lhs + if self.len() == 1 { + let lhs = self.get(0); + let mut buf = vec![]; + return match lhs { + Some(lhs) => rhs.apply_mut(|s| { + concat_binary_arrs(lhs, s, &mut buf); + + let out = buf.as_slice(); + // safety: lifetime is bound to the outer scope and the + // ref is valid for the lifetime of this closure + unsafe { std::mem::transmute::<_, &'static [u8]>(out) } + }), + None => BinaryChunked::full_null(self.name(), rhs.len()), + }; + } + + let (lhs, rhs) = align_chunks_binary(self, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(a, b)| Box::new(concat_binary(a, b)) as ArrayRef) + .collect(); + + unsafe { BinaryChunked::from_chunks(self.name(), chunks) } + } +} + +impl Add for BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +impl Add<&[u8]> for &BinaryChunked { + type Output = BinaryChunked; + + fn add(self, rhs: &[u8]) -> Self::Output { + let arr = BinaryArray::::from_slice([rhs]); + let rhs = unsafe { BinaryChunked::from_chunks("", vec![Box::new(arr) as ArrayRef]) }; + self.add(&rhs) + } +} + +fn add_boolean(a: &BooleanArray, b: &BooleanArray) -> PrimitiveArray { + let validity = combine_validities_and(a.validity(), b.validity()); + + let values = a + .values_iter() + .zip(b.values_iter()) + .map(|(a, b)| a as IdxSize + b as IdxSize) + .collect::>(); + PrimitiveArray::from_data_default(values.into(), validity) +} + +impl Add for &BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + // broadcasting path rhs + if rhs.len() == 1 { + let rhs = rhs.get(0); + return match rhs { + Some(rhs) => self.apply_cast_numeric(|v| v as IdxSize + rhs as IdxSize), + None => IdxCa::full_null(self.name(), self.len()), + }; + } + // broadcasting path lhs + if self.len() == 1 { + return rhs.add(self); + } + let (lhs, rhs) = align_chunks_binary(self, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(a, b)| Box::new(add_boolean(a, b)) as ArrayRef) + .collect::>(); + + unsafe { IdxCa::from_chunks(self.name(), chunks) } + } +} + +impl Add for BooleanChunked { + type Output = IdxCa; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::prelude::*; + + pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { + let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); + let a2 = Int32Chunked::new("a", &[4, 5, 6]); + let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); + a1.append(&a2); + (a1, a3) + } + + #[test] + #[allow(clippy::eq_op)] + fn test_chunk_mismatch() { + let (a1, a2) = create_two_chunked(); + // with different chunks + let _ = &a1 + &a2; + let _ = &a1 - &a2; + let _ = &a1 / &a2; + let _ = &a1 * &a2; + + // with same chunks + let _ = &a1 + &a1; + let _ = &a1 - &a1; + let _ = &a1 / &a1; + let _ = &a1 * &a1; + } +} diff --git a/polars/polars-core/src/chunked_array/arithmetic/numeric.rs b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs new file mode 100644 index 000000000000..42cb157b6bf2 --- /dev/null +++ b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -0,0 +1,395 @@ +use super::*; + +pub(crate) fn arithmetic_helper( + lhs: &ChunkedArray, + rhs: &ChunkedArray, + kernel: Kernel, + operation: F, +) -> ChunkedArray + where + T: PolarsNumericType, + Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, + F: Fn(T::Native, T::Native) -> T::Native, +{ + let mut ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + let chunks = lhs + .downcast_iter() + .zip(rhs.downcast_iter()) + .map(|(lhs, rhs)| Box::new(kernel(lhs, rhs)) as ArrayRef) + .collect(); + lhs.copy_with_chunks(chunks, false, false) + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs) => lhs.apply(|lhs| operation(lhs, rhs)), + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs) => rhs.apply(|rhs| operation(lhs, rhs)), + } + } + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + ca.rename(lhs.name()); + ca +} + +/// This assigns to the owned buffer if the ref count is 1 +fn arithmetic_helper_owned( + mut lhs: ChunkedArray, + mut rhs: ChunkedArray, + kernel: Kernel, + operation: F, +) -> ChunkedArray + where + T: PolarsNumericType, + Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), + F: Fn(T::Native, T::Native) -> T::Native, +{ + let ca = match (lhs.len(), rhs.len()) { + (a, b) if a == b => { + let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs); + // safety, we do no t change the lengths + unsafe { + lhs.downcast_iter_mut() + .zip(rhs.downcast_iter_mut()) + .for_each(|(lhs, rhs)| kernel(lhs, rhs)); + } + lhs.set_sorted_flag(IsSorted::Not); + lhs + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => ChunkedArray::full_null(lhs.name(), lhs.len()), + Some(rhs) => { + lhs.apply_mut(|lhs| operation(lhs, rhs)); + lhs + } + } + } + (1, _) => { + let opt_lhs = lhs.get(0); + match opt_lhs { + None => ChunkedArray::full_null(lhs.name(), rhs.len()), + Some(lhs_val) => { + rhs.apply_mut(|rhs| operation(lhs_val, rhs)); + rhs.rename(lhs.name()); + rhs + } + } + } + _ => panic!("Cannot apply operation on arrays of different lengths"), + }; + ca +} + +// Operands on ChunkedArray & ChunkedArray + +impl Add for &ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn add(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::add, + |lhs, rhs| lhs + rhs, + ) + } +} + +impl Div for &ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn div(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::div, + |lhs, rhs| lhs / rhs, + ) + } +} + +impl Mul for &ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn mul(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::mul, + |lhs, rhs| lhs * rhs, + ) + } +} + +impl Rem for &ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::rem, + |lhs, rhs| lhs % rhs, + ) + } +} + +impl Sub for &ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: Self) -> Self::Output { + arithmetic_helper( + self, + rhs, + ::sub, + |lhs, rhs| lhs - rhs, + ) + } +} + +impl Add for ChunkedArray + where + T: PolarsNumericType, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a + b), + |lhs, rhs| lhs + rhs, + ) + } +} + +impl Div for ChunkedArray + where + T: PolarsNumericType, +{ + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a / b), + |lhs, rhs| lhs / rhs, + ) + } +} + +impl Mul for ChunkedArray + where + T: PolarsNumericType, +{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a * b), + |lhs, rhs| lhs * rhs, + ) + } +} + +impl Sub for ChunkedArray + where + T: PolarsNumericType, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + arithmetic_helper_owned( + self, + rhs, + |a, b| arity_assign::binary(a, b, |a, b| a - b), + |lhs, rhs| lhs - rhs, + ) + } +} + +impl Rem for ChunkedArray + where + T: PolarsNumericType, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: Self) -> Self::Output { + (&self).rem(&rhs) + } +} + +// Operands on ChunkedArray & Num + +impl Add for &ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn add(self, rhs: N) -> Self::Output { + let adder: T::Native = NumCast::from(rhs).unwrap(); + let mut out = self.apply(|val| val + adder); + out.set_sorted_flag(self.is_sorted_flag()); + out + } +} + +impl Sub for &ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: N) -> Self::Output { + let subber: T::Native = NumCast::from(rhs).unwrap(); + let mut out = self.apply(|val| val - subber); + out.set_sorted_flag(self.is_sorted_flag()); + out + } +} + +impl Div for &ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn div(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); + let mut out = self + .apply_kernel(&|arr| Box::new(::div_scalar(arr, &rhs))); + + if rhs < T::Native::zero() { + out.set_sorted_flag(self.is_sorted_flag().reverse()); + } else { + out.set_sorted_flag(self.is_sorted_flag()); + } + out + } +} + +impl Mul for &ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn mul(self, rhs: N) -> Self::Output { + // don't set sorted flag as probability of overflow is higher + let multiplier: T::Native = NumCast::from(rhs).unwrap(); + let rhs = ChunkedArray::from_vec("", vec![multiplier]); + self.mul(&rhs) + } +} + +impl Rem for &ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: N) -> Self::Output { + let rhs: T::Native = NumCast::from(rhs).expect("could not cast"); + let rhs = ChunkedArray::from_vec("", vec![rhs]); + self.rem(&rhs) + } +} + +impl Add for ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn add(self, rhs: N) -> Self::Output { + (&self).add(rhs) + } +} + +impl Sub for ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn sub(self, rhs: N) -> Self::Output { + (&self).sub(rhs) + } +} + +impl Div for ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn div(self, rhs: N) -> Self::Output { + (&self).div(rhs) + } +} + +impl Mul for ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn mul(mut self, rhs: N) -> Self::Output { + let multiplier: T::Native = NumCast::from(rhs).unwrap(); + self.apply_mut(|val| val * multiplier); + self + } +} + +impl Rem for ChunkedArray + where + T: PolarsNumericType, + N: Num + ToPrimitive, +{ + type Output = ChunkedArray; + + fn rem(self, rhs: N) -> Self::Output { + (&self).rem(rhs) + } +} diff --git a/polars/polars-core/src/datatypes/dtype.rs b/polars/polars-core/src/datatypes/dtype.rs index 76b749c85b54..2e9538ddf797 100644 --- a/polars/polars-core/src/datatypes/dtype.rs +++ b/polars/polars-core/src/datatypes/dtype.rs @@ -160,7 +160,7 @@ impl DataType { self.is_numeric() | matches!(self, DataType::Boolean | DataType::Utf8 | DataType::Binary) } - /// Check if this [`DataType`] is a numeric type + /// Check if this [`DataType`] is a numeric type. pub fn is_numeric(&self) -> bool { // allow because it cannot be replaced when object feature is activated #[allow(clippy::match_like_matches_macro)] @@ -181,6 +181,8 @@ impl DataType { DataType::Categorical(_) => false, #[cfg(feature = "dtype-struct")] DataType::Struct(_) => false, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => false, _ => true, } } From 1a35cde35f62d95b083341e48d3d1952b2532af7 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 30 May 2023 18:44:09 +0200 Subject: [PATCH 2/3] dry --- .../src/compute/arithmetics/decimal/add.rs | 61 ++------- .../arithmetics/decimal/commutative.rs | 93 ++++++++++++++ .../src/compute/arithmetics/decimal/mod.rs | 9 ++ .../src/compute/arithmetics/decimal/mul.rs | 16 +++ .../src/compute/arithmetics/decimal/sub.rs | 19 +++ .../src/compute/arithmetics/mod.rs | 2 +- polars/polars-arrow/src/compute/mod.rs | 27 ++-- .../src/chunked_array/arithmetic/decimal.rs | 107 +++++++++------- .../src/chunked_array/arithmetic/mod.rs | 7 +- .../src/chunked_array/arithmetic/numeric.rs | 116 +++++++++--------- 10 files changed, 293 insertions(+), 164 deletions(-) create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs index fc2ff3e8998a..b9c2c6dc7813 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/add.rs @@ -1,53 +1,16 @@ -//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. -use arrow::{ - array::PrimitiveArray, - compute::{ - arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}, - arity::{binary, binary_checked}, - }, -}; -use arrow::{ - datatypes::DataType, - error::{Error, Result}, -}; -use arrow::compute::arity::unary; -use polars_error::*; -use crate::compute::{binary_mut, unary_mut}; -use crate::utils::combine_validities_and; +use super::*; -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; - -pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PolarsResult> { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - let mut overflow = false; - let op = move |a, b| { - let res: i128 = a + b; - if res.abs() > max { - overflow = true - } - res - }; - polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - - Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) +pub fn add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + commutative(lhs, rhs, |a, b| a + b) } - -pub fn add_scalar(lhs: &PrimitiveArray, rhs: i128, rhs_dtype: &DataType) -> PolarsResult> { - let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap(); - - let max = max_value(precision); - let mut overflow = false; - let op = move |a| { - let res: i128 = a + rhs; - if res.abs() > max { - overflow = true - } - res - }; - polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - - Ok(unary_mut(lhs, op, lhs.data_type().clone())) +pub fn add_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a + b) } diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs new file mode 100644 index 000000000000..aca337284387 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs @@ -0,0 +1,93 @@ +use arrow::array::PrimitiveArray; +use arrow::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; +use arrow::compute::arity::{binary, binary_checked, unary}; +use arrow::datatypes::DataType; +use arrow::error::{Error, Result}; +use polars_error::*; + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::compute::{binary_mut, unary_mut}; +use crate::utils::combine_validities_and; + +pub fn commutative( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + op: F +) -> PolarsResult> + where F: Fn(i128, i128) -> i128 +{ + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = move |a, b| { + let res = op(a, b); + overflow |= res.abs() > max; + res + }; + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + + Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) +} + +pub fn commutative_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, + op: F +) -> PolarsResult> + where F: Fn(i128, i128) -> i128 +{ + let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap(); + + let max = max_value(precision); + let mut overflow = false; + let op = move |a| { + let res = op(a, rhs); + overflow |= res.abs() > max; + res + }; + polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); + + Ok(unary_mut(lhs, op, lhs.data_type().clone())) +} + +pub fn non_commutative( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + op: F +) -> PolarsResult> + where F: Fn(i128, i128) -> i128 +{ + let op = move |a, b| { + op(a, b) + }; + + Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) +} + +pub fn non_commutative_scalar(lhs: &PrimitiveArray, rhs: i128, + op: F +) -> PolarsResult> + where F: Fn(i128, i128) -> i128 +{ + let op = move |a| { + op(a, rhs) + }; + + Ok(unary_mut(lhs, op, lhs.data_type().clone())) +} + +pub fn non_commutative_scalar_swapped( + lhs: i128, + rhs: &PrimitiveArray, + op: F +) -> PolarsResult> + where F: Fn(i128, i128) -> i128 +{ + let op = move |a| { + op(lhs, a) + }; + + Ok(unary_mut(rhs, op, rhs.data_type().clone())) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs index 6b0cd14acad6..36aae1bbd247 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs @@ -1,9 +1,18 @@ use arrow::datatypes::DataType; use polars_error::{PolarsError, PolarsResult}; +use arrow::array::PrimitiveArray; +use commutative::{ + commutative_scalar, commutative, non_commutative, non_commutative_scalar_swapped, non_commutative_scalar +}; mod add; +mod sub; +mod mul; +mod commutative; pub use add::*; +pub use sub::*; +pub use mul::*; /// Maximum value that can exist with a selected precision #[inline] diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs new file mode 100644 index 000000000000..f8db20097df7 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs @@ -0,0 +1,16 @@ +use super::*; + +pub fn mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + commutative(lhs, rhs, |a, b| a * b) +} + +pub fn mul_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a * b) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs new file mode 100644 index 000000000000..4e141a25a086 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs @@ -0,0 +1,19 @@ +use super::*; + +pub fn sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + non_commutative(lhs, rhs, |a, b | a - b) +} + +pub fn sub_scalar(lhs: &PrimitiveArray, rhs: i128) -> PolarsResult> { + non_commutative_scalar(lhs, rhs, |a, b | a - b) +} + +pub fn sub_scalar_swapped( + lhs: i128, + rhs: &PrimitiveArray, +) -> PolarsResult> { + non_commutative_scalar_swapped(lhs, rhs, |a, b | a - b) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/mod.rs b/polars/polars-arrow/src/compute/arithmetics/mod.rs index 16dd409300a3..0abcbaba757a 100644 --- a/polars/polars-arrow/src/compute/arithmetics/mod.rs +++ b/polars/polars-arrow/src/compute/arithmetics/mod.rs @@ -1,2 +1,2 @@ #[cfg(feature = "dtype-decimal")] -pub mod decimal; \ No newline at end of file +pub mod decimal; diff --git a/polars/polars-arrow/src/compute/mod.rs b/polars/polars-arrow/src/compute/mod.rs index 46521718db88..ead240d11755 100644 --- a/polars/polars-arrow/src/compute/mod.rs +++ b/polars/polars-arrow/src/compute/mod.rs @@ -1,8 +1,11 @@ use arrow::array::PrimitiveArray; use arrow::datatypes::DataType; use arrow::types::NativeType; + use crate::utils::combine_validities_and; +pub mod arithmetics; +pub mod arity; pub mod bitwise; #[cfg(feature = "compute")] pub mod cast; @@ -10,8 +13,6 @@ pub mod cast; pub mod decimal; pub mod take; pub mod tile; -pub mod arithmetics; -pub mod arity; #[inline] pub fn binary_mut( @@ -20,10 +21,10 @@ pub fn binary_mut( data_type: DataType, mut op: F, ) -> PrimitiveArray - where - T: NativeType, - D: NativeType, - F: FnMut(T, D) -> T, +where + T: NativeType, + D: NativeType, + F: FnMut(T, D) -> T, { assert_eq!(lhs.len(), rhs.len()); let validity = combine_validities_and(lhs.validity(), rhs.validity()); @@ -40,11 +41,15 @@ pub fn binary_mut( } #[inline] -pub fn unary_mut(array: &PrimitiveArray, mut op: F, data_type: DataType) -> PrimitiveArray - where - I: NativeType, - O: NativeType, - F: FnMut(I) -> O, +pub fn unary_mut( + array: &PrimitiveArray, + mut op: F, + data_type: DataType, +) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: FnMut(I) -> O, { let values = array.values().iter().map(|v| op(*v)).collect::>(); diff --git a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs index bf0f26aa3de8..93dee0ca5792 100644 --- a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs +++ b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -1,30 +1,28 @@ -use crate::prelude::DecimalChunked; use polars_arrow::compute::arithmetics::decimal; + use super::*; +use crate::prelude::DecimalChunked; +// TODO: remove impl ArrayArithmetics for i128 { - fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - decimal::add(lhs, rhs).unwrap() + fn add(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() } - fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - todo!() - // decimal::sub(lhs, rhs) + fn sub(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() } - fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - todo!() - // decimal::mul(lhs, rhs) + fn mul(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() } - fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - todo!() - // decimal::div(lhs, rhs) + fn div(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { + unimplemented!() } fn div_scalar(_lhs: &PrimitiveArray, _rhs: &Self) -> PrimitiveArray { - // decimal::div_scalar(lhs, rhs) - todo!("decimal::div_scalar exists, but takes &PrimitiveScalar, not &i128"); + unimplemented!() } fn rem(_lhs: &PrimitiveArray, _rhs: &PrimitiveArray) -> PrimitiveArray { @@ -37,16 +35,18 @@ impl ArrayArithmetics for i128 { } impl DecimalChunked { - fn arithmetic_helper(&self, rhs: &DecimalChunked, - kernel: Kernel, - operation_lhs: ScalarKernelLhs, - operation_rhs: ScalarKernelRhs + fn arithmetic_helper( + &self, + rhs: &DecimalChunked, + kernel: Kernel, + operation_lhs: ScalarKernelLhs, + operation_rhs: ScalarKernelRhs, ) -> PolarsResult - where - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PolarsResult>, - ScalarKernelLhs: Fn(&PrimitiveArray, i128, &ArrowDataType) -> PolarsResult>, - ScalarKernelRhs: Fn(i128, &ArrowDataType, &PrimitiveArray) -> PolarsResult> - + where + Kernel: + Fn(&PrimitiveArray, &PrimitiveArray) -> PolarsResult>, + ScalarKernelLhs: Fn(&PrimitiveArray, i128) -> PolarsResult>, + ScalarKernelRhs: Fn(i128, &PrimitiveArray) -> PolarsResult>, { let lhs = self; @@ -66,11 +66,12 @@ impl DecimalChunked { match opt_rhs { None => ChunkedArray::full_null(lhs.name(), lhs.len()), Some(rhs_val) => { - let chunks = lhs.downcast_iter().map(|lhs| { - operation_lhs(lhs, rhs_val, &rhs.dtype().to_arrow()).map(|a| Box::new(a) as ArrayRef) - }).collect::>()?; + let chunks = lhs + .downcast_iter() + .map(|lhs| operation_lhs(lhs, rhs_val).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; lhs.copy_with_chunks(chunks, false, false) - }, + } } } (1, _) => { @@ -78,34 +79,58 @@ impl DecimalChunked { match opt_lhs { None => ChunkedArray::full_null(lhs.name(), rhs.len()), Some(lhs_val) => { - let chunks = rhs.downcast_iter().map(|rhs| { - operation_rhs(lhs_val, &lhs.dtype().to_arrow(), rhs).map(|a| Box::new(a) as ArrayRef) - }).collect::>()?; + let chunks = rhs + .downcast_iter() + .map(|rhs| operation_rhs(lhs_val, rhs).map(|a| Box::new(a) as ArrayRef)) + .collect::>()?; lhs.copy_with_chunks(chunks, false, false) - } } } - _ => polars_bail!(ComputeError: "Cannot apply operation on arrays of different lengths"), + _ => { + polars_bail!(ComputeError: "Cannot apply operation on arrays of different lengths") + } }; ca.rename(lhs.name()); Ok(ca.into_decimal_unchecked(self.precision(), self.scale())) - } +} +impl Add for &DecimalChunked { + type Output = PolarsResult; + + fn add(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::add, + |lhs, rhs_val| decimal::add_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::add_scalar(rhs, lhs_val, &self.dtype().to_arrow()), + ) + } } -fn reversed(lhs_val: i128, lhs_dtype: &ArrowDataType, rhs: &PrimitiveArray, op: Kernel) -> PolarsResult> -where Kernel: Fn(&PrimitiveArray, i128, &ArrowDataType) -> PolarsResult>, -{ - op(rhs, lhs_val, lhs_dtype) +impl Sub for &DecimalChunked { + type Output = PolarsResult; + + fn sub(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::sub, + decimal::sub_scalar, + decimal::sub_scalar_swapped, + ) + } } -impl Add for &DecimalChunked -{ +impl Mul for &DecimalChunked { type Output = PolarsResult; - fn add(self, rhs: Self) -> Self::Output { - self.arithmetic_helper(rhs, decimal::add, decimal::add_scalar, |lhs_val, lhs_dtype, rhs| reversed(lhs_val, lhs_dtype, rhs, decimal::add_scalar)) + fn mul(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::mul, + |lhs, rhs_val| decimal::mul_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::mul_scalar(rhs, lhs_val, &self.dtype().to_arrow()), + ) } } diff --git a/polars/polars-core/src/chunked_array/arithmetic/mod.rs b/polars/polars-core/src/chunked_array/arithmetic/mod.rs index 63d8c72915b4..6aa0a37cb310 100644 --- a/polars/polars-core/src/chunked_array/arithmetic/mod.rs +++ b/polars/polars-core/src/chunked_array/arithmetic/mod.rs @@ -1,6 +1,7 @@ //! Implementations of arithmetic operations on ChunkedArray's. -mod numeric; +#[cfg(feature = "dtype-decimal")] mod decimal; +mod numeric; use std::ops::{Add, Div, Mul, Rem, Sub}; @@ -9,12 +10,12 @@ use arrow::compute::arithmetics::basic; use arrow::compute::arity_assign; use arrow::types::NativeType; use num_traits::{Num, NumCast, ToPrimitive, Zero}; +pub(super) use numeric::arithmetic_helper; use polars_arrow::utils::combine_validities_and; use crate::prelude::*; use crate::series::IsSorted; use crate::utils::{align_chunks_binary, align_chunks_binary_owned}; -pub(super) use numeric::arithmetic_helper; pub trait ArrayArithmetics where @@ -63,8 +64,6 @@ macro_rules! native_array_arithmetics { native_array_arithmetics!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64); - - fn concat_binary_arrs(l: &[u8], r: &[u8], buf: &mut Vec) { buf.clear(); diff --git a/polars/polars-core/src/chunked_array/arithmetic/numeric.rs b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs index 42cb157b6bf2..03cb849c2e59 100644 --- a/polars/polars-core/src/chunked_array/arithmetic/numeric.rs +++ b/polars/polars-core/src/chunked_array/arithmetic/numeric.rs @@ -6,10 +6,10 @@ pub(crate) fn arithmetic_helper( kernel: Kernel, operation: F, ) -> ChunkedArray - where - T: PolarsNumericType, - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, - F: Fn(T::Native, T::Native) -> T::Native, +where + T: PolarsNumericType, + Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, + F: Fn(T::Native, T::Native) -> T::Native, { let mut ca = match (lhs.len(), rhs.len()) { (a, b) if a == b => { @@ -49,10 +49,10 @@ fn arithmetic_helper_owned( kernel: Kernel, operation: F, ) -> ChunkedArray - where - T: PolarsNumericType, - Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), - F: Fn(T::Native, T::Native) -> T::Native, +where + T: PolarsNumericType, + Kernel: Fn(&mut PrimitiveArray, &mut PrimitiveArray), + F: Fn(T::Native, T::Native) -> T::Native, { let ca = match (lhs.len(), rhs.len()) { (a, b) if a == b => { @@ -96,8 +96,8 @@ fn arithmetic_helper_owned( // Operands on ChunkedArray & ChunkedArray impl Add for &ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -112,8 +112,8 @@ impl Add for &ChunkedArray } impl Div for &ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -128,8 +128,8 @@ impl Div for &ChunkedArray } impl Mul for &ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -144,8 +144,8 @@ impl Mul for &ChunkedArray } impl Rem for &ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -160,8 +160,8 @@ impl Rem for &ChunkedArray } impl Sub for &ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -176,8 +176,8 @@ impl Sub for &ChunkedArray } impl Add for ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = Self; @@ -192,8 +192,8 @@ impl Add for ChunkedArray } impl Div for ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = Self; @@ -208,8 +208,8 @@ impl Div for ChunkedArray } impl Mul for ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = Self; @@ -224,8 +224,8 @@ impl Mul for ChunkedArray } impl Sub for ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = Self; @@ -240,8 +240,8 @@ impl Sub for ChunkedArray } impl Rem for ChunkedArray - where - T: PolarsNumericType, +where + T: PolarsNumericType, { type Output = ChunkedArray; @@ -253,9 +253,9 @@ impl Rem for ChunkedArray // Operands on ChunkedArray & Num impl Add for &ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -268,9 +268,9 @@ impl Add for &ChunkedArray } impl Sub for &ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -283,9 +283,9 @@ impl Sub for &ChunkedArray } impl Div for &ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -304,9 +304,9 @@ impl Div for &ChunkedArray } impl Mul for &ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -319,9 +319,9 @@ impl Mul for &ChunkedArray } impl Rem for &ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -333,9 +333,9 @@ impl Rem for &ChunkedArray } impl Add for ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -345,9 +345,9 @@ impl Add for ChunkedArray } impl Sub for ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -357,9 +357,9 @@ impl Sub for ChunkedArray } impl Div for ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -369,9 +369,9 @@ impl Div for ChunkedArray } impl Mul for ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; @@ -383,9 +383,9 @@ impl Mul for ChunkedArray } impl Rem for ChunkedArray - where - T: PolarsNumericType, - N: Num + ToPrimitive, +where + T: PolarsNumericType, + N: Num + ToPrimitive, { type Output = ChunkedArray; From c81814a9f943cbace5d508f39b3cf346c3a8bbe2 Mon Sep 17 00:00:00 2001 From: ritchie Date: Wed, 31 May 2023 09:01:11 +0200 Subject: [PATCH 3/3] wrap-up --- .../arithmetics/decimal/commutative.rs | 58 ++++++++--------- .../src/compute/arithmetics/decimal/div.rs | 43 +++++++++++++ .../src/compute/arithmetics/decimal/mod.rs | 64 +++---------------- .../src/compute/arithmetics/decimal/mul.rs | 21 +++++- .../src/compute/arithmetics/decimal/sub.rs | 6 +- .../src/chunked_array/arithmetic/decimal.rs | 13 ++++ .../src/series/implementations/decimal.rs | 16 +++++ .../src/physical_plan/expressions/binary.rs | 2 + .../tests/unit/datatypes/test_decimal.py | 29 +++++++++ 9 files changed, 161 insertions(+), 91 deletions(-) create mode 100644 polars/polars-arrow/src/compute/arithmetics/decimal/div.rs diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs index aca337284387..304f9770a0f6 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs @@ -1,79 +1,76 @@ use arrow::array::PrimitiveArray; -use arrow::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; -use arrow::compute::arity::{binary, binary_checked, unary}; use arrow::datatypes::DataType; -use arrow::error::{Error, Result}; use polars_error::*; -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use super::{get_parameters, max_value}; use crate::compute::{binary_mut, unary_mut}; -use crate::utils::combine_validities_and; pub fn commutative( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - op: F + op: F, ) -> PolarsResult> - where F: Fn(i128, i128) -> i128 +where + F: Fn(i128, i128) -> i128, { let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); let max = max_value(precision); let mut overflow = false; - let op = move |a, b| { + let op = |a, b| { let res = op(a, b); overflow |= res.abs() > max; res }; + let out = binary_mut(lhs, rhs, lhs.data_type().clone(), op); polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - - Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) + Ok(out) } pub fn commutative_scalar( lhs: &PrimitiveArray, rhs: i128, rhs_dtype: &DataType, - op: F + op: F, ) -> PolarsResult> - where F: Fn(i128, i128) -> i128 +where + F: Fn(i128, i128) -> i128, { let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap(); let max = max_value(precision); let mut overflow = false; - let op = move |a| { + let op = |a| { let res = op(a, rhs); overflow |= res.abs() > max; res }; + let out = unary_mut(lhs, op, lhs.data_type().clone()); polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}"); - Ok(unary_mut(lhs, op, lhs.data_type().clone())) + Ok(out) } pub fn non_commutative( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - op: F + op: F, ) -> PolarsResult> - where F: Fn(i128, i128) -> i128 +where + F: Fn(i128, i128) -> i128, { - let op = move |a, b| { - op(a, b) - }; - Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op)) } -pub fn non_commutative_scalar(lhs: &PrimitiveArray, rhs: i128, - op: F +pub fn non_commutative_scalar( + lhs: &PrimitiveArray, + rhs: i128, + op: F, ) -> PolarsResult> - where F: Fn(i128, i128) -> i128 +where + F: Fn(i128, i128) -> i128, { - let op = move |a| { - op(a, rhs) - }; + let op = move |a| op(a, rhs); Ok(unary_mut(lhs, op, lhs.data_type().clone())) } @@ -81,13 +78,12 @@ pub fn non_commutative_scalar(lhs: &PrimitiveArray, rhs: i128, pub fn non_commutative_scalar_swapped( lhs: i128, rhs: &PrimitiveArray, - op: F + op: F, ) -> PolarsResult> - where F: Fn(i128, i128) -> i128 +where + F: Fn(i128, i128) -> i128, { - let op = move |a| { - op(lhs, a) - }; + let op = move |a| op(lhs, a); Ok(unary_mut(rhs, op, rhs.data_type().clone())) } diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs new file mode 100644 index 000000000000..94c817eea735 --- /dev/null +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/div.rs @@ -0,0 +1,43 @@ +use super::*; + +#[inline] +fn decimal_div(a: i128, b: i128, scale: i128) -> i128 { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + a * scale / b +} + +pub fn div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + non_commutative(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} + +pub fn div_scalar( + lhs: &PrimitiveArray, + rhs: i128, + rhs_dtype: &DataType, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; + let scale = 10i128.pow(scale as u32); + non_commutative_scalar(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} + +pub fn div_scalar_swapped( + lhs: i128, + lhs_dtype: &DataType, + rhs: &PrimitiveArray, +) -> PolarsResult> { + let (_, scale) = get_parameters(lhs_dtype, rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + non_commutative_scalar_swapped(lhs, rhs, |a, b| decimal_div(a, b, scale)) +} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs index 36aae1bbd247..d74f4ddb8e78 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs @@ -1,18 +1,21 @@ -use arrow::datatypes::DataType; -use polars_error::{PolarsError, PolarsResult}; use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; use commutative::{ - commutative_scalar, commutative, non_commutative, non_commutative_scalar_swapped, non_commutative_scalar + commutative, commutative_scalar, non_commutative, non_commutative_scalar, + non_commutative_scalar_swapped, }; +use polars_error::{PolarsError, PolarsResult}; mod add; -mod sub; -mod mul; mod commutative; +mod div; +mod mul; +mod sub; pub use add::*; -pub use sub::*; +pub use div::*; pub use mul::*; +pub use sub::*; /// Maximum value that can exist with a selected precision #[inline] @@ -20,20 +23,6 @@ fn max_value(precision: usize) -> i128 { 10i128.pow(precision as u32) - 1 } -// Calculates the number of digits in a i128 number -fn number_digits(num: i128) -> usize { - let mut num = num.abs(); - let mut digit: i128 = 0; - let base = 10i128; - - while num != 0 { - num /= base; - digit += 1; - } - - digit as usize -} - fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> { if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = (lhs.to_logical_type(), rhs.to_logical_type()) @@ -49,38 +38,3 @@ fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize) unreachable!() } } - -/// Returns the adjusted precision and scale for the lhs and rhs precision and -/// scale -fn adjusted_precision_scale( - lhs_p: usize, - lhs_s: usize, - rhs_p: usize, - rhs_s: usize, -) -> (usize, usize, usize) { - // The initial new precision and scale is based on the number of digits - // that lhs and rhs number has before and after the point. The max - // number of digits before and after the point will make the last - // precision and scale of the result - - // Digits before/after point - // before after - // 11.1111 -> 5, 4 -> 2 4 - // 11111.01 -> 7, 2 -> 5 2 - // ----------------- - // 11122.1211 -> 9, 4 -> 5 4 - let lhs_digits_before = lhs_p - lhs_s; - let rhs_digits_before = rhs_p - rhs_s; - - let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before); - - let (res_s, diff) = if lhs_s > rhs_s { - (lhs_s, lhs_s - rhs_s) - } else { - (rhs_s, rhs_s - lhs_s) - }; - - let res_p = res_digits_before + res_s; - - (res_p, res_s, diff) -} diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs index f8db20097df7..e8e22e73e2ac 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs @@ -1,10 +1,25 @@ use super::*; +#[inline] +fn decimal_mul(a: i128, b: i128, scale: i128) -> i128 { + // The multiplication is done using the numbers without scale. + // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + a * b / scale +} + pub fn mul( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> PolarsResult> { - commutative(lhs, rhs, |a, b| a * b) + let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?; + let scale = 10i128.pow(scale as u32); + commutative(lhs, rhs, |a, b| decimal_mul(a, b, scale)) } pub fn mul_scalar( @@ -12,5 +27,7 @@ pub fn mul_scalar( rhs: i128, rhs_dtype: &DataType, ) -> PolarsResult> { - commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a * b) + let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?; + let scale = 10i128.pow(scale as u32); + commutative_scalar(lhs, rhs, rhs_dtype, |a, b| decimal_mul(a, b, scale)) } diff --git a/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs index 4e141a25a086..da67a8593bde 100644 --- a/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs +++ b/polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs @@ -4,16 +4,16 @@ pub fn sub( lhs: &PrimitiveArray, rhs: &PrimitiveArray, ) -> PolarsResult> { - non_commutative(lhs, rhs, |a, b | a - b) + non_commutative(lhs, rhs, |a, b| a - b) } pub fn sub_scalar(lhs: &PrimitiveArray, rhs: i128) -> PolarsResult> { - non_commutative_scalar(lhs, rhs, |a, b | a - b) + non_commutative_scalar(lhs, rhs, |a, b| a - b) } pub fn sub_scalar_swapped( lhs: i128, rhs: &PrimitiveArray, ) -> PolarsResult> { - non_commutative_scalar_swapped(lhs, rhs, |a, b | a - b) + non_commutative_scalar_swapped(lhs, rhs, |a, b| a - b) } diff --git a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs index 93dee0ca5792..f9e2206ecb81 100644 --- a/polars/polars-core/src/chunked_array/arithmetic/decimal.rs +++ b/polars/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -134,3 +134,16 @@ impl Mul for &DecimalChunked { ) } } + +impl Div for &DecimalChunked { + type Output = PolarsResult; + + fn div(self, rhs: Self) -> Self::Output { + self.arithmetic_helper( + rhs, + decimal::div, + |lhs, rhs_val| decimal::div_scalar(lhs, rhs_val, &rhs.dtype().to_arrow()), + |lhs_val, rhs| decimal::div_scalar_swapped(lhs_val, &self.dtype().to_arrow(), rhs), + ) + } +} diff --git a/polars/polars-core/src/series/implementations/decimal.rs b/polars/polars-core/src/series/implementations/decimal.rs index 125bfee5a285..8f5f9bdde833 100644 --- a/polars/polars-core/src/series/implementations/decimal.rs +++ b/polars/polars-core/src/series/implementations/decimal.rs @@ -37,6 +37,22 @@ impl private::PrivateSeries for SeriesWrap { .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } + fn subtract(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) - rhs).map(|ca| ca.into_series()) + } + fn add_to(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) + rhs).map(|ca| ca.into_series()) + } + fn multiply(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) * rhs).map(|ca| ca.into_series()) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + let rhs = rhs.decimal()?; + ((&self.0) / rhs).map(|ca| ca.into_series()) + } } impl SeriesTrait for SeriesWrap { diff --git a/polars/polars-lazy/src/physical_plan/expressions/binary.rs b/polars/polars-lazy/src/physical_plan/expressions/binary.rs index e04140296c2a..06c39265cb30 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/binary.rs @@ -56,6 +56,8 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu Operator::Multiply => Ok(left * right), Operator::Divide => Ok(left / right), Operator::TrueDivide => match left.dtype() { + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => Ok(left / right), Date | Datetime(_, _) | Float32 | Float64 => Ok(left / right), _ => Ok(&left.cast(&Float64)? / &right.cast(&Float64)?), }, diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 48bea4f6c436..0de98fe4be26 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -134,3 +134,32 @@ def test_read_csv_decimal(monkeypatch: Any) -> None: D("1.10000000000000000000"), D("0.01000000000000000000"), ] + + +def test_decimal_arithmetic() -> None: + df = pl.DataFrame( + { + "a": [D("0.1"), D("10.1"), D("100.01")], + "b": [D("20.1"), D("10.19"), D("39.21")], + } + ) + + out = df.select( + out1=pl.col("a") * pl.col("b"), + out2=pl.col("a") + pl.col("b"), + out3=pl.col("a") / pl.col("b"), + out4=pl.col("a") - pl.col("b"), + ) + assert out.dtypes == [ + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + pl.Decimal(precision=None, scale=2), + ] + + assert out.to_dict(False) == { + "out1": [D("2.01"), D("102.91"), D("3921.39")], + "out2": [D("20.20"), D("20.29"), D("139.22")], + "out3": [D("0.00"), D("0.99"), D("2.55")], + "out4": [D("-20.00"), D("-0.09"), D("60.80")], + }