From 9392bbe347226fb6686d5ddb77717b0ab5142902 Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Sat, 3 Oct 2020 16:15:37 +0200 Subject: [PATCH] fix bug with rand(::MersenneTwister, ::UInt128) (#37808) Generation of UInt64 and UInt128 share the same cache, but the routine handling generation of UInt128 was not fully aknowledging the sharing. This leads to situations like: ``` julia> m = MersenneTwister(0); rand(m, UInt64); rand(m, UInt128) 0x79ed9db9ec79a6a019c5f638a776ab3c julia> rand(m, UInt64) 0x19c5f638a776ab3c ``` These values aren't independent enough! --- stdlib/Random/src/RNGs.jl | 7 +++---- stdlib/Random/test/runtests.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/stdlib/Random/src/RNGs.jl b/stdlib/Random/src/RNGs.jl index 281acad533dad..e9b7f152ef4fe 100644 --- a/stdlib/Random/src/RNGs.jl +++ b/stdlib/Random/src/RNGs.jl @@ -232,12 +232,11 @@ function mt_pop!(r::MersenneTwister, ::Type{T}) where T<:BitInteger (x128 >> (i128 * (sizeof(T) << 3))) % T end -# not necessary, but very slightly more efficient function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}} reserve1(r, T) - @inbounds res = r.ints[r.idxI >> 4] - r.idxI -= 16 - res % T + idx = r.idxI >> 4 + r.idxI = idx << 4 - 16 + @inbounds r.ints[idx] % T end diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index b78e3ae4b8a1f..2aeea0f623877 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -801,3 +801,35 @@ end @testset "RNGs broadcast as scalars: T" for T in (MersenneTwister, RandomDevice) @test length.(rand.(T(), 1:3)) == 1:3 end + +@testset "generated scalar integers do not overlap" begin + m = MersenneTwister() + xs = reinterpret(UInt64, m.ints) + x = rand(m, UInt128) # m.idxI % 16 == 0 + @test x % UInt64 == xs[end-1] + x = rand(m, UInt64) + @test x == xs[end-2] + x = rand(m, UInt64) + @test x == xs[end-3] + x = rand(m, UInt64) + @test x == xs[end-4] + x = rand(m, UInt128) # m.idxI % 16 == 8 + @test (x >> 64) % UInt64 == xs[end-6] + @test x % UInt64 == xs[end-7] + x = rand(m, UInt64) + @test x == xs[end-8] # should not be == xs[end-7] + + s = Set{UInt64}() + n = 0 + for _=1:2000 + x = rand(m, rand((UInt64, UInt128, Int64, Int128))) + if sizeof(x) == 8 + push!(s, x % UInt64) + n += 1 + else + push!(s, x % UInt64, (x >> 64) % UInt64) + n += 2 + end + end + @test length(s) == n +end