From eae40333bd122dc5d74566f142f3351e1e618ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benno=20F=C3=BCnfst=C3=BCck?= Date: Tue, 19 May 2020 15:44:59 +0200 Subject: [PATCH] Add Sum/Product impls for NotNan --- src/lib.rs | 26 ++++++++++++++++++++++++++ tests/test.rs | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 444a99c..41812a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ use core::hash::{Hash, Hasher}; use core::fmt; use core::mem; use core::hint::unreachable_unchecked; +use core::iter::{Sum, Product}; use core::str::FromStr; use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Signed, ToPrimitive, Zero}; @@ -322,6 +323,19 @@ impl AddAssign for NotNan { } } + +impl Sum for NotNan { + fn sum>>(iter: I) -> Self { + NotNan::new(iter.map(|v| v.0).sum()).expect("Sum resulted in NaN") + } +} + +impl<'a, T: Float + Sum> Sum<&'a NotNan> for NotNan { + fn sum>>(iter: I) -> Self { + iter.map(|v| *v).sum() + } +} + impl Sub for NotNan { type Output = Self; @@ -392,6 +406,18 @@ impl MulAssign for NotNan { } } +impl Product for NotNan { + fn product>>(iter: I) -> Self { + NotNan::new(iter.map(|v| v.0).product()).expect("Product resulted in NaN") + } +} + +impl<'a, T: Float + Product> Product<&'a NotNan> for NotNan { + fn product>>(iter: I) -> Self { + iter.map(|v| *v).product() + } +} + impl Div for NotNan { type Output = Self; diff --git a/tests/test.rs b/tests/test.rs index 1c3c385..9c36594 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -545,3 +545,37 @@ fn ordered_f32_neg() { fn ordered_f64_neg() { assert_eq!(OrderedFloat(-7.0f64), -OrderedFloat(7.0f64)); } + +#[test] +#[should_panic] +fn test_sum_fails_on_nan() { + let a = NotNan::new(std::f32::INFINITY).unwrap(); + let b = NotNan::new(std::f32::NEG_INFINITY).unwrap(); + let _c: NotNan<_> = [a,b].iter().sum(); +} + +#[test] +#[should_panic] +fn test_product_fails_on_nan() { + let a = NotNan::new(std::f32::INFINITY).unwrap(); + let b = NotNan::new(0f32).unwrap(); + let _c: NotNan<_> = [a,b].iter().product(); +} + +#[test] +fn not_nan64_sum_product() { + let a = NotNan::new(2138.1237).unwrap(); + let b = NotNan::new(132f64).unwrap(); + let c = NotNan::new(5.1).unwrap(); + + assert_eq!(std::iter::empty::>().sum::>(), NotNan::new(0f64).unwrap()); + assert_eq!([a].iter().sum::>(), a); + assert_eq!([a,b].iter().sum::>(), a + b); + assert_eq!([a,b,c].iter().sum::>(), a + b + c); + + assert_eq!(std::iter::empty::>().product::>(), NotNan::new(1f64).unwrap()); + assert_eq!([a].iter().product::>(), a); + assert_eq!([a,b].iter().product::>(), a * b); + assert_eq!([a,b,c].iter().product::>(), a * b * c); + +}