From 82d92de2b4e2879b622d37a1d51466ac88451a07 Mon Sep 17 00:00:00 2001 From: Paul Dicker Date: Sun, 15 Apr 2018 08:53:08 +0200 Subject: [PATCH] Add rayon support --- Cargo.toml | 1 + src/distributions/mod.rs | 48 +++++++-- src/distributions/rayon.rs | 212 +++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 4 files changed, 255 insertions(+), 8 deletions(-) create mode 100644 src/distributions/rayon.rs diff --git a/Cargo.toml b/Cargo.toml index 27b0aadc3d9..10565b1e82d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ members = ["rand_core"] [dependencies] rand_core = { path="rand_core", default-features = false } log = { version = "0.4", optional = true } +rayon = { version = "1", optional = true } serde = { version = "1", optional = true } serde_derive = { version = "1", optional = true } diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index ae259843ac4..33a64ed2be6 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -24,6 +24,10 @@ //! [`Standard`]: struct.Standard.html use Rng; +#[cfg(feature = "rayon")] +use SeedableRng; +#[cfg(feature = "rayon")] +use distributions::rayon::ParallelDistIter; pub use self::other::Alphanumeric; pub use self::range::Range; @@ -49,6 +53,8 @@ pub mod exponential; pub mod poisson; #[cfg(feature = "std")] pub mod binomial; +#[cfg(feature = "rayon")] +mod rayon; mod float; mod integer; @@ -169,8 +175,8 @@ pub trait Distribution { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter<'a, R: Rng>(&'a self, rng: &'a mut R) - -> DistIter<'a, Self, R, T> where Self: Sized + fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T> + where Self: Sized, R: Rng { DistIter { distr: self, @@ -178,8 +184,25 @@ pub trait Distribution { phantom: ::core::marker::PhantomData, } } + + /// Create a parallel iterator. + #[cfg(feature = "rayon")] + fn sample_par_iter<'a, R>(&'a self, rng: &mut R, amount: usize) + -> ParallelDistIter<'a, Self, R, T> + where Self: Sized, + R: Rng + SeedableRng, + { + ParallelDistIter::new(self, rng, amount) + } +} + +impl<'a, T, D: Distribution> Distribution for &'a D { + fn sample(&self, rng: &mut R) -> T { + (*self).sample(rng) + } } + /// An iterator that generates random values of `T` with distribution `D`, /// using `R` as the source of randomness. /// @@ -189,7 +212,7 @@ pub trait Distribution { /// [`Distribution`]: trait.Distribution.html /// [`sample_iter`]: trait.Distribution.html#method.sample_iter #[derive(Debug)] -pub struct DistIter<'a, D, R, T> where D: Distribution + 'a, R: Rng + 'a { +pub struct DistIter<'a, D: 'a, R: 'a, T> { distr: &'a D, rng: &'a mut R, phantom: ::core::marker::PhantomData, @@ -206,11 +229,6 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T> } } -impl<'a, T, D: Distribution> Distribution for &'a D { - fn sample(&self, rng: &mut R) -> T { - (*self).sample(rng) - } -} /// A generic random value distribution. Generates values for various types /// with numerically uniform distribution. @@ -620,4 +638,18 @@ mod tests { let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect(); println!("{:?}", results); } + + #[cfg(all(feature="std", feature="rayon"))] + #[test] + fn test_distributions_par_iter() { + use distributions::Range; + use rayon::iter::ParallelIterator; + use NewRng; + use prng::XorShiftRng; // *EXTREMELY* bad choice! + let mut rng = XorShiftRng::new(); + let range = Range::new(100, 200); + let results: Vec<_> = range.sample_par_iter(&mut rng, 1000).collect(); + println!("{:?}", results); + panic!(); + } } diff --git a/src/distributions/rayon.rs b/src/distributions/rayon.rs new file mode 100644 index 00000000000..21f0ba05c0d --- /dev/null +++ b/src/distributions/rayon.rs @@ -0,0 +1,212 @@ +// Copyright 2018 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// https://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Parallel iterator to sample from distributions. + +use {Rng, SeedableRng}; +use distributions::Distribution; + +use rayon::iter::plumbing::{Consumer, Producer, ProducerCallback, UnindexedConsumer, bridge}; +use rayon::iter::{ParallelIterator, IndexedParallelIterator}; + + +/// An iterator that generates random values of `T` with distribution `D`, +/// using `R` as the source of randomness. +/// +/// This `struct` is created by the [`par_sample_iter`] method on +/// [`Distribution`]. See its documentation for more. +/// +/// [`Distribution`]: trait.Distribution.html +/// [`sample_iter`]: trait.Distribution.html#method.sample_iter +#[cfg(feature = "rayon")] +#[derive(Debug)] +pub struct ParallelDistIter<'a, D: 'a, R, T> { + distr: &'a D, + rng: R, + amount: usize, + phantom: ::core::marker::PhantomData, +} + +impl<'a, D, R, T> ParallelDistIter<'a, D, R, T> { + pub fn new(distr: &'a D, rng: &mut R, amount: usize) + -> ParallelDistIter<'a, D, R, T> + where D: Distribution, + R: Rng + SeedableRng, + { + ParallelDistIter { + distr, + rng: R::from_rng(rng).unwrap(), + amount, + phantom: ::core::marker::PhantomData, + } + } +} + +#[cfg(feature = "rayon")] +impl<'a, D, R, T> ParallelIterator for ParallelDistIter<'a, D, R, T> + where D: Distribution + Send + Sync, + R: Rng + SeedableRng + Send, + T: Send, +{ + type Item = T; + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: UnindexedConsumer + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.amount) + } +} + +#[cfg(feature = "rayon")] +impl<'a, D, R, T> IndexedParallelIterator for ParallelDistIter<'a, D, R, T> + where D: Distribution + Send + Sync, + R: Rng + SeedableRng + Send, + T: Send, +{ + fn drive(self, consumer: C) -> C::Result + where C: Consumer + { + bridge(self, consumer) + } + + fn len(&self) -> usize { + self.amount + } + + fn with_producer(self, callback: CB) -> CB::Output + where CB: ProducerCallback + { + callback.callback( + DistProducer { + distr: self.distr.clone(), + amount: self.amount, + rng: self.rng, + phantom: ::core::marker::PhantomData, + } + ) + } +} + +/// FIXME +#[cfg(feature = "rayon")] +#[derive(Debug)] +pub struct DistProducer<'a, D: 'a, R, T> { + distr: &'a D, + rng: R, + amount: usize, + phantom: ::core::marker::PhantomData, +} + +/// This method is intented to be used by fast and relatively simple PRNGs used +/// for simulations etc. While it will also work with cryptographic RNGs, that +/// is not optimal. +/// +/// Every time `rayon` splits the work in two to create parallel tasks, one new +/// PRNG is created. The original PRNG is used to seed the new one using +/// `SeedableRng::from_rng`. **Important**: Not all RNG algorithms support this! +/// Notably the low-quality plain Xorshift, the current default for `SmallRng`, +/// will simply clone itself using this method instead of seeding the split off +/// RNG well. Consider using something like PCG or Xoroshiro128+. +/// +/// It is hard to predict what will happen to the statistical quality of PRNGs +/// when they are split off many times, and only very short runs are used. We +/// limit the minimum number of items that should be used of the PRNG to at +/// least 100 to hopefully keep similar statistical properties as one PRNG used +/// continuously. +#[cfg(feature = "rayon")] +impl<'a, D, R, T> Producer for DistProducer<'a, D, R, T> + where D: Distribution + Send + Sync, + R: Rng + SeedableRng + Send, + T: Send, +{ + type Item = T; + type IntoIter = BoundedDistIter<'a, D, R, T>; + fn into_iter(self) -> Self::IntoIter { + BoundedDistIter { + distr: self.distr, + amount: self.amount, + rng: self.rng, + phantom: ::core::marker::PhantomData, + } + } + + fn split_at(mut self, index: usize) -> (Self, Self) { + assert!(index <= self.amount); + // Create a new PRNG of the same type, by seeding it with this PRNG. + // `from_rng` should never fail. + let new = DistProducer { + distr: self.distr, + amount: self.amount - index, + rng: R::from_rng(&mut self.rng).unwrap(), + phantom: ::core::marker::PhantomData, + }; + self.amount = index; + (self, new) + } + + fn min_len(&self) -> usize { + 100 + } +} + +/// FIXME +#[cfg(feature = "rayon")] +#[derive(Debug)] +pub struct BoundedDistIter<'a, D: 'a, R, T> { + distr: &'a D, + rng: R, + amount: usize, + phantom: ::core::marker::PhantomData, +} + +#[cfg(feature = "rayon")] +impl<'a, D, R, T> Iterator for BoundedDistIter<'a, D, R, T> + where D: Distribution, R: Rng +{ + type Item = T; + + #[inline(always)] + fn next(&mut self) -> Option { + if self.amount > 0 { + self.amount -= 1; + Some(self.distr.sample(&mut self.rng)) + } else { + None + } + } +} + +#[cfg(feature = "rayon")] +impl<'a, D, R, T> DoubleEndedIterator for BoundedDistIter<'a, D, R, T> + where D: Distribution, R: Rng +{ + #[inline(always)] + fn next_back(&mut self) -> Option { + if self.amount > 0 { + self.amount -= 1; + Some(self.distr.sample(&mut self.rng)) + } else { + None + } + } +} + +#[cfg(feature = "rayon")] +impl<'a, D, R, T> ExactSizeIterator for BoundedDistIter<'a, D, R, T> + where D: Distribution, R: Rng +{ + fn len(&self) -> usize { + self.amount + } +} diff --git a/src/lib.rs b/src/lib.rs index a572e8670c6..fcc1242ca9a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -182,6 +182,8 @@ #[cfg(feature="std")] extern crate std as core; #[cfg(all(feature = "alloc", not(feature="std")))] extern crate alloc; +#[cfg(feature = "rayon")] extern crate rayon; + #[cfg(test)] #[cfg(feature="serde1")] extern crate bincode; #[cfg(feature="serde1")] extern crate serde; #[cfg(feature="serde1")] #[macro_use] extern crate serde_derive;