Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate rand_distr to num-traits for no_std support #987

Merged
merged 16 commits into from
Aug 1, 2020
Merged
10 changes: 9 additions & 1 deletion rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ travis-ci = { repository = "rust-random/rand" }
appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = "0.7" }
rand = { path = "..", version = "0.7", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }
newpavlov marked this conversation as resolved.
Show resolved Hide resolved

[features]
default = ["std"]
std = ["alloc"]
alloc = []

[dev-dependencies]
rand_pcg = { version = "0.2", path = "../rand_pcg" }
# For inline examples
rand = { path = "..", version = "0.7", default-features = false, features = ["std_rng", "std"] }
# Histogram implementation for testing uniformity
average = "0.10.3"
27 changes: 5 additions & 22 deletions rand_distr/src/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

use crate::{Distribution, Uniform};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The binomial distribution `Binomial(n, p)`.
///
Expand Down Expand Up @@ -53,7 +53,8 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl Binomial {
/// Construct a new `Binomial` with the given shape parameters `n` (number
Expand All @@ -72,7 +73,7 @@ impl Binomial {
/// Convert a `f64` to an `i64`, panicing on overflow.
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
fn f64_to_i64(x: f64) -> i64 {
assert!(x < (::std::i64::MAX as f64));
assert!(x < (core::i64::MAX as f64));
x as i64
}

Expand Down Expand Up @@ -106,7 +107,7 @@ impl Distribution<u64> for Binomial {
// Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.;

if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (::std::i32::MAX as u64) {
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) {
// Use the BINV algorithm.
let s = p / q;
let a = ((self.n + 1) as f64) * s;
Expand Down Expand Up @@ -338,22 +339,4 @@ mod test {
fn test_binomial_invalid_lambda_neg() {
Binomial::new(20, -10.0).unwrap();
}

#[test]
fn value_stability() {
fn test_samples(n: u64, p: f64, expected: &[u64]) {
let distr = Binomial::new(n, p).unwrap();
let mut rng = crate::test::rng(353);
let mut buf = [0; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}

// We have multiple code paths: np < 10, p > 0.5
test_samples(2, 0.7, &[1, 1, 2, 1]);
test_samples(20, 0.3, &[7, 7, 5, 7]);
test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]);
}
}
41 changes: 23 additions & 18 deletions rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

//! The Cauchy distribution.

use crate::utils::Float;
use num_traits::{Float, FloatConst};
use crate::{Distribution, Standard};
use rand::Rng;
use std::{error, fmt};
use core::fmt;

/// The Cauchy distribution `Cauchy(median, scale)`.
///
Expand All @@ -32,9 +32,11 @@ use std::{error, fmt};
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Cauchy<N> {
median: N,
scale: N,
pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
median: F,
scale: F,
}

/// Error type returned from `Cauchy::new`.
Expand All @@ -52,30 +54,31 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Cauchy<N>
where Standard: Distribution<N>
impl<F> Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
/// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor.
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
if !(scale > N::from(0.0)) {
pub fn new(median: F, scale: F) -> Result<Cauchy<F>, Error> {
if !(scale > F::zero()) {
return Err(Error::ScaleTooSmall);
}
Ok(Cauchy { median, scale })
}
}

impl<N: Float> Distribution<N> for Cauchy<N>
where Standard: Distribution<N>
impl<F> Distribution<F> for Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// sample from [0, 1)
let x = Standard.sample(rng);
// get standard cauchy random number
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
let comp_dev = (N::pi() * x).tan();
let comp_dev = (F::PI() * x).tan();
// shift and scale according to parameters
self.median + self.scale * comp_dev
}
Expand Down Expand Up @@ -108,10 +111,12 @@ mod test {
sum += numbers[i];
}
let median = median(&mut numbers);
println!("Cauchy median: {}", median);
#[cfg(feature = "std")]
std::println!("Cauchy median: {}", median);
assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough
let mean = sum / 1000.0;
println!("Cauchy mean: {}", mean);
#[cfg(feature = "std")]
std::println!("Cauchy mean: {}", mean);
// for a Cauchy distribution the mean should not converge
assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough
}
Expand All @@ -130,8 +135,8 @@ mod test {

#[test]
fn value_stability() {
fn gen_samples<N: Float + core::fmt::Debug>(m: N, s: N, buf: &mut [N])
where Standard: Distribution<N> {
fn gen_samples<F: Float + FloatConst + core::fmt::Debug>(m: F, s: F, buf: &mut [F])
where Standard: Distribution<F> {
let distr = Cauchy::new(m, s).unwrap();
let mut rng = crate::test::rng(353);
for x in buf {
Expand Down
89 changes: 41 additions & 48 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
// except according to those terms.

//! The dirichlet distribution.

use crate::utils::Float;
#![cfg(feature = "alloc")]
use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use std::{error, fmt};
use core::fmt;
use alloc::{boxed::Box, vec, vec::Vec};

/// The Dirichlet distribution `Dirichlet(alpha)`.
///
Expand All @@ -26,14 +27,20 @@ use std::{error, fmt};
/// use rand::prelude::*;
/// use rand_distr::Dirichlet;
///
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
/// let samples = dirichlet.sample(&mut rand::thread_rng());
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet<N> {
pub struct Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Concentration parameters (alpha)
alpha: Vec<N>,
alpha: Box<[F]>,
}

/// Error type returned from `Dirchlet::new`.
Expand All @@ -58,68 +65,70 @@ impl fmt::Display for Error {
}
}

impl error::Error for Error {}
#[cfg(feature = "std")]
impl std::error::Error for Error {}

impl<N: Float> Dirichlet<N>
impl<F> Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
let a = alpha.into();
if a.len() < 2 {
pub fn new(alpha: &[F]) -> Result<Dirichlet<F>, Error> {
if alpha.len() < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in &a {
if !(ai > N::from(0.0)) {
for &ai in alpha.iter() {
if !(ai > F::zero()) {
return Err(Error::AlphaTooSmall);
}
}

Ok(Dirichlet { alpha: a })
Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() })
}

/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
if !(alpha > N::from(0.0)) {
pub fn new_with_size(alpha: F, size: usize) -> Result<Dirichlet<F>, Error> {
if !(alpha > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
return Err(Error::SizeTooSmall);
}
Ok(Dirichlet {
alpha: vec![alpha; size],
alpha: vec![alpha; size].into_boxed_slice(),
})
}
}

impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
impl<F> Distribution<Vec<F>> for Dirichlet<F>
where
StandardNormal: Distribution<N>,
Exp1: Distribution<N>,
Open01: Distribution<N>,
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![N::from(0.0); n];
let mut sum = N::from(0.0);
let mut samples = vec![F::zero(); n];
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, N::from(1.0)).unwrap();
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum += *s;
sum = sum + (*s);
}
let invacc = N::from(1.0) / sum;
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s *= invacc;
*s = (*s)*invacc;
}
samples
}
Expand All @@ -131,7 +140,7 @@ mod test {

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap();
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
Expand Down Expand Up @@ -170,20 +179,4 @@ mod test {
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}

#[test]
fn value_stability() {
let mut rng = crate::test::rng(223);
assert_eq!(
rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()),
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
);
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![
0.17684200044809556,
0.29915953935953055,
0.1832858056608014,
0.1425623503573967,
0.19815030417417595
]);
}
}
Loading