Skip to content

Commit

Permalink
Add argmin/max_skipnan and indexed_fold_skipnan (#33)
Browse files Browse the repository at this point in the history
* Implement argmin_skipnan

* Implement argmax_skipnan

* Loosen the rule for argmin max related methods

* Make returning code clearer

* Add quickcheck for argmin_skipnan, argmax_skipnan

* Use `fold` instead of `for`

* Add indexed_fold_skipnan to MaybeNanExt

* Impl argmin/max_skipnan using indexed_fold_skipnan

* Fix argmin/max_skipnan quickcheck tests

The old tests were incorrect because `min`/`max` return `None` when
there are *any* NaN values (or the array is empty), while
`argmin/max_skipnan` should return `None` only when *all* the values
are NaNs (or the array is empty).

This wasn't caught earlier because the `quickcheck::Arbitrary`
implementation for `f32` generates only finite values. To make sure
the behavior with NaN values is properly tested, the element type in
the test has been changed to `Option<i32>`.

* Replace min/max.map with if for clarity

* Add () to make the match clearer
  • Loading branch information
phungleson authored and jturner314 committed Mar 25, 2019
1 parent 7df0728 commit d838ee7
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/maybe_nan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ where
A: 'a,
F: FnMut(B, &'a A::NotNan) -> B;

/// Traverse the non-NaN elements and their indices and apply a fold,
/// returning the resulting value.
///
/// Elements are visited in arbitrary order.
fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, f: F) -> B
where
A: 'a,
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B;

/// Visit each non-NaN element in the array by calling `f` on each element.
///
/// Elements are visited in arbitrary order.
Expand Down Expand Up @@ -302,6 +311,20 @@ where
})
}

fn indexed_fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B
where
A: 'a,
F: FnMut(B, (D::Pattern, &'a A::NotNan)) -> B,
{
self.indexed_iter().fold(init, |acc, (idx, elem)| {
if let Some(not_nan) = elem.try_as_not_nan() {
f(acc, (idx, not_nan))
} else {
acc
}
})
}

fn visit_skipnan<'a, F>(&'a self, mut f: F)
where
A: 'a,
Expand Down
98 changes: 98 additions & 0 deletions src/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,33 @@ where
where
A: PartialOrd;

/// Finds the index of the minimum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are minima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmin_skipnan(), Some((1, 1)));
/// ```
fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise minimum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -269,6 +296,33 @@ where
where
A: PartialOrd;

/// Finds the index of the maximum value of the array skipping NaN values.
///
/// Returns `None` if the array is empty or none of the values in the array
/// are non-NaN values.
///
/// Even if there are multiple (equal) elements that are maxima, only one
/// index is returned. (Which one is returned is unspecified and may depend
/// on the memory layout of the array.)
///
/// # Example
///
/// ```
/// extern crate ndarray;
/// extern crate ndarray_stats;
///
/// use ndarray::array;
/// use ndarray_stats::QuantileExt;
///
/// let a = array![[::std::f64::NAN, 3., 5.],
/// [2., 0., 6.]];
/// assert_eq!(a.argmax_skipnan(), Some((1, 2)));
/// ```
fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord;

/// Finds the elementwise maximum of the array.
///
/// Returns `None` if any of the pairwise orderings tested by the function
Expand Down Expand Up @@ -369,6 +423,28 @@ where
Some(current_pattern_min)
}

fn argmin_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let mut pattern_min = D::zeros(self.ndim()).into_pattern();
let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| {
Some(match current_min {
Some(m) if (m <= elem) => m,
_ => {
pattern_min = pattern;
elem
}
})
});
if min.is_some() {
Some(pattern_min)
} else {
None
}
}

fn min(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down Expand Up @@ -411,6 +487,28 @@ where
Some(current_pattern_max)
}

fn argmax_skipnan(&self) -> Option<D::Pattern>
where
A: MaybeNan,
A::NotNan: Ord,
{
let mut pattern_max = D::zeros(self.ndim()).into_pattern();
let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| {
Some(match current_max {
Some(m) if m >= elem => m,
_ => {
pattern_max = pattern;
elem
}
})
});
if max.is_some() {
Some(pattern_max)
} else {
None
}
}

fn max(&self) -> Option<&A>
where
A: PartialOrd,
Expand Down
65 changes: 65 additions & 0 deletions tests/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,37 @@ quickcheck! {
}
}

#[test]
fn test_argmin_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 1)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((0, 0)));

let a = array![[::std::f64::NAN, 5., 3.], [2., ::std::f64::NAN, 6.]];
assert_eq!(a.argmin_skipnan(), Some((1, 0)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmin_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmin_skipnan(), None);
}

quickcheck! {
fn argmin_skipnan_matches_min_skipnan(data: Vec<Option<i32>>) -> bool {
let a = Array1::from(data);
let min = a.min_skipnan();
let argmin = a.argmin_skipnan();
if min.is_none() {
argmin == None
} else {
a[argmin.unwrap()] == *min
}
}
}

#[test]
fn test_min() {
let a = array![[1, 5, 3], [2, 0, 6]];
Expand Down Expand Up @@ -81,6 +112,40 @@ quickcheck! {
}
}

#[test]
fn test_argmax_skipnan() {
let a = array![[1., 5., 3.], [2., 0., 6.]];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a = array![[1., 5., 3.], [2., ::std::f64::NAN, ::std::f64::NAN]];
assert_eq!(a.argmax_skipnan(), Some((0, 1)));

let a = array![
[::std::f64::NAN, ::std::f64::NAN, 3.],
[2., ::std::f64::NAN, 6.]
];
assert_eq!(a.argmax_skipnan(), Some((1, 2)));

let a: Array2<f64> = array![[], []];
assert_eq!(a.argmax_skipnan(), None);

let a = arr2(&[[::std::f64::NAN; 2]; 2]);
assert_eq!(a.argmax_skipnan(), None);
}

quickcheck! {
fn argmax_skipnan_matches_max_skipnan(data: Vec<Option<i32>>) -> bool {
let a = Array1::from(data);
let max = a.max_skipnan();
let argmax = a.argmax_skipnan();
if max.is_none() {
argmax == None
} else {
a[argmax.unwrap()] == *max
}
}
}

#[test]
fn test_max() {
let a = array![[1, 5, 7], [2, 0, 6]];
Expand Down

0 comments on commit d838ee7

Please sign in to comment.