From deeb242f3832e15833bf21791a53ecf79234553f Mon Sep 17 00:00:00 2001 From: Paul Mason Date: Sun, 16 Jul 2023 17:05:17 -0700 Subject: [PATCH] Fixes issue with truncating implicitly rounding in some cases (#600) --- src/ops/array.rs | 30 ++++----------- tests/decimal_tests.rs | 86 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 38 deletions(-) diff --git a/src/ops/array.rs b/src/ops/array.rs index 2840e4b..ed3fce6 100644 --- a/src/ops/array.rs +++ b/src/ops/array.rs @@ -2,8 +2,13 @@ use crate::constants::{MAX_PRECISION_U32, POWERS_10, U32_MASK}; /// Rescales the given decimal to new scale. /// e.g. with 1.23 and new scale 3 rescale the value to 1.230 -#[inline(always)] +#[inline] pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_scale: u32) { + rescale::(value, value_scale, new_scale); +} + +#[inline(always)] +fn rescale(value: &mut [u32; 3], value_scale: &mut u32, new_scale: u32) { if *value_scale == new_scale { // Nothing to do return; @@ -32,7 +37,7 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ // Any remainder is discarded if diff > 0 still (i.e. lost precision) remainder = div_by_u32(value, 10); } - if remainder >= 5 { + if ROUND && remainder >= 5 { for part in value.iter_mut() { let digit = u64::from(*part) + 1u64; remainder = if digit > U32_MASK { 1 } else { 0 }; @@ -60,26 +65,7 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ #[inline] pub(crate) fn truncate_internal(value: &mut [u32; 3], value_scale: &mut u32, desired_scale: u32) { - if *value_scale <= desired_scale { - // Nothing to do, we're already at the desired scale (or less) - return; - } - if is_all_zero(value) { - *value_scale = desired_scale; - return; - } - while *value_scale > desired_scale { - // We're removing precision, so we don't care about handling the remainder - if *value_scale < 10 { - let adjustment = *value_scale - desired_scale; - div_by_u32(value, POWERS_10[adjustment as usize]); - *value_scale = desired_scale; - } else { - div_by_u32(value, POWERS_10[9]); - // Only 9 as this array starts with 1 - *value_scale -= 9; - } - } + rescale::(value, value_scale, desired_scale); } #[cfg(feature = "legacy-ops")] diff --git a/tests/decimal_tests.rs b/tests/decimal_tests.rs index c106ea8..1391cdb 100644 --- a/tests/decimal_tests.rs +++ b/tests/decimal_tests.rs @@ -2674,24 +2674,78 @@ fn it_can_trunc() { #[test] fn it_can_trunc_with_scale() { let cmp = Decimal::from_str("1.2345").unwrap(); - assert_eq!(Decimal::from_str("1.23450").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.234500001").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.23451").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.23454").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.23455").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.23456").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.23459").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("1.234599999").unwrap().trunc_with_scale(4), cmp); + let tests = [ + "1.23450", + "1.234500001", + "1.23451", + "1.23454", + "1.23455", + "1.23456", + "1.23459", + "1.234599999", + ]; + for test in tests { + assert_eq!( + Decimal::from_str(test).unwrap().trunc_with_scale(4), + cmp, + "Original: {}", + test + ); + } let cmp = Decimal::from_str("-1.2345").unwrap(); - assert_eq!(Decimal::from_str("-1.23450").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.234500001").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.23451").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.23454").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.23455").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.23456").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.23459").unwrap().trunc_with_scale(4), cmp); - assert_eq!(Decimal::from_str("-1.234599999").unwrap().trunc_with_scale(4), cmp); + let tests = [ + "-1.23450", + "-1.234500001", + "-1.23451", + "-1.23454", + "-1.23455", + "-1.23456", + "-1.23459", + "-1.234599999", + ]; + for test in tests { + assert_eq!( + Decimal::from_str(test).unwrap().trunc_with_scale(4), + cmp, + "Original: {}", + test + ); + } + + // Complex cases + let cmp = Decimal::from_str("0.5156").unwrap(); + let tests = [ + "0.51560089", + "0.515600893", + "0.5156008936", + "0.51560089369", + "0.515600893691", + "0.5156008936910", + "0.51560089369101", + "0.515600893691016", + "0.5156008936910161", + "0.51560089369101613", + "0.515600893691016134", + "0.5156008936910161349", + "0.51560089369101613494", + "0.515600893691016134941", + "0.5156008936910161349411", + "0.51560089369101613494115", + "0.515600893691016134941151", + "0.5156008936910161349411515", + "0.51560089369101613494115158", + "0.515600893691016134941151581", + "0.5156008936910161349411515818", + ]; + for test in tests { + assert_eq!( + Decimal::from_str(test).unwrap().trunc_with_scale(4), + cmp, + "Original: {}", + test + ); + } } #[test]