Skip to content

Commit

Permalink
Make Uniform and its helper traits use arguments of type Borrow<X> ra…
Browse files Browse the repository at this point in the history
…ther than type X.
  • Loading branch information
sicking committed Jun 12, 2018
1 parent ec3d7ef commit 8bec2de
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 36 deletions.
128 changes: 105 additions & 23 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,17 @@
//! Those methods should include an assert to check the range is valid (i.e.
//! `low < high`). The example below merely wraps another back-end.
//!
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
//! type Borrow<X> in order to support passing in values by reference or by
//! value. In the implementation of these functions, you can choose to simply
//! use the reference returned by `Borrow::borrow`, or you can choose to copy
//! or clone the value, whatever is appropriate for your type.
//!
//! ```
//! use rand::prelude::*;
//! use rand::distributions::uniform::{Uniform, SampleUniform,
//! UniformSampler, UniformFloat};
//! use std::borrow::Borrow;
//!
//! struct MyF32(f32);
//!
Expand All @@ -67,12 +74,18 @@
//!
//! impl UniformSampler for UniformMyF32 {
//! type X = MyF32;
//! fn new(low: Self::X, high: Self::X) -> Self {
//! fn new<B1, B2>(low: B1, high: B2) -> Self
//! where B1: Borrow<Self::X> + Sized,
//! B2: Borrow<Self::X> + Sized
//! {
//! UniformMyF32 {
//! inner: UniformFloat::<f32>::new(low.0, high.0),
//! inner: UniformFloat::<f32>::new(low.borrow().0, high.borrow().0),
//! }
//! }
//! fn new_inclusive(low: Self::X, high: Self::X) -> Self {
//! fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
//! where B1: Borrow<Self::X> + Sized,
//! B2: Borrow<Self::X> + Sized
//! {
//! UniformSampler::new(low, high)
//! }
//! fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Expand All @@ -99,6 +112,7 @@

#[cfg(feature = "std")]
use std::time::Duration;
use core::borrow::Borrow;

use Rng;
use distributions::Distribution;
Expand Down Expand Up @@ -155,13 +169,19 @@ pub struct Uniform<X: SampleUniform> {
impl<X: SampleUniform> Uniform<X> {
/// Create a new `Uniform` instance which samples uniformly from the half
/// open range `[low, high)` (excluding `high`). Panics if `low >= high`.
pub fn new(low: X, high: X) -> Uniform<X> {
pub fn new<B1, B2>(low: B1, high: B2) -> Uniform<X>
where B1: Borrow<X> + Sized,
B2: Borrow<X> + Sized
{
Uniform { inner: X::Sampler::new(low, high) }
}

/// Create a new `Uniform` instance which samples uniformly from the closed
/// range `[low, high]` (inclusive). Panics if `low > high`.
pub fn new_inclusive(low: X, high: X) -> Uniform<X> {
pub fn new_inclusive<B1, B2>(low: B1, high: B2) -> Uniform<X>
where B1: Borrow<X> + Sized,
B2: Borrow<X> + Sized
{
Uniform { inner: X::Sampler::new_inclusive(low, high) }
}
}
Expand Down Expand Up @@ -206,14 +226,18 @@ pub trait UniformSampler: Sized {
///
/// Usually users should not call this directly but instead use
/// `Uniform::new`, which asserts that `low < high` before calling this.
fn new(low: Self::X, high: Self::X) -> Self;
fn new<B1, B2>(low: B1, high: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized;

/// Construct self, with inclusive bounds `[low, high]`.
///
/// Usually users should not call this directly but instead use
/// `Uniform::new_inclusive`, which asserts that `low <= high` before
/// calling this.
fn new_inclusive(low: Self::X, high: Self::X) -> Self;
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized;

/// Sample a value.
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X;
Expand All @@ -229,8 +253,10 @@ pub trait UniformSampler: Sized {
/// sampling only a single value from the specified range. The default
/// implementation simply calls `UniformSampler::new` then `sample` on the
/// result.
fn sample_single<R: Rng + ?Sized>(low: Self::X, high: Self::X, rng: &mut R)
fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R)
-> Self::X
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let uniform: Self = UniformSampler::new(low, high);
uniform.sample(rng)
Expand Down Expand Up @@ -311,14 +337,24 @@ macro_rules! uniform_int_impl {

#[inline] // if the range is constant, this helps LLVM to do the
// calculations at compile-time.
fn new(low: Self::X, high: Self::X) -> Self {
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low < high, "Uniform::new called with `low >= high`");
UniformSampler::new_inclusive(low, high - 1)
}

#[inline] // if the range is constant, this helps LLVM to do the
// calculations at compile-time.
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low <= high,
"Uniform::new_inclusive called with `low > high`");
let unsigned_max = ::core::$unsigned::MAX;
Expand Down Expand Up @@ -362,10 +398,13 @@ macro_rules! uniform_int_impl {
}
}

fn sample_single<R: Rng + ?Sized>(low: Self::X,
high: Self::X,
rng: &mut R) -> Self::X
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
-> Self::X
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low < high,
"Uniform::sample_single called with low >= high");
let range = high.wrapping_sub(low) as $unsigned as $u_large;
Expand Down Expand Up @@ -532,7 +571,12 @@ macro_rules! uniform_float_impl {
impl UniformSampler for UniformFloat<$ty> {
type X = $ty;

fn new(low: Self::X, high: Self::X) -> Self {
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low < high, "Uniform::new called with `low >= high`");
let scale = high - low;
let offset = low - scale;
Expand All @@ -542,7 +586,12 @@ macro_rules! uniform_float_impl {
}
}

fn new_inclusive(low: Self::X, high: Self::X) -> Self {
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low <= high,
"Uniform::new_inclusive called with `low > high`");
let scale = high - low;
Expand All @@ -565,9 +614,13 @@ macro_rules! uniform_float_impl {
value1_2 * self.scale + self.offset
}

fn sample_single<R: Rng + ?Sized>(low: Self::X,
high: Self::X,
rng: &mut R) -> Self::X {
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R)
-> Self::X
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low < high,
"Uniform::sample_single called with low >= high");
let scale = high - low;
Expand Down Expand Up @@ -624,13 +677,23 @@ impl UniformSampler for UniformDuration {
type X = Duration;

#[inline]
fn new(low: Duration, high: Duration) -> UniformDuration {
fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low < high, "Uniform::new called with `low >= high`");
UniformDuration::new_inclusive(low, high - Duration::new(0, 1))
}

#[inline]
fn new_inclusive(low: Duration, high: Duration) -> UniformDuration {
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low <= high, "Uniform::new_inclusive called with `low > high`");
let size = high - low;
let nanos = size
Expand Down Expand Up @@ -750,6 +813,18 @@ mod tests {
assert!(low <= v && v <= high);
}

let my_uniform = Uniform::new(&low, high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v < high);
}

let my_uniform = Uniform::new_inclusive(&low, &high);
for _ in 0..1000 {
let v: $ty = rng.sample(my_uniform);
assert!(low <= v && v <= high);
}

for _ in 0..1000 {
let v: $ty = rng.gen_range(low, high);
assert!(low <= v && v < high);
Expand Down Expand Up @@ -809,6 +884,7 @@ mod tests {

#[test]
fn test_custom_uniform() {
use core::borrow::Borrow;
#[derive(Clone, Copy, PartialEq, PartialOrd)]
struct MyF32 {
x: f32,
Expand All @@ -819,12 +895,18 @@ mod tests {
}
impl UniformSampler for UniformMyF32 {
type X = MyF32;
fn new(low: Self::X, high: Self::X) -> Self {
fn new<B1, B2>(low: B1, high: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
UniformMyF32 {
inner: UniformFloat::<f32>::new(low.x, high.x),
inner: UniformFloat::<f32>::new(low.borrow().x, high.borrow().x),
}
}
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where B1: Borrow<Self::X> + Sized,
B2: Borrow<Self::X> + Sized
{
UniformSampler::new(low, high)
}
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Expand Down
28 changes: 15 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,10 @@ pub mod isaac {


use core::{mem, slice};
use core::borrow::Borrow;
use distributions::{Distribution, Standard};
use distributions::uniform::{SampleUniform, UniformSampler};


/// An automatically-implemented extension trait on [`RngCore`] providing high-level
/// generic methods for sampling values and other convenience methods.
///
Expand Down Expand Up @@ -387,7 +387,9 @@ pub trait Rng: RngCore {
/// ```
///
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
fn gen_range<T: SampleUniform, B1, B2>(&mut self, low: B1, high: B2) -> T
where B1: Borrow<T> + Sized,
B2: Borrow<T> + Sized {
T::Sampler::sample_single(low, high, self)
}

Expand Down Expand Up @@ -935,19 +937,19 @@ mod test {
fn test_gen_range() {
let mut r = rng(101);
for _ in 0..1000 {
let a = r.gen_range(-3, 42);
assert!(a >= -3 && a < 42);
assert_eq!(r.gen_range(0, 1), 0);
assert_eq!(r.gen_range(-12, -11), -12);
}

for _ in 0..1000 {
let a = r.gen_range(10, 42);
assert!(a >= 10 && a < 42);
assert_eq!(r.gen_range(0, 1), 0);
let a = r.gen_range(-4711, 17);
assert!(a >= -4711 && a < 17);
let a = r.gen_range(-3i8, 42);
assert!(a >= -3i8 && a < 42i8);
let a = r.gen_range(&10u16, 99);
assert!(a >= 10u16 && a < 99u16);
let a = r.gen_range(-100i32, &2000);
assert!(a >= -100i32 && a < 2000i32);

assert_eq!(r.gen_range(0u32, 1), 0u32);
assert_eq!(r.gen_range(-12i64, -11), -12i64);
assert_eq!(r.gen_range(3_000_000, 3_000_001), 3_000_000);
}

}

#[test]
Expand Down

0 comments on commit 8bec2de

Please sign in to comment.