diff --git a/base/random.jl b/base/random.jl index c883e9977c294..6b80fa50b1d2c 100644 --- a/base/random.jl +++ b/base/random.jl @@ -37,7 +37,10 @@ function gen_rand(r::MersenneTwister) mt_setfull!(r) end -@inline gen_rand_maybe(r::MersenneTwister) = mt_empty(r) && gen_rand(r) +@inline reserve_1(r::MersenneTwister) = mt_empty(r) && gen_rand(r) +# `reserve` allows to call `rand_inbounds` n times +# precondition: n <= MTCacheLength +@inline reserve(r::MersenneTwister, n::Int) = mt_avail(r) < n && gen_rand(r) abstract FloatInterval type CloseOpen <: FloatInterval end @@ -49,13 +52,11 @@ type Close1Open2 <: FloatInterval end @inline rand_inbounds(r::MersenneTwister) = rand_inbounds(r, CloseOpen) # produce Float64 values -@inline rand{I<:FloatInterval}(r::MersenneTwister, ::Type{I}) = (gen_rand_maybe(r); rand_inbounds(r, I)) +@inline rand{I<:FloatInterval}(r::MersenneTwister, ::Type{I}) = (reserve_1(r); rand_inbounds(r, I)) -# this is similar to `dsfmt_genrand_uint32` from dSFMT.h: -@inline rand_ui32(r::MersenneTwister) = reinterpret(UInt64, rand(r, Close1Open2)) % UInt32 - -@inline rand_ui52_raw(r::MersenneTwister) = reinterpret(UInt64, rand(r, Close1Open2)) -@inline rand_ui2x52_raw(r::MersenneTwister) = (((rand_ui52_raw(r) % UInt128) << 64) | rand_ui52_raw(r)) +@inline rand_ui52_raw_inbounds(r::MersenneTwister) = reinterpret(UInt64, rand_inbounds(r, Close1Open2)) +@inline rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r)) +@inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r) function srand(r::MersenneTwister, seed::Vector{UInt32}) r.seed = seed @@ -165,15 +166,20 @@ rand{T<:Union(Float16, Float32)}(r::MersenneTwister, ::Type{T}) = convert(T, ran ## random integers (MersenneTwister) -rand(r::MersenneTwister, ::Type{UInt8}) = rand(r, UInt32) % UInt8 -rand(r::MersenneTwister, ::Type{UInt16}) = rand(r, UInt32) % UInt16 -rand(r::MersenneTwister, ::Type{UInt32}) = rand_ui32(r) -rand(r::MersenneTwister, ::Type{UInt64}) = uint64(rand(r, UInt32)) <<32 | rand(r, UInt32) -rand(r::MersenneTwister, ::Type{UInt128}) = uint128(rand(r, UInt64))<<64 | rand(r, UInt64) +rand{T<:Union(Int8, UInt8, Int16, UInt16, Int32, UInt32)}(r::MersenneTwister, ::Type{T}) = rand_ui52_raw(r) % T + +function rand(r::MersenneTwister, ::Type{UInt64}) + reserve(r, 2) + rand_ui52_raw_inbounds(r) << 32 $ rand_ui52_raw_inbounds(r) +end + +function rand(r::MersenneTwister, ::Type{UInt128}) + reserve(r, 3) + rand_ui52_raw_inbounds(r) % UInt128 << 96 $ + rand_ui52_raw_inbounds(r) % UInt128 << 48 $ + rand_ui52_raw_inbounds(r) +end -rand(r::MersenneTwister, ::Type{Int8}) = rand(r, UInt32) % Int8 -rand(r::MersenneTwister, ::Type{Int16}) = rand(r, UInt32) % Int16 -rand(r::MersenneTwister, ::Type{Int32}) = reinterpret(Int32, rand(r, UInt32)) rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64)) rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128)) diff --git a/test/random.jl b/test/random.jl index a017578ef966a..ecd66c1a618b5 100644 --- a/test/random.jl +++ b/test/random.jl @@ -25,7 +25,7 @@ A = zeros(2, 2) rand!(MersenneTwister(0), A) @test A == [0.8236475079774124 0.16456579813368521; 0.9103565379264364 0.17732884646626457] -@test rand(MersenneTwister(0), Int64, 1) == [172014471070449746] +@test rand(MersenneTwister(0), Int64, 1) == [4439861565447045202] A = zeros(Int64, 2, 2) rand!(MersenneTwister(0), A) @test A == [858542123778948672 5715075217119798169; @@ -215,7 +215,6 @@ let mt = MersenneTwister(0) 0x066d8695ebf85f833427c93416193e1f, 0x48fab49cc9fcee1c920d6dae629af446, 0x4b54632b4619f4eca22675166784d229][i] - end srand(mt,0) @@ -227,7 +226,7 @@ let mt = MersenneTwister(0) @test A[end] == Any[21,0x7b,17385,0x3086,-1574090021,0xadcb4460,6797283068698303107,0x4e91c9c4d4f5f759, -3482609696641744459568613291754091152,float16(0.03125),0.68733835f0][i] - @test B[end] == Any[49,0x65,-3725,0x719d,814246081,0xdf61843a,-1603010949539670188,0x5e4ca1658810985d, + @test B[end] == Any[49,0x65,-3725,0x719d,814246081,0xdf61843a,-3010919637398300844,0x61b367cf8810985d, -33032345278809823492812856023466859769,float16(0.9346),0.5929704f0][i] end