diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 8fda0319fef..b0df597542a 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -53,10 +53,17 @@ //! Those methods should include an assert to check the range is valid (i.e. //! `low < high`). The example below merely wraps another back-end. //! +//! The `new`, `new_inclusive` and `sample_single` functions use arguments of +//! type Borrow in order to support passing in values by reference or by +//! value. In the implementation of these functions, you can choose to simply +//! use the reference returned by `Borrow::borrow`, or you can choose to copy +//! or clone the value, whatever is appropriate for your type. +//! //! ``` //! use rand::prelude::*; //! use rand::distributions::uniform::{Uniform, SampleUniform, //! UniformSampler, UniformFloat}; +//! use std::borrow::Borrow; //! //! struct MyF32(f32); //! @@ -67,12 +74,18 @@ //! //! impl UniformSampler for UniformMyF32 { //! type X = MyF32; -//! fn new(low: Self::X, high: Self::X) -> Self { +//! fn new(low: B1, high: B2) -> Self +//! where B1: Borrow + Sized, +//! B2: Borrow + Sized +//! { //! UniformMyF32 { -//! inner: UniformFloat::::new(low.0, high.0), +//! inner: UniformFloat::::new(low.borrow().0, high.borrow().0), //! } //! } -//! fn new_inclusive(low: Self::X, high: Self::X) -> Self { +//! fn new_inclusive(low: B1, high: B2) -> Self +//! where B1: Borrow + Sized, +//! B2: Borrow + Sized +//! { //! UniformSampler::new(low, high) //! } //! fn sample(&self, rng: &mut R) -> Self::X { @@ -99,6 +112,7 @@ #[cfg(feature = "std")] use std::time::Duration; +use core::borrow::Borrow; use Rng; use distributions::Distribution; @@ -155,13 +169,19 @@ pub struct Uniform { impl Uniform { /// Create a new `Uniform` instance which samples uniformly from the half /// open range `[low, high)` (excluding `high`). Panics if `low >= high`. - pub fn new(low: X, high: X) -> Uniform { + pub fn new(low: B1, high: B2) -> Uniform + where B1: Borrow + Sized, + B2: Borrow + Sized + { Uniform { inner: X::Sampler::new(low, high) } } /// Create a new `Uniform` instance which samples uniformly from the closed /// range `[low, high]` (inclusive). Panics if `low > high`. - pub fn new_inclusive(low: X, high: X) -> Uniform { + pub fn new_inclusive(low: B1, high: B2) -> Uniform + where B1: Borrow + Sized, + B2: Borrow + Sized + { Uniform { inner: X::Sampler::new_inclusive(low, high) } } } @@ -206,14 +226,18 @@ pub trait UniformSampler: Sized { /// /// Usually users should not call this directly but instead use /// `Uniform::new`, which asserts that `low < high` before calling this. - fn new(low: Self::X, high: Self::X) -> Self; + fn new(low: B1, high: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized; /// Construct self, with inclusive bounds `[low, high]`. /// /// Usually users should not call this directly but instead use /// `Uniform::new_inclusive`, which asserts that `low <= high` before /// calling this. - fn new_inclusive(low: Self::X, high: Self::X) -> Self; + fn new_inclusive(low: B1, high: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized; /// Sample a value. fn sample(&self, rng: &mut R) -> Self::X; @@ -229,8 +253,10 @@ pub trait UniformSampler: Sized { /// sampling only a single value from the specified range. The default /// implementation simply calls `UniformSampler::new` then `sample` on the /// result. - fn sample_single(low: Self::X, high: Self::X, rng: &mut R) + fn sample_single(low: B1, high: B2, rng: &mut R) -> Self::X + where B1: Borrow + Sized, + B2: Borrow + Sized { let uniform: Self = UniformSampler::new(low, high); uniform.sample(rng) @@ -311,14 +337,24 @@ macro_rules! uniform_int_impl { #[inline] // if the range is constant, this helps LLVM to do the // calculations at compile-time. - fn new(low: Self::X, high: Self::X) -> Self { + fn new(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low < high, "Uniform::new called with `low >= high`"); UniformSampler::new_inclusive(low, high - 1) } #[inline] // if the range is constant, this helps LLVM to do the // calculations at compile-time. - fn new_inclusive(low: Self::X, high: Self::X) -> Self { + fn new_inclusive(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low <= high, "Uniform::new_inclusive called with `low > high`"); let unsigned_max = ::core::$unsigned::MAX; @@ -362,10 +398,13 @@ macro_rules! uniform_int_impl { } } - fn sample_single(low: Self::X, - high: Self::X, - rng: &mut R) -> Self::X + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) + -> Self::X + where B1: Borrow + Sized, + B2: Borrow + Sized { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low < high, "Uniform::sample_single called with low >= high"); let range = high.wrapping_sub(low) as $unsigned as $u_large; @@ -532,7 +571,12 @@ macro_rules! uniform_float_impl { impl UniformSampler for UniformFloat<$ty> { type X = $ty; - fn new(low: Self::X, high: Self::X) -> Self { + fn new(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low < high, "Uniform::new called with `low >= high`"); let scale = high - low; let offset = low - scale; @@ -542,7 +586,12 @@ macro_rules! uniform_float_impl { } } - fn new_inclusive(low: Self::X, high: Self::X) -> Self { + fn new_inclusive(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low <= high, "Uniform::new_inclusive called with `low > high`"); let scale = high - low; @@ -565,9 +614,13 @@ macro_rules! uniform_float_impl { value1_2 * self.scale + self.offset } - fn sample_single(low: Self::X, - high: Self::X, - rng: &mut R) -> Self::X { + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) + -> Self::X + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low < high, "Uniform::sample_single called with low >= high"); let scale = high - low; @@ -624,13 +677,23 @@ impl UniformSampler for UniformDuration { type X = Duration; #[inline] - fn new(low: Duration, high: Duration) -> UniformDuration { + fn new(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low < high, "Uniform::new called with `low >= high`"); UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) } #[inline] - fn new_inclusive(low: Duration, high: Duration) -> UniformDuration { + fn new_inclusive(low_b: B1, high_b: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); assert!(low <= high, "Uniform::new_inclusive called with `low > high`"); let size = high - low; let nanos = size @@ -750,6 +813,18 @@ mod tests { assert!(low <= v && v <= high); } + let my_uniform = Uniform::new(&low, high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!(low <= v && v < high); + } + + let my_uniform = Uniform::new_inclusive(&low, &high); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!(low <= v && v <= high); + } + for _ in 0..1000 { let v: $ty = rng.gen_range(low, high); assert!(low <= v && v < high); @@ -809,6 +884,7 @@ mod tests { #[test] fn test_custom_uniform() { + use core::borrow::Borrow; #[derive(Clone, Copy, PartialEq, PartialOrd)] struct MyF32 { x: f32, @@ -819,12 +895,18 @@ mod tests { } impl UniformSampler for UniformMyF32 { type X = MyF32; - fn new(low: Self::X, high: Self::X) -> Self { + fn new(low: B1, high: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { UniformMyF32 { - inner: UniformFloat::::new(low.x, high.x), + inner: UniformFloat::::new(low.borrow().x, high.borrow().x), } } - fn new_inclusive(low: Self::X, high: Self::X) -> Self { + fn new_inclusive(low: B1, high: B2) -> Self + where B1: Borrow + Sized, + B2: Borrow + Sized + { UniformSampler::new(low, high) } fn sample(&self, rng: &mut R) -> Self::X { diff --git a/src/lib.rs b/src/lib.rs index 6d6d762b3cb..9152846d23b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -300,10 +300,10 @@ pub mod isaac { use core::{mem, slice}; +use core::borrow::Borrow; use distributions::{Distribution, Standard}; use distributions::uniform::{SampleUniform, UniformSampler}; - /// An automatically-implemented extension trait on [`RngCore`] providing high-level /// generic methods for sampling values and other convenience methods. /// @@ -387,7 +387,9 @@ pub trait Rng: RngCore { /// ``` /// /// [`Uniform`]: distributions/uniform/struct.Uniform.html - fn gen_range(&mut self, low: T, high: T) -> T { + fn gen_range(&mut self, low: B1, high: B2) -> T + where B1: Borrow + Sized, + B2: Borrow + Sized { T::Sampler::sample_single(low, high, self) } @@ -935,19 +937,19 @@ mod test { fn test_gen_range() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-3, 42); - assert!(a >= -3 && a < 42); - assert_eq!(r.gen_range(0, 1), 0); - assert_eq!(r.gen_range(-12, -11), -12); - } - - for _ in 0..1000 { - let a = r.gen_range(10, 42); - assert!(a >= 10 && a < 42); - assert_eq!(r.gen_range(0, 1), 0); + let a = r.gen_range(-4711, 17); + assert!(a >= -4711 && a < 17); + let a = r.gen_range(-3i8, 42); + assert!(a >= -3i8 && a < 42i8); + let a = r.gen_range(&10u16, 99); + assert!(a >= 10u16 && a < 99u16); + let a = r.gen_range(-100i32, &2000); + assert!(a >= -100i32 && a < 2000i32); + + assert_eq!(r.gen_range(0u32, 1), 0u32); + assert_eq!(r.gen_range(-12i64, -11), -12i64); assert_eq!(r.gen_range(3_000_000, 3_000_001), 3_000_000); } - } #[test]