Skip to content

Commit

Permalink
Alternative implementation for sum_axis
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeMathWalker committed Jan 9, 2019
1 parent b3d2b42 commit 8f95705
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,11 @@ impl<A, S, D> ArrayBase<S, D>
where A: Clone + Zero + Add<Output=A>,
D: RemoveAxis,
{
let n = self.len_of(axis);
let stride = self.strides()[axis.index()];
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
if self.ndim() == 2 && stride == 1 {
// contiguous along the axis we are summing
let ax = axis.index();
for (i, elt) in enumerate(&mut res) {
*elt = self.index_axis(Axis(1 - ax), i).sum();
}
res
} else {
numeric_util::array_pairwise_sum(
(0..n).map(|i| self.index_axis(axis, i)),
|| res.clone()
)
}
let mut out = Array::zeros(self.dim.remove_axis(axis));
Zip::from(&mut out)
.and(self.lanes(axis))
.apply(|out, lane| *out = lane.sum());
out
}

/// Return mean along `axis`.
Expand Down

0 comments on commit 8f95705

Please sign in to comment.