Add basic SIMD support #523

Merged: 6 commits, merged Jun 29, 2018
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -19,10 +19,11 @@ appveyor = { repository = "alexcrichton/rand" }

[features]
default = ["std" ] # without "std" rand uses libcore
nightly = ["i128_support"] # enables all features requiring nightly rust
nightly = ["i128_support", "simd_support"] # enables all features requiring nightly rust
std = ["rand_core/std", "alloc", "libc", "winapi", "cloudabi", "fuchsia-zircon"]
alloc = ["rand_core/alloc"] # enables Vec and Box support (without std)
i128_support = [] # enables i128 and u128 support
simd_support = [] # enables SIMD support
serde1 = ["serde", "serde_derive", "rand_core/serde1"] # enables serialization for PRNGs

[workspace]
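For context, a rough sketch of how a downstream crate could opt in once this lands. The nightly feature gate (`stdsimd`) and the exact usage shown here are assumptions on my part, not something stated in this diff:

#![feature(stdsimd)]

extern crate rand;

use core::simd::f32x4;
use rand::{thread_rng, Rng};

fn main() {
    let mut rng = thread_rng();
    // With `simd_support` enabled, one call fills all four lanes with
    // independent samples uniform in [0, 1).
    let v: f32x4 = rng.gen();
    println!("{:?}", v);
}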
167 changes: 110 additions & 57 deletions src/distributions/float.rs
@@ -13,6 +13,9 @@
use core::mem;
use Rng;
use distributions::{Distribution, Standard};
use distributions::utils::CastFromInt;
#[cfg(feature="simd_support")]
use core::simd::*;

/// A distribution to sample floating point numbers uniformly in the half-open
/// interval `(0, 1]`, i.e. including 1 but not 0.
@@ -83,15 +86,16 @@ pub(crate) trait IntoFloat {
}

macro_rules! float_impls {
($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => {
($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty,
$fraction_bits:expr, $exponent_bias:expr) => {
impl IntoFloat for $uty {
type F = $ty;
#[inline(always)]
fn into_float_with_exponent(self, exponent: i32) -> $ty {
// The exponent is encoded using an offset-binary representation
let exponent_bits =
(($exponent_bias + exponent) as $uty) << $fraction_bits;
unsafe { mem::transmute(self | exponent_bits) }
let exponent_bits: $u_scalar =
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
$ty::from_bits(self | exponent_bits)
}
}

@@ -100,12 +104,13 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; [0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$ty>() * 8;
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $uty << precision) as $ty);
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
scale * (value >> (float_size - precision)) as $ty
let value = value >> (float_size - precision);
scale * $ty::cast_from_int(value)
}
}

@@ -114,14 +119,14 @@
// Multiply-based method; 24/53 random bits; (0, 1] interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$ty>() * 8;
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $uty << precision) as $ty);
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
// Add 1 to shift up; will not overflow because of right-shift:
scale * (value + 1) as $ty
scale * $ty::cast_from_int(value + 1)
}
}

@@ -130,8 +135,8 @@
// Transmute-based method; 23/52 random bits; (0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
const EPSILON: $ty = 1.0 / (1u64 << $fraction_bits) as $ty;
let float_size = mem::size_of::<$ty>() * 8;
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() * 8;

let value: $uty = rng.gen();
let fraction = value >> (float_size - $fraction_bits);
@@ -140,67 +145,115 @@
}
}
}
float_impls! { f32, u32, 23, 127 }
float_impls! { f64, u64, 52, 1023 }

float_impls! { f32, u32, f32, u32, 23, 127 }
float_impls! { f64, u64, f64, u64, 52, 1023 }

#[cfg(feature="simd_support")]
float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x16, u32x16, f32, u32, 23, 127 }

#[cfg(feature="simd_support")]
float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }


#[cfg(test)]
mod tests {
use Rng;
use distributions::{Open01, OpenClosed01};
use rngs::mock::StepRng;
#[cfg(feature="simd_support")]
use core::simd::*;

const EPSILON32: f32 = ::core::f32::EPSILON;
const EPSILON64: f64 = ::core::f64::EPSILON;

#[test]
fn standard_fp_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<f32>(), 0.0);
assert_eq!(zeros.gen::<f64>(), 0.0);

let mut one32 = StepRng::new(1 << 8, 0);
assert_eq!(one32.gen::<f32>(), EPSILON32 / 2.0);

let mut one64 = StepRng::new(1 << 11, 0);
assert_eq!(one64.gen::<f64>(), EPSILON64 / 2.0);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<f32>(), 1.0 - EPSILON32 / 2.0);
assert_eq!(max.gen::<f64>(), 1.0 - EPSILON64 / 2.0);
}
macro_rules! test_f32 {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);

#[test]
fn openclosed01_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<f32, _>(OpenClosed01), 0.0 + EPSILON32 / 2.0);
assert_eq!(zeros.sample::<f64, _>(OpenClosed01), 0.0 + EPSILON64 / 2.0);
// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);

let mut one32 = StepRng::new(1 << 8, 0);
assert_eq!(one32.sample::<f32, _>(OpenClosed01), EPSILON32);

let mut one64 = StepRng::new(1 << 11, 0);
assert_eq!(one64.sample::<f64, _>(OpenClosed01), EPSILON64);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<f32, _>(OpenClosed01), 1.0);
assert_eq!(max.sample::<f64, _>(OpenClosed01), 1.0);
// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
}
}
}
test_f32! { f32_edge_cases, f32, 0.0, EPSILON32 }
#[cfg(feature="simd_support")]
test_f32! { f32x2_edge_cases, f32x2, f32x2::splat(0.0), f32x2::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x4_edge_cases, f32x4, f32x4::splat(0.0), f32x4::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x8_edge_cases, f32x8, f32x8::splat(0.0), f32x8::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x16_edge_cases, f32x16, f32x16::splat(0.0), f32x16::splat(EPSILON32) }

#[test]
fn open01_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<f32, _>(Open01), 0.0 + EPSILON32 / 2.0);
assert_eq!(zeros.sample::<f64, _>(Open01), 0.0 + EPSILON64 / 2.0);
macro_rules! test_f64 {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);

let mut one32 = StepRng::new(1 << 9, 0);
assert_eq!(one32.sample::<f32, _>(Open01), EPSILON32 / 2.0 * 3.0);
// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);

let mut one64 = StepRng::new(1 << 12, 0);
assert_eq!(one64.sample::<f64, _>(Open01), EPSILON64 / 2.0 * 3.0);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<f32, _>(Open01), 1.0 - EPSILON32 / 2.0);
assert_eq!(max.sample::<f64, _>(Open01), 1.0 - EPSILON64 / 2.0);
// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 12, 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
}
}
}
test_f64! { f64_edge_cases, f64, 0.0, EPSILON64 }
#[cfg(feature="simd_support")]
test_f64! { f64x2_edge_cases, f64x2, f64x2::splat(0.0), f64x2::splat(EPSILON64) }
#[cfg(feature="simd_support")]
test_f64! { f64x4_edge_cases, f64x4, f64x4::splat(0.0), f64x4::splat(EPSILON64) }
#[cfg(feature="simd_support")]
test_f64! { f64x8_edge_cases, f64x8, f64x8::splat(0.0), f64x8::splat(EPSILON64) }
}
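As an aside on the multiply-based method used above, here is a standalone scalar illustration of the same [0, 1) conversion for f32 (illustrative only, not code from this PR):

use std::mem;

// Convert 32 random bits into an f32 uniform in [0, 1), keeping the 24 most
// significant bits (23 fraction bits plus the implicit leading one).
fn u32_to_f32_zero_one(value: u32) -> f32 {
    let float_size = mem::size_of::<f32>() * 8;
    let precision = 23 + 1;
    let scale = 1.0 / ((1u32 << precision) as f32);
    scale * (value >> (float_size - precision)) as f32
}

fn main() {
    assert_eq!(u32_to_f32_zero_one(0), 0.0);
    // The largest input maps to 1.0 - EPSILON / 2.0, never to 1.0 itself,
    // matching the `max.gen::<f32>()` edge-case test above.
    assert_eq!(u32_to_f32_zero_one(!0u32), 1.0 - ::std::f32::EPSILON / 2.0);
}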
35 changes: 35 additions & 0 deletions src/distributions/integer.rs
@@ -12,6 +12,8 @@

use {Rng};
use distributions::{Distribution, Standard};
#[cfg(feature="simd_support")]
use core::simd::*;

impl Distribution<u8> for Standard {
#[inline]
@@ -84,6 +86,39 @@ impl_int_from_uint! { i64, u64 }
#[cfg(feature = "i128_support")] impl_int_from_uint! { i128, u128 }
impl_int_from_uint! { isize, usize }

#[cfg(feature="simd_support")]
macro_rules! simd_impl {
($bits:expr,) => {};
($bits:expr, $ty:ty, $($ty_more:ty,)*) => {
simd_impl!($bits, $($ty_more,)*);
Member:
Neat usage of recursive macros.

But why do we need to pass $bits here instead of just using mem::size_of? It seems like an unnecessary risk of underfill/overfill. (A sketch of that alternative follows the macro definition below.)


impl Distribution<$ty> for Standard {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
let mut vec = Default::default();
unsafe {
let ptr = &mut vec;
let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]);
rng.fill_bytes(b_ptr);
Member:
This doesn't look portable to me. Elsewhere we've made an effort to keep things portable; I don't think this needs to be an exception?

Unfortunately it doesn't look like the SIMD types support to_le. @TheIronBorn is this what you mean about using swap_bytes?

Collaborator:
Yes, that's exactly where we'd use it. (A rough sketch of the byte-swapping approach appears at the end of this file's diff.)

}
vec
}
}
}
}
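Following up on the $bits question in the review thread above, a minimal sketch of the mem::size_of-based alternative (an illustration of the suggestion only, not the PR's code; the helper name is hypothetical):

use core::{mem, slice};

// Derive the byte count from the type itself instead of passing `$bits`,
// removing any chance of under- or over-filling the vector. Using a slice
// also sidesteps the need for a const array length.
fn sample_via_fill<T: Default, R: Rng + ?Sized>(rng: &mut R) -> T {
    let mut vec = T::default();
    unsafe {
        let ptr = &mut vec as *mut T as *mut u8;
        let bytes = slice::from_raw_parts_mut(ptr, mem::size_of::<T>());
        rng.fill_bytes(bytes);
    }
    vec
}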

#[cfg(feature="simd_support")]
simd_impl!(16, u8x2, i8x2,);
#[cfg(feature="simd_support")]
simd_impl!(32, u8x4, i8x4, u16x2, i16x2,);
#[cfg(feature="simd_support")]
simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,);
#[cfg(feature="simd_support")]
simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,);
#[cfg(feature="simd_support")]
simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,);
#[cfg(feature="simd_support")]
simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,);

#[cfg(test)]
mod tests {
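On the portability concern raised in the thread above, a rough sketch of the per-lane byte-order normalization being discussed, written out for u32x4 (hypothetical helper, not the PR's code; it assumes the `new` constructor provided by the `core::simd` vector types):

#[cfg(feature = "simd_support")]
fn sample_u32x4_portable<R: Rng + ?Sized>(rng: &mut R) -> u32x4 {
    let mut lanes = [0u32; 4];
    unsafe {
        let bytes = &mut *(&mut lanes as *mut [u32; 4] as *mut [u8; 16]);
        rng.fill_bytes(bytes);
    }
    // Reinterpret every lane as little-endian so the value stream is the same
    // on big- and little-endian targets (the `swap_bytes`/`to_le` idea above).
    u32x4::new(
        u32::from_le(lanes[0]),
        u32::from_le(lanes[1]),
        u32::from_le(lanes[2]),
        u32::from_le(lanes[3]),
    )
}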
1 change: 1 addition & 0 deletions src/distributions/mod.rs
@@ -215,6 +215,7 @@ mod integer;
#[cfg(feature="std")]
mod log_gamma;
mod other;
mod utils;
#[cfg(feature="std")]
mod ziggurat_tables;
#[cfg(feature="std")]