Skip to content

Commit

Permalink
Add math_helpers module and various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pitdicker committed Jun 22, 2018
1 parent 209836f commit 4dab1e3
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 165 deletions.
103 changes: 22 additions & 81 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
use core::mem;
use Rng;
use distributions::{Distribution, Standard};
use distributions::math_helpers::CastFromInt;
#[cfg(feature="simd_support")]
use core::simd::*;

Expand Down Expand Up @@ -85,70 +86,7 @@ pub(crate) trait IntoFloat {
}

// Implements `IntoFloat` plus the `Standard`, `OpenClosed01` and `Open01`
// distributions for one float type.
//   $ty            - the float type (f32 / f64)
//   $uty           - the unsigned integer type of the same width (u32 / u64)
//   $fraction_bits - number of mantissa bits in $ty
//   $exponent_bias - offset added to the exponent when encoding it
macro_rules! float_impls {
    ($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => {
        impl IntoFloat for $uty {
            type F = $ty;
            #[inline(always)]
            fn into_float_with_exponent(self, exponent: i32) -> $ty {
                // The exponent is encoded using an offset-binary representation
                let exponent_bits =
                    (($exponent_bias + exponent) as $uty) << $fraction_bits;
                // SAFETY: $uty and $ty have the same size, so reinterpreting
                // the combined bit pattern as a float is well-defined.
                unsafe { mem::transmute(self | exponent_bits) }
            }
        }

        impl Distribution<$ty> for Standard {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
                // Multiply-based method; 24/53 random bits; [0, 1) interval.
                // We use the most significant bits because for simple RNGs
                // those are usually more random.
                let float_size = mem::size_of::<$ty>() * 8;
                // One more bit than the fraction: the implicit leading bit.
                let precision = $fraction_bits + 1;
                // 2^-precision: maps the top `precision` bits onto [0, 1).
                let scale = 1.0 / ((1 as $uty << precision) as $ty);

                let value: $uty = rng.gen();
                scale * (value >> (float_size - precision)) as $ty
            }
        }

        impl Distribution<$ty> for OpenClosed01 {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
                // Multiply-based method; 24/53 random bits; (0, 1] interval.
                // We use the most significant bits because for simple RNGs
                // those are usually more random.
                let float_size = mem::size_of::<$ty>() * 8;
                let precision = $fraction_bits + 1;
                let scale = 1.0 / ((1 as $uty << precision) as $ty);

                let value: $uty = rng.gen();
                let value = value >> (float_size - precision);
                // Add 1 to shift up; will not overflow because of right-shift:
                // the result is in 1..=2^precision, so scaling gives (0, 1].
                scale * (value + 1) as $ty
            }
        }

        impl Distribution<$ty> for Open01 {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
                // Transmute-based method; 23/52 random bits; (0, 1) interval.
                // We use the most significant bits because for simple RNGs
                // those are usually more random.
                // EPSILON here is 2^-$fraction_bits.
                const EPSILON: $ty = 1.0 / (1u64 << $fraction_bits) as $ty;
                let float_size = mem::size_of::<$ty>() * 8;

                let value: $uty = rng.gen();
                let fraction = value >> (float_size - $fraction_bits);
                // `into_float_with_exponent(0)` yields a value in [1, 2);
                // subtracting (1 - EPSILON/2) offsets it into the open
                // interval so neither endpoint can be produced.
                fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
            }
        }
    }
}
float_impls! { f32, u32, 23, 127 }   // binary32: 23 fraction bits, bias 127
float_impls! { f64, u64, 52, 1023 }  // binary64: 52 fraction bits, bias 1023


#[cfg(feature="simd_support")]
macro_rules! simd_float_impls {
($ty:ident, $uty:ident, $f_scalar:ty, $u_scalar:ty,
($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty,
$fraction_bits:expr, $exponent_bias:expr) => {
impl IntoFloat for $uty {
type F = $ty;
Expand All @@ -157,7 +95,7 @@ macro_rules! simd_float_impls {
// The exponent is encoded using an offset-binary representation
let exponent_bits: $u_scalar =
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
unsafe { mem::transmute(self | $uty::splat(exponent_bits)) }
$ty::from_bits(self | exponent_bits)
}
}

Expand All @@ -168,11 +106,11 @@ macro_rules! simd_float_impls {
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = $ty::splat(1.0 / ((1 as $u_scalar << precision) as $f_scalar));
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = $ty::from(value >> (float_size - precision));
scale * value
let value = value >> (float_size - precision);
scale * $ty::cast_from_int(value)
}
}

Expand All @@ -183,12 +121,12 @@ macro_rules! simd_float_impls {
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = $ty::splat(1.0 / ((1 as $u_scalar << precision) as $f_scalar));
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
// Add 1 to shift up; will not overflow because of right-shift:
let value = $ty::from((value >> (float_size - precision)) + 1);
scale * value
scale * $ty::cast_from_int(value + 1)
}
}

Expand All @@ -197,32 +135,35 @@ macro_rules! simd_float_impls {
// Transmute-based method; 23/52 random bits; (0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
const EPSILON: $f_scalar = 1.0 / (1u64 << $fraction_bits) as $f_scalar;
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() * 8;

let value: $uty = rng.gen();
let fraction = value >> (float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
}
}
}
}

float_impls! { f32, u32, f32, u32, 23, 127 }
float_impls! { f64, u64, f64, u64, 52, 1023 }

#[cfg(feature="simd_support")]
simd_float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
simd_float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
simd_float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
simd_float_impls! { f32x16, u32x16, f32, u32, 23, 127 }
float_impls! { f32x16, u32x16, f32, u32, 23, 127 }

#[cfg(feature="simd_support")]
simd_float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
simd_float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
simd_float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }


#[cfg(test)]
Expand Down Expand Up @@ -267,7 +208,7 @@ mod tests {
}
}
}
test_f32! { f32_edge_cases, f32, 0.0, ::core::f32::EPSILON }
test_f32! { f32_edge_cases, f32, 0.0, EPSILON32 }
#[cfg(feature="simd_support")]
test_f32! { f32x2_edge_cases, f32x2, f32x2::splat(0.0), f32x2::splat(EPSILON32) }
#[cfg(feature="simd_support")]
Expand Down
161 changes: 161 additions & 0 deletions src/distributions/math_helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions

#[cfg(feature="simd_support")]
use core::simd::*;


/// Full-width ("widening") multiplication: multiplying two `N`-bit integers
/// yields a `2N`-bit result, returned as a `(high, low)` pair of `N`-bit
/// halves.
pub trait WideningMultiply<RHS = Self> {
    type Output;

    fn wmul(self, x: RHS) -> Self::Output;
}

// For types where a double-width primitive exists: widen both operands,
// multiply, and split the product into its high and low halves.
//   $t: the type implemented; $w: its double-width type; $bits: width of $t.
macro_rules! wmul_impl {
    ($t:ty, $w:ty, $bits:expr) => {
        impl WideningMultiply for $t {
            type Output = ($t, $t);

            #[inline(always)]
            fn wmul(self, x: $t) -> Self::Output {
                let product = (self as $w) * (x as $w);
                ((product >> $bits) as $t, product as $t)
            }
        }
    }
}
wmul_impl! { u8, u16, 8 }
wmul_impl! { u16, u32, 16 }
wmul_impl! { u32, u64, 32 }
#[cfg(feature = "i128_support")]
wmul_impl! { u64, u128, 64 }

// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`, working on half-width "digits".
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
macro_rules! wmul_impl_large {
    ($t:ty, $half:expr) => {
        impl WideningMultiply for $t {
            type Output = ($t, $t);

            #[inline(always)]
            fn wmul(self, rhs: $t) -> Self::Output {
                // Mask selecting the low half of a value.
                const LOWER_MASK: $t = !0 >> $half;
                // low-half * low-half partial product.
                let mut lo = (self & LOWER_MASK).wrapping_mul(rhs & LOWER_MASK);
                let mut carry = lo >> $half;
                lo &= LOWER_MASK;
                // Fold in the self.high * rhs.low cross product.
                carry += (self >> $half).wrapping_mul(rhs & LOWER_MASK);
                lo += (carry & LOWER_MASK) << $half;
                let mut hi = carry >> $half;
                carry = lo >> $half;
                lo &= LOWER_MASK;
                // Fold in the rhs.high * self.low cross product.
                carry += (rhs >> $half).wrapping_mul(self & LOWER_MASK);
                lo += (carry & LOWER_MASK) << $half;
                hi += carry >> $half;
                // Finally the high-half * high-half partial product.
                hi += (self >> $half).wrapping_mul(rhs >> $half);

                (hi, lo)
            }
        }
    }
}
#[cfg(not(feature = "i128_support"))]
wmul_impl_large! { u64, 32 }
#[cfg(feature = "i128_support")]
wmul_impl_large! { u128, 64 }

// `usize` simply forwards to the fixed-width type of the same size.
macro_rules! wmul_impl_usize {
    ($t:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (hi, lo) = (self as $t).wmul(x as $t);
                (hi as usize, lo as usize)
            }
        }
    }
}
#[cfg(target_pointer_width = "32")]
wmul_impl_usize! { u32 }
#[cfg(target_pointer_width = "64")]
wmul_impl_usize! { u64 }


/// Conversion of an unsigned integer (scalar or SIMD vector) into the float
/// type of the same width. Large values may be rounded, as with `as` casts.
pub trait CastFromInt<T> {
    fn cast_from_int(i: T) -> Self;
}

// Scalar conversions are plain `as` casts.
macro_rules! scalar_float_from_int {
    ($f:ty, $u:ty) => {
        impl CastFromInt<$u> for $f {
            fn cast_from_int(i: $u) -> Self { i as $f }
        }
    }
}

scalar_float_from_int! { f32, u32 }
scalar_float_from_int! { f64, u64 }

// SIMD vectors convert lane-wise via `From`.
#[cfg(feature="simd_support")]
macro_rules! simd_float_from_int {
    ($f:ident, $u:ident) => {
        impl CastFromInt<$u> for $f {
            fn cast_from_int(i: $u) -> Self { $f::from(i) }
        }
    }
}

#[cfg(feature="simd_support")] simd_float_from_int! { f32x2, u32x2 }
#[cfg(feature="simd_support")] simd_float_from_int! { f32x4, u32x4 }
#[cfg(feature="simd_support")] simd_float_from_int! { f32x8, u32x8 }
#[cfg(feature="simd_support")] simd_float_from_int! { f32x16, u32x16 }
#[cfg(feature="simd_support")] simd_float_from_int! { f64x2, u64x2 }
#[cfg(feature="simd_support")] simd_float_from_int! { f64x4, u64x4 }
#[cfg(feature="simd_support")] simd_float_from_int! { f64x8, u64x8 }


/// `PartialOrd` for vectors compares lexicographically. We want natural,
/// element-wise order: a vector comparison holds only if it holds for
/// *every* lane. Only the comparison functions we need are implemented.
pub trait NaturalCompare {
    /// `self < other` (for vectors: in every lane).
    fn cmp_lt(self, other: Self) -> bool;
    /// `self <= other` (for vectors: in every lane).
    fn cmp_le(self, other: Self) -> bool;
}

impl NaturalCompare for f32 {
    fn cmp_lt(self, other: Self) -> bool { self < other }
    fn cmp_le(self, other: Self) -> bool { self <= other }
}

impl NaturalCompare for f64 {
    fn cmp_lt(self, other: Self) -> bool { self < other }
    fn cmp_le(self, other: Self) -> bool { self <= other }
}

// Lane-wise comparison, reduced with `all()`.
// (Renamed from `simd_less_then`: fixed the "then"/"than" misspelling; the
// macro is file-local and all invocations are below.)
#[cfg(feature="simd_support")]
macro_rules! simd_less_than {
    ($ty:ident) => {
        impl NaturalCompare for $ty {
            fn cmp_lt(self, other: Self) -> bool { self.lt(other).all() }
            fn cmp_le(self, other: Self) -> bool { self.le(other).all() }
        }
    }
}

#[cfg(feature="simd_support")] simd_less_than! { f32x2 }
#[cfg(feature="simd_support")] simd_less_than! { f32x4 }
#[cfg(feature="simd_support")] simd_less_than! { f32x8 }
#[cfg(feature="simd_support")] simd_less_than! { f32x16 }
#[cfg(feature="simd_support")] simd_less_than! { f64x2 }
#[cfg(feature="simd_support")] simd_less_than! { f64x4 }
#[cfg(feature="simd_support")] simd_less_than! { f64x8 }
1 change: 1 addition & 0 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ mod float;
mod integer;
#[cfg(feature="std")]
mod log_gamma;
mod math_helpers;
mod other;
#[cfg(feature="std")]
mod ziggurat_tables;
Expand Down
Loading

0 comments on commit 4dab1e3

Please sign in to comment.