Skip to content

Commit

Permalink
Implements and tests product_axis.
Browse files Browse the repository at this point in the history
  • Loading branch information
akern40 authored and adamreichold committed May 19, 2024
1 parent bde682a commit 17a628e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
38 changes: 38 additions & 0 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#[cfg(feature = "std")]
use num_traits::Float;
use num_traits::One;
use num_traits::{FromPrimitive, Zero};
use std::ops::{Add, Div, Mul};

Expand Down Expand Up @@ -253,6 +254,43 @@ where
}
}

/// Return product along `axis`.
///
/// The product of an empty array is 1.
///
/// ```
/// use ndarray::{aview0, aview1, arr2, Axis};
///
/// let a = arr2(&[[1., 2., 3.],
/// [4., 5., 6.]]);
///
/// assert!(
/// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
/// a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
///
/// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
/// );
/// ```
///
/// **Panics** if `axis` is out of bounds.
#[track_caller]
pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
where
A: Clone + One + Mul<Output = A>,
D: RemoveAxis,
{
let min_stride_axis = self.dim.min_stride_axis(&self.strides);
if axis == min_stride_axis {
crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
} else {
let mut res = Array::ones(self.raw_dim().remove_axis(axis));
for subview in self.axis_iter(axis) {
res = res * &subview;
}
res
}
}

/// Return mean along `axis`.
///
/// Return `None` if the length of the axis is zero.
Expand Down
13 changes: 11 additions & 2 deletions tests/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,36 @@ fn test_mean_with_array_of_floats()
}

#[test]
fn sum_mean()
fn sum_mean_prod()
{
let a: Array2<f64> = arr2(&[[1., 2.], [3., 4.]]);
assert_eq!(a.sum_axis(Axis(0)), arr1(&[4., 6.]));
assert_eq!(a.sum_axis(Axis(1)), arr1(&[3., 7.]));
assert_eq!(a.product_axis(Axis(0)), arr1(&[3., 8.]));
assert_eq!(a.product_axis(Axis(1)), arr1(&[2., 12.]));
assert_eq!(a.mean_axis(Axis(0)), Some(arr1(&[2., 3.])));
assert_eq!(a.mean_axis(Axis(1)), Some(arr1(&[1.5, 3.5])));
assert_eq!(a.sum_axis(Axis(1)).sum_axis(Axis(0)), arr0(10.));
assert_eq!(a.product_axis(Axis(1)).product_axis(Axis(0)), arr0(24.));
assert_eq!(a.view().mean_axis(Axis(1)).unwrap(), aview1(&[1.5, 3.5]));
assert_eq!(a.sum(), 10.);
}

#[test]
fn sum_mean_empty()
fn sum_mean_prod_empty()
{
assert_eq!(Array3::<f32>::ones((2, 0, 3)).sum(), 0.);
assert_eq!(Array3::<f32>::ones((2, 0, 3)).product(), 1.);
assert_eq!(Array1::<f32>::ones(0).sum_axis(Axis(0)), arr0(0.));
assert_eq!(Array1::<f32>::ones(0).product_axis(Axis(0)), arr0(1.));
assert_eq!(
Array3::<f32>::ones((2, 0, 3)).sum_axis(Axis(1)),
Array::zeros((2, 3)),
);
assert_eq!(
Array3::<f32>::ones((2, 0, 3)).product_axis(Axis(1)),
Array::ones((2, 3)),
);
let a = Array1::<f32>::ones(0).mean_axis(Axis(0));
assert_eq!(a, None);
let a = Array3::<f32>::ones((2, 0, 3)).mean_axis(Axis(1));
Expand Down

0 comments on commit 17a628e

Please sign in to comment.