Skip to content

Commit

Permalink
Merge pull request #916 from dam5h/permute-axis-not-sorting
Browse files Browse the repository at this point in the history
Fix sort-axis example, add test to confirm original error and resolution
  • Loading branch information
bluss authored Feb 11, 2021
2 parents 07853e8 + c27626a commit a6fe82f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 9 deletions.
80 changes: 71 additions & 9 deletions examples/sort-axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,27 @@ where
assert_eq!(axis_len, perm.indices.len());
debug_assert!(perm.correct());

if self.is_empty() {
return self;
}

let mut result = Array::uninit(self.dim());

unsafe {
// logically move ownership of all elements from self into result
// the result realizes this ownership at .assume_init() further down
let mut moved_elements = 0;
for i in 0..axis_len {
let perm_i = perm.indices[i];
Zip::from(result.index_axis_mut(axis, perm_i))
.and(self.index_axis(axis, i))
.for_each(|to, from| {
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
}
Zip::from(&perm.indices)
.and(result.axis_iter_mut(axis))
.for_each(|&perm_i, result_pane| {
// possible improvement: use unchecked indexing for `index_axis`
Zip::from(result_pane)
.and(self.index_axis(axis, perm_i))
.for_each(|to, from| {
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
});
debug_assert_eq!(result.len(), moved_elements);
// panic-critical begin: we must not panic
// forget moved array elements but not its vec
Expand All @@ -129,6 +135,7 @@ where
}
}
}

#[cfg(feature = "std")]
fn main() {
let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
Expand All @@ -143,5 +150,60 @@ fn main() {
let c = strings.permute_axis(Axis(1), &perm);
println!("{:?}", c);
}

#[cfg(not(feature = "std"))]
fn main() {}

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_permute_axis() {
let a = array![
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[108000.33, 4.],
[107999.45, 5.],
[107999.57, 6.],
[108010.69, 7.],
[107999.81, 8.],
[107999.94, 9.],
[75600.09, 10.],
[75600.21, 11.],
[75601.33, 12.],
[75600.45, 13.],
[75600.58, 14.],
[109000.70, 15.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
];

let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);
let b = a.permute_axis(Axis(0), &perm);
assert_eq!(
b,
array![
[75600.09, 10.],
[75600.21, 11.],
[75600.45, 13.],
[75600.58, 14.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
[75601.33, 12.],
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[107999.45, 5.],
[107999.57, 6.],
[107999.81, 8.],
[107999.94, 9.],
[108000.33, 4.],
[108010.69, 7.],
[109000.70, 15.],
]
);
}
}
1 change: 1 addition & 0 deletions scripts/all-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ cargo test --manifest-path=ndarray-rand/Cargo.toml --no-default-features --verbo
cargo test --manifest-path=ndarray-rand/Cargo.toml --features quickcheck --verbose
cargo test --manifest-path=serialization-tests/Cargo.toml --verbose
cargo test --manifest-path=blas-tests/Cargo.toml --verbose
cargo test --examples
CARGO_TARGET_DIR=target/ cargo test --manifest-path=numeric-tests/Cargo.toml --verbose
([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES")

0 comments on commit a6fe82f

Please sign in to comment.