Skip to content

Commit

Permalink
feat(rust, python): Decimal arithmetic (pola-rs#9123)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored and c-peters committed Jul 14, 2023
1 parent 2d06f37 commit 12811c8
Show file tree
Hide file tree
Showing 17 changed files with 1,142 additions and 679 deletions.
16 changes: 16 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use super::*;

pub fn add(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
commutative(lhs, rhs, |a, b| a + b)
}

pub fn add_scalar(
lhs: &PrimitiveArray<i128>,
rhs: i128,
rhs_dtype: &DataType,
) -> PolarsResult<PrimitiveArray<i128>> {
commutative_scalar(lhs, rhs, rhs_dtype, |a, b| a + b)
}
89 changes: 89 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/commutative.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use arrow::array::PrimitiveArray;
use arrow::datatypes::DataType;
use polars_error::*;

use super::{get_parameters, max_value};
use crate::compute::{binary_mut, unary_mut};

pub fn commutative<F>(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
op: F,
) -> PolarsResult<PrimitiveArray<i128>>
where
F: Fn(i128, i128) -> i128,
{
let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

let max = max_value(precision);
let mut overflow = false;
let op = |a, b| {
let res = op(a, b);
overflow |= res.abs() > max;
res
};
let out = binary_mut(lhs, rhs, lhs.data_type().clone(), op);
polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}");
Ok(out)
}

pub fn commutative_scalar<F>(
lhs: &PrimitiveArray<i128>,
rhs: i128,
rhs_dtype: &DataType,
op: F,
) -> PolarsResult<PrimitiveArray<i128>>
where
F: Fn(i128, i128) -> i128,
{
let (precision, _) = get_parameters(lhs.data_type(), rhs_dtype).unwrap();

let max = max_value(precision);
let mut overflow = false;
let op = |a| {
let res = op(a, rhs);
overflow |= res.abs() > max;
res
};
let out = unary_mut(lhs, op, lhs.data_type().clone());
polars_ensure!(!overflow, ComputeError: "Decimal overflowed the allowed precision: {precision}");

Ok(out)
}

pub fn non_commutative<F>(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
op: F,
) -> PolarsResult<PrimitiveArray<i128>>
where
F: Fn(i128, i128) -> i128,
{
Ok(binary_mut(lhs, rhs, lhs.data_type().clone(), op))
}

pub fn non_commutative_scalar<F>(
lhs: &PrimitiveArray<i128>,
rhs: i128,
op: F,
) -> PolarsResult<PrimitiveArray<i128>>
where
F: Fn(i128, i128) -> i128,
{
let op = move |a| op(a, rhs);

Ok(unary_mut(lhs, op, lhs.data_type().clone()))
}

pub fn non_commutative_scalar_swapped<F>(
lhs: i128,
rhs: &PrimitiveArray<i128>,
op: F,
) -> PolarsResult<PrimitiveArray<i128>>
where
F: Fn(i128, i128) -> i128,
{
let op = move |a| op(lhs, a);

Ok(unary_mut(rhs, op, rhs.data_type().clone()))
}
43 changes: 43 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/div.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use super::*;

#[inline]
fn decimal_div(a: i128, b: i128, scale: i128) -> i128 {
// The division is done using the numbers without scale.
// The dividend is scaled up to maintain precision after the
// division

// 222.222 --> 222222000
// 123.456 --> 123456
// -------- ---------
// 1.800 <-- 1800
a * scale / b
}

pub fn div(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?;
let scale = 10i128.pow(scale as u32);
non_commutative(lhs, rhs, |a, b| decimal_div(a, b, scale))
}

pub fn div_scalar(
lhs: &PrimitiveArray<i128>,
rhs: i128,
rhs_dtype: &DataType,
) -> PolarsResult<PrimitiveArray<i128>> {
let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?;
let scale = 10i128.pow(scale as u32);
non_commutative_scalar(lhs, rhs, |a, b| decimal_div(a, b, scale))
}

pub fn div_scalar_swapped(
lhs: i128,
lhs_dtype: &DataType,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
let (_, scale) = get_parameters(lhs_dtype, rhs.data_type())?;
let scale = 10i128.pow(scale as u32);
non_commutative_scalar_swapped(lhs, rhs, |a, b| decimal_div(a, b, scale))
}
40 changes: 40 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use arrow::array::PrimitiveArray;
use arrow::datatypes::DataType;
use commutative::{
commutative, commutative_scalar, non_commutative, non_commutative_scalar,
non_commutative_scalar_swapped,
};
use polars_error::{PolarsError, PolarsResult};

mod add;
mod commutative;
mod div;
mod mul;
mod sub;

pub use add::*;
pub use div::*;
pub use mul::*;
pub use sub::*;

/// Maximum value that can exist with a selected precision
#[inline]
fn max_value(precision: usize) -> i128 {
10i128.pow(precision as u32) - 1
}

fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> {
if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.to_logical_type(), rhs.to_logical_type())
{
if lhs_p == rhs_p && lhs_s == rhs_s {
Ok((*lhs_p, *lhs_s))
} else {
Err(PolarsError::InvalidOperation(
"Arrays must have the same precision and scale".into(),
))
}
} else {
unreachable!()
}
}
33 changes: 33 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use super::*;

#[inline]
fn decimal_mul(a: i128, b: i128, scale: i128) -> i128 {
// The multiplication is done using the numbers without scale.
// The resulting scale of the value has to be corrected by
// dividing by (10^scale)

// 111.111 --> 111111
// 222.222 --> 222222
// -------- -------
// 24691.308 <-- 24691308642
a * b / scale
}

pub fn mul(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
let (_, scale) = get_parameters(lhs.data_type(), rhs.data_type())?;
let scale = 10i128.pow(scale as u32);
commutative(lhs, rhs, |a, b| decimal_mul(a, b, scale))
}

pub fn mul_scalar(
lhs: &PrimitiveArray<i128>,
rhs: i128,
rhs_dtype: &DataType,
) -> PolarsResult<PrimitiveArray<i128>> {
let (_, scale) = get_parameters(lhs.data_type(), rhs_dtype)?;
let scale = 10i128.pow(scale as u32);
commutative_scalar(lhs, rhs, rhs_dtype, |a, b| decimal_mul(a, b, scale))
}
19 changes: 19 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/decimal/sub.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use super::*;

pub fn sub(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
non_commutative(lhs, rhs, |a, b| a - b)
}

pub fn sub_scalar(lhs: &PrimitiveArray<i128>, rhs: i128) -> PolarsResult<PrimitiveArray<i128>> {
non_commutative_scalar(lhs, rhs, |a, b| a - b)
}

pub fn sub_scalar_swapped(
lhs: i128,
rhs: &PrimitiveArray<i128>,
) -> PolarsResult<PrimitiveArray<i128>> {
non_commutative_scalar_swapped(lhs, rhs, |a, b| a - b)
}
2 changes: 2 additions & 0 deletions polars/polars-arrow/src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#[cfg(feature = "dtype-decimal")]
pub mod decimal;
1 change: 1 addition & 0 deletions polars/polars-arrow/src/compute/arity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

50 changes: 50 additions & 0 deletions polars/polars-arrow/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,57 @@
use arrow::array::PrimitiveArray;
use arrow::datatypes::DataType;
use arrow::types::NativeType;

use crate::utils::combine_validities_and;

pub mod arithmetics;
pub mod arity;
pub mod bitwise;
#[cfg(feature = "compute")]
pub mod cast;
#[cfg(feature = "dtype-decimal")]
pub mod decimal;
pub mod take;
pub mod tile;

#[inline]
pub fn binary_mut<T, D, F>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<D>,
data_type: DataType,
mut op: F,
) -> PrimitiveArray<T>
where
T: NativeType,
D: NativeType,
F: FnMut(T, D) -> T,
{
assert_eq!(lhs.len(), rhs.len());
let validity = combine_validities_and(lhs.validity(), rhs.validity());

let values = lhs
.values()
.iter()
.zip(rhs.values().iter())
.map(|(l, r)| op(*l, *r))
.collect::<Vec<_>>()
.into();

PrimitiveArray::<T>::new(data_type, values, validity)
}

#[inline]
pub fn unary_mut<I, F, O>(
array: &PrimitiveArray<I>,
mut op: F,
data_type: DataType,
) -> PrimitiveArray<O>
where
I: NativeType,
O: NativeType,
F: FnMut(I) -> O,
{
let values = array.values().iter().map(|v| op(*v)).collect::<Vec<_>>();

PrimitiveArray::<O>::new(data_type, values.into(), array.validity().cloned())
}
Loading

0 comments on commit 12811c8

Please sign in to comment.