From 34bd7efc6b5352e62724edfe6cf54edd68de805d Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Fri, 24 Feb 2023 12:41:47 -0500 Subject: [PATCH] Use const generics in Dirichlet --- rand_distr/CHANGELOG.md | 2 + rand_distr/src/dirichlet.rs | 62 +++++++---------------------- rand_distr/tests/value_stability.rs | 6 +-- 3 files changed, 19 insertions(+), 51 deletions(-) diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index d1da24a5a74..7d6c0602e17 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Upgrade Rand - Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda - Fix `Poisson` distribution instantiation so it return an error if lambda is infinite +- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292) +- The `Dirichlet::new_with_size` constructor was removed (#1292) ## [0.4.3] - 2021-12-30 - Fix `no_std` build (#1208) diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 786cbccd0cc..74fbff39bc7 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -13,7 +13,6 @@ use num_traits::Float; use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; use rand::Rng; use core::fmt; -use alloc::{boxed::Box, vec, vec::Vec}; /// The Dirichlet distribution `Dirichlet(alpha)`. /// @@ -27,14 +26,14 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// use rand::prelude::*; /// use rand_distr::Dirichlet; /// -/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); +/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); /// let samples = dirichlet.sample(&mut rand::thread_rng()); /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Dirichlet +pub struct Dirichlet where F: Float, StandardNormal: Distribution, @@ -42,7 +41,7 @@ where Open01: Distribution, { /// Concentration parameters (alpha) - alpha: Box<[F]>, + alpha: [F; N], } /// Error type returned from `Dirchlet::new`. @@ -72,7 +71,7 @@ impl fmt::Display for Error { #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] impl std::error::Error for Error {} -impl Dirichlet +impl Dirichlet where F: Float, StandardNormal: Distribution, @@ -83,8 +82,8 @@ where /// /// Requires `alpha.len() >= 2`. #[inline] - pub fn new(alpha: &[F]) -> Result, Error> { - if alpha.len() < 2 { + pub fn new(alpha: [F; N]) -> Result, Error> { + if N < 2 { return Err(Error::AlphaTooShort); } for &ai in alpha.iter() { @@ -93,36 +92,19 @@ where } } - Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() }) - } - - /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. - /// - /// Requires `size >= 2`. - #[inline] - pub fn new_with_size(alpha: F, size: usize) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaTooSmall); - } - if size < 2 { - return Err(Error::SizeTooSmall); - } - Ok(Dirichlet { - alpha: vec![alpha; size].into_boxed_slice(), - }) + Ok(Dirichlet { alpha }) } } -impl Distribution> for Dirichlet +impl Distribution<[F; N]> for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> Vec { - let n = self.alpha.len(); - let mut samples = vec![F::zero(); n]; + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; let mut sum = F::zero(); for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { @@ -144,23 +126,7 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); - let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); - } - - #[test] - fn test_dirichlet_with_param() { - let alpha = 0.5f64; - let size = 2; - let d = Dirichlet::new_with_size(alpha, size).unwrap(); + let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); let _: Vec = samples @@ -175,17 +141,17 @@ mod test { #[test] #[should_panic] fn test_dirichlet_invalid_length() { - Dirichlet::new_with_size(0.5f64, 1).unwrap(); + Dirichlet::new([0.5]).unwrap(); } #[test] #[should_panic] fn test_dirichlet_invalid_alpha() { - Dirichlet::new_with_size(0.0f64, 2).unwrap(); + Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); } #[test] fn dirichlet_distributions_can_be_compared() { - assert_eq!(Dirichlet::new(&[1.0, 2.0]), Dirichlet::new(&[1.0, 2.0])); + assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); } } diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index d3754705db5..4b9490a6581 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -348,10 +348,10 @@ fn weibull_stability() { fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), - vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] + rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] ); - assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ + assert_eq!(rng.sample(Dirichlet::new([8.0; 5]).unwrap()), [ 0.17684200044809556, 0.29915953935953055, 0.1832858056608014,