diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 63cf1c397..8d2b9193b 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -34,8 +34,7 @@ use crate::rand::rngs::SmallRng; use crate::rand::seq::index; use crate::rand::{thread_rng, Rng, SeedableRng}; -use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; -use ndarray::{ArrayBase, DataOwned, Dimension}; +use ndarray::{Array, Axis, Data, DataMut, RemoveAxis, ShapeBuilder, ArrayBase, DataOwned, Dimension}; #[cfg(feature = "quickcheck")] use quickcheck::{Arbitrary, Gen}; @@ -64,7 +63,7 @@ pub mod rand_distr { /// [`.random_using()`](#tymethod.random_using). pub trait RandomExt where - S: DataOwned, + S: Data, D: Dimension, { /// Create an array with shape `dim` with elements drawn from @@ -87,6 +86,7 @@ where /// # } fn random(shape: Sh, distribution: IdS) -> ArrayBase where + S: DataOwned, IdS: Distribution, Sh: ShapeBuilder; @@ -116,6 +116,7 @@ where /// # } fn random_using(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase where + S: DataOwned, IdS: Distribution, R: Rng + ?Sized, Sh: ShapeBuilder; @@ -225,17 +226,93 @@ where R: Rng + ?Sized, A: Copy, D: RemoveAxis; + + /// Shuffle `self`'s slices along `axis`. + /// + /// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle). + /// + /// ***Panics*** when creation of the RNG fails. + /// ``` + /// use ndarray::{array, Axis}; + /// use ndarray_rand::RandomExt; + /// + /// # fn main() { + /// let mut a = array![ + /// [1., 2., 3.], + /// [4., 5., 6.], + /// [7., 8., 9.], + /// [10., 11., 12.], + /// ]; + /// // Let's shuffle `a`'s columns! + /// // Shuffling modifies the array in place, nothing is returned + /// a.shuffle_axis_inplace(Axis(1)); + /// println!("{:?}", a); + /// // Example Output: + /// // [ + /// // [1., 3., 2.], + /// // [4., 6., 5.], + /// // [7., 9., 8.], + /// // [10., 12., 11.], + /// // ] + /// # } + /// ``` + fn shuffle_axis_inplace(&mut self, axis: Axis) + where + D: RemoveAxis, + S: DataMut; + + /// Shuffle `self`'s slices along `axis` using the specified random number generator `rng`. + /// + /// It uses [Fisher-Yates shuffling algorithm](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle). + /// + /// ***Panics*** when creation of the RNG fails. + /// ``` + /// use ndarray::{array, Axis}; + /// use ndarray_rand::RandomExt; + /// use ndarray_rand::rand::SeedableRng; + /// use rand_isaac::isaac64::Isaac64Rng; + /// + /// # fn main() { + /// // Get a seeded random number generator for reproducibility (Isaac64 algorithm) + /// let seed = 42; + /// let mut rng = Isaac64Rng::seed_from_u64(seed); + /// + /// let mut a = array![ + /// [1., 2., 3.], + /// [4., 5., 6.], + /// [7., 8., 9.], + /// [10., 11., 12.], + /// ]; + /// // Let's shuffle `a`'s rows! + /// // Shuffling modifies the array in place, nothing is returned + /// a.shuffle_axis_inplace_using(Axis(0), &mut rng); + /// println!("{:?}", a); + /// // Example Output: + /// // [ + /// // [7., 8., 9.], + /// // [4., 5., 6.], + /// // [10., 11., 12.], + /// // [1., 2., 3.], + /// // ] + /// # } + /// ``` + fn shuffle_axis_inplace_using(&mut self, axis: Axis, rng: &mut R) + where + R: Rng + ?Sized, + D: RemoveAxis, + S: DataMut; } impl RandomExt for ArrayBase where - S: DataOwned, + S: Data, D: Dimension, { fn random(shape: Sh, dist: IdS) -> ArrayBase where IdS: Distribution, Sh: ShapeBuilder, + S: DataOwned, { Self::random_using(shape, dist, &mut get_rng()) } @@ -245,6 +322,7 @@ where IdS: Distribution, R: Rng + ?Sized, Sh: ShapeBuilder, + S: DataOwned, { Self::from_shape_simple_fn(shape, move || dist.sample(rng)) } @@ -280,6 +358,41 @@ where }; self.select(axis, &indices) } + + fn shuffle_axis_inplace(&mut self, axis: Axis) + where + D: RemoveAxis, + S: DataMut, + { + self.shuffle_axis_inplace_using(axis, &mut get_rng()) + } + + fn shuffle_axis_inplace_using(&mut self, axis: Axis, rng: &mut R) + where + R: Rng + ?Sized, + D: RemoveAxis, + S: DataMut, + { + for i in (1..self.len_of(axis)).rev() { + // Invariant: elements with index > i have been locked in place. + let j = rng.gen_range(0, i + 1); + + if i != j { + // Swap the two slices along `axis` + let slice1 = self.index_axis(axis, i); + let slice2 = self.index_axis(axis, j); + + for (x, y) in slice1.iter().zip(slice2.iter()) { + // Swap the two elements. + let ptr1 = x as *const A as *mut A; + let ptr2 = y as *const A as *mut A; + unsafe { + std::ptr::swap(ptr1, ptr2); + } + } + } + } + } } /// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index f7860ac12..7e0c27364 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -125,3 +125,20 @@ fn sampling_with_replacement_from_a_zero_length_axis_should_panic() { let a = Array::random((0, n), Uniform::new(0., 2.)); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement); } + +quickcheck! { + fn shuffling_works(m: usize, n: usize) -> bool { + let a = Array::random((m, n), Uniform::new(0., 2.)); + + // Get a clone of `a` and shuffle it in place + let mut results = vec![]; + for &axis in &[Axis(0), Axis(1)] { + let mut b = a.clone(); + b.shuffle_axis_inplace(axis); + + let result = b.axis_iter(axis).all(|lane| is_subset(&a, &lane, axis)); + results.push(result) + } + results.into_iter().all(|p| p) + } +}