Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve pairwise summation #4

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ serde = { version = "1.0", optional = true }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "0.7.2", default-features = false }
quickcheck = { version = "0.8.1", default-features = false }
quickcheck_macros = "0.8"
rawpointer = "0.1"
rand = "0.5.5"

Expand Down
29 changes: 23 additions & 6 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
indices
}

/// Compute the minimum stride axis (absolute value), under the constraint
/// that the length of the axis is > 1;
/// Compute the minimum stride axis (absolute value), preferring axes with
/// length > 1.
#[doc(hidden)]
fn min_stride_axis(&self, strides: &Self) -> Axis {
let n = match self.ndim() {
Expand All @@ -301,7 +301,7 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
n => n,
};
axes_of(self, strides)
.rev()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(Axis(n - 1), |ax| ax.axis())
}
Expand Down Expand Up @@ -588,9 +588,9 @@ impl Dimension for Dim<[Ix; 2]> {

#[inline]
fn min_stride_axis(&self, strides: &Self) -> Axis {
let s = get!(strides, 0) as Ixs;
let t = get!(strides, 1) as Ixs;
if s.abs() < t.abs() {
let s = (get!(strides, 0) as isize).abs();
let t = (get!(strides, 1) as isize).abs();
if s < t && get!(self, 0) > 1 {
Axis(0)
} else {
Axis(1)
Expand Down Expand Up @@ -697,6 +697,23 @@ impl Dimension for Dim<[Ix; 3]> {
Some(Ix3(i, j, k))
}

#[inline]
fn min_stride_axis(&self, strides: &Self) -> Axis {
let s = (get!(strides, 0) as isize).abs();
let t = (get!(strides, 1) as isize).abs();
let u = (get!(strides, 2) as isize).abs();
let (argmin, min) = if t < u && get!(self, 1) > 1 {
(Axis(1), t)
} else {
(Axis(2), u)
};
if s < min && get!(self, 0) > 1 {
Axis(0)
} else {
argmin
}
}

/// Self is an index, return the stride offset
#[inline]
fn stride_offset(index: &Self, strides: &Self) -> isize {
Expand Down
157 changes: 78 additions & 79 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ mod test {
use crate::error::{from_kind, ErrorKind};
use crate::slice::Slice;
use num_integer::gcd;
use quickcheck::{quickcheck, TestResult};
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;

#[test]
fn slice_indexing_uncommon_strides() {
Expand Down Expand Up @@ -738,30 +739,29 @@ mod test {
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
}

quickcheck! {
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
let dim = IxDyn(&dim);
let result = can_index_slice_not_custom(&data, &dim);
if dim.size_checked().is_none() {
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
}
#[quickcheck]
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
let dim = IxDyn(&dim);
let result = can_index_slice_not_custom(&data, &dim);
if dim.size_checked().is_none() {
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
}
}

quickcheck! {
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
let (g, (x, y)) = extended_gcd(a, b);
a * x + b * y == g
}
#[quickcheck]
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
let (g, (x, y)) = extended_gcd(a, b);
a * x + b * y == g
}

fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
let (g, _) = extended_gcd(a, b);
g == gcd(a, b)
}
#[quickcheck]
fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
let (g, _) = extended_gcd(a, b);
g == gcd(a, b)
}

#[test]
Expand All @@ -773,73 +773,72 @@ mod test {
assert_eq!(extended_gcd(-5, 0), (5, (-1, 0)));
}

quickcheck! {
fn solve_linear_diophantine_eq_solution_existence(
a: isize, b: isize, c: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
TestResult::from_bool(
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
)
}
#[quickcheck]
fn solve_linear_diophantine_eq_solution_existence(
a: isize, b: isize, c: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
TestResult::from_bool(
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
)
}
}

fn solve_linear_diophantine_eq_correct_solution(
a: isize, b: isize, c: isize, t: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
match solve_linear_diophantine_eq(a, b, c) {
Some((x0, xd)) => {
let x = x0 + xd * t;
let y = (c - a * x) / b;
TestResult::from_bool(a * x + b * y == c)
}
None => TestResult::discard(),
#[quickcheck]
fn solve_linear_diophantine_eq_correct_solution(
a: isize, b: isize, c: isize, t: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
match solve_linear_diophantine_eq(a, b, c) {
Some((x0, xd)) => {
let x = x0 + xd * t;
let y = (c - a * x) / b;
TestResult::from_bool(a * x + b * y == c)
}
None => TestResult::discard(),
}
}
}

quickcheck! {
fn arith_seq_intersect_correct(
first1: isize, len1: isize, step1: isize,
first2: isize, len2: isize, step2: isize
) -> TestResult {
use std::cmp;
#[quickcheck]
fn arith_seq_intersect_correct(
first1: isize, len1: isize, step1: isize,
first2: isize, len2: isize, step2: isize
) -> TestResult {
use std::cmp;

if len1 == 0 || len2 == 0 {
// This case is impossible to reach in `arith_seq_intersect()`
// because the `min*` and `max*` arguments are inclusive.
return TestResult::discard();
}
let len1 = len1.abs();
let len2 = len2.abs();

// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
let last1 = first1 + step1 * (len1 - 1);
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
let last2 = first2 + step2 * (len2 - 1);
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));

// Naively determine if the sequences intersect.
let seq1: Vec<_> = (0..len1)
.map(|n| first1 + step1 * n)
.collect();
let intersects = (0..len2)
.map(|n| first2 + step2 * n)
.any(|elem2| seq1.contains(&elem2));

TestResult::from_bool(
arith_seq_intersect(
(min1, max1, if step1 == 0 { 1 } else { step1 }),
(min2, max2, if step2 == 0 { 1 } else { step2 })
) == intersects
)
if len1 == 0 || len2 == 0 {
// This case is impossible to reach in `arith_seq_intersect()`
// because the `min*` and `max*` arguments are inclusive.
return TestResult::discard();
}
let len1 = len1.abs();
let len2 = len2.abs();

// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
let last1 = first1 + step1 * (len1 - 1);
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
let last2 = first2 + step2 * (len2 - 1);
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));

// Naively determine if the sequences intersect.
let seq1: Vec<_> = (0..len1)
.map(|n| first1 + step1 * n)
.collect();
let intersects = (0..len2)
.map(|n| first2 + step2 * n)
.any(|elem2| seq1.contains(&elem2));

TestResult::from_bool(
arith_seq_intersect(
(min1, max1, if step1 == 0 { 1 } else { step1 }),
(min2, max2, if step2 == 0 { 1 } else { step2 })
) == intersects
)
}

#[test]
Expand Down
70 changes: 47 additions & 23 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,12 +1619,11 @@ where
axes_of(&self.dim, &self.strides)
}

/*
/// Return the axis with the least stride (by absolute value)
pub fn min_stride_axis(&self) -> Axis {
/// Return the axis with the least stride (by absolute value),
/// preferring axes with len > 1.
pub(crate) fn min_stride_axis(&self) -> Axis {
self.dim.min_stride_axis(&self.strides)
}
*/

/// Return the axis with the greatest stride (by absolute value),
/// preferring axes with len > 1.
Expand Down Expand Up @@ -1854,25 +1853,11 @@ where
} else {
let mut v = self.view();
// put the narrowest axis at the last position
match v.ndim() {
0 | 1 => {}
2 => {
if self.len_of(Axis(1)) <= 1
|| self.len_of(Axis(0)) > 1
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
{
v.swap_axes(0, 1);
}
}
n => {
let last = n - 1;
let narrow_axis = v
.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
}
let n = v.ndim();
if n > 1 {
let last = n - 1;
let narrow_axis = self.min_stride_axis();
v.swap_axes(last, narrow_axis.index());
}
v.into_elements_base().fold(init, f)
}
Expand Down Expand Up @@ -2103,3 +2088,42 @@ where
})
}
}

#[cfg(test)]
mod tests {
use crate::prelude::*;

#[test]
fn min_stride_axis() {
let a = Array1::<u8>::zeros(10);
assert_eq!(a.min_stride_axis(), Axis(0));

let a = Array2::<u8>::zeros((3, 3));
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));

let a = ArrayD::<u8>::zeros(vec![3, 3]);
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));

let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis();
assert_eq!(min_axis, Axis(1));

let mut b = ArrayD::<u8>::zeros(vec![2, 3, 4, 5]);
assert_eq!(b.min_stride_axis(), Axis(3));
for ax in 0..3 {
b.swap_axes(3, ax);
assert_eq!(b.min_stride_axis(), Axis(ax));
b.swap_axes(3, ax);
}
let mut v = b.view();
v.collapse_axis(Axis(3), 0);
assert_eq!(v.min_stride_axis(), Axis(2));

let a = Array2::<u8>::zeros((3, 3));
let v = a.broadcast((8, 3, 3)).unwrap();
assert_eq!(v.min_stride_axis(), Axis(0));
let v2 = a.broadcast((1, 3, 3)).unwrap();
assert_eq!(v2.min_stride_axis(), Axis(2));
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ extern crate num_integer;
#[cfg(test)]
extern crate quickcheck;
#[cfg(test)]
extern crate quickcheck_macros;
#[cfg(test)]
extern crate rand;

#[cfg(feature = "docs")]
Expand Down
13 changes: 10 additions & 3 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ impl<A, S, D> ArrayBase<S, D>
where A: Clone + Add<Output=A> + num_traits::Zero,
{
if let Some(slc) = self.as_slice_memory_order() {
numeric_util::pairwise_sum(&slc)
} else {
numeric_util::iterator_pairwise_sum(self.iter())
return numeric_util::pairwise_sum(&slc);
}
if self.ndim() > 1 {
let ax = self.dim.min_stride_axis(&self.strides);
if self.len_of(ax) >= numeric_util::UNROLL_SIZE && self.stride_of(ax) == 1 {
let partial_sums: Vec<_> =
self.lanes(ax).into_iter().map(|lane| lane.sum()).collect();
return numeric_util::pure_pairwise_sum(&partial_sums);
}
}
numeric_util::iterator_pairwise_sum(self.iter())
}

/// Return the sum of all elements in the array.
Expand Down
Loading