diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index ca6f24bbe..e12aaf7e1 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -8,6 +8,7 @@ #[cfg(feature = "std")] use num_traits::Float; +use num_traits::One; use num_traits::{FromPrimitive, Zero}; use std::ops::{Add, Div, Mul}; @@ -253,6 +254,43 @@ where } } + /// Return product along `axis`. + /// + /// The product of an empty array is 1. + /// + /// ``` + /// use ndarray::{aview0, aview1, arr2, Axis}; + /// + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); + /// + /// assert!( + /// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) && + /// a.product_axis(Axis(1)) == aview1(&[6., 120.]) && + /// + /// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.) + /// ); + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn product_axis(&self, axis: Axis) -> Array + where + A: Clone + One + Mul, + D: RemoveAxis, + { + let min_stride_axis = self.dim.min_stride_axis(&self.strides); + if axis == min_stride_axis { + crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product()) + } else { + let mut res = Array::ones(self.raw_dim().remove_axis(axis)); + for subview in self.axis_iter(axis) { + res = res * &subview; + } + res + } + } + /// Return mean along `axis`. /// /// Return `None` if the length of the axis is zero. diff --git a/tests/numeric.rs b/tests/numeric.rs index 4d70d4502..f6de146c9 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -39,27 +39,36 @@ fn test_mean_with_array_of_floats() } #[test] -fn sum_mean() +fn sum_mean_prod() { let a: Array2 = arr2(&[[1., 2.], [3., 4.]]); assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.])); assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.])); + assert_eq!(a.product_axis(Axis(0)), arr1(&[3., 8.])); + assert_eq!(a.product_axis(Axis(1)), arr1(&[2., 12.])); assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.]))); assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5]))); assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.)); + assert_eq!(a.product_axis(Axis(1)).product_axis(Axis(0)), arr0(24.)); assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5])); assert_eq!(a.sum(), 10.); } #[test] -fn sum_mean_empty() +fn sum_mean_prod_empty() { assert_eq!(Array3::::ones((2, 0, 3)).sum(), 0.); + assert_eq!(Array3::::ones((2, 0, 3)).product(), 1.); assert_eq!(Array1::::ones(0).sum_axis(Axis(0)), arr0(0.)); + assert_eq!(Array1::::ones(0).product_axis(Axis(0)), arr0(1.)); assert_eq!( Array3::::ones((2, 0, 3)).sum_axis(Axis(1)), Array::zeros((2, 3)), ); + assert_eq!( + Array3::::ones((2, 0, 3)).product_axis(Axis(1)), + Array::ones((2, 3)), + ); let a = Array1::::ones(0).mean_axis(Axis(0)); assert_eq!(a, None); let a = Array3::::ones((2, 0, 3)).mean_axis(Axis(1));