From f306b5f2812652a2763ae86520cce18328454366 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 18:11:37 -0500 Subject: [PATCH 1/7] Fix min_stride_axis to prefer axes with length > 1 --- src/dimension/dimension_trait.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index f8357f069..ba42ead07 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -291,8 +291,8 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default + indices } - /// Compute the minimum stride axis (absolute value), under the constraint - /// that the length of the axis is > 1; + /// Compute the minimum stride axis (absolute value), preferring axes with + /// length > 1. #[doc(hidden)] fn min_stride_axis(&self, strides: &Self) -> Axis { let n = match self.ndim() { @@ -301,7 +301,7 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default + n => n, }; axes_of(self, strides) - .rev() + .filter(|ax| ax.len() > 1) .min_by_key(|ax| ax.stride().abs()) .map_or(Axis(n - 1), |ax| ax.axis()) } @@ -588,9 +588,9 @@ impl Dimension for Dim<[Ix; 2]> { #[inline] fn min_stride_axis(&self, strides: &Self) -> Axis { - let s = get!(strides, 0) as Ixs; - let t = get!(strides, 1) as Ixs; - if s.abs() < t.abs() { + let s = (get!(strides, 0) as isize).abs(); + let t = (get!(strides, 1) as isize).abs(); + if s < t && get!(self, 0) > 1 { Axis(0) } else { Axis(1) From b7951df25e396e3d8cfd9b1b674331314c937c45 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 18:34:13 -0500 Subject: [PATCH 2/7] Specialize min_stride_axis for Ix3 --- src/dimension/dimension_trait.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index ba42ead07..038e5473c 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -697,6 +697,23 @@ impl Dimension for Dim<[Ix; 3]> { Some(Ix3(i, j, k)) } + #[inline] + fn min_stride_axis(&self, strides: &Self) -> Axis { + let s = (get!(strides, 0) as isize).abs(); + let t = (get!(strides, 1) as isize).abs(); + let u = (get!(strides, 2) as isize).abs(); + let (argmin, min) = if t < u && get!(self, 1) > 1 { + (Axis(1), t) + } else { + (Axis(2), u) + }; + if s < min && get!(self, 0) > 1 { + Axis(0) + } else { + argmin + } + } + /// Self is an index, return the stride offset #[inline] fn stride_offset(index: &Self, strides: &Self) -> isize { From 3326de4b8c80aaa4cbccc9068ede3dcc92d6036d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 18:09:09 -0500 Subject: [PATCH 3/7] Enable min_stride_axis as pub(crate) method --- src/impl_methods.rs | 46 +++++++++++++++++++++++++++++++++++++++++---- tests/dimension.rs | 31 ------------------------------ 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index de74af795..02e81736a 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1619,12 +1619,11 @@ where axes_of(&self.dim, &self.strides) } - /* - /// Return the axis with the least stride (by absolute value) - pub fn min_stride_axis(&self) -> Axis { + /// Return the axis with the least stride (by absolute value), + /// preferring axes with len > 1. + pub(crate) fn min_stride_axis(&self) -> Axis { self.dim.min_stride_axis(&self.strides) } - */ /// Return the axis with the greatest stride (by absolute value), /// preferring axes with len > 1. @@ -2103,3 +2102,42 @@ where }) } } + +#[cfg(test)] +mod tests { + use crate::prelude::*; + + #[test] + fn min_stride_axis() { + let a = Array1::::zeros(10); + assert_eq!(a.min_stride_axis(), Axis(0)); + + let a = Array2::::zeros((3, 3)); + assert_eq!(a.min_stride_axis(), Axis(1)); + assert_eq!(a.t().min_stride_axis(), Axis(0)); + + let a = ArrayD::::zeros(vec![3, 3]); + assert_eq!(a.min_stride_axis(), Axis(1)); + assert_eq!(a.t().min_stride_axis(), Axis(0)); + + let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis(); + assert_eq!(min_axis, Axis(1)); + + let mut b = ArrayD::::zeros(vec![2, 3, 4, 5]); + assert_eq!(b.min_stride_axis(), Axis(3)); + for ax in 0..3 { + b.swap_axes(3, ax); + assert_eq!(b.min_stride_axis(), Axis(ax)); + b.swap_axes(3, ax); + } + let mut v = b.view(); + v.collapse_axis(Axis(3), 0); + assert_eq!(v.min_stride_axis(), Axis(2)); + + let a = Array2::::zeros((3, 3)); + let v = a.broadcast((8, 3, 3)).unwrap(); + assert_eq!(v.min_stride_axis(), Axis(0)); + let v2 = a.broadcast((1, 3, 3)).unwrap(); + assert_eq!(v2.min_stride_axis(), Axis(2)); + } +} diff --git a/tests/dimension.rs b/tests/dimension.rs index c76b8d7ad..11d6b3e0f 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -132,37 +132,6 @@ fn fastest_varying_order() { type ArrayF32 = Array; -/* -#[test] -fn min_stride_axis() { - let a = ArrayF32::zeros(10); - assert_eq!(a.min_stride_axis(), Axis(0)); - - let a = ArrayF32::zeros((3, 3)); - assert_eq!(a.min_stride_axis(), Axis(1)); - assert_eq!(a.t().min_stride_axis(), Axis(0)); - - let a = ArrayF32::zeros(vec![3, 3]); - assert_eq!(a.min_stride_axis(), Axis(1)); - assert_eq!(a.t().min_stride_axis(), Axis(0)); - - let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis(); - assert_eq!(min_axis, Axis(1)); - - let mut b = ArrayF32::zeros(vec![2, 3, 4, 5]); - assert_eq!(b.min_stride_axis(), Axis(3)); - for ax in 0..3 { - b.swap_axes(3, ax); - assert_eq!(b.min_stride_axis(), Axis(ax)); - b.swap_axes(3, ax); - } - - let a = ArrayF32::zeros((3, 3)); - let v = a.broadcast((8, 3, 3)).unwrap(); - assert_eq!(v.min_stride_axis(), Axis(0)); -} -*/ - #[test] fn max_stride_axis() { let a = ArrayF32::zeros(10); From 65b6046b0d7e5e1060b81b04d564decde8dc5e66 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 18:20:12 -0500 Subject: [PATCH 4/7] Simplify fold to use min_stride_axis This causes no change in performance according to the relevant benchmarks in `bench1`. --- src/impl_methods.rs | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 02e81736a..2e2c632e9 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1853,25 +1853,11 @@ where } else { let mut v = self.view(); // put the narrowest axis at the last position - match v.ndim() { - 0 | 1 => {} - 2 => { - if self.len_of(Axis(1)) <= 1 - || self.len_of(Axis(0)) > 1 - && self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs() - { - v.swap_axes(0, 1); - } - } - n => { - let last = n - 1; - let narrow_axis = v - .axes() - .filter(|ax| ax.len() > 1) - .min_by_key(|ax| ax.stride().abs()) - .map_or(last, |ax| ax.axis().index()); - v.swap_axes(last, narrow_axis); - } + let n = v.ndim(); + if n > 1 { + let last = n - 1; + let narrow_axis = self.min_stride_axis(); + v.swap_axes(last, narrow_axis.index()); } v.into_elements_base().fold(init, f) } From b0b391aa878a2839a8952fe1d5c6bb8fff6ced2d Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 18:31:30 -0500 Subject: [PATCH 5/7] Improve performance of sum in certain cases This improves the performance of `sum` when an axis is contiguous but the array as a whole is not contiguous. --- src/numeric/impl_numeric.rs | 13 ++++++++++--- src/numeric_util.rs | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index b8c393372..cf2c0eafd 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -32,10 +32,17 @@ impl ArrayBase where A: Clone + Add + num_traits::Zero, { if let Some(slc) = self.as_slice_memory_order() { - numeric_util::pairwise_sum(&slc) - } else { - numeric_util::iterator_pairwise_sum(self.iter()) + return numeric_util::pairwise_sum(&slc); + } + if self.ndim() > 1 { + let ax = self.dim.min_stride_axis(&self.strides); + if self.len_of(ax) >= numeric_util::UNROLL_SIZE && self.stride_of(ax) == 1 { + let partial_sums: Vec<_> = + self.lanes(ax).into_iter().map(|lane| lane.sum()).collect(); + return numeric_util::pure_pairwise_sum(&partial_sums); + } } + numeric_util::iterator_pairwise_sum(self.iter()) } /// Return the sum of all elements in the array. diff --git a/src/numeric_util.rs b/src/numeric_util.rs index 2557f2c68..855a1b803 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -79,7 +79,7 @@ where /// An implementation of pairwise summation for a vector slice that never /// switches to the naive sum algorithm. -fn pure_pairwise_sum(v: &[A]) -> A +pub(crate) fn pure_pairwise_sum(v: &[A]) -> A where A: Clone + Add + Zero, { From 7f04e6fa56b945330278b3e9332b05611165d378 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 13:40:39 -0500 Subject: [PATCH 6/7] Update quickcheck and use quickcheck_macros --- Cargo.toml | 3 +- src/dimension/mod.rs | 157 +++++++++++++++++++++---------------------- src/lib.rs | 2 + src/numeric_util.rs | 10 ++- 4 files changed, 86 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bdab7df52..ce3b563c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,8 @@ serde = { version = "1.0", optional = true } [dev-dependencies] defmac = "0.2" -quickcheck = { version = "0.7.2", default-features = false } +quickcheck = { version = "0.8.1", default-features = false } +quickcheck_macros = "0.8" rawpointer = "0.1" rand = "0.5.5" diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index aebcfc662..719acc678 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -629,7 +629,8 @@ mod test { use crate::error::{from_kind, ErrorKind}; use crate::slice::Slice; use num_integer::gcd; - use quickcheck::{quickcheck, TestResult}; + use quickcheck::TestResult; + use quickcheck_macros::quickcheck; #[test] fn slice_indexing_uncommon_strides() { @@ -738,30 +739,29 @@ mod test { can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err(); } - quickcheck! { - fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec, dim: Vec) -> bool { - let dim = IxDyn(&dim); - let result = can_index_slice_not_custom(&data, &dim); - if dim.size_checked().is_none() { - // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`. - result.is_err() - } else { - result == can_index_slice(&data, &dim, &dim.default_strides()) && - result == can_index_slice(&data, &dim, &dim.fortran_strides()) - } + #[quickcheck] + fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec, dim: Vec) -> bool { + let dim = IxDyn(&dim); + let result = can_index_slice_not_custom(&data, &dim); + if dim.size_checked().is_none() { + // Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`. + result.is_err() + } else { + result == can_index_slice(&data, &dim, &dim.default_strides()) && + result == can_index_slice(&data, &dim, &dim.fortran_strides()) } } - quickcheck! { - fn extended_gcd_solves_eq(a: isize, b: isize) -> bool { - let (g, (x, y)) = extended_gcd(a, b); - a * x + b * y == g - } + #[quickcheck] + fn extended_gcd_solves_eq(a: isize, b: isize) -> bool { + let (g, (x, y)) = extended_gcd(a, b); + a * x + b * y == g + } - fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool { - let (g, _) = extended_gcd(a, b); - g == gcd(a, b) - } + #[quickcheck] + fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool { + let (g, _) = extended_gcd(a, b); + g == gcd(a, b) } #[test] @@ -773,73 +773,72 @@ mod test { assert_eq!(extended_gcd(-5, 0), (5, (-1, 0))); } - quickcheck! { - fn solve_linear_diophantine_eq_solution_existence( - a: isize, b: isize, c: isize - ) -> TestResult { - if a == 0 || b == 0 { - TestResult::discard() - } else { - TestResult::from_bool( - (c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some() - ) - } + #[quickcheck] + fn solve_linear_diophantine_eq_solution_existence( + a: isize, b: isize, c: isize + ) -> TestResult { + if a == 0 || b == 0 { + TestResult::discard() + } else { + TestResult::from_bool( + (c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some() + ) } + } - fn solve_linear_diophantine_eq_correct_solution( - a: isize, b: isize, c: isize, t: isize - ) -> TestResult { - if a == 0 || b == 0 { - TestResult::discard() - } else { - match solve_linear_diophantine_eq(a, b, c) { - Some((x0, xd)) => { - let x = x0 + xd * t; - let y = (c - a * x) / b; - TestResult::from_bool(a * x + b * y == c) - } - None => TestResult::discard(), + #[quickcheck] + fn solve_linear_diophantine_eq_correct_solution( + a: isize, b: isize, c: isize, t: isize + ) -> TestResult { + if a == 0 || b == 0 { + TestResult::discard() + } else { + match solve_linear_diophantine_eq(a, b, c) { + Some((x0, xd)) => { + let x = x0 + xd * t; + let y = (c - a * x) / b; + TestResult::from_bool(a * x + b * y == c) } + None => TestResult::discard(), } } } - quickcheck! { - fn arith_seq_intersect_correct( - first1: isize, len1: isize, step1: isize, - first2: isize, len2: isize, step2: isize - ) -> TestResult { - use std::cmp; + #[quickcheck] + fn arith_seq_intersect_correct( + first1: isize, len1: isize, step1: isize, + first2: isize, len2: isize, step2: isize + ) -> TestResult { + use std::cmp; - if len1 == 0 || len2 == 0 { - // This case is impossible to reach in `arith_seq_intersect()` - // because the `min*` and `max*` arguments are inclusive. - return TestResult::discard(); - } - let len1 = len1.abs(); - let len2 = len2.abs(); - - // Convert to `min*` and `max*` arguments for `arith_seq_intersect()`. - let last1 = first1 + step1 * (len1 - 1); - let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1)); - let last2 = first2 + step2 * (len2 - 1); - let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2)); - - // Naively determine if the sequences intersect. - let seq1: Vec<_> = (0..len1) - .map(|n| first1 + step1 * n) - .collect(); - let intersects = (0..len2) - .map(|n| first2 + step2 * n) - .any(|elem2| seq1.contains(&elem2)); - - TestResult::from_bool( - arith_seq_intersect( - (min1, max1, if step1 == 0 { 1 } else { step1 }), - (min2, max2, if step2 == 0 { 1 } else { step2 }) - ) == intersects - ) + if len1 == 0 || len2 == 0 { + // This case is impossible to reach in `arith_seq_intersect()` + // because the `min*` and `max*` arguments are inclusive. + return TestResult::discard(); } + let len1 = len1.abs(); + let len2 = len2.abs(); + + // Convert to `min*` and `max*` arguments for `arith_seq_intersect()`. + let last1 = first1 + step1 * (len1 - 1); + let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1)); + let last2 = first2 + step2 * (len2 - 1); + let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2)); + + // Naively determine if the sequences intersect. + let seq1: Vec<_> = (0..len1) + .map(|n| first1 + step1 * n) + .collect(); + let intersects = (0..len2) + .map(|n| first2 + step2 * n) + .any(|elem2| seq1.contains(&elem2)); + + TestResult::from_bool( + arith_seq_intersect( + (min1, max1, if step1 == 0 { 1 } else { step1 }), + (min2, max2, if step2 == 0 { 1 } else { step2 }) + ) == intersects + ) } #[test] diff --git a/src/lib.rs b/src/lib.rs index 39f102e6f..39402bf56 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,6 +105,8 @@ extern crate num_integer; #[cfg(test)] extern crate quickcheck; #[cfg(test)] +extern crate quickcheck_macros; +#[cfg(test)] extern crate rand; #[cfg(feature = "docs")] diff --git a/src/numeric_util.rs b/src/numeric_util.rs index 855a1b803..68be6fa14 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -215,14 +215,12 @@ pub fn unrolled_eq(xs: &[A], ys: &[A]) -> bool #[cfg(test)] mod tests { - use quickcheck::quickcheck; + use quickcheck_macros::quickcheck; use std::num::Wrapping; use super::iterator_pairwise_sum; - quickcheck! { - fn iterator_pairwise_sum_is_correct(xs: Vec) -> bool { - let xs: Vec<_> = xs.into_iter().map(|x| Wrapping(x)).collect(); - iterator_pairwise_sum(xs.iter()) == xs.iter().sum() - } + #[quickcheck] + fn iterator_pairwise_sum_is_correct(xs: Vec>) -> bool { + iterator_pairwise_sum(xs.iter()) == xs.iter().sum() } } From 1ed1a638e7ec86c37de36a30c21bdd12f92940c5 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 3 Feb 2019 19:07:03 -0500 Subject: [PATCH 7/7] Clarify capacity calculation in iterator_pairwise_sum --- src/numeric_util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numeric_util.rs b/src/numeric_util.rs index 68be6fa14..f6540fb92 100644 --- a/src/numeric_util.rs +++ b/src/numeric_util.rs @@ -62,7 +62,7 @@ where A: Clone + Add + Zero, { let (len, _) = iter.size_hint(); - let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division + let cap = len / NAIVE_SUM_THRESHOLD + if len % NAIVE_SUM_THRESHOLD != 0 { 1 } else { 0 }; let mut partial_sums = Vec::with_capacity(cap); let (_, last_sum) = iter.fold((0, A::zero()), |(count, partial_sum), x| { if count < NAIVE_SUM_THRESHOLD {