From 81f3b1520b882857fe2fab662d5cd4496ea5134a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9F=D0=B0=D0=B2=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2=20=5BArtyom=20Pavlov=5D?= Date: Tue, 6 Dec 2022 11:07:55 +0300 Subject: [PATCH] Rework CryptoRng --- benches/generators.rs | 3 +- benches/seq.rs | 2 +- rand_chacha/src/chacha.rs | 102 ++++++----- rand_chacha/src/guts.rs | 7 +- rand_core/src/block.rs | 49 +++--- rand_core/src/error.rs | 9 +- rand_core/src/impls.rs | 4 + rand_core/src/lib.rs | 114 ++++-------- rand_core/src/os.rs | 14 +- rand_distr/src/binomial.rs | 13 +- rand_distr/src/cauchy.rs | 16 +- rand_distr/src/dirichlet.rs | 14 +- rand_distr/src/exponential.rs | 22 ++- rand_distr/src/gamma.rs | 53 +++--- rand_distr/src/geometric.rs | 49 +++--- rand_distr/src/hypergeometric.rs | 200 ++++++++++++++-------- rand_distr/src/inverse_gaussian.rs | 7 +- rand_distr/src/lib.rs | 17 +- rand_distr/src/normal.rs | 34 ++-- rand_distr/src/normal_inverse_gaussian.rs | 15 +- rand_distr/src/pareto.rs | 16 +- rand_distr/src/pert.rs | 16 +- rand_distr/src/poisson.rs | 23 ++- rand_distr/src/skew_normal.rs | 32 ++-- rand_distr/src/triangular.rs | 27 +-- rand_distr/src/unit_ball.rs | 2 +- rand_distr/src/unit_circle.rs | 2 +- rand_distr/src/unit_disc.rs | 2 +- rand_distr/src/unit_sphere.rs | 8 +- rand_distr/src/utils.rs | 10 +- rand_distr/src/weibull.rs | 16 +- rand_distr/src/weighted_alias.rs | 19 +- rand_distr/src/zipf.rs | 55 +++--- rand_distr/tests/sparkline.rs | 18 +- rand_distr/tests/value_stability.rs | 153 ++++++++++++----- rand_pcg/src/pcg128.rs | 15 +- rand_pcg/src/pcg128cm.rs | 8 +- rand_pcg/src/pcg64.rs | 8 +- src/distributions/bernoulli.rs | 6 +- src/distributions/distribution.rs | 9 +- src/distributions/float.rs | 26 ++- src/distributions/integer.rs | 11 +- src/distributions/mod.rs | 14 +- src/distributions/other.rs | 20 +-- src/distributions/uniform.rs | 126 +++++++++----- src/distributions/utils.rs | 2 - src/distributions/weighted.rs | 6 +- src/distributions/weighted_index.rs | 6 +- src/lib.rs | 7 +- src/prelude.rs | 5 +- src/rng.rs | 66 ++----- src/rngs/adapter/mod.rs | 3 +- src/rngs/adapter/read.rs | 28 +-- src/rngs/adapter/reseeding.rs | 50 +++--- src/rngs/mock.rs | 12 +- src/rngs/mod.rs | 15 +- src/rngs/small.rs | 13 +- src/rngs/std.rs | 16 +- src/rngs/thread.rs | 20 +-- src/rngs/xoshiro128plusplus.rs | 22 +-- src/rngs/xoshiro256plusplus.rs | 32 ++-- src/seq/index.rs | 40 +++-- src/seq/mod.rs | 93 ++++++---- 63 files changed, 981 insertions(+), 841 deletions(-) diff --git a/benches/generators.rs b/benches/generators.rs index 96fa302b6a0..02c91c6de7c 100644 --- a/benches/generators.rs +++ b/benches/generators.rs @@ -21,7 +21,7 @@ use rand::prelude::*; use rand::rngs::adapter::ReseedingRng; use rand::rngs::{mock::StepRng, OsRng}; use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; -use rand_pcg::{Pcg32, Pcg64, Pcg64Mcg, Pcg64Dxsm}; +use rand_pcg::{Pcg32, Pcg64, Pcg64Dxsm, Pcg64Mcg}; macro_rules! gen_bytes { ($fnn:ident, $gen:expr) => { @@ -142,7 +142,6 @@ reseeding_bytes!(reseeding_chacha20_64k, 64); reseeding_bytes!(reseeding_chacha20_256k, 256); reseeding_bytes!(reseeding_chacha20_1M, 1024); - macro_rules! threadrng_uint { ($fnn:ident, $ty:ty) => { #[bench] diff --git a/benches/seq.rs b/benches/seq.rs index 5b3a846f60b..8f5bea87f25 100644 --- a/benches/seq.rs +++ b/benches/seq.rs @@ -13,9 +13,9 @@ extern crate test; use test::Bencher; +use core::mem::size_of; use rand::prelude::*; use rand::seq::*; -use core::mem::size_of; // We force use of 32-bit RNG since seq code is optimised for use with 32-bit // generators on all platforms. diff --git a/rand_chacha/src/chacha.rs b/rand_chacha/src/chacha.rs index ad74b35f62b..161fc7c1874 100644 --- a/rand_chacha/src/chacha.rs +++ b/rand_chacha/src/chacha.rs @@ -13,10 +13,11 @@ use self::core::fmt; use crate::guts::ChaCha; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Serialize, Deserialize, Serializer, Deserializer}; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; // NB. this must remain consistent with some currently hard-coded numbers in this module const BUF_BLOCKS: u8 = 4; @@ -85,6 +86,7 @@ macro_rules! chacha_impl { impl BlockRngCore for $ChaChaXCore { type Item = u32; type Results = Array64; + #[inline] fn generate(&mut self, r: &mut Self::Results) { self.state.refill4($rounds, &mut r.0); @@ -93,13 +95,16 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXCore { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { - $ChaChaXCore { state: ChaCha::new(&seed, &[0u8; 8]) } + $ChaChaXCore { + state: ChaCha::new(&seed, &[0u8; 8]), + } } } - impl CryptoRng for $ChaChaXCore {} + impl CryptoBlockRng for $ChaChaXCore {} /// A cryptographically secure random number generator that uses the ChaCha algorithm. /// @@ -146,6 +151,7 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXRng { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { let core = $ChaChaXCore::from_seed(seed); @@ -160,18 +166,16 @@ macro_rules! chacha_impl { fn next_u32(&mut self) -> u32 { self.rng.next_u32() } + #[inline] fn next_u64(&mut self) -> u64 { self.rng.next_u64() } + #[inline] fn fill_bytes(&mut self, bytes: &mut [u8]) { self.rng.fill_bytes(bytes) } - #[inline] - fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), Error> { - self.rng.try_fill_bytes(bytes) - } } impl $ChaChaXRng { @@ -209,11 +213,9 @@ macro_rules! chacha_impl { #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { let block = (word_offset / u128::from(BLOCK_WORDS)) as u64; + self.rng.core.state.set_block_pos(block); self.rng - .core - .state - .set_block_pos(block); - self.rng.generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); + .generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); } /// Set the stream number. @@ -229,10 +231,7 @@ macro_rules! chacha_impl { /// indirectly via `set_word_pos`), but this is not directly supported. #[inline] pub fn set_stream(&mut self, stream: u64) { - self.rng - .core - .state - .set_nonce(stream); + self.rng.core.state.set_nonce(stream); if self.rng.index() != 64 { let wp = self.get_word_pos(); self.set_word_pos(wp); @@ -242,19 +241,13 @@ macro_rules! chacha_impl { /// Get the stream number. #[inline] pub fn get_stream(&self) -> u64 { - self.rng - .core - .state - .get_nonce() + self.rng.core.state.get_nonce() } /// Get the seed. #[inline] pub fn get_seed(&self) -> [u8; 32] { - self.rng - .core - .state - .get_seed() + self.rng.core.state.get_seed() } } @@ -286,22 +279,20 @@ macro_rules! chacha_impl { } #[cfg(feature = "serde1")] impl<'de> Deserialize<'de> for $ChaChaXRng { - fn deserialize(d: D) -> Result where D: Deserializer<'de> { + fn deserialize(d: D) -> Result + where D: Deserializer<'de> { $abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x)) } } mod $abst { - #[cfg(feature = "serde1")] use serde::{Serialize, Deserialize}; + #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; // The abstract state of a ChaCha stream, independent of implementation choices. The // comparison and serialization of this object is considered a semver-covered part of // the API. #[derive(Debug, PartialEq, Eq)] - #[cfg_attr( - feature = "serde1", - derive(Serialize, Deserialize), - )] + #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub(crate) struct $ChaChaXRng { seed: [u8; 32], stream: u64, @@ -331,18 +322,36 @@ macro_rules! chacha_impl { } } } - } + }; } -chacha_impl!(ChaCha20Core, ChaCha20Rng, 10, "ChaCha with 20 rounds", abstract20); -chacha_impl!(ChaCha12Core, ChaCha12Rng, 6, "ChaCha with 12 rounds", abstract12); -chacha_impl!(ChaCha8Core, ChaCha8Rng, 4, "ChaCha with 8 rounds", abstract8); +chacha_impl!( + ChaCha20Core, + ChaCha20Rng, + 10, + "ChaCha with 20 rounds", + abstract20 +); +chacha_impl!( + ChaCha12Core, + ChaCha12Rng, + 6, + "ChaCha with 12 rounds", + abstract12 +); +chacha_impl!( + ChaCha8Core, + ChaCha8Rng, + 4, + "ChaCha with 8 rounds", + abstract8 +); #[cfg(test)] mod test { use rand_core::{RngCore, SeedableRng}; - #[cfg(feature = "serde1")] use super::{ChaCha20Rng, ChaCha12Rng, ChaCha8Rng}; + #[cfg(feature = "serde1")] use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng}; type ChaChaRng = super::ChaCha20Rng; @@ -350,8 +359,8 @@ mod test { #[test] fn test_chacha_serde_roundtrip() { let seed = [ - 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, 0, 0, - 0, 2, 92, + 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, + 0, 0, 0, 2, 92, ]; let mut rng1 = ChaCha20Rng::from_seed(seed); let mut rng2 = ChaCha12Rng::from_seed(seed); @@ -402,7 +411,7 @@ mod test { let mut rng1 = ChaChaRng::from_seed(seed); assert_eq!(rng1.next_u32(), 137206642); - let mut rng2 = ChaChaRng::from_rng(rng1).unwrap(); + let mut rng2 = ChaChaRng::from_rng(rng1); assert_eq!(rng2.next_u32(), 1325750369); } @@ -598,7 +607,7 @@ mod test { #[test] fn test_chacha_word_pos_wrap_exact() { - use super::{BUF_BLOCKS, BLOCK_WORDS}; + use super::{BLOCK_WORDS, BUF_BLOCKS}; let mut rng = ChaChaRng::from_seed(Default::default()); // refilling the buffer in set_word_pos will wrap the block counter to 0 let last_block = (1 << 68) - u128::from(BUF_BLOCKS * BLOCK_WORDS); @@ -626,12 +635,13 @@ mod test { #[test] fn test_trait_objects() { - use rand_core::CryptoRngCore; + use rand_core::CryptoRng; - let rng = &mut ChaChaRng::from_seed(Default::default()) as &mut dyn CryptoRngCore; - let r1 = rng.next_u64(); - let rng: &mut dyn RngCore = rng.as_rngcore(); - let r2 = rng.next_u64(); - assert_ne!(r1, r2); + let mut rng = ChaChaRng::from_seed(Default::default()); + let dyn_rng = &mut rng.clone() as &mut dyn CryptoRng; + let mut box_rng = Box::new(rng.clone()); + let exp = rng.next_u64(); + assert_ne!(exp, dyn_rng.next_u64()); + assert_ne!(exp, box_rng.next_u64()); } } diff --git a/rand_chacha/src/guts.rs b/rand_chacha/src/guts.rs index 797ded6fa73..1ecfea1b25f 100644 --- a/rand_chacha/src/guts.rs +++ b/rand_chacha/src/guts.rs @@ -12,7 +12,9 @@ use ppv_lite86::{dispatch, dispatch_light128}; pub use ppv_lite86::Machine; -use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector}; +use ppv_lite86::{ + vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector, +}; pub(crate) const BLOCK: usize = 16; pub(crate) const BLOCK64: u64 = BLOCK as u64; @@ -140,7 +142,8 @@ fn add_pos(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 { #[cfg(target_endian = "little")] fn d0123(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { let d0: Mach::u64x2 = m.unpack(d); - let incr = Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); + let incr = + Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into()) } diff --git a/rand_core/src/block.rs b/rand_core/src/block.rs index a527dda2971..019e2ef6fcf 100644 --- a/rand_core/src/block.rs +++ b/rand_core/src/block.rs @@ -43,7 +43,7 @@ //! } //! } //! -//! // optionally, also implement CryptoRng for MyRngCore +//! // optionally, also implement CryptoBlockRng for MyRngCore //! //! // Final RNG. //! let mut rng = BlockRng::::seed_from_u64(0); @@ -57,8 +57,7 @@ use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; use crate::{CryptoRng, Error, RngCore, SeedableRng}; use core::convert::AsRef; use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A trait for RNGs which do not generate random numbers individually, but in /// blocks (typically `[u32; N]`). This technique is commonly used by @@ -77,6 +76,10 @@ pub trait BlockRngCore { fn generate(&mut self, results: &mut Self::Results); } +/// A marker trait used to indicate that an [`BlockRngCore`] +/// implementation is supposed to produce cryptographically secure random data. +pub trait CryptoBlockRng: BlockRngCore {} + /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u32` array buffer; i.e. this can be used to implement /// a full RNG from just a `generate` function. @@ -92,13 +95,13 @@ pub trait BlockRngCore { /// `BlockRng` has heavily optimized implementations of the [`RngCore`] methods /// reading values from the results buffer, as well as /// calling [`BlockRngCore::generate`] directly on the output array when -/// [`fill_bytes`] / [`try_fill_bytes`] is called on a large array. These methods +/// [`fill_bytes`] / [`crypto_fill_bytes`] is called on a large array. These methods /// also handle the bookkeeping of when to generate a new batch of values. /// /// No whole generated `u32` values are thrown away and all values are consumed /// in-order. [`next_u32`] simply takes the next available `u32` value. /// [`next_u64`] is implemented by combining two `u32` values, least -/// significant first. [`fill_bytes`] and [`try_fill_bytes`] consume a whole +/// significant first. [`fill_bytes`] and [`crypto_fill_bytes`] consume a whole /// number of `u32` values, converting each `u32` to a byte slice in /// little-endian order. If the requested byte length is not a multiple of 4, /// some bytes will be discarded. @@ -111,7 +114,7 @@ pub trait BlockRngCore { /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`crypto_fill_bytes`]: CryptoRng::crypto_fill_bytes #[derive(Clone)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr( @@ -229,9 +232,11 @@ impl> RngCore for BlockRng { read_len += filled_u8; } } +} +impl + CryptoBlockRng> CryptoRng for BlockRng { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { self.fill_bytes(dest); Ok(()) } @@ -251,8 +256,8 @@ impl SeedableRng for BlockRng { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: S) -> Self { + Self::new(R::from_rng(rng)) } } @@ -270,14 +275,14 @@ impl SeedableRng for BlockRng { /// then the other half is then consumed, however both [`next_u64`] and /// [`fill_bytes`] discard the rest of any half-consumed `u64`s when called. /// -/// [`fill_bytes`] and [`try_fill_bytes`] consume a whole number of `u64` +/// [`fill_bytes`] and [`crypto_fill_bytes`] consume a whole number of `u64` /// values. If the requested length is not a multiple of 8, some bytes will be /// discarded. /// /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`crypto_fill_bytes`]: CryptoRng::crypto_fill_bytes #[derive(Clone)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct BlockRng64 { @@ -395,9 +400,11 @@ impl> RngCore for BlockRng64 { read_len += filled_u8; } } +} +impl + CryptoBlockRng> CryptoRng for BlockRng64 { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { self.fill_bytes(dest); Ok(()) } @@ -417,17 +424,15 @@ impl SeedableRng for BlockRng64 { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: S) -> Self { + Self::new(R::from_rng(rng)) } } -impl CryptoRng for BlockRng {} - #[cfg(test)] mod test { - use crate::{SeedableRng, RngCore}; use crate::block::{BlockRng, BlockRng64, BlockRngCore}; + use crate::{RngCore, SeedableRng}; #[derive(Debug, Clone)] struct DummyRng { @@ -436,7 +441,6 @@ mod test { impl BlockRngCore for DummyRng { type Item = u32; - type Results = [u32; 16]; fn generate(&mut self, results: &mut Self::Results) { @@ -451,7 +455,9 @@ mod test { type Seed = [u8; 4]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng { counter: u32::from_le_bytes(seed) } + DummyRng { + counter: u32::from_le_bytes(seed), + } } } @@ -486,7 +492,6 @@ mod test { impl BlockRngCore for DummyRng64 { type Item = u64; - type Results = [u64; 8]; fn generate(&mut self, results: &mut Self::Results) { @@ -501,7 +506,9 @@ mod test { type Seed = [u8; 8]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng64 { counter: u64::from_le_bytes(seed) } + DummyRng64 { + counter: u64::from_le_bytes(seed), + } } } diff --git a/rand_core/src/error.rs b/rand_core/src/error.rs index 411896f2c47..46cc1a108ce 100644 --- a/rand_core/src/error.rs +++ b/rand_core/src/error.rs @@ -50,9 +50,7 @@ impl Error { #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] #[inline] pub fn new(err: E) -> Self - where - E: Into>, - { + where E: Into> { Error { inner: err.into() } } @@ -223,6 +221,9 @@ mod test { fn test_error_codes() { // Make sure the values are the same as in `getrandom`. assert_eq!(super::Error::CUSTOM_START, getrandom::Error::CUSTOM_START); - assert_eq!(super::Error::INTERNAL_START, getrandom::Error::INTERNAL_START); + assert_eq!( + super::Error::INTERNAL_START, + getrandom::Error::INTERNAL_START + ); } } diff --git a/rand_core/src/impls.rs b/rand_core/src/impls.rs index 4b7688c5c80..e92b2dbb512 100644 --- a/rand_core/src/impls.rs +++ b/rand_core/src/impls.rs @@ -61,9 +61,11 @@ trait Observable: Copy { } impl Observable for u32 { type Bytes = [u8; 4]; + fn to_le_bytes(self) -> Self::Bytes { self.to_le_bytes() } + fn as_byte_slice(x: &[Self]) -> &[u8] { let ptr = x.as_ptr() as *const u8; let len = x.len() * core::mem::size_of::(); @@ -72,9 +74,11 @@ impl Observable for u32 { } impl Observable for u64 { type Bytes = [u8; 8]; + fn to_le_bytes(self) -> Self::Bytes { self.to_le_bytes() } + fn as_byte_slice(x: &[Self]) -> &[u8] { let ptr = x.as_ptr() as *const u8; let len = x.len() * core::mem::size_of::(); diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index 1234a566c05..14eddbd8a9b 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -41,21 +41,19 @@ use core::convert::AsMut; use core::default::Default; -#[cfg(feature = "std")] extern crate std; #[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "std")] extern crate std; #[cfg(feature = "alloc")] use alloc::boxed::Box; pub use error::Error; #[cfg(feature = "getrandom")] pub use os::OsRng; - pub mod block; mod error; pub mod impls; pub mod le; #[cfg(feature = "getrandom")] mod os; - /// The core of a random number generator. /// /// This trait encapsulates the low-level functionality common to all @@ -71,11 +69,6 @@ pub mod le; /// [`next_u32`] and [`next_u64`] methods, implementations may discard some /// random bits for efficiency. /// -/// The [`try_fill_bytes`] method is a variant of [`fill_bytes`] allowing error -/// handling; it is not deemed sufficiently useful to add equivalents for -/// [`next_u32`] or [`next_u64`] since the latter methods are almost always used -/// with algorithmic generators (PRNGs), which are normally infallible. -/// /// Implementers should produce bits uniformly. Pathological RNGs (e.g. always /// returning the same value, or never setting certain bits) can break rejection /// sampling used by random distributions, and also break other RNGs when @@ -110,7 +103,7 @@ pub mod le; /// /// ``` /// #![allow(dead_code)] -/// use rand_core::{RngCore, Error, impls}; +/// use rand_core::{RngCore, impls}; /// /// struct CountingRng(u64); /// @@ -127,15 +120,11 @@ pub mod le; /// fn fill_bytes(&mut self, dest: &mut [u8]) { /// impls::fill_bytes_via_next(self, dest) /// } -/// -/// fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { -/// Ok(self.fill_bytes(dest)) -/// } /// } /// ``` /// /// [`rand`]: https://docs.rs/rand -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`crypto_fill_bytes`]: RngCore::crypto_fill_bytes /// [`fill_bytes`]: RngCore::fill_bytes /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 @@ -159,7 +148,7 @@ pub trait RngCore { /// RNGs must implement at least one method from this trait directly. In /// the case this method is not implemented directly, it can be implemented /// via [`impls::fill_bytes_via_next`] or - /// via [`RngCore::try_fill_bytes`]; if this generator can + /// via [`CryptoRng::crypto_fill_bytes`]; if this generator can /// fail the implementation must choose how best to handle errors here /// (e.g. panic with a descriptive message or log a warning and retry a few /// times). @@ -169,29 +158,16 @@ pub trait RngCore { /// (e.g. reading past the end of a file that is being used as the /// source of randomness). fn fill_bytes(&mut self, dest: &mut [u8]); - - /// Fill `dest` entirely with random data. - /// - /// This is the only method which allows an RNG to report errors while - /// generating random data thus making this the primary method implemented - /// by external (true) RNGs (e.g. `OsRng`) which can fail. It may be used - /// directly to generate keys and to seed (infallible) PRNGs. - /// - /// Other than error handling, this method is identical to [`RngCore::fill_bytes`]; - /// thus this may be implemented using `Ok(self.fill_bytes(dest))` or - /// `fill_bytes` may be implemented with - /// `self.try_fill_bytes(dest).unwrap()` or more specific error handling. - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error>; } -/// A marker trait used to indicate that an [`RngCore`] or [`BlockRngCore`] -/// implementation is supposed to be cryptographically secure. +/// Trait used to indicate that an [`RngCore`] implementation is supposed +/// to produce cryptographically secure data. /// /// *Cryptographically secure generators*, also known as *CSPRNGs*, should /// satisfy an additional properties over other generators: given the first -/// *k* bits of an algorithm's output -/// sequence, it should not be possible using polynomial-time algorithms to -/// predict the next bit with probability significantly greater than 50%. +/// *k* bits of an algorithm's output sequence, it should not be possible +/// using polynomial-time algorithms to predict the next bit with probability +/// significantly greater than 50%. /// /// Some generators may satisfy an additional property, however this is not /// required by this trait: if the CSPRNG's state is revealed, it should not be @@ -206,34 +182,20 @@ pub trait RngCore { /// weaknesses such as seeding from a weak entropy source or leaking state. /// /// [`BlockRngCore`]: block::BlockRngCore -pub trait CryptoRng {} - -/// An extension trait that is automatically implemented for any type -/// implementing [`RngCore`] and [`CryptoRng`]. -/// -/// It may be used as a trait object, and supports upcasting to [`RngCore`] via -/// the [`CryptoRngCore::as_rngcore`] method. -/// -/// # Example -/// -/// ``` -/// use rand_core::CryptoRngCore; -/// -/// #[allow(unused)] -/// fn make_token(rng: &mut dyn CryptoRngCore) -> [u8; 32] { -/// let mut buf = [0u8; 32]; -/// rng.fill_bytes(&mut buf); -/// buf -/// } -/// ``` -pub trait CryptoRngCore: CryptoRng + RngCore { - /// Upcast to an [`RngCore`] trait object. - fn as_rngcore(&mut self) -> &mut dyn RngCore; -} - -impl CryptoRngCore for T { - fn as_rngcore(&mut self) -> &mut dyn RngCore { - self +pub trait CryptoRng: RngCore { + /// Fill `dest` entirely with cryptographically secure random data. + /// + /// This is the only method which allows an RNG to report errors while + /// generating random data thus making this the primary method implemented + /// by external (true) RNGs (e.g. `OsRng`) which can fail. + /// + /// Other than error handling, this method is identical to [`RngCore::fill_bytes`]; + /// thus this may be implemented using `Ok(self.fill_bytes(dest))` or + /// `fill_bytes` may be implemented with `self.crypto_fill_bytes(dest).unwrap()` + /// or more specific error handling. + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.fill_bytes(dest); + Ok(()) } } @@ -387,10 +349,10 @@ pub trait SeedableRng: Sized { /// (in prior versions this was not required). /// /// [`rand`]: https://docs.rs/rand - fn from_rng(mut rng: R) -> Result { + fn from_rng(mut rng: R) -> Self { let mut seed = Self::Seed::default(); - rng.try_fill_bytes(seed.as_mut())?; - Ok(Self::from_seed(seed)) + rng.fill_bytes(seed.as_mut()); + Self::from_seed(seed) } /// Creates a new instance of the RNG seeded via [`getrandom`]. @@ -436,10 +398,13 @@ impl<'a, R: RngCore + ?Sized> RngCore for &'a mut R { fn fill_bytes(&mut self, dest: &mut [u8]) { (**self).fill_bytes(dest) } +} +// Implement `CryptoRng` for references to a `CryptoRng`. +impl<'a, R: CryptoRng + ?Sized> CryptoRng for &'a mut R { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + (**self).crypto_fill_bytes(dest) } } @@ -462,28 +427,25 @@ impl RngCore for Box { fn fill_bytes(&mut self, dest: &mut [u8]) { (**self).fill_bytes(dest) } +} +// Implement `CryptoRng` for boxed references to a `CryptoRng`. +#[cfg(feature = "alloc")] +impl CryptoRng for Box { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + (**self).crypto_fill_bytes(dest) } } #[cfg(feature = "std")] impl std::io::Read for dyn RngCore { fn read(&mut self, buf: &mut [u8]) -> Result { - self.try_fill_bytes(buf)?; + self.fill_bytes(buf); Ok(buf.len()) } } -// Implement `CryptoRng` for references to a `CryptoRng`. -impl<'a, R: CryptoRng + ?Sized> CryptoRng for &'a mut R {} - -// Implement `CryptoRng` for boxed references to a `CryptoRng`. -#[cfg(feature = "alloc")] -impl CryptoRng for Box {} - #[cfg(test)] mod test { use super::*; diff --git a/rand_core/src/os.rs b/rand_core/src/os.rs index b43c9fdaf05..ea6e3f86a08 100644 --- a/rand_core/src/os.rs +++ b/rand_core/src/os.rs @@ -48,7 +48,12 @@ use getrandom::getrandom; #[derive(Clone, Copy, Debug, Default)] pub struct OsRng; -impl CryptoRng for OsRng {} +impl CryptoRng for OsRng { + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + getrandom(dest)?; + Ok(()) + } +} impl RngCore for OsRng { fn next_u32(&mut self) -> u32 { @@ -60,15 +65,10 @@ impl RngCore for OsRng { } fn fill_bytes(&mut self, dest: &mut [u8]) { - if let Err(e) = self.try_fill_bytes(dest) { + if let Err(e) = self.crypto_fill_bytes(dest) { panic!("Error: {}", e); } } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - getrandom(dest)?; - Ok(()) - } } #[test] diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 6dbf7ab7494..d02c46b085d 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -10,11 +10,10 @@ //! The binomial distribution. use crate::{Distribution, Uniform}; -use rand::Rng; -use core::fmt; use core::cmp::Ordering; -#[allow(unused_imports)] -use num_traits::Float; +use core::fmt; +#[allow(unused_imports)] use num_traits::Float; +use rand::Rng; /// The binomial distribution `Binomial(n, p)`. /// @@ -224,7 +223,7 @@ impl Distribution for Binomial { break; } } - }, + } Ordering::Greater => { let mut i = y; loop { @@ -234,8 +233,8 @@ impl Distribution for Binomial { break; } } - }, - Ordering::Equal => {}, + } + Ordering::Equal => {} } if v > f { continue; diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 9aff7e625f4..cb9601514a4 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -9,10 +9,10 @@ //! The Cauchy distribution. -use num_traits::{Float, FloatConst}; use crate::{Distribution, Standard}; -use rand::Rng; use core::fmt; +use num_traits::{Float, FloatConst}; +use rand::Rng; /// The Cauchy distribution `Cauchy(median, scale)`. /// @@ -34,7 +34,9 @@ use core::fmt; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { median: F, scale: F, @@ -60,7 +62,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. @@ -73,7 +77,9 @@ where F: Float + FloatConst, Standard: Distribution } impl Distribution for Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { fn sample(&self, rng: &mut R) -> F { // sample from [0, 1) diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 786cbccd0cc..a6a0b59ffbe 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -9,11 +9,11 @@ //! The dirichlet distribution. #![cfg(feature = "alloc")] -use num_traits::Float; use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; -use rand::Rng; -use core::fmt; use alloc::{boxed::Box, vec, vec::Vec}; +use core::fmt; +use num_traits::Float; +use rand::Rng; /// The Dirichlet distribution `Dirichlet(alpha)`. /// @@ -93,7 +93,9 @@ where } } - Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() }) + Ok(Dirichlet { + alpha: alpha.to_vec().into_boxed_slice(), + }) } /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. @@ -128,11 +130,11 @@ where for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { let g = Gamma::new(a, F::one()).unwrap(); *s = g.sample(rng); - sum = sum + (*s); + sum = sum + (*s); } let invacc = F::one() / sum; for s in samples.iter_mut() { - *s = (*s)*invacc; + *s = (*s) * invacc; } samples } diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index e3d2a8d1cf6..6c930ea9451 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -10,10 +10,10 @@ //! The exponential distribution. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -94,7 +94,9 @@ impl Distribution for Exp1 { #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// `lambda` stored as `1/lambda`, since this is what we scale by. lambda_inverse: F, @@ -120,16 +122,18 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// Construct a new `Exp` with the given shape parameter /// `lambda`. - /// + /// /// # Remarks - /// + /// /// For custom types `N` implementing the [`Float`] trait, /// the case `lambda = 0` is handled as follows: each sample corresponds - /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types + /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types /// yield infinity, since `1 / 0 = infinity`. #[inline] pub fn new(lambda: F) -> Result, Error> { @@ -143,7 +147,9 @@ where F: Float, Exp1: Distribution } impl Distribution for Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { fn sample(&self, rng: &mut R) -> F { rng.sample(Exp1) * self.lambda_inverse diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 1a575bd6a9f..3e3c8c7cf3e 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -17,12 +17,11 @@ use self::ChiSquaredRepr::*; use self::GammaRepr::*; use crate::normal::StandardNormal; -use num_traits::Float; use crate::{Distribution, Exp, Exp1, Open01}; -use rand::Rng; use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -566,7 +565,9 @@ where F: Float, Open01: Distribution, { - a: F, b: F, switched_params: bool, + a: F, + b: F, + switched_params: bool, algorithm: BetaAlgorithm, } @@ -618,15 +619,15 @@ where if a > F::one() { // Algorithm BB let alpha = a + b; - let beta = ((alpha - F::from(2.).unwrap()) - / (F::from(2.).unwrap()*a*b - alpha)).sqrt(); + let beta = + ((alpha - F::from(2.).unwrap()) / (F::from(2.).unwrap() * a * b - alpha)).sqrt(); let gamma = a + F::one() / beta; Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BB(BB { - alpha, beta, gamma, - }) + a, + b, + switched_params, + algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), }) } else { // Algorithm BC @@ -637,16 +638,21 @@ where let beta = F::one() / b; let delta = F::one() + a - b; let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) - / (a*beta - F::from(14. / 18.).unwrap()); + * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) + / (a * beta - F::from(14. / 18.).unwrap()); let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; + + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; Ok(Beta { - a, b, switched_params, + a, + b, + switched_params, algorithm: BetaAlgorithm::BC(BC { - alpha, beta, kappa1, kappa2, - }) + alpha, + beta, + kappa1, + kappa2, + }), }) } } @@ -667,12 +673,11 @@ where let u2 = rng.sample(Open01); let v = algo.beta * (u1 / (F::one() - u1)).ln(); w = self.a * v.exp(); - let z = u1*u1 * u2; + let z = u1 * u1 * u2; let r = algo.gamma * v - F::from(4.).unwrap().ln(); let s = self.a + r - w; // 2. - if s + F::one() + F::from(5.).unwrap().ln() - >= F::from(5.).unwrap() * z { + if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { break; } // 3. @@ -685,7 +690,7 @@ where break; } } - }, + } BetaAlgorithm::BC(algo) => { loop { let z; @@ -716,11 +721,13 @@ where let v = algo.beta * (u1 / (F::one() - u1)).ln(); w = self.a * v.exp(); if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() < z.ln()) { + - F::from(4.).unwrap().ln() + < z.ln()) + { break; }; } - }, + } }; // 5. for BB, 6. for BC if !self.switched_params { diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs index 3ea8b8f3e13..552e931942b 100644 --- a/rand_distr/src/geometric.rs +++ b/rand_distr/src/geometric.rs @@ -1,20 +1,19 @@ //! The geometric distribution. use crate::Distribution; -use rand::Rng; use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; +#[allow(unused_imports)] use num_traits::Float; +use rand::Rng; /// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`. -/// +/// /// This is the probability distribution of the number of failures before the /// first success in a series of Bernoulli trials. It has the density function /// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success /// on each trial. -/// +/// /// This is the discrete analogue of the [exponential distribution](crate::Exp). -/// +/// /// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised /// implementation for `p = 0.5`. /// @@ -29,11 +28,10 @@ use num_traits::Float; /// ``` #[derive(Copy, Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Geometric -{ +pub struct Geometric { p: f64, pi: f64, - k: u64 + k: u64, } /// Error type returned from `Geometric::new`. @@ -46,7 +44,9 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution", + Error::InvalidProbability => { + "p is NaN or outside the interval [0, 1] in geometric distribution" + } }) } } @@ -80,21 +80,24 @@ impl Geometric { } } -impl Distribution for Geometric -{ +impl Distribution for Geometric { fn sample(&self, rng: &mut R) -> u64 { if self.p >= 2.0 / 3.0 { // use the trivial algorithm: let mut failures = 0; loop { let u = rng.gen::(); - if u <= self.p { break; } + if u <= self.p { + break; + } failures += 1; } return failures; } - - if self.p == 0.0 { return core::u64::MAX; } + + if self.p == 0.0 { + return core::u64::MAX; + } let Geometric { p, pi, k } = *self; @@ -116,7 +119,7 @@ impl Distribution for Geometric // Use rejection sampling for the remainder M from Geo(p) % 2^k: // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M - // NOTE: The paper suggests using bitwise sampling here, which is + // NOTE: The paper suggests using bitwise sampling here, which is // currently unsupported, but should improve performance by requiring // fewer iterations on average. ~ October 28, 2020 let m = loop { @@ -126,7 +129,7 @@ impl Distribution for Geometric } else { (1.0 - p).powf(m as f64) }; - + let u = rng.gen::(); if u < p_reject { break m; @@ -140,16 +143,16 @@ impl Distribution for Geometric /// Samples integers according to the geometric distribution with success /// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`, /// but faster. -/// +/// /// See [`Geometric`](crate::Geometric) for the general geometric distribution. -/// +/// /// Implemented via iterated [Rng::gen::().leading_zeros()]. -/// +/// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::StandardGeometric; -/// +/// /// let v = StandardGeometric.sample(&mut thread_rng()); /// println!("{} is from a Geometric(0.5) distribution", v); /// ``` @@ -163,7 +166,9 @@ impl Distribution for StandardGeometric { loop { let x = rng.gen::().leading_zeros() as u64; result += x; - if x < 64 { break; } + if x < 64 { + break; + } } result } diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs index 4761450360d..71b8aa90b7e 100644 --- a/rand_distr/src/hypergeometric.rs +++ b/rand_distr/src/hypergeometric.rs @@ -1,17 +1,19 @@ //! The hypergeometric distribution. use crate::Distribution; -use rand::Rng; -use rand::distributions::uniform::Uniform; use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; +#[allow(unused_imports)] use num_traits::Float; +use rand::distributions::uniform::Uniform; +use rand::Rng; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] enum SamplingMethod { - InverseTransform{ initial_p: f64, initial_x: i64 }, - RejectionAcceptance{ + InverseTransform { + initial_p: f64, + initial_x: i64, + }, + RejectionAcceptance { m: f64, a: f64, lambda_l: f64, @@ -20,24 +22,24 @@ enum SamplingMethod { x_r: f64, p1: f64, p2: f64, - p3: f64 + p3: f64, }, } /// The hypergeometric distribution `Hypergeometric(N, K, n)`. -/// +/// /// This is the distribution of successes in samples of size `n` drawn without /// replacement from a population of size `N` containing `K` success states. /// It has the density function: /// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, /// where `binomial(a, b) = a! / (b! * (a - b)!)`. -/// +/// /// The [binomial distribution](crate::Binomial) is the analogous distribution /// for sampling with replacement. It is a good approximation when the population /// size is much larger than the sample size. -/// +/// /// # Example -/// +/// /// ``` /// use rand_distr::{Distribution, Hypergeometric}; /// @@ -70,9 +72,15 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution", - Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution", - Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution", + Error::PopulationTooLarge => { + "total_population_size is too large causing underflow in geometric distribution" + } + Error::ProbabilityTooLarge => { + "population_with_feature > total_population_size in geometric distribution" + } + Error::SampleSizeTooLarge => { + "sample_size > total_population_size in geometric distribution" + } }) } } @@ -97,20 +105,20 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, if i <= min_top { result *= i as f64; } - + if i <= min_bottom { result /= i as f64; } - + if i <= max_top { result *= i as f64; } - + if i <= max_bottom { result /= i as f64; } } - + result } @@ -126,7 +134,9 @@ impl Hypergeometric { /// `K = population_with_feature`, /// `n = sample_size`. #[allow(clippy::many_single_char_names)] // Same names as in the reference. - pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result { + pub fn new( + total_population_size: u64, population_with_feature: u64, sample_size: u64, + ) -> Result { if population_with_feature > total_population_size { return Err(Error::ProbabilityTooLarge); } @@ -151,7 +161,7 @@ impl Hypergeometric { }; // when sampling more than half the total population, take the smaller // group as sampled instead (we can then return n1-x instead). - // + // // Note: the boundary condition given in the paper is `sample_size < n / 2`; // we're deviating here, because when n is even, it doesn't matter whether // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching @@ -167,7 +177,7 @@ impl Hypergeometric { // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, // where `M` is the mode of the distribution. // Use algorithm HIN for the remaining parameter space. - // + // // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer // generation of hypergeometric random variates. // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 @@ -176,21 +186,30 @@ impl Hypergeometric { let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { let (initial_p, initial_x) = if k < n2 { - (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0) + ( + fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), + 0, + ) } else { - (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64) + ( + fraction_of_products_of_factorials((n1, k), (n, k - n2)), + (k - n2) as i64, + ) }; if initial_p <= 0.0 || !initial_p.is_finite() { return Err(Error::PopulationTooLarge); } - SamplingMethod::InverseTransform { initial_p, initial_x } + SamplingMethod::InverseTransform { + initial_p, + initial_x, + } } else { - let a = ln_of_factorial(m) + - ln_of_factorial(n1 as f64 - m) + - ln_of_factorial(k as f64 - m) + - ln_of_factorial((n2 - k) as f64 + m); + let a = ln_of_factorial(m) + + ln_of_factorial(n1 as f64 - m) + + ln_of_factorial(k as f64 - m) + + ln_of_factorial((n2 - k) as f64 + m); let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; let denominator = (n - 1) as f64 * n as f64 * n as f64; @@ -199,17 +218,19 @@ impl Hypergeometric { let x_l = m - d + 0.5; let x_r = m + d + 0.5; - let k_l = f64::exp(a - - ln_of_factorial(x_l) - - ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(k as f64 - x_l) - - ln_of_factorial((n2 - k) as f64 + x_l)); - let k_r = f64::exp(a - - ln_of_factorial(x_r - 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0) - - ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial((n2 - k) as f64 + x_r - 1.0)); - + let k_l = f64::exp( + a - ln_of_factorial(x_l) + - ln_of_factorial(n1 as f64 - x_l) + - ln_of_factorial(k as f64 - x_l) + - ln_of_factorial((n2 - k) as f64 + x_l), + ); + let k_r = f64::exp( + a - ln_of_factorial(x_r - 1.0) + - ln_of_factorial(n1 as f64 - x_r + 1.0) + - ln_of_factorial(k as f64 - x_r + 1.0) + - ln_of_factorial((n2 - k) as f64 + x_r - 1.0), + ); + let numerator = x_l * ((n2 - k) as f64 + x_l); let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); let lambda_l = -((numerator / denominator).ln()); @@ -225,11 +246,26 @@ impl Hypergeometric { let p3 = p2 + k_r / lambda_r; SamplingMethod::RejectionAcceptance { - m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, } }; - Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method }) + Ok(Hypergeometric { + n1, + n2, + k, + offset_x, + sign_x, + sampling_method, + }) } } @@ -238,25 +274,46 @@ impl Distribution for Hypergeometric { fn sample(&self, rng: &mut R) -> u64 { use SamplingMethod::*; - let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self; + let Hypergeometric { + n1, + n2, + k, + sign_x, + offset_x, + sampling_method, + } = *self; let x = match sampling_method { - InverseTransform { initial_p: mut p, initial_x: mut x } => { + InverseTransform { + initial_p: mut p, + initial_x: mut x, + } => { let mut u = rng.gen::(); - while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense + while u > p && x < k as i64 { + // the paper erroneously uses `until n < p`, which doesn't make any sense u -= p; p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64; p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64; x += 1; } x - }, - RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => { + } + RejectionAcceptance { + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, + } => { let distr_region_select = Uniform::new(0.0, p3); loop { let (y, v) = loop { let u = distr_region_select.sample(rng); let v = rng.gen::(); // for the accept/reject decision - + if u <= p1 { // Region 1, central bell let y = (x_l + u).floor(); @@ -277,7 +334,7 @@ impl Distribution for Hypergeometric { } } }; - + // Step 4: Acceptance/Rejection Comparison if m < 100.0 || y <= 50.0 { // Step 4.1: evaluate f(y) via recursive relationship @@ -293,8 +350,10 @@ impl Distribution for Hypergeometric { f /= (n1 - i) as f64 * (k - i) as f64; } } - - if v <= f { break y as i64; } + + if v <= f { + break y as i64; + } } else { // Step 4.2: Squeezing let y1 = y + 1.0; @@ -307,24 +366,24 @@ impl Distribution for Hypergeometric { let t = ym / yk; let e = -ym / nk; let g = yn * yk / (y1 * nk) - 1.0; - let dg = if g < 0.0 { - 1.0 + g - } else { - 1.0 - }; + let dg = if g < 0.0 { 1.0 + g } else { 1.0 }; let gu = g * (1.0 + g * (-0.5 + g / 3.0)); let gl = gu - g.powi(4) / (4.0 * dg); let xm = m + 0.5; let xn = n1 as f64 - m + 0.5; let xk = k as f64 - m + 0.5; let nm = n2 as f64 - k as f64 + xm; - let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + - xn * s * (1.0 + s * (-0.5 + s / 3.0)) + - xk * t * (1.0 + t * (-0.5 + t / 3.0)) + - nm * e * (1.0 + e * (-0.5 + e / 3.0)) + - y * gu - m * gl + 0.0034; + let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + + xn * s * (1.0 + s * (-0.5 + s / 3.0)) + + xk * t * (1.0 + t * (-0.5 + t / 3.0)) + + nm * e * (1.0 + e * (-0.5 + e / 3.0)) + + y * gu + - m * gl + + 0.0034; let av = v.ln(); - if av > ub { continue; } + if av > ub { + continue; + } let dr = if r < 0.0 { xm * r.powi(4) / (1.0 + r) } else { @@ -345,17 +404,17 @@ impl Distribution for Hypergeometric { } else { nm * e.powi(4) }; - - if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 { + + if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 { break y as i64; } - + // Step 4.3: Final Acceptance/Rejection Test - let av_critical = a - - ln_of_factorial(y) - - ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(k as f64 - y) - - ln_of_factorial((n2 - k) as f64 + y); + let av_critical = a + - ln_of_factorial(y) + - ln_of_factorial(n1 as f64 - y) + - ln_of_factorial(k as f64 - y) + - ln_of_factorial((n2 - k) as f64 + y); if v.ln() <= av_critical { break y as i64; } @@ -380,8 +439,7 @@ mod test { assert!(Hypergeometric::new(100, 10, 5).is_ok()); } - fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) - { + fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) { let distr = Hypergeometric::new(n, k, s).unwrap(); let expected_mean = s as f64 * k as f64 / n as f64; diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index ba845fd1505..dc178bba4e0 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -1,7 +1,7 @@ use crate::{Distribution, Standard, StandardNormal}; +use core::fmt; use num_traits::Float; use rand::Rng; -use core::fmt; /// Error type returned from `InverseGaussian::new` #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -112,6 +112,9 @@ mod tests { #[test] fn inverse_gaussian_distributions_can_be_compared() { - assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0)); + assert_eq!( + InverseGaussian::new(1.0, 2.0), + InverseGaussian::new(1.0, 2.0) + ); } } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 6d8d81bd2f3..8263823b3b0 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -81,15 +81,12 @@ //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution -#[cfg(feature = "alloc")] -extern crate alloc; +#[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "std")] -extern crate std; +#[cfg(feature = "std")] extern crate std; // This is used for doc links: -#[allow(unused)] -use rand::Rng; +#[allow(unused)] use rand::Rng; pub use rand::distributions::{ uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, @@ -173,10 +170,14 @@ mod test { macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); - assert!(diff <= $prec, + assert!( + diff <= $prec, "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", - diff, $prec, $a, $b + diff, + $prec, + $a, + $b ); }; } diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index b3b801dfed9..8cb21e9ef88 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -10,10 +10,10 @@ //! The normal and derived distributions. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution, Open01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -115,7 +115,9 @@ impl Distribution for StandardNormal { #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { mean: F, std_dev: F, @@ -144,7 +146,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from mean and standard deviation /// @@ -204,14 +208,15 @@ where F: Float, StandardNormal: Distribution } impl Distribution for Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { fn sample(&self, rng: &mut R) -> F { self.from_zscore(rng.sample(StandardNormal)) } } - /// The log-normal distribution `ln N(mean, std_dev**2)`. /// /// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` @@ -230,13 +235,17 @@ where F: Float, StandardNormal: Distribution #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { norm: Normal, } impl LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from (log-space) mean and standard deviation /// @@ -307,7 +316,9 @@ where F: Float, StandardNormal: Distribution } impl Distribution for LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -348,7 +359,10 @@ mod tests { #[test] fn test_log_normal_cv() { let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (-core::f64::INFINITY, 0.0)); + assert_eq!( + (lnorm.norm.mean, lnorm.norm.std_dev), + (-core::f64::INFINITY, 0.0) + ); let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index b1ba588ac8d..ddbc47531e4 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -1,7 +1,7 @@ use crate::{Distribution, InverseGaussian, Standard, StandardNormal}; +use core::fmt; use num_traits::Float; use rand::Rng; -use core::fmt; /// Error type returned from `NormalInverseGaussian::new` #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -15,8 +15,12 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution", - Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution", + Error::AlphaNegativeOrNull => { + "alpha <= 0 or is NaN in normal inverse Gaussian distribution" + } + Error::AbsoluteBetaNotLessThanAlpha => { + "|beta| >= alpha or is NaN in normal inverse Gaussian distribution" + } }) } } @@ -105,6 +109,9 @@ mod tests { #[test] fn normal_inverse_gaussian_distributions_can_be_compared() { - assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0)); + assert_eq!( + NormalInverseGaussian::new(1.0, 2.0), + NormalInverseGaussian::new(1.0, 2.0) + ); } } diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 25c8e0537dd..d8869bf1f01 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -8,10 +8,10 @@ //! The Pareto distribution. -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// Samples floating-point numbers according to the Pareto distribution /// @@ -26,7 +26,9 @@ use core::fmt; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { scale: F, inv_neg_shape: F, @@ -55,7 +57,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new Pareto distribution with given `scale` and `shape`. /// @@ -78,7 +82,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let u: F = OpenClosed01.sample(rng); diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index db89fff7bfb..a4bf812bb3c 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -7,10 +7,10 @@ // except according to those terms. //! The PERT distribution. -use num_traits::Float; use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// The PERT distribution. /// @@ -129,20 +129,12 @@ mod test { #[test] fn test_pert() { - for &(min, max, mode) in &[ - (-1., 1., 0.), - (1., 2., 1.), - (5., 25., 25.), - ] { + for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] { let _distr = Pert::new(min, max, mode).unwrap(); // TODO: test correctness } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { assert!(Pert::new(min, max, mode).is_err()); } } diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 8b9bffd020e..bbc11422d8e 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -9,10 +9,10 @@ //! The Poisson distribution. -use num_traits::{Float, FloatConst}; use crate::{Cauchy, Distribution, Standard}; -use rand::Rng; use core::fmt; +use num_traits::{Float, FloatConst}; +use rand::Rng; /// The Poisson distribution `Poisson(lambda)`. /// @@ -31,7 +31,9 @@ use core::fmt; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Poisson -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { lambda: F, // precalculated values @@ -61,7 +63,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Poisson -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { /// Construct a new `Poisson` with the given shape parameter /// `lambda`. @@ -81,7 +85,9 @@ where F: Float + FloatConst, Standard: Distribution } impl Distribution for Poisson -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + Standard: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -92,7 +98,7 @@ where F: Float + FloatConst, Standard: Distribution let mut result = F::zero(); let mut p = F::one(); while p > self.exp_lambda { - p = p*rng.gen::(); + p = p * rng.gen::(); result = result + F::one(); } result - F::one() @@ -147,8 +153,7 @@ mod test { use super::*; fn test_poisson_avg_gen(lambda: F, tol: F) - where Standard: Distribution - { + where Standard: Distribution { let poisson = Poisson::new(lambda).unwrap(); let mut rng = crate::test::rng(123); let mut sum = F::zero(); @@ -183,4 +188,4 @@ mod test { fn poisson_distributions_can_be_compared() { assert_eq!(Poisson::new(1.0), Poisson::new(1.0)); } -} \ No newline at end of file +} diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs index 146b4ead125..87546c1bf56 100644 --- a/rand_distr/src/skew_normal.rs +++ b/rand_distr/src/skew_normal.rs @@ -204,21 +204,18 @@ mod tests { #[test] fn skew_normal_value_stability() { - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f32, - &[-0.11844189, 0.781378, 0.06563994, -1.1932899], - ); - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f64, - &[ - -0.11844188827977231, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ], - ); + test_samples(SkewNormal::new(0.0, 1.0, 0.0).unwrap(), 0f32, &[ + -0.11844189, + 0.781378, + 0.06563994, + -1.1932899, + ]); + test_samples(SkewNormal::new(0.0, 1.0, 0.0).unwrap(), 0f64, &[ + -0.11844188827977231, + 0.7813779637772346, + 0.06563993969580051, + -1.1932899004186373, + ]); test_samples( SkewNormal::new(core::f64::INFINITY, 1.0, 0.0).unwrap(), 0f64, @@ -256,6 +253,9 @@ mod tests { #[test] fn skew_normal_distributions_can_be_compared() { - assert_eq!(SkewNormal::new(1.0, 2.0, 3.0), SkewNormal::new(1.0, 2.0, 3.0)); + assert_eq!( + SkewNormal::new(1.0, 2.0, 3.0), + SkewNormal::new(1.0, 2.0, 3.0) + ); } } diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index eef7d190133..7b6d0199b04 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -7,10 +7,10 @@ // except according to those terms. //! The triangular distribution. -use num_traits::Float; use crate::{Distribution, Standard}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// The triangular distribution. /// @@ -34,7 +34,9 @@ use core::fmt; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Triangular -where F: Float, Standard: Distribution +where + F: Float, + Standard: Distribution, { min: F, max: F, @@ -66,7 +68,9 @@ impl fmt::Display for TriangularError { impl std::error::Error for TriangularError {} impl Triangular -where F: Float, Standard: Distribution +where + F: Float, + Standard: Distribution, { /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] @@ -82,7 +86,9 @@ where F: Float, Standard: Distribution } impl Distribution for Triangular -where F: Float, Standard: Distribution +where + F: Float, + Standard: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -122,17 +128,16 @@ mod test { assert_eq!(distr.sample(&mut half_rng), median); } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { assert!(Triangular::new(min, max, mode).is_err()); } } #[test] fn triangular_distributions_can_be_compared() { - assert_eq!(Triangular::new(1.0, 3.0, 2.0), Triangular::new(1.0, 3.0, 2.0)); + assert_eq!( + Triangular::new(1.0, 3.0, 2.0), + Triangular::new(1.0, 3.0, 2.0) + ); } } diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs index 8a4b4fbf3d1..91a13fc3aee 100644 --- a/rand_distr/src/unit_ball.rs +++ b/rand_distr/src/unit_ball.rs @@ -6,8 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the unit ball (surface and interior) in three diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 24a06f3f4de..ea4fe84fddb 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -6,8 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the edge of the unit circle in two dimensions. diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs index 937c1d01b84..2d7147b38d9 100644 --- a/rand_distr/src/unit_disc.rs +++ b/rand_distr/src/unit_disc.rs @@ -6,8 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the unit disc in two dimensions. diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index 2b299239f49..a20e8662e45 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -6,8 +6,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the surface of the unit sphere in three dimensions. @@ -42,7 +42,11 @@ impl Distribution<[F; 3]> for UnitSphere { continue; } let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); - return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; + return [ + x1 * factor, + x2 * factor, + F::from(1.).unwrap() - F::from(2.).unwrap() * sum, + ]; } } } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 4638e3623d2..ea7c4381951 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -9,9 +9,9 @@ //! Math helper functions use crate::ziggurat_tables; +use num_traits::Float; use rand::distributions::hidden_export::IntoFloat; use rand::Rng; -use num_traits::Float; /// Calculates ln(gamma(x)) (natural logarithm of the gamma /// function) using the Lanczos approximation. @@ -72,12 +72,8 @@ pub(crate) fn log_gamma(x: F) -> F { // size from force-inlining. #[inline(always)] pub(crate) fn ziggurat( - rng: &mut R, - symmetric: bool, - x_tab: ziggurat_tables::ZigTable, - f_tab: ziggurat_tables::ZigTable, - mut pdf: P, - mut zero_case: Z + rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, + f_tab: ziggurat_tables::ZigTable, mut pdf: P, mut zero_case: Z, ) -> f64 where P: FnMut(f64) -> f64, diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index fe45eff6613..ca619e3b8c8 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -8,10 +8,10 @@ //! The Weibull distribution. -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; /// Samples floating-point numbers according to the Weibull distribution /// @@ -26,7 +26,9 @@ use core::fmt; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { inv_shape: F, scale: F, @@ -55,7 +57,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new `Weibull` distribution with given `scale` and `shape`. pub fn new(scale: F, shape: F) -> Result, Error> { @@ -73,7 +77,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let x: F = rng.sample(OpenClosed01); diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted_alias.rs index 582a4dd9ba8..4671f33d425 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted_alias.rs @@ -11,13 +11,12 @@ use super::WeightedError; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use alloc::{boxed::Box, vec, vec::Vec}; use core::fmt; use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use rand::Rng; -use alloc::{boxed::Box, vec, vec::Vec}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -67,8 +66,14 @@ use serde::{Serialize, Deserialize}; /// [`Uniform::sample`]: Distribution::sample #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))] +#[cfg_attr( + feature = "serde1", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde1", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] pub struct WeightedAliasIndex { aliases: Box<[u32]>, no_alias_odds: Box<[W]>, @@ -500,7 +505,9 @@ mod test { #[test] fn value_stability() { - fn test_samples(weights: Vec, buf: &mut [usize], expected: &[usize]) { + fn test_samples( + weights: Vec, buf: &mut [usize], expected: &[usize], + ) { assert_eq!(buf.len(), expected.len()); let distr = WeightedAliasIndex::new(weights).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs index e15b6cdd197..58b2b4f7c68 100644 --- a/rand_distr/src/zipf.rs +++ b/rand_distr/src/zipf.rs @@ -8,10 +8,10 @@ //! The Zeta and related distributions. -use num_traits::Float; use crate::{Distribution, Standard}; -use rand::{Rng, distributions::OpenClosed01}; use core::fmt; +use num_traits::Float; +use rand::{distributions::OpenClosed01, Rng}; /// Samples integers according to the [zeta distribution]. /// @@ -48,7 +48,10 @@ use core::fmt; /// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8 #[derive(Clone, Copy, Debug, PartialEq)] pub struct Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, { a_minus_1: F, b: F, @@ -74,7 +77,10 @@ impl fmt::Display for ZetaError { impl std::error::Error for ZetaError {} impl Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, { /// Construct a new `Zeta` distribution with given `a` parameter. #[inline] @@ -92,7 +98,10 @@ where F: Float, Standard: Distribution, OpenClosed01: Distribution } impl Distribution for Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -144,7 +153,10 @@ where F: Float, Standard: Distribution, OpenClosed01: Distribution /// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa #[derive(Clone, Copy, Debug, PartialEq)] pub struct Zipf -where F: Float, Standard: Distribution { +where + F: Float, + Standard: Distribution, +{ s: F, t: F, q: F, @@ -173,7 +185,10 @@ impl fmt::Display for ZipfError { impl std::error::Error for ZipfError {} impl Zipf -where F: Float, Standard: Distribution { +where + F: Float, + Standard: Distribution, +{ /// Construct a new `Zipf` distribution for a set with `n` elements and a /// frequency rank exponent `s`. /// @@ -186,7 +201,7 @@ where F: Float, Standard: Distribution { if n < 1 { return Err(ZipfError::NTooSmall); } - let n = F::from(n).unwrap(); // This does not fail. + let n = F::from(n).unwrap(); // This does not fail. let q = if s != F::one() { // Make sure to calculate the division only once. F::one() / (F::one() - s) @@ -200,9 +215,7 @@ where F: Float, Standard: Distribution { F::one() + n.ln() }; debug_assert!(t > F::zero()); - Ok(Zipf { - s, t, q - }) + Ok(Zipf { s, t, q }) } /// Inverse cumulative density function @@ -221,7 +234,9 @@ where F: Float, Standard: Distribution { } impl Distribution for Zipf -where F: Float, Standard: Distribution +where + F: Float, + Standard: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -293,12 +308,8 @@ mod tests { #[test] fn zeta_value_stability() { - test_samples(Zeta::new(1.5).unwrap(), 0f32, &[ - 1.0, 2.0, 1.0, 1.0, - ]); - test_samples(Zeta::new(2.0).unwrap(), 0f64, &[ - 2.0, 1.0, 1.0, 1.0, - ]); + test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); + test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); } #[test] @@ -363,12 +374,8 @@ mod tests { #[test] fn zipf_value_stability() { - test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[ - 10.0, 2.0, 6.0, 7.0 - ]); - test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[ - 1.0, 2.0, 3.0, 2.0 - ]); + test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]); + test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]); } #[test] diff --git a/rand_distr/tests/sparkline.rs b/rand_distr/tests/sparkline.rs index 6ba48ba886e..4ae14bed400 100644 --- a/rand_distr/tests/sparkline.rs +++ b/rand_distr/tests/sparkline.rs @@ -16,7 +16,7 @@ pub fn render_u64(data: &[u64], buffer: &mut String) { match data.len() { 0 => { return; - }, + } 1 => { if data[0] == 0 { buffer.push(TICKS[0]); @@ -24,8 +24,8 @@ pub fn render_u64(data: &[u64], buffer: &mut String) { buffer.push(TICKS[N - 1]); } return; - }, - _ => {}, + } + _ => {} } let max = data.iter().max().unwrap(); let min = data.iter().min().unwrap(); @@ -56,7 +56,7 @@ pub fn render_f64(data: &[f64], buffer: &mut String) { match data.len() { 0 => { return; - }, + } 1 => { if data[0] == 0. { buffer.push(TICKS[0]); @@ -64,16 +64,14 @@ pub fn render_f64(data: &[f64], buffer: &mut String) { buffer.push(TICKS[N - 1]); } return; - }, - _ => {}, + } + _ => {} } for x in data { assert!(x.is_finite(), "can only render finite values"); } - let max = data.iter().fold( - core::f64::NEG_INFINITY, |a, &b| a.max(b)); - let min = data.iter().fold( - core::f64::INFINITY, |a, &b| a.min(b)); + let max = data.iter().fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)); + let min = data.iter().fold(core::f64::INFINITY, |a, &b| a.min(b)); let scale = ((N - 1) as f64) / (max - min); for x in data { let tick = ((x - min) * scale) as usize; diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index d3754705db5..58cad5ce6fa 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -53,9 +53,7 @@ impl ApproxEq for [T; 3] { } } -fn test_samples>( - seed: u64, distr: D, expected: &[F], -) { +fn test_samples>(seed: u64, distr: D, expected: &[F]) { let mut rng = get_rng(seed); for val in expected { let x = rng.sample(&distr); @@ -68,35 +66,61 @@ fn binomial_stability() { // We have multiple code paths: np < 10, p > 0.5 test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); - test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); + test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[ + 1194, 1208, 1192, 1210, + ]); } #[test] fn geometric_stability() { test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); - + test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); - test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]); - test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]); + test_samples(464, Geometric::new(0.05).unwrap(), &[ + 24, 51, 81, 67, 27, 11, 7, 6, + ]); + test_samples(464, Geometric::new(0.95).unwrap(), &[ + 0, 0, 0, 0, 1, 0, 0, 0, + ]); // expect non-random behaviour for series of pre-determined trials - test_samples(464, Geometric::new(0.0).unwrap(), &[u64::max_value(); 100][..]); + test_samples( + 464, + Geometric::new(0.0).unwrap(), + &[u64::max_value(); 100][..], + ); test_samples(464, Geometric::new(1.0).unwrap(), &[0; 100][..]); } #[test] fn hypergeometric_stability() { // We have multiple code paths based on the distribution's mode and sample_size - test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN - test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE + test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[ + 4, 3, 2, 2, 3, 2, 3, 1, + ]); // Algorithm HIN + test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[ + 23, 27, 26, 27, 22, 24, 31, 22, + ]); // Algorithm H2PE } #[test] fn unit_ball_stability() { test_samples(2, UnitBall, &[ - [0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], - [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], - [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] + [ + 0.018035709265959987f64, + -0.4348771383120438, + -0.07982762085055706, + ], + [ + 0.10588569388223945, + -0.4734350111375454, + -0.7392104908825501, + ], + [ + 0.11060237642041049, + -0.16065642822852677, + -0.8444043930440075, + ], ]); } @@ -112,8 +136,16 @@ fn unit_circle_stability() { #[test] fn unit_sphere_stability() { test_samples(2, UnitSphere, &[ - [0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], - [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], + [ + 0.03247542860231647f64, + -0.7830477442152738, + 0.6211131755296027, + ], + [ + -0.09978440840914075, + 0.9706650829833128, + -0.21875184231323952, + ], [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], ]); } @@ -130,7 +162,10 @@ fn unit_disc_stability() { #[test] fn pareto_stability() { test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ - 1.0423688f32, 2.1235929, 4.132709, 1.4679428, + 1.0423688f32, + 2.1235929, + 4.132709, + 1.4679428, ]); test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ 9.019295276219136f64, @@ -144,10 +179,11 @@ fn pareto_stability() { fn poisson_stability() { test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); - test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); + test_samples(223, Poisson::new(27.0).unwrap(), &[ + 28.0f32, 32.0, 36.0, 36.0, + ]); } - #[test] fn triangular_stability() { test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ @@ -159,11 +195,13 @@ fn triangular_stability() { ]); } - #[test] fn normal_inverse_gaussian_stability() { test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6568966f32, 1.3744819, 2.216063, 0.11488572, + 0.6568966f32, + 1.3744819, + 2.216063, + 0.11488572, ]); test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ 0.6838707059642927f64, @@ -187,8 +225,11 @@ fn pert_stability() { #[test] fn inverse_gaussian_stability() { - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ - 0.9339157f32, 1.108113, 0.50864697, 0.39849377, + test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ + 0.9339157f32, + 1.108113, + 0.50864697, + 0.39849377, ]); test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ 1.0707604954722476f64, @@ -202,10 +243,16 @@ fn inverse_gaussian_stability() { fn gamma_stability() { // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ - 5.398085f32, 9.162783, 0.2300583, 1.7235851, + 5.398085f32, + 9.162783, + 0.2300583, + 1.7235851, ]); test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ - 0.5051203f32, 0.9048302, 3.095812, 1.8566116, + 0.5051203f32, + 0.9048302, + 3.095812, + 1.8566116, ]); test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ 7.783878094584059f64, @@ -228,15 +275,24 @@ fn gamma_stability() { 0.00000002291755769542258, ]); test_samples(223, ChiSquared::new(10.0).unwrap(), &[ - 12.693656f32, 6.812016, 11.082001, 12.436167, + 12.693656f32, + 6.812016, + 11.082001, + 12.436167, ]); // FisherF has same special cases as ChiSquared on each param test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ - 0.32283646f32, 0.048049655, 0.0788893, 1.817178, + 0.32283646f32, + 0.048049655, + 0.0788893, + 1.817178, ]); test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ - 0.29925257f32, 3.4392934, 9.567652, 0.020074, + 0.29925257f32, + 3.4392934, + 9.567652, + 0.020074, ]); test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ 3.3196593155045124f64, @@ -247,7 +303,10 @@ fn gamma_stability() { // StudentT has same special cases as ChiSquared test_samples(223, StudentT::new(1.0).unwrap(), &[ - 0.54703987f32, -1.8545331, 3.093162, -0.14168274, + 0.54703987f32, + -1.8545331, + 3.093162, + -0.14168274, ]); test_samples(223, StudentT::new(1.1).unwrap(), &[ 0.7729195887949754f64, @@ -276,9 +335,7 @@ fn gamma_stability() { #[test] fn exponential_stability() { - test_samples(223, Exp1, &[ - 1.079617f32, 1.8325565, 0.04601166, 0.34471703, - ]); + test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]); test_samples(223, Exp1, &[ 1.0796170642388276f64, 1.8325565304274, @@ -287,7 +344,10 @@ fn exponential_stability() { ]); test_samples(223, Exp::new(2.0).unwrap(), &[ - 0.5398085f32, 0.91627824, 0.02300583, 0.17235851, + 0.5398085f32, + 0.91627824, + 0.02300583, + 0.17235851, ]); test_samples(223, Exp::new(1.0).unwrap(), &[ 1.0796170642388276f64, @@ -300,7 +360,10 @@ fn exponential_stability() { #[test] fn normal_stability() { test_samples(213, StandardNormal, &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, + -0.11844189f32, + 0.781378, + 0.06563994, + -1.1932899, ]); test_samples(213, StandardNormal, &[ -0.11844188827977231f64, @@ -310,7 +373,10 @@ fn normal_stability() { ]); test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, + -0.11844189f32, + 0.781378, + 0.06563994, + -1.1932899, ]); test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ 1.940779055860114f64, @@ -320,7 +386,10 @@ fn normal_stability() { ]); test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ - 0.88830346f32, 2.1844804, 1.0678421, 0.30322206, + 0.88830346f32, + 2.1844804, + 1.0678421, + 0.30322206, ]); test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ 6.964174338639032f64, @@ -333,7 +402,10 @@ fn normal_stability() { #[test] fn weibull_stability() { test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ - 0.041495778f32, 0.7531094, 1.4189332, 0.38386202, + 0.041495778f32, + 0.7531094, + 1.4189332, + 0.38386202, ]); test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ 1.1343478702739669f64, @@ -347,10 +419,11 @@ fn weibull_stability() { #[test] fn dirichlet_stability() { let mut rng = get_rng(223); - assert_eq!( - rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), - vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); + assert_eq!(rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), vec![ + 0.12941567177708177, + 0.4702121891675036, + 0.4003721390554146 + ]); assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ 0.17684200044809556, 0.29915953935953055, diff --git a/rand_pcg/src/pcg128.rs b/rand_pcg/src/pcg128.rs index df2025dc444..5f02f041371 100644 --- a/rand_pcg/src/pcg128.rs +++ b/rand_pcg/src/pcg128.rs @@ -14,7 +14,7 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; +use rand_core::{impls, le, RngCore, SeedableRng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A PCG random number generator (XSL RR 128/64 (LCG) variant). @@ -151,15 +151,8 @@ impl RngCore for Lcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } - /// A PCG random number generator (XSL 128/64 (MCG) variant). /// /// Permuted Congruential Generator with 128-bit state, internal Multiplicative @@ -261,12 +254,6 @@ impl RngCore for Mcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg128cm.rs b/rand_pcg/src/pcg128cm.rs index 7ac5187e4e0..374e05ba6a9 100644 --- a/rand_pcg/src/pcg128cm.rs +++ b/rand_pcg/src/pcg128cm.rs @@ -14,7 +14,7 @@ const MULTIPLIER: u64 = 15750249268501108917; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; +use rand_core::{impls, le, RngCore, SeedableRng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A PCG random number generator (CM DXSM 128/64 (LCG) variant). @@ -157,12 +157,6 @@ impl RngCore for Lcg128CmDxsm64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg64.rs b/rand_pcg/src/pcg64.rs index 365f1c0b117..0b6864a42f3 100644 --- a/rand_pcg/src/pcg64.rs +++ b/rand_pcg/src/pcg64.rs @@ -11,7 +11,7 @@ //! PCG random number generators use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; +use rand_core::{impls, le, RngCore, SeedableRng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; // This is the default multiplier used by PCG for 64-bit state. @@ -160,10 +160,4 @@ impl RngCore for Lcg64Xsh32 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs index 676b79a5c10..a4677061879 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distributions/bernoulli.rs @@ -12,8 +12,7 @@ use crate::distributions::Distribution; use crate::Rng; use core::{fmt, u64}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// The Bernoulli distribution. /// @@ -151,7 +150,8 @@ mod test { #[cfg(feature = "serde1")] fn test_serializing_deserializing_bernoulli() { let coin_flip = Bernoulli::new(0.5).unwrap(); - let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); + let de_coin_flip: Bernoulli = + bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); assert_eq!(coin_flip.p_int, de_coin_flip.p_int); } diff --git a/src/distributions/distribution.rs b/src/distributions/distribution.rs index c5cf6a607b4..ec7676505cf 100644 --- a/src/distributions/distribution.rs +++ b/src/distributions/distribution.rs @@ -10,9 +10,8 @@ //! Distribution trait and associates use crate::Rng; +#[cfg(feature = "alloc")] use alloc::string::String; use core::iter; -#[cfg(feature = "alloc")] -use alloc::string::String; /// Types (distributions) that can be used to create a random instance of `T`. /// @@ -236,9 +235,7 @@ mod tests { #[test] fn test_make_an_iter() { - fn ten_dice_rolls_other_than_five( - rng: &mut R, - ) -> impl Iterator + '_ { + fn ten_dice_rolls_other_than_five(rng: &mut R) -> impl Iterator + '_ { Uniform::new_inclusive(1, 6) .sample_iter(rng) .filter(|x| *x != 5) @@ -257,8 +254,8 @@ mod tests { #[test] #[cfg(feature = "alloc")] fn test_dist_string() { - use core::str; use crate::distributions::{Alphanumeric, DistString, Standard}; + use core::str; let mut rng = crate::test::rng(213); let s1 = Alphanumeric.sample_string(&mut rng, 20); diff --git a/src/distributions/float.rs b/src/distributions/float.rs index 54aebad4dc5..9b65b89b4ad 100644 --- a/src/distributions/float.rs +++ b/src/distributions/float.rs @@ -8,14 +8,13 @@ //! Basic floating-point number distributions -use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils}; +use crate::distributions::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; use crate::distributions::{Distribution, Standard}; use crate::Rng; use core::mem; #[cfg(feature = "simd_support")] use core::simd::*; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A distribution to sample floating point numbers uniformly in the half-open /// interval `(0, 1]`, i.e. including 1 but not 0. @@ -72,7 +71,6 @@ pub struct OpenClosed01; #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Open01; - // This trait is needed by both this lib and rand_distr hence is a hidden export #[doc(hidden)] pub trait IntoFloat { @@ -204,9 +202,15 @@ mod tests { let mut zeros = StepRng::new(0, 0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -246,9 +250,15 @@ mod tests { let mut zeros = StepRng::new(0, 0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 12, 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } diff --git a/src/distributions/integer.rs b/src/distributions/integer.rs index 418eea9ff13..ab762aa2e83 100644 --- a/src/distributions/integer.rs +++ b/src/distributions/integer.rs @@ -12,16 +12,13 @@ use crate::distributions::{Distribution, Standard}; use crate::Rng; #[cfg(all(target_arch = "x86", feature = "simd_support"))] use core::arch::x86::__m512i; -#[cfg(target_arch = "x86")] -use core::arch::x86::{__m128i, __m256i}; +#[cfg(target_arch = "x86")] use core::arch::x86::{__m128i, __m256i}; #[cfg(all(target_arch = "x86_64", feature = "simd_support"))] use core::arch::x86_64::__m512i; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::{__m128i, __m256i}; -use core::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize, - NonZeroU128}; -#[cfg(feature = "simd_support")] use core::simd::*; +#[cfg(target_arch = "x86_64")] use core::arch::x86_64::{__m128i, __m256i}; use core::mem; +use core::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize}; +#[cfg(feature = "simd_support")] use core::simd::*; impl Distribution for Standard { #[inline] diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index a923f879d22..e979db8dd7d 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -100,8 +100,7 @@ mod integer; mod other; mod slice; mod utils; -#[cfg(feature = "alloc")] -mod weighted_index; +#[cfg(feature = "alloc")] mod weighted_index; #[doc(hidden)] pub mod hidden_export { @@ -117,19 +116,16 @@ pub mod uniform; pub mod weighted; pub use self::bernoulli::{Bernoulli, BernoulliError}; -pub use self::distribution::{Distribution, DistIter, DistMap}; -#[cfg(feature = "alloc")] -pub use self::distribution::DistString; +#[cfg(feature = "alloc")] pub use self::distribution::DistString; +pub use self::distribution::{DistIter, DistMap, Distribution}; pub use self::float::{Open01, OpenClosed01}; pub use self::other::Alphanumeric; pub use self::slice::Slice; -#[doc(inline)] -pub use self::uniform::Uniform; +#[doc(inline)] pub use self::uniform::Uniform; #[cfg(feature = "alloc")] pub use self::weighted_index::{WeightedError, WeightedIndex}; -#[allow(unused)] -use crate::Rng; +#[allow(unused)] use crate::Rng; /// A generic random value distribution, implemented for many primitive types. /// Usually generates values with a numerically uniform distribution, and with a diff --git a/src/distributions/other.rs b/src/distributions/other.rs index 4cb31086734..ea635a37b96 100644 --- a/src/distributions/other.rs +++ b/src/distributions/other.rs @@ -8,22 +8,17 @@ //! The implementations of the `Standard` distribution for other built-in types. +#[cfg(feature = "alloc")] use alloc::string::String; use core::char; use core::num::Wrapping; -#[cfg(feature = "alloc")] -use alloc::string::String; +#[cfg(feature = "alloc")] use crate::distributions::DistString; use crate::distributions::{Distribution, Standard, Uniform}; -#[cfg(feature = "alloc")] -use crate::distributions::DistString; use crate::Rng; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; use core::mem::{self, MaybeUninit}; -#[cfg(feature = "simd_support")] -use core::simd::*; - +#[cfg(feature = "simd_support")] use core::simd::*; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; // ----- Sampling distributions ----- @@ -69,7 +64,6 @@ use core::simd::*; #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Alphanumeric; - // ----- Implementations of distributions ----- impl Distribution for Standard { @@ -273,7 +267,6 @@ where Standard: Distribution } } - #[cfg(test)] mod tests { use super::*; @@ -312,9 +305,8 @@ mod tests { let mut incorrect = false; for _ in 0..100 { let c: char = rng.sample(Alphanumeric).into(); - incorrect |= !(('0'..='9').contains(&c) || - ('A'..='Z').contains(&c) || - ('a'..='z').contains(&c) ); + incorrect |= + !(('0'..='9').contains(&c) || ('A'..='Z').contains(&c) || ('a'..='z').contains(&c)); } assert!(!incorrect); } diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index a7b4cb1a777..642fc01436a 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -106,14 +106,15 @@ //! [`UniformDuration`]: crate::distributions::uniform::UniformDuration //! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow -use core::time::Duration; use core::ops::{Range, RangeInclusive}; +use core::time::Duration; use crate::distributions::float::IntoFloat; -use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply}; +use crate::distributions::utils::{ + BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply, +}; use crate::distributions::Distribution; -#[cfg(feature = "simd_support")] -use crate::distributions::Standard; +#[cfg(feature = "simd_support")] use crate::distributions::Standard; use crate::{Rng, RngCore}; #[cfg(not(feature = "std"))] @@ -122,8 +123,7 @@ use crate::distributions::utils::Float; #[cfg(feature = "simd_support")] use core::simd::*; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// Sample values uniformly between two bounds. /// @@ -177,7 +177,10 @@ use serde::{Serialize, Deserialize}; #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] +#[cfg_attr( + feature = "serde1", + serde(bound(deserialize = "X::Sampler: Deserialize<'de>")) +)] pub struct Uniform(X::Sampler); impl Uniform { @@ -293,10 +296,10 @@ pub trait UniformSampler: Sized { /// some types more optimal implementations for single usage may be provided /// via this method. /// Results may not be identical. - fn sample_single_inclusive(low: B1, high: B2, rng: &mut R) - -> Self::X - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized + fn sample_single_inclusive(low: B1, high: B2, rng: &mut R) -> Self::X + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, { let uniform: Self = UniformSampler::new_inclusive(low, high); uniform.sample(rng) @@ -315,7 +318,6 @@ impl From> for Uniform { } } - /// Helper trait similar to [`Borrow`] but implemented /// only for SampleUniform and references to SampleUniform in /// order to resolve ambiguity issues. @@ -380,12 +382,10 @@ impl SampleRange for RangeInclusive { } } - //////////////////////////////////////////////////////////////////////////////// // What follows are all back-ends. - /// The back-end implementing [`UniformSampler`] for integer types. /// /// Unless you are implementing [`UniformSampler`] for your own type, this type @@ -518,14 +518,19 @@ macro_rules! uniform_int_impl { } #[inline] - fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Self::X + fn sample_single_inclusive( + low_b: B1, high_b: B2, rng: &mut R, + ) -> Self::X where B1: SampleBorrow + Sized, B2: SampleBorrow + Sized, { let low = *low_b.borrow(); let high = *high_b.borrow(); - assert!(low <= high, "UniformSampler::sample_single_inclusive: low > high"); + assert!( + low <= high, + "UniformSampler::sample_single_inclusive: low > high" + ); let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large; // If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX, // and any integer will do. @@ -875,7 +880,8 @@ macro_rules! uniform_float_impl { fn sample(&self, rng: &mut R) -> Self::X { // Generate a value in the range [1, 2) - let value1_2 = (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + let value1_2 = (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)) + .into_float_with_exponent(0); // Get a value in the range [0, 1) in order to avoid // overflowing into infinity when multiplying with scale @@ -910,12 +916,15 @@ macro_rules! uniform_float_impl { "UniformSampler::sample_single: low >= high" ); let mut scale = high - low; - assert!(scale.all_finite(), "UniformSampler::sample_single: range overflow"); + assert!( + scale.all_finite(), + "UniformSampler::sample_single: range overflow" + ); loop { // Generate a value in the range [1, 2) - let value1_2 = - (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + let value1_2 = (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)) + .into_float_with_exponent(0); // Get a value in the range [0, 1) in order to avoid // overflowing into infinity when multiplying with scale @@ -990,7 +999,6 @@ uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 } #[cfg(feature = "simd_support")] uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } - /// The back-end implementing [`UniformSampler`] for `Duration`. /// /// Unless you are implementing [`UniformSampler`] for your own types, this type @@ -1132,24 +1140,43 @@ mod tests { #[cfg(feature = "serde1")] fn test_serialization_uniform_duration() { let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)); - let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); - assert_eq!( - distr.offset, de_distr.offset - ); + let de_distr: UniformDuration = + bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); + assert_eq!(distr.offset, de_distr.offset); match (distr.mode, de_distr.mode) { - (UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => { + ( + UniformDurationMode::Small { + secs: a_secs, + nanos: a_nanos, + }, + UniformDurationMode::Small { secs, nanos }, + ) => { assert_eq!(a_secs, secs); assert_eq!(a_nanos.0.low, nanos.0.low); assert_eq!(a_nanos.0.range, nanos.0.range); assert_eq!(a_nanos.0.z, nanos.0.z); } - (UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => { + ( + UniformDurationMode::Medium { nanos: a_nanos }, + UniformDurationMode::Medium { nanos }, + ) => { assert_eq!(a_nanos.0.low, nanos.0.low); assert_eq!(a_nanos.0.range, nanos.0.range); assert_eq!(a_nanos.0.z, nanos.0.z); } - (UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => { + ( + UniformDurationMode::Large { + max_secs: a_max_secs, + max_nanos: a_max_nanos, + secs: a_secs, + }, + UniformDurationMode::Large { + max_secs, + max_nanos, + secs, + }, + ) => { assert_eq!(a_max_secs, max_secs); assert_eq!(a_max_nanos, max_nanos); @@ -1157,22 +1184,24 @@ mod tests { assert_eq!(a_secs.0.range, secs.0.range); assert_eq!(a_secs.0.z, secs.0.z); } - _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly") + _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly"), } } #[test] #[cfg(feature = "serde1")] fn test_uniform_serialization() { - let unit_box: Uniform = Uniform::new(-1, 1); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + let unit_box: Uniform = Uniform::new(-1, 1); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); assert_eq!(unit_box.0.low, de_unit_box.0.low); assert_eq!(unit_box.0.range, de_unit_box.0.range); assert_eq!(unit_box.0.z, de_unit_box.0.z); let unit_box: Uniform = Uniform::new(-1., 1.); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); assert_eq!(unit_box.0.low, de_unit_box.0.low); assert_eq!(unit_box.0.scale, de_unit_box.0.scale); @@ -1337,8 +1366,9 @@ mod tests { assert!(low_scalar <= v && v < high_scalar); let v = rng.sample(my_incl_uniform).extract(lane); assert!(low_scalar <= v && v <= high_scalar); - let v = <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut rng).extract(lane); + let v = + <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) + .extract(lane); assert!(low_scalar <= v && v < high_scalar); } @@ -1349,9 +1379,15 @@ mod tests { assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); - assert_eq!(<$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut zero_rng) - .extract(lane), low_scalar); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut zero_rng + ) + .extract(lane), + low_scalar + ); assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); @@ -1364,9 +1400,13 @@ mod tests { (-1i64 << $bits_shifted) as u64, ); assert!( - <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut lowering_max_rng) - .extract(lane) < high_scalar + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut lowering_max_rng + ) + .extract(lane) + < high_scalar ); } } @@ -1477,7 +1517,6 @@ mod tests { } } - #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_durations() { @@ -1629,6 +1668,9 @@ mod tests { assert_eq!(Uniform::new(1.0, 2.0), Uniform::new(1.0, 2.0)); // To cover UniformInt - assert_eq!(Uniform::new(1 as u32, 2 as u32), Uniform::new(1 as u32, 2 as u32)); + assert_eq!( + Uniform::new(1 as u32, 2 as u32), + Uniform::new(1 as u32, 2 as u32) + ); } } diff --git a/src/distributions/utils.rs b/src/distributions/utils.rs index bddb0a4a599..06cd6698b4a 100644 --- a/src/distributions/utils.rs +++ b/src/distributions/utils.rs @@ -10,7 +10,6 @@ #[cfg(feature = "simd_support")] use core::simd::*; - pub(crate) trait WideningMultiply { type Output; @@ -371,7 +370,6 @@ macro_rules! scalar_float_impl { scalar_float_impl!(f32, u32); scalar_float_impl!(f64, u64); - #[cfg(feature = "simd_support")] macro_rules! simd_impl { ($fty:ident, $uty:ident) => { diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs index 846b9df9c28..0ef76dd19d6 100644 --- a/src/distributions/weighted.rs +++ b/src/distributions/weighted.rs @@ -11,16 +11,16 @@ //! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and //! [`crate::distributions::WeightedError`] instead. -pub use super::{WeightedIndex, WeightedError}; +pub use super::{WeightedError, WeightedIndex}; #[allow(missing_docs)] #[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] pub mod alias_method { // This module exists to provide a deprecation warning which minimises // compile errors, but still fails to compile if ever used. - use core::marker::PhantomData; - use alloc::vec::Vec; use super::WeightedError; + use alloc::vec::Vec; + use core::marker::PhantomData; #[derive(Debug)] pub struct WeightedIndex { diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 2af2446d7ba..8a3525e4846 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -17,8 +17,7 @@ use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling of discrete items /// @@ -262,7 +261,7 @@ mod test { } #[test] - fn test_accepting_nan(){ + fn test_accepting_nan() { assert_eq!( WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), WeightedError::InvalidWeight, @@ -285,7 +284,6 @@ mod test { ) } - #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weightedindex() { diff --git a/src/lib.rs b/src/lib.rs index 755b5ba6e9c..3028831e5ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,13 +51,10 @@ #![no_std] #![cfg_attr(feature = "simd_support", feature(stdsimd, portable_simd))] #![cfg_attr(doc_cfg, feature(doc_cfg))] -#![allow( - clippy::float_cmp, - clippy::neg_cmp_op_on_partial_ord, -)] +#![allow(clippy::float_cmp, clippy::neg_cmp_op_on_partial_ord)] -#[cfg(feature = "std")] extern crate std; #[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "std")] extern crate std; #[allow(unused)] macro_rules! trace { ($($x:tt)*) => ( diff --git a/src/prelude.rs b/src/prelude.rs index 51c457e3f9e..9c37ec1901f 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -14,7 +14,7 @@ //! //! ``` //! use rand::prelude::*; -//! # let mut r = StdRng::from_rng(thread_rng()).unwrap(); +//! # let mut r = StdRng::from_rng(thread_rng()); //! # let _: f32 = r.gen(); //! ``` @@ -23,7 +23,8 @@ #[doc(no_inline)] pub use crate::rngs::SmallRng; #[cfg(feature = "std_rng")] -#[doc(no_inline)] pub use crate::rngs::StdRng; +#[doc(no_inline)] +pub use crate::rngs::StdRng; #[doc(no_inline)] #[cfg(all(feature = "std", feature = "std_rng"))] pub use crate::rngs::ThreadRng; diff --git a/src/rng.rs b/src/rng.rs index c9f3a5f72e5..b19e4bb6efa 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -9,11 +9,11 @@ //! [`Rng`] trait -use rand_core::{Error, RngCore}; use crate::distributions::uniform::{SampleRange, SampleUniform}; use crate::distributions::{self, Distribution, Standard}; use core::num::Wrapping; use core::{mem, slice}; +use rand_core::RngCore; /// An automatically-implemented extension trait on [`RngCore`] providing high-level /// generic methods for sampling values and other convenience methods. @@ -127,7 +127,7 @@ pub trait Rng: RngCore { fn gen_range(&mut self, range: R) -> T where T: SampleUniform, - R: SampleRange + R: SampleRange, { assert!(!range.is_empty(), "cannot sample empty range"); range.sample_single(self) @@ -214,35 +214,7 @@ pub trait Rng: RngCore { /// [`fill_bytes`]: RngCore::fill_bytes /// [`try_fill`]: Rng::try_fill fn fill(&mut self, dest: &mut T) { - dest.try_fill(self).unwrap_or_else(|_| panic!("Rng::fill failed")) - } - - /// Fill any type implementing [`Fill`] with random data - /// - /// The distribution is expected to be uniform with portable results, but - /// this cannot be guaranteed for third-party implementations. - /// - /// This is identical to [`fill`] except that it forwards errors. - /// - /// # Example - /// - /// ``` - /// # use rand::Error; - /// use rand::{thread_rng, Rng}; - /// - /// # fn try_inner() -> Result<(), Error> { - /// let mut arr = [0u64; 4]; - /// thread_rng().try_fill(&mut arr[..])?; - /// # Ok(()) - /// # } - /// - /// # try_inner().unwrap() - /// ``` - /// - /// [`try_fill_bytes`]: RngCore::try_fill_bytes - /// [`fill`]: Rng::fill - fn try_fill(&mut self, dest: &mut T) -> Result<(), Error> { - dest.try_fill(self) + dest.fill(self) } /// Return a bool with a probability `p` of being true. @@ -311,18 +283,17 @@ impl Rng for R {} /// [Chapter on Portability](https://rust-random.github.io/book/portability.html)). pub trait Fill { /// Fill self with random data - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error>; + fn fill(&mut self, rng: &mut R); } macro_rules! impl_fill_each { () => {}; ($t:ty) => { impl Fill for [$t] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { for elt in self.iter_mut() { *elt = rng.gen(); } - Ok(()) } } }; @@ -335,8 +306,8 @@ macro_rules! impl_fill_each { impl_fill_each!(bool, char, f32, f64,); impl Fill for [u8] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - rng.try_fill_bytes(self) + fn fill(&mut self, rng: &mut R) { + rng.fill_bytes(self) } } @@ -345,37 +316,35 @@ macro_rules! impl_fill { ($t:ty) => { impl Fill for [$t] { #[inline(never)] // in micro benchmarks, this improves performance - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { + rng.fill_bytes(unsafe { slice::from_raw_parts_mut(self.as_mut_ptr() as *mut u8, self.len() * mem::size_of::<$t>() ) - })?; + }); for x in self { *x = x.to_le(); } } - Ok(()) } } impl Fill for [Wrapping<$t>] { #[inline(never)] - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { + rng.fill_bytes(unsafe { slice::from_raw_parts_mut(self.as_mut_ptr() as *mut u8, self.len() * mem::size_of::<$t>() ) - })?; + }); for x in self { - *x = Wrapping(x.0.to_le()); + *x = Wrapping(x.0.to_le()); } } - Ok(()) } } }; @@ -393,16 +362,17 @@ impl_fill!(i8, i16, i32, i64, isize, i128,); impl Fill for [T; N] where [T]: Fill { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - self[..].try_fill(rng) + fn fill(&mut self, rng: &mut R) { + let dst: &mut [T] = self; + Fill::fill(dst, rng) } } #[cfg(test)] mod test { use super::*; - use crate::test::rng; use crate::rngs::mock::StepRng; + use crate::test::rng; #[cfg(feature = "alloc")] use alloc::boxed::Box; #[test] diff --git a/src/rngs/adapter/mod.rs b/src/rngs/adapter/mod.rs index bd1d2943233..112e2c862fe 100644 --- a/src/rngs/adapter/mod.rs +++ b/src/rngs/adapter/mod.rs @@ -11,6 +11,5 @@ mod read; mod reseeding; -#[allow(deprecated)] -pub use self::read::{ReadError, ReadRng}; +#[allow(deprecated)] pub use self::read::{ReadError, ReadRng}; pub use self::reseeding::ReseedingRng; diff --git a/src/rngs/adapter/read.rs b/src/rngs/adapter/read.rs index 25a9ca7fca4..fb2dbe12e6b 100644 --- a/src/rngs/adapter/read.rs +++ b/src/rngs/adapter/read.rs @@ -14,8 +14,7 @@ use std::fmt; use std::io::Read; -use rand_core::{impls, Error, RngCore}; - +use rand_core::{impls, RngCore}; /// An RNG that reads random bytes straight from any type supporting /// [`std::io::Read`], for example files. @@ -35,7 +34,7 @@ use rand_core::{impls, Error, RngCore}; /// [`OsRng`]: crate::rngs::OsRng /// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Debug)] -#[deprecated(since="0.8.4", note="removal due to lack of usage")] +#[deprecated(since = "0.8.4", note = "removal due to lack of usage")] pub struct ReadRng { reader: R, } @@ -57,28 +56,19 @@ impl RngCore for ReadRng { } fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap_or_else(|err| { - panic!( - "reading random bytes from Read implementation failed; error: {}", - err - ) - }); - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { if dest.is_empty() { - return Ok(()); + return; } // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. self.reader .read_exact(dest) - .map_err(|e| Error::new(ReadError(e))) + .expect("reading random bytes from Read implementation failed") } } /// `ReadRng` error type #[derive(Debug)] -#[deprecated(since="0.8.4")] +#[deprecated(since = "0.8.4")] pub struct ReadError(std::io::Error); impl fmt::Display for ReadError { @@ -93,11 +83,8 @@ impl std::error::Error for ReadError { } } - #[cfg(test)] mod test { - use std::println; - use super::ReadRng; use crate::RngCore; @@ -137,14 +124,13 @@ mod test { } #[test] + #[should_panic] fn test_reader_rng_insufficient_bytes() { let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; let mut w = [0u8; 9]; let mut rng = ReadRng::new(&v[..]); - let result = rng.try_fill_bytes(&mut w); - assert!(result.is_err()); - println!("Error: {}", result.unwrap_err()); + rng.fill_bytes(&mut w); } } diff --git a/src/rngs/adapter/reseeding.rs b/src/rngs/adapter/reseeding.rs index 5ab453c928e..350434e4f16 100644 --- a/src/rngs/adapter/reseeding.rs +++ b/src/rngs/adapter/reseeding.rs @@ -12,7 +12,7 @@ use core::mem::size_of; -use rand_core::block::{BlockRng, BlockRngCore}; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// A wrapper around any PRNG that implements [`BlockRngCore`], that adds the @@ -103,16 +103,15 @@ where } /// Reseed the internal PRNG. - pub fn reseed(&mut self) -> Result<(), Error> { - self.0.core.reseed() + pub fn reseed(&mut self) { + self.0.core.reseed(); } } // TODO: this should be implemented for any type where the inner type // implements RngCore, but we can't specify that because ReseedingCore is private impl RngCore for ReseedingRng -where - R: BlockRngCore + SeedableRng, +where R: BlockRngCore + SeedableRng { #[inline(always)] fn next_u32(&mut self) -> u32 { @@ -127,10 +126,6 @@ where fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl Clone for ReseedingRng @@ -147,9 +142,12 @@ where impl CryptoRng for ReseedingRng where - R: BlockRngCore + SeedableRng + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, Rsdr: RngCore + CryptoRng, { + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { + self.0.crypto_fill_bytes(dest) + } } #[derive(Debug)] @@ -215,11 +213,10 @@ where } /// Reseed the internal PRNG. - fn reseed(&mut self) -> Result<(), Error> { - R::from_rng(&mut self.reseeder).map(|result| { - self.bytes_until_reseed = self.threshold; - self.inner = result - }) + fn reseed(&mut self) { + let res = R::from_rng(&mut self.reseeder); + self.bytes_until_reseed = self.threshold; + self.inner = res; } fn is_forked(&self, global_fork_counter: usize) -> bool { @@ -249,10 +246,7 @@ where let num_bytes = results.as_ref().len() * size_of::<::Item>(); - if let Err(e) = self.reseed() { - warn!("Reseeding RNG failed: {}", e); - let _ = e; - } + self.reseed(); self.fork_counter = global_fork_counter; self.bytes_until_reseed = self.threshold - num_bytes as i64; @@ -276,14 +270,13 @@ where } } -impl CryptoRng for ReseedingCore +impl CryptoBlockRng for ReseedingCore where - R: BlockRngCore + SeedableRng + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, Rsdr: RngCore + CryptoRng, { } - #[cfg(all(unix, not(target_os = "emscripten")))] mod fork { use core::sync::atomic::{AtomicUsize, Ordering}; @@ -317,11 +310,9 @@ mod fork { static REGISTER: Once = Once::new(); REGISTER.call_once(|| { // Bump the counter before and after forking (see #1169): - let ret = unsafe { libc::pthread_atfork( - Some(fork_handler), - Some(fork_handler), - Some(fork_handler), - ) }; + let ret = unsafe { + libc::pthread_atfork(Some(fork_handler), Some(fork_handler), Some(fork_handler)) + }; if ret != 0 { panic!("libc::pthread_atfork failed with code {}", ret); } @@ -337,7 +328,6 @@ mod fork { pub fn register_fork_handler() {} } - #[cfg(feature = "std_rng")] #[cfg(test)] mod test { @@ -349,7 +339,7 @@ mod test { #[test] fn test_reseeding() { let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); + let rng = Core::from_rng(&mut zero); let thresh = 1; // reseed every time the buffer is exhausted let mut reseeding = ReseedingRng::new(rng, thresh, zero); @@ -371,7 +361,7 @@ mod test { #![allow(clippy::redundant_clone)] let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); + let rng = Core::from_rng(&mut zero); let mut rng1 = ReseedingRng::new(rng, 32 * 4, zero); let first: u32 = rng1.gen(); diff --git a/src/rngs/mock.rs b/src/rngs/mock.rs index a1745a490dd..649e76edb17 100644 --- a/src/rngs/mock.rs +++ b/src/rngs/mock.rs @@ -8,10 +8,9 @@ //! Mock random number generator -use rand_core::{impls, Error, RngCore}; +use rand_core::{impls, RngCore}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A simple implementation of `RngCore` for testing purposes. /// @@ -62,12 +61,6 @@ impl RngCore for StepRng { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest); } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[cfg(test)] @@ -82,6 +75,5 @@ mod tests { bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); assert_eq!(some_rng.v, de_some_rng.v); assert_eq!(some_rng.a, de_some_rng.a); - } } diff --git a/src/rngs/mod.rs b/src/rngs/mod.rs index ac3c2c595da..05b396bef8b 100644 --- a/src/rngs/mod.rs +++ b/src/rngs/mod.rs @@ -97,23 +97,26 @@ //! [`rng` tag]: https://crates.io/keywords/rng #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -#[cfg(feature = "std")] pub mod adapter; +#[cfg(feature = "std")] +pub mod adapter; pub mod mock; // Public so we don't export `StepRng` directly, making it a bit // more clear it is intended for testing. -#[cfg(all(feature = "small_rng", target_pointer_width = "64"))] -mod xoshiro256plusplus; +#[cfg(feature = "small_rng")] mod small; #[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] mod xoshiro128plusplus; -#[cfg(feature = "small_rng")] mod small; +#[cfg(all(feature = "small_rng", target_pointer_width = "64"))] +mod xoshiro256plusplus; #[cfg(feature = "std_rng")] mod std; #[cfg(all(feature = "std", feature = "std_rng"))] pub(crate) mod thread; #[cfg(feature = "small_rng")] pub use self::small::SmallRng; #[cfg(feature = "std_rng")] pub use self::std::StdRng; -#[cfg(all(feature = "std", feature = "std_rng"))] pub use self::thread::ThreadRng; +#[cfg(all(feature = "std", feature = "std_rng"))] +pub use self::thread::ThreadRng; #[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] -#[cfg(feature = "getrandom")] pub use rand_core::OsRng; +#[cfg(feature = "getrandom")] +pub use rand_core::OsRng; diff --git a/src/rngs/small.rs b/src/rngs/small.rs index a3261757847..b4e094d7fd6 100644 --- a/src/rngs/small.rs +++ b/src/rngs/small.rs @@ -8,7 +8,7 @@ //! A small fast RNG -use rand_core::{Error, RngCore, SeedableRng}; +use rand_core::{RngCore, SeedableRng}; #[cfg(target_pointer_width = "64")] type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus; @@ -68,7 +68,7 @@ type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; /// // Create small, cheap to initialize and fast RNGs with random seeds. /// // One can generally assume this won't fail. /// let rngs: Vec = (0..10) -/// .map(|_| SmallRng::from_rng(&mut thread_rng).unwrap()) +/// .map(|_| SmallRng::from_rng(&mut thread_rng)) /// .collect(); /// ``` /// @@ -95,11 +95,6 @@ impl RngCore for SmallRng { fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest); } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl SeedableRng for SmallRng { @@ -111,8 +106,8 @@ impl SeedableRng for SmallRng { } #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(SmallRng) + fn from_rng(rng: R) -> Self { + SmallRng(Rng::from_rng(rng)) } #[inline(always)] diff --git a/src/rngs/std.rs b/src/rngs/std.rs index cdae8fab01c..bcd617e4670 100644 --- a/src/rngs/std.rs +++ b/src/rngs/std.rs @@ -8,7 +8,7 @@ //! The standard RNG -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use crate::{CryptoRng, RngCore, SeedableRng}; pub(crate) use rand_chacha::ChaCha12Core as Core; @@ -19,7 +19,7 @@ use rand_chacha::ChaCha12Rng as Rng; /// (meaning a cryptographically secure PRNG). /// /// The current algorithm used is the ChaCha block cipher with 12 rounds. Please -/// see this relevant [rand issue] for the discussion. This may change as new +/// see this relevant [rand issue] for the discussion. This may change as new /// evidence of cipher security and performance becomes available. /// /// The algorithm is deterministic but should not be considered reproducible @@ -48,11 +48,6 @@ impl RngCore for StdRng { fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest); } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl SeedableRng for StdRng { @@ -64,14 +59,13 @@ impl SeedableRng for StdRng { } #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(StdRng) + fn from_rng(rng: R) -> Self { + StdRng(Rng::from_rng(rng)) } } impl CryptoRng for StdRng {} - #[cfg(test)] mod test { use crate::rngs::StdRng; @@ -90,7 +84,7 @@ mod test { let mut rng0 = StdRng::from_seed(seed); let x0 = rng0.next_u64(); - let mut rng1 = StdRng::from_rng(rng0).unwrap(); + let mut rng1 = StdRng::from_rng(rng0); let x1 = rng1.next_u64(); assert_eq!([x0, x1], target); diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index 78cecde5755..a8e9c6da37f 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -9,9 +9,9 @@ //! Thread-local random number generator use core::cell::UnsafeCell; +use std::fmt; use std::rc::Rc; use std::thread_local; -use std::fmt; use super::std::Core; use crate::rngs::adapter::ReseedingRng; @@ -32,7 +32,6 @@ use crate::{CryptoRng, Error, RngCore, SeedableRng}; // `ThreadRng` internally, which is nonsensical anyway. We should also never run // `ThreadRng` in destructors of its implementation, which is also nonsensical. - // Number of generated bytes after which to reseed `ThreadRng`. // According to benchmarks, reseeding has a noticeable impact with thresholds // of 32 kB and less. We choose 64 kB to avoid significant overhead. @@ -81,8 +80,7 @@ thread_local!( // We require Rc<..> to avoid premature freeing when thread_rng is used // within thread-local destructors. See #968. static THREAD_RNG_KEY: Rc>> = { - let r = Core::from_rng(OsRng).unwrap_or_else(|err| - panic!("could not initialize thread_rng: {}", err)); + let r = Core::from_rng(OsRng); let rng = ReseedingRng::new(r, THREAD_RNG_RESEED_THRESHOLD, OsRng); @@ -142,18 +140,17 @@ impl RngCore for ThreadRng { let rng = unsafe { &mut *self.rng.get() }; rng.fill_bytes(dest) } +} - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { +impl CryptoRng for ThreadRng { + fn crypto_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { // SAFETY: We must make sure to stop using `rng` before anyone else // creates another mutable reference let rng = unsafe { &mut *self.rng.get() }; - rng.try_fill_bytes(dest) + rng.crypto_fill_bytes(dest) } } -impl CryptoRng for ThreadRng {} - - #[cfg(test)] mod test { #[test] @@ -168,6 +165,9 @@ mod test { fn test_debug_output() { // We don't care about the exact output here, but it must not include // private CSPRNG state or the cache stored by BlockRng! - assert_eq!(std::format!("{:?}", crate::thread_rng()), "ThreadRng { .. }"); + assert_eq!( + std::format!("{:?}", crate::thread_rng()), + "ThreadRng { .. }" + ); } } diff --git a/src/rngs/xoshiro128plusplus.rs b/src/rngs/xoshiro128plusplus.rs index ece98fafd6a..441cb3928fd 100644 --- a/src/rngs/xoshiro128plusplus.rs +++ b/src/rngs/xoshiro128plusplus.rs @@ -6,10 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; -use rand_core::impls::{next_u64_via_u32, fill_bytes_via_next}; +use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32}; use rand_core::le::read_u32_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A xoshiro128++ random number generator. /// @@ -20,7 +20,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro128plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Xoshiro128PlusPlus { s: [u32; 4], } @@ -89,12 +89,6 @@ impl RngCore for Xoshiro128PlusPlus { fn fill_bytes(&mut self, dest: &mut [u8]) { fill_bytes_via_next(self, dest); } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[cfg(test)] @@ -103,13 +97,13 @@ mod tests { #[test] fn reference() { - let mut rng = Xoshiro128PlusPlus::from_seed( - [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); + let mut rng = + Xoshiro128PlusPlus::from_seed([1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro128plusplus.c let expected = [ - 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, - 3867114732, 1355841295, 495546011, 621204420, + 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, 3867114732, 1355841295, + 495546011, 621204420, ]; for &e in &expected { assert_eq!(rng.next_u32(), e); diff --git a/src/rngs/xoshiro256plusplus.rs b/src/rngs/xoshiro256plusplus.rs index 8ffb18b8033..c9044cf436e 100644 --- a/src/rngs/xoshiro256plusplus.rs +++ b/src/rngs/xoshiro256plusplus.rs @@ -6,10 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; use rand_core::impls::fill_bytes_via_next; use rand_core::le::read_u64_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A xoshiro256++ random number generator. /// @@ -20,7 +20,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro256plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Xoshiro256PlusPlus { s: [u64; 4], } @@ -91,12 +91,6 @@ impl RngCore for Xoshiro256PlusPlus { fn fill_bytes(&mut self, dest: &mut [u8]) { fill_bytes_via_next(self, dest); } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[cfg(test)] @@ -105,15 +99,23 @@ mod tests { #[test] fn reference() { - let mut rng = Xoshiro256PlusPlus::from_seed( - [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + let mut rng = Xoshiro256PlusPlus::from_seed([ + 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, + 0, 0, 0, + ]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro256plusplus.c let expected = [ - 41943041, 58720359, 3588806011781223, 3591011842654386, - 9228616714210784205, 9973669472204895162, 14011001112246962877, - 12406186145184390807, 15849039046786891736, 10450023813501588000, + 41943041, + 58720359, + 3588806011781223, + 3591011842654386, + 9228616714210784205, + 9973669472204895162, + 14011001112246962877, + 12406186145184390807, + 15849039046786891736, + 10450023813501588000, ]; for &e in &expected { assert_eq!(rng.next_u64(), e); diff --git a/src/seq/index.rs b/src/seq/index.rs index 7682facd51e..45418eeea5a 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -16,14 +16,15 @@ use alloc::collections::BTreeSet; #[cfg(feature = "std")] use std::collections::HashSet; -#[cfg(feature = "std")] -use crate::distributions::WeightedError; +#[cfg(feature = "std")] use crate::distributions::WeightedError; #[cfg(feature = "alloc")] -use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; +use crate::{ + distributions::{uniform::SampleUniform, Distribution, Uniform}, + Rng, +}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; /// A vector of indices. /// @@ -88,8 +89,8 @@ impl IndexVec { } impl IntoIterator for IndexVec { - type Item = usize; type IntoIter = IndexVecIntoIter; + type Item = usize; /// Convert into an iterator over the indices as a sequence of `usize` values #[inline] @@ -196,7 +197,6 @@ impl Iterator for IndexVecIntoIter { impl ExactSizeIterator for IndexVecIntoIter {} - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in random order (fully shuffled). /// @@ -290,7 +290,6 @@ where } } - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in an arbitrary order (there is no guarantee of shuffling or /// ordering). The weights are to be provided by the input function `weights`, @@ -331,8 +330,8 @@ where } impl Ord for Element { fn cmp(&self, other: &Self) -> core::cmp::Ordering { - // partial_cmp will always produce a value, - // because we check that the weights are not nan + // partial_cmp will always produce a value, + // because we check that the weights are not nan self.partial_cmp(other).unwrap() } } @@ -361,8 +360,7 @@ where // keys. Do this by using `select_nth_unstable` to put the elements with // the *smallest* keys at the beginning of the list in `O(n)` time, which // provides equivalent information about the elements with the *greatest* keys. - let (_, mid, greater) - = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); + let (_, mid, greater) = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); let mut result: Vec = Vec::with_capacity(amount.as_usize()); result.push(mid.index); @@ -436,8 +434,9 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform - + core::hash::Hash + core::ops::AddAssign { +trait UInt: + Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + core::hash::Hash + core::ops::AddAssign +{ fn zero() -> Self; fn one() -> Self; fn as_usize(self) -> usize; @@ -516,15 +515,18 @@ mod test { #[cfg(feature = "serde1")] fn test_serialization_index_vec() { let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); - let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); + let de_some_index_vec: IndexVec = + bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); match (some_index_vec, de_some_index_vec) { (IndexVec::U32(a), IndexVec::U32(b)) => { assert_eq!(a, b); - }, + } (IndexVec::USize(a), IndexVec::USize(b)) => { assert_eq!(a, b); - }, - _ => {panic!("failed to seralize/deserialize `IndexVec`")} + } + _ => { + panic!("failed to seralize/deserialize `IndexVec`") + } } } @@ -602,7 +604,7 @@ mod test { for &i in &indices { assert!((i as usize) < len); } - }, + } IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), } } diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 420ef253d5e..73430f419a0 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -24,7 +24,6 @@ //! `usize` indices are sampled as a `u32` where possible (also providing a //! small performance boost in some cases). - #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod index; @@ -218,7 +217,6 @@ pub trait SliceRandom { /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); /// ``` /// [`choose_multiple`]: SliceRandom::choose_multiple - // // Note: this is feature-gated on std due to usage of f64::powf. // If necessary, we may use alloc+libm as an alternative (see PR #1089). #[cfg(feature = "std")] @@ -392,7 +390,7 @@ pub trait IteratorRandom: Iterator + Sized { let (lower, _) = self.size_hint(); if lower >= 2 { let highest_selected = (0..lower) - .filter(|ix| gen_index(rng, consumed+ix+1) == 0) + .filter(|ix| gen_index(rng, consumed + ix + 1) == 0) .last(); consumed += lower; @@ -407,10 +405,10 @@ pub trait IteratorRandom: Iterator + Sized { let elem = self.nth(next); if elem.is_none() { - return result + return result; } - if gen_index(rng, consumed+1) == 0 { + if gen_index(rng, consumed + 1) == 0 { result = elem; } consumed += 1; @@ -495,7 +493,6 @@ pub trait IteratorRandom: Iterator + Sized { } } - impl SliceRandom for [T] { type Item = T; @@ -621,7 +618,6 @@ impl SliceRandom for [T] { impl IteratorRandom for I where I: Iterator + Sized {} - /// An iterator over multiple slice elements. /// /// This struct is created by @@ -658,7 +654,6 @@ impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator } } - // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where // possible, primarily in order to produce the same output on 32-bit and 64-bit // platforms. @@ -671,7 +666,6 @@ fn gen_index(rng: &mut R, ubound: usize) -> usize { } } - #[cfg(test)] mod test { use super::*; @@ -932,33 +926,48 @@ mod test { } let reference = test_iter(0..9); - assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); + assert_eq!( + test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), + reference + ); #[cfg(feature = "alloc")] assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }), reference); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), + reference + ); } #[test] @@ -1260,9 +1269,13 @@ mod test { // Case 2: All of the weights are 0 let choices = [('a', 0), ('b', 0), ('c', 0)]; - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap().count(), 2); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .count(), + 2 + ); // Case 3: Negative weights let choices = [('a', -1), ('b', 1), ('c', 1)]; @@ -1275,9 +1288,13 @@ mod test { // Case 4: Empty list let choices = []; - assert_eq!(choices - .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) - .unwrap().count(), 0); + assert_eq!( + choices + .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) + .unwrap() + .count(), + 0 + ); // Case 5: NaN weights let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)];