Skip to content

Commit

Permalink
Add rayon support
Browse files Browse the repository at this point in the history
  • Loading branch information
pitdicker committed Apr 15, 2018
1 parent 95ea68c commit 82d92de
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ members = ["rand_core"]
[dependencies]
rand_core = { path="rand_core", default-features = false }
log = { version = "0.4", optional = true }
rayon = { version = "1", optional = true }
serde = { version = "1", optional = true }
serde_derive = { version = "1", optional = true }

Expand Down
48 changes: 40 additions & 8 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
//! [`Standard`]: struct.Standard.html

use Rng;
#[cfg(feature = "rayon")]
use SeedableRng;
#[cfg(feature = "rayon")]
use distributions::rayon::ParallelDistIter;

pub use self::other::Alphanumeric;
pub use self::range::Range;
Expand All @@ -49,6 +53,8 @@ pub mod exponential;
pub mod poisson;
#[cfg(feature = "std")]
pub mod binomial;
#[cfg(feature = "rayon")]
mod rayon;

mod float;
mod integer;
Expand Down Expand Up @@ -169,17 +175,34 @@ pub trait Distribution<T> {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, R: Rng>(&'a self, rng: &'a mut R)
-> DistIter<'a, Self, R, T> where Self: Sized
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
where Self: Sized, R: Rng
{
DistIter {
distr: self,
rng: rng,
phantom: ::core::marker::PhantomData,
}
}

/// Create a parallel iterator.
#[cfg(feature = "rayon")]
fn sample_par_iter<'a, R>(&'a self, rng: &mut R, amount: usize)
-> ParallelDistIter<'a, Self, R, T>
where Self: Sized,
R: Rng + SeedableRng,
{
ParallelDistIter::new(self, rng, amount)
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
(*self).sample(rng)
}
}


/// An iterator that generates random values of `T` with distribution `D`,
/// using `R` as the source of randomness.
///
Expand All @@ -189,7 +212,7 @@ pub trait Distribution<T> {
/// [`Distribution`]: trait.Distribution.html
/// [`sample_iter`]: trait.Distribution.html#method.sample_iter
#[derive(Debug)]
pub struct DistIter<'a, D, R, T> where D: Distribution<T> + 'a, R: Rng + 'a {
pub struct DistIter<'a, D: 'a, R: 'a, T> {
distr: &'a D,
rng: &'a mut R,
phantom: ::core::marker::PhantomData<T>,
Expand All @@ -206,11 +229,6 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
}
}

impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
(*self).sample(rng)
}
}

/// A generic random value distribution. Generates values for various types
/// with numerically uniform distribution.
Expand Down Expand Up @@ -620,4 +638,18 @@ mod tests {
let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect();
println!("{:?}", results);
}

#[cfg(all(feature="std", feature="rayon"))]
#[test]
fn test_distributions_par_iter() {
use distributions::Range;
use rayon::iter::ParallelIterator;
use NewRng;
use prng::XorShiftRng; // *EXTREMELY* bad choice!
let mut rng = XorShiftRng::new();
let range = Range::new(100, 200);
let results: Vec<_> = range.sample_par_iter(&mut rng, 1000).collect();
println!("{:?}", results);
panic!();
}
}
212 changes: 212 additions & 0 deletions src/distributions/rayon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Parallel iterator to sample from distributions.

use {Rng, SeedableRng};
use distributions::Distribution;

use rayon::iter::plumbing::{Consumer, Producer, ProducerCallback, UnindexedConsumer, bridge};
use rayon::iter::{ParallelIterator, IndexedParallelIterator};


/// An iterator that generates random values of `T` with distribution `D`,
/// using `R` as the source of randomness.
///
/// This `struct` is created by the [`par_sample_iter`] method on
/// [`Distribution`]. See its documentation for more.
///
/// [`Distribution`]: trait.Distribution.html
/// [`sample_iter`]: trait.Distribution.html#method.sample_iter
#[cfg(feature = "rayon")]
#[derive(Debug)]
pub struct ParallelDistIter<'a, D: 'a, R, T> {
distr: &'a D,
rng: R,
amount: usize,
phantom: ::core::marker::PhantomData<T>,
}

impl<'a, D, R, T> ParallelDistIter<'a, D, R, T> {
pub fn new(distr: &'a D, rng: &mut R, amount: usize)
-> ParallelDistIter<'a, D, R, T>
where D: Distribution<T>,
R: Rng + SeedableRng,
{
ParallelDistIter {
distr,
rng: R::from_rng(rng).unwrap(),
amount,
phantom: ::core::marker::PhantomData,
}
}
}

#[cfg(feature = "rayon")]
impl<'a, D, R, T> ParallelIterator for ParallelDistIter<'a, D, R, T>
where D: Distribution<T> + Send + Sync,
R: Rng + SeedableRng + Send,
T: Send,
{
type Item = T;

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where C: UnindexedConsumer<Self::Item>
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.amount)
}
}

#[cfg(feature = "rayon")]
impl<'a, D, R, T> IndexedParallelIterator for ParallelDistIter<'a, D, R, T>
where D: Distribution<T> + Send + Sync,
R: Rng + SeedableRng + Send,
T: Send,
{
fn drive<C>(self, consumer: C) -> C::Result
where C: Consumer<Self::Item>
{
bridge(self, consumer)
}

fn len(&self) -> usize {
self.amount
}

fn with_producer<CB>(self, callback: CB) -> CB::Output
where CB: ProducerCallback<Self::Item>
{
callback.callback(
DistProducer {
distr: self.distr.clone(),
amount: self.amount,
rng: self.rng,
phantom: ::core::marker::PhantomData,
}
)
}
}

/// FIXME
#[cfg(feature = "rayon")]
#[derive(Debug)]
pub struct DistProducer<'a, D: 'a, R, T> {
distr: &'a D,
rng: R,
amount: usize,
phantom: ::core::marker::PhantomData<T>,
}

/// This method is intented to be used by fast and relatively simple PRNGs used
/// for simulations etc. While it will also work with cryptographic RNGs, that
/// is not optimal.
///
/// Every time `rayon` splits the work in two to create parallel tasks, one new
/// PRNG is created. The original PRNG is used to seed the new one using
/// `SeedableRng::from_rng`. **Important**: Not all RNG algorithms support this!
/// Notably the low-quality plain Xorshift, the current default for `SmallRng`,
/// will simply clone itself using this method instead of seeding the split off
/// RNG well. Consider using something like PCG or Xoroshiro128+.
///
/// It is hard to predict what will happen to the statistical quality of PRNGs
/// when they are split off many times, and only very short runs are used. We
/// limit the minimum number of items that should be used of the PRNG to at
/// least 100 to hopefully keep similar statistical properties as one PRNG used
/// continuously.
#[cfg(feature = "rayon")]
impl<'a, D, R, T> Producer for DistProducer<'a, D, R, T>
where D: Distribution<T> + Send + Sync,
R: Rng + SeedableRng + Send,
T: Send,
{
type Item = T;
type IntoIter = BoundedDistIter<'a, D, R, T>;
fn into_iter(self) -> Self::IntoIter {
BoundedDistIter {
distr: self.distr,
amount: self.amount,
rng: self.rng,
phantom: ::core::marker::PhantomData,
}
}

fn split_at(mut self, index: usize) -> (Self, Self) {
assert!(index <= self.amount);
// Create a new PRNG of the same type, by seeding it with this PRNG.
// `from_rng` should never fail.
let new = DistProducer {
distr: self.distr,
amount: self.amount - index,
rng: R::from_rng(&mut self.rng).unwrap(),
phantom: ::core::marker::PhantomData,
};
self.amount = index;
(self, new)
}

fn min_len(&self) -> usize {
100
}
}

/// FIXME
#[cfg(feature = "rayon")]
#[derive(Debug)]
pub struct BoundedDistIter<'a, D: 'a, R, T> {
distr: &'a D,
rng: R,
amount: usize,
phantom: ::core::marker::PhantomData<T>,
}

#[cfg(feature = "rayon")]
impl<'a, D, R, T> Iterator for BoundedDistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng
{
type Item = T;

#[inline(always)]
fn next(&mut self) -> Option<T> {
if self.amount > 0 {
self.amount -= 1;
Some(self.distr.sample(&mut self.rng))
} else {
None
}
}
}

#[cfg(feature = "rayon")]
impl<'a, D, R, T> DoubleEndedIterator for BoundedDistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng
{
#[inline(always)]
fn next_back(&mut self) -> Option<T> {
if self.amount > 0 {
self.amount -= 1;
Some(self.distr.sample(&mut self.rng))
} else {
None
}
}
}

#[cfg(feature = "rayon")]
impl<'a, D, R, T> ExactSizeIterator for BoundedDistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng
{
fn len(&self) -> usize {
self.amount
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@
#[cfg(feature="std")] extern crate std as core;
#[cfg(all(feature = "alloc", not(feature="std")))] extern crate alloc;

#[cfg(feature = "rayon")] extern crate rayon;

#[cfg(test)] #[cfg(feature="serde1")] extern crate bincode;
#[cfg(feature="serde1")] extern crate serde;
#[cfg(feature="serde1")] #[macro_use] extern crate serde_derive;
Expand Down

0 comments on commit 82d92de

Please sign in to comment.