From 04c75b173762bfd40e2c64efeb127432a04e038a Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Tue, 12 Sep 2017 08:25:15 +0200 Subject: [PATCH] random: introduce `State` to formalize hooking into rand machinery --- base/random/RNGs.jl | 172 ++++++++++--------- base/random/generation.jl | 338 +++++++++++++------------------------- base/random/misc.jl | 3 +- base/random/random.jl | 106 ++++++++++++ test/random.jl | 13 +- 5 files changed, 324 insertions(+), 308 deletions(-) diff --git a/base/random/RNGs.jl b/base/random/RNGs.jl index 0514ace843499..11a0b56ee7340 100644 --- a/base/random/RNGs.jl +++ b/base/random/RNGs.jl @@ -2,8 +2,9 @@ ## RandomDevice -const BoolBitIntegerType = Union{Type{Bool},Base.BitIntegerType} -const BoolBitIntegerArray = Union{Array{Bool},Base.BitIntegerArray} + +StateTypes(U::Union) = Union{map(T->StateType{T}, Base.uniontypes(U))...} +const StateBoolBitInteger = StateTypes(Union{Bool, Base.BitInteger}) if Sys.iswindows() struct RandomDevice <: AbstractRNG @@ -12,15 +13,9 @@ if Sys.iswindows() RandomDevice() = new(Vector{UInt128}(1)) end - function rand(rd::RandomDevice, T::BoolBitIntegerType) + function rand(rd::RandomDevice, st::StateBoolBitInteger) rand!(rd, rd.buffer) - @inbounds return rd.buffer[1] % T - end - - function rand!(rd::RandomDevice, A::BoolBitIntegerArray) - ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32), - A, sizeof(A)) - A + @inbounds return rd.buffer[1] % st[] end else # !windows struct RandomDevice <: AbstractRNG @@ -31,10 +26,22 @@ else # !windows new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited) end - rand(rd::RandomDevice, T::BoolBitIntegerType) = read( rd.file, T) - rand!(rd::RandomDevice, A::BoolBitIntegerArray) = read!(rd.file, A) + rand(rd::RandomDevice, st::StateBoolBitInteger) = read( rd.file, st[]) end # os-test +# NOTE: this can't be put in within the if-else block above +for T in (Bool, Base.BitInteger_types...) + if Sys.iswindows() + @eval function rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T}) + ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32), + A, sizeof(A)) + A + end + else + @eval rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T}) = read!(rd.file, A) + end +end + """ RandomDevice() @@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng ### generation of floats -rand(r::RandomDevice, I::FloatInterval) = rand_generic(r, I) +rand(r::RandomDevice, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[]) ## MersenneTwister @@ -229,30 +236,30 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r) #### floats -rand(r::MersenneTwister, I::FloatInterval_64) = (reserve_1(r); rand_inbounds(r, I)) +rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval_64}) = (reserve_1(r); rand_inbounds(r, st[])) -rand(r::MersenneTwister, I::FloatInterval) = rand_generic(r, I) +rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[]) #### integers rand(r::MersenneTwister, - ::Type{T}) where {T<:Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}} = - rand_ui52_raw(r) % T + T::StateTypes(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) = + rand_ui52_raw(r) % T[] -function rand(r::MersenneTwister, ::Type{UInt64}) +function rand(r::MersenneTwister, ::StateType{UInt64}) reserve(r, 2) rand_ui52_raw_inbounds(r) << 32 ⊻ rand_ui52_raw_inbounds(r) end -function rand(r::MersenneTwister, ::Type{UInt128}) +function rand(r::MersenneTwister, ::StateType{UInt128}) reserve(r, 3) xor(rand_ui52_raw_inbounds(r) % UInt128 << 96, rand_ui52_raw_inbounds(r) % UInt128 << 48, rand_ui52_raw_inbounds(r)) end -rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64)) -rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128)) +rand(r::MersenneTwister, ::StateType{Int64}) = reinterpret(Int64, rand(r, UInt64)) +rand(r::MersenneTwister, ::StateType{Int128}) = reinterpret(Int128, rand(r, UInt128)) #### arrays of floats @@ -278,7 +285,8 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6 A end -rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A) +rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::StateTrivial{<:FloatInterval_64}) = + rand_AbstractArray_Float64!(r, A, length(A), I[]) fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) = dsfmt_fill_array_close_open!(s, A, n) @@ -286,8 +294,8 @@ fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) = fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) = dsfmt_fill_array_close1_open2!(s, A, n) -function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A), - I::FloatInterval_64=CloseOpen()) +function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int, + I::FloatInterval_64) # depending on the alignment of A, the data written by fill_array! may have # to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for # reproducibility purposes; @@ -317,65 +325,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A), A end +rand!(r::MersenneTwister, A::Array{Float64}, st::StateTrivial{<:FloatInterval_64}) = + _rand!(r, A, length(A), st[]) + mask128(u::UInt128, ::Type{Float16}) = (u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00 mask128(u::UInt128, ::Type{Float32}) = (u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000 -function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, - ::Close1Open2_64) - T = eltype(A) - n = length(A) - n128 = n * sizeof(T) ÷ 16 - Base.@gc_preserve A rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128), - 2*n128, Close1Open2()) - # FIXME: This code is completely invalid!!! - A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128) - @inbounds for i in 1:n128 - u = A128[i] - u ⊻= u << 26 - # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+" - # the bit xor, are: - # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1] - # the bits needing to be random are - # [1:10, 17:26, 33:42, 49:58] (for Float16) - # [1:23, 33:55] (for Float32) - # this is obviously satisfied on the 32 low bits side, and on the high side, - # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32 - # (which are discarded on the low side) - # this is similar for the 64 high bits of u - A128[i] = mask128(u, T) - end - for i in 16*n128÷sizeof(T)+1:n - @inbounds A[i] = rand(r, T) + oneunit(T) +for T in (Float16, Float32) + @eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{Close1Open2{$T}}) + n = length(A) + n128 = n * sizeof($T) ÷ 16 + Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128), + 2*n128, Close1Open2()) + # FIXME: This code is completely invalid!!! + A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128) + @inbounds for i in 1:n128 + u = A128[i] + u ⊻= u << 26 + # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+" + # the bit xor, are: + # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1] + # the bits needing to be random are + # [1:10, 17:26, 33:42, 49:58] (for Float16) + # [1:23, 33:55] (for Float32) + # this is obviously satisfied on the 32 low bits side, and on the high side, + # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32 + # (which are discarded on the low side) + # this is similar for the 64 high bits of u + A128[i] = mask128(u, $T) + end + for i in 16*n128÷sizeof($T)+1:n + @inbounds A[i] = rand(r, $T) + oneunit($T) + end + A end - A -end -function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, ::CloseOpen_64) - rand!(r, A, Close1Open2()) - I32 = one(Float32) - for i in eachindex(A) - @inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16 + @eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{CloseOpen{$T}}) + rand!(r, A, Close1Open2($T)) + I32 = one(Float32) + for i in eachindex(A) + @inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16 + end + A end - A end -rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}) = - rand!(r, A, CloseOpen()) - #### arrays of integers -function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A)) - if n > length(A) - throw(BoundsError(A,n)) - end +function rand!(r::MersenneTwister, A::Array{UInt128}, ::StateType{UInt128}) + n::Int=length(A) # FIXME: This code is completely invalid!!! Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n) i = n while true - rand!(r, Af, 2i, Close1Open2()) + _rand!(r, Af, 2i, Close1Open2()) n < 5 && break i = 0 @inbounds while n-i >= 5 @@ -396,17 +402,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A)) A end -# A::Array{UInt128} will match the specialized method above -function rand!(r::MersenneTwister, A::Base.BitIntegerArray) - n = length(A) - T = eltype(A) - n128 = n * sizeof(T) ÷ 16 - # FIXME: This code is completely invalid!!! - rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)) - for i = 16*n128÷sizeof(T)+1:n - @inbounds A[i] = rand(r, T) +for T in Base.BitInteger_types + T === UInt128 && continue + @eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateType{$T}) + n = length(A) + n128 = n * sizeof($T) ÷ 16 + # FIXME: This code is completely invalid!!! + rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)) + for i = 16*n128÷sizeof($T)+1:n + @inbounds A[i] = rand(r, $T) + end + A end - A end #### from a range @@ -418,7 +425,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer end end -function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInteger64,Bool} +function rand(rng::MersenneTwister, + st::StateTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool} + r = st[] isempty(r) && throw(ArgumentError("range must be non-empty")) m = last(r) % UInt64 - first(r) % UInt64 bw = (64 - leading_zeros(m)) % UInt # bit-width @@ -428,7 +437,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte (x + first(r) % UInt64) % T end -function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt128} +function rand(rng::MersenneTwister, + st::StateTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128} + r = st[] isempty(r) && throw(ArgumentError("range must be non-empty")) m = (last(r)-first(r)) % UInt128 bw = (128 - leading_zeros(m)) % UInt # bit-width @@ -439,6 +450,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1 x % T + first(r) end +for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise + @eval State(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) = + StateTrivial(r) +end + ### randjump diff --git a/base/random/generation.jl b/base/random/generation.jl index 413b714478701..f86d62a96bff2 100644 --- a/base/random/generation.jl +++ b/base/random/generation.jl @@ -4,15 +4,10 @@ ## from types: rand(::Type, [dims...]) -### GLOBAL_RNG fallback for all types - -rand(::Type{T}) where {T} = rand(GLOBAL_RNG, T) - ### random floats -# CloseOpen(T) is the fallback for an AbstractFloat T -rand(r::AbstractRNG=GLOBAL_RNG, ::Type{T}=Float64) where {T<:AbstractFloat} = - rand(r, CloseOpen(T)) +State(rng::AbstractRNG, ::Type{T}, n::Repetition) where {T<:AbstractFloat} = + State(rng, CloseOpen(T), n) # generic random generation function which can be used by RNG implementors # it is not defined as a fallback rand method as this could create ambiguities @@ -34,13 +29,13 @@ rand_generic(r::AbstractRNG, ::CloseOpen_64) = rand(r, Close1Open2()) - 1.0 const bits_in_Limb = sizeof(Limb) << 3 const Limb_high_bit = one(Limb) << (bits_in_Limb-1) -struct BigFloatRandGenerator +struct StateBigFloat{I<:FloatInterval{BigFloat}} <: State prec::Int nlimbs::Int limbs::Vector{Limb} shift::UInt - function BigFloatRandGenerator(prec::Int=precision(BigFloat)) + function StateBigFloat{I}(prec::Int) where I<:FloatInterval{BigFloat} nlimbs = (prec-1) ÷ bits_in_Limb + 1 limbs = Vector{Limb}(nlimbs) shift = nlimbs * bits_in_Limb - prec @@ -48,28 +43,31 @@ struct BigFloatRandGenerator end end -function _rand(rng::AbstractRNG, gen::BigFloatRandGenerator) +State(::AbstractRNG, I::FloatInterval{BigFloat}, ::Repetition) = + StateBigFloat{typeof(I)}(precision(BigFloat)) + +function _rand(rng::AbstractRNG, st::StateBigFloat) z = BigFloat() - limbs = gen.limbs + limbs = st.limbs rand!(rng, limbs) @inbounds begin - limbs[1] <<= gen.shift + limbs[1] <<= st.shift randbool = iszero(limbs[end] & Limb_high_bit) limbs[end] |= Limb_high_bit end z.sign = 1 - Base.@gc_preserve limbs unsafe_copy!(z.d, pointer(limbs), gen.nlimbs) + Base.@gc_preserve limbs unsafe_copy!(z.d, pointer(limbs), st.nlimbs) (z, randbool) end -function rand(rng::AbstractRNG, gen::BigFloatRandGenerator, ::Close1Open2{BigFloat}) - z = _rand(rng, gen)[1] +function _rand(rng::AbstractRNG, st::StateBigFloat, ::Close1Open2{BigFloat}) + z = _rand(rng, st)[1] z.exp = 1 z end -function rand(rng::AbstractRNG, gen::BigFloatRandGenerator, ::CloseOpen{BigFloat}) - z, randbool = _rand(rng, gen) +function _rand(rng::AbstractRNG, st::StateBigFloat, ::CloseOpen{BigFloat}) + z, randbool = _rand(rng, st) z.exp = 0 randbool && ccall((:mpfr_sub_d, :libmpfr), Int32, @@ -80,15 +78,15 @@ end # alternative, with 1 bit less of precision # TODO: make an API for requesting full or not-full precision -function rand(rng::AbstractRNG, gen::BigFloatRandGenerator, ::CloseOpen{BigFloat}, ::Void) - z = rand(rng, Close1Open2(BigFloat), gen) +function _rand(rng::AbstractRNG, st::StateBigFloat, ::CloseOpen{BigFloat}, ::Void) + z = _rand(rng, st, Close1Open2(BigFloat)) ccall((:mpfr_sub_ui, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Culong, Int32), z, z, 1, Base.MPFR.ROUNDING_MODE[]) z end -rand_generic(rng::AbstractRNG, I::FloatInterval{BigFloat}) = - rand(rng, BigFloatRandGenerator(), I) +rand(rng::AbstractRNG, st::StateBigFloat{T}) where {T<:FloatInterval{BigFloat}} = + _rand(rng, st, T()) ### random integers @@ -100,76 +98,21 @@ rand_ui52(r::AbstractRNG) = rand_ui52_raw(r) & 0x000fffffffffffff ### random complex numbers -rand(r::AbstractRNG, ::Type{Complex{T}}) where {T<:Real} = complex(rand(r, T), rand(r, T)) +rand(r::AbstractRNG, ::StateType{Complex{T}}) where {T<:Real} = + complex(rand(r, T), rand(r, T)) ### random characters # returns a random valid Unicode scalar value (i.e. 0 - 0xd7ff, 0xe000 - # 0x10ffff) -function rand(r::AbstractRNG, ::Type{Char}) +function rand(r::AbstractRNG, ::StateType{Char}) c = rand(r, 0x00000000:0x0010f7ff) (c < 0xd800) ? Char(c) : Char(c+0x800) end -### arrays of random numbers - -function rand!(r::AbstractRNG, A::AbstractArray{T}, ::Type{X}=T) where {T,X} - for i in eachindex(A) - @inbounds A[i] = rand(r, X) - end - A -end - -rand!(A::AbstractArray, ::Type{X}) where {X} = rand!(GLOBAL_RNG, A, X) -# NOTE: if the second parameter above is defaulted to eltype(A) and the -# method below is removed, then some specialized methods (e.g. for -# rand!(::Array{Float64})) will fail to be called -rand!(A::AbstractArray) = rand!(GLOBAL_RNG, A) - - -rand(r::AbstractRNG, dims::Dims) = rand(r, Float64, dims) -rand( dims::Dims) = rand(GLOBAL_RNG, dims) -rand(r::AbstractRNG, dims::Integer...) = rand(r, Dims(dims)) -rand( dims::Integer...) = rand(Dims(dims)) - -rand(r::AbstractRNG, ::Type{T}, dims::Dims) where {T} = rand!(r, Array{T}(dims)) -rand( ::Type{T}, dims::Dims) where {T} = rand(GLOBAL_RNG, T, dims) - -rand(r::AbstractRNG, ::Type{T}, d::Integer, dims::Integer...) where {T} = - rand(r, T, Dims((d, dims...))) - -rand( ::Type{T}, d::Integer, dims::Integer...) where {T} = - rand(T, Dims((d, dims...))) -# note: the above methods would trigger an ambiguity warning if d was not separated out: -# rand(r, ()) would match both this method and rand(r, dims::Dims) -# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop - -#### arrays of floats - -rand!(r::AbstractRNG, A::AbstractArray, ::Type{T}) where {T<:AbstractFloat} = - rand!(r, A, CloseOpen{T}()) - -function rand!(r::AbstractRNG, A::AbstractArray, I::FloatInterval) - for i in eachindex(A) - @inbounds A[i] = rand(r, I) - end - A -end - -function rand!(rng::AbstractRNG, A::AbstractArray, I::FloatInterval{BigFloat}) - gen = BigFloatRandGenerator() - for i in eachindex(A) - @inbounds A[i] = rand(rng, gen, I) - end - A -end - -rand!(A::AbstractArray, I::FloatInterval) = rand!(GLOBAL_RNG, A, I) ## Generate random integer within a range -abstract type RangeGenerator end - -### RangeGenerator for BitInteger +### BitInteger # remainder function according to Knuth, where rem_knuth(a, 0) = a rem_knuth(a::UInt, b::UInt) = a % (b + (b == 0)) + a * (b == 0) @@ -186,232 +129,181 @@ maxmultiplemix(k::UInt64) = k >> 32 != 0 ? maxmultiple(k) : (div(0x0000000100000000, k + (k == 0))*k - oneunit(k))::UInt64 -struct RangeGeneratorInt{T<:Integer,U<:Unsigned} <: RangeGenerator +struct StateRangeInt{T<:Integer,U<:Unsigned} <: State a::T # first element of the range k::U # range length or zero for full range u::U # rejection threshold end # generators with 32, 128 bits entropy -RangeGeneratorInt(a::T, k::U) where {T,U<:Union{UInt32,UInt128}} = - RangeGeneratorInt{T,U}(a, k, maxmultiple(k)) +StateRangeInt(a::T, k::U) where {T,U<:Union{UInt32,UInt128}} = + StateRangeInt{T,U}(a, k, maxmultiple(k)) # mixed 32/64 bits entropy generator -RangeGeneratorInt(a::T, k::UInt64) where {T} = - RangeGeneratorInt{T,UInt64}(a, k, maxmultiplemix(k)) +StateRangeInt(a::T, k::UInt64) where {T} = + StateRangeInt{T,UInt64}(a, k, maxmultiplemix(k)) -function RangeGenerator(r::UnitRange{T}) where T<:Unsigned +function State(::AbstractRNG, r::UnitRange{T}, ::Repetition) where T<:Unsigned isempty(r) && throw(ArgumentError("range must be non-empty")) - RangeGeneratorInt(first(r), last(r) - first(r) + oneunit(T)) + StateRangeInt(first(r), last(r) - first(r) + oneunit(T)) end for (T, U) in [(UInt8, UInt32), (UInt16, UInt32), (Int8, UInt32), (Int16, UInt32), (Int32, UInt32), (Int64, UInt64), (Int128, UInt128), (Bool, UInt32)] - @eval RangeGenerator(r::UnitRange{$T}) = begin + @eval State(::AbstractRNG, r::UnitRange{$T}, ::Repetition) = begin isempty(r) && throw(ArgumentError("range must be non-empty")) # overflow ok: - RangeGeneratorInt(first(r), convert($U, unsigned(last(r) - first(r)) + one($U))) + StateRangeInt(first(r), convert($U, unsigned(last(r) - first(r)) + one($U))) end end -### RangeGenerator for BigInt - -struct RangeGeneratorBigInt <: RangeGenerator - a::BigInt # first - m::BigInt # range length - 1 - nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m]) - nlimbsmax::Int # max number of limbs for z+a - mask::Limb # applied to the highest limb -end - - -function RangeGenerator(r::UnitRange{BigInt}) - m = last(r) - first(r) - m < 0 && throw(ArgumentError("range must be non-empty")) - nd = ndigits(m, 2) - nlimbs, highbits = divrem(nd, 8*sizeof(Limb)) - highbits > 0 && (nlimbs += 1) - mask = highbits == 0 ? ~zero(Limb) : one(Limb)<> 32 == 0 + if (st.k - 1) >> 32 == 0 x = rand(rng, UInt32) - while x > g.u + while x > st.u x = rand(rng, UInt32) end else x = rand(rng, UInt64) - while x > g.u + while x > st.u x = rand(rng, UInt64) end end - return reinterpret(T, reinterpret(UInt64, g.a) + rem_knuth(x, g.k)) + return reinterpret(T, reinterpret(UInt64, st.a) + rem_knuth(x, st.k)) end -function rand(rng::AbstractRNG, g::RangeGeneratorInt{T,U}) where {T<:Integer,U<:Unsigned} +function rand(rng::AbstractRNG, st::StateRangeInt{T,U}) where {T<:Integer,U<:Unsigned} x = rand(rng, U) - while x > g.u + while x > st.u x = rand(rng, U) end - (unsigned(g.a) + rem_knuth(x, g.k)) % T + (unsigned(st.a) + rem_knuth(x, st.k)) % T end -function rand(rng::AbstractRNG, g::RangeGeneratorBigInt) - x = MPZ.realloc2(g.nlimbsmax*8*sizeof(Limb)) - limbs = unsafe_wrap(Array, x.d, g.nlimbs) +### BigInt + +struct StateBigInt <: State + a::BigInt # first + m::BigInt # range length - 1 + nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m]) + nlimbsmax::Int # max number of limbs for z+a + mask::Limb # applied to the highest limb +end + +function State(::AbstractRNG, r::UnitRange{BigInt}, ::Repetition) + m = last(r) - first(r) + m < 0 && throw(ArgumentError("range must be non-empty")) + nd = ndigits(m, 2) + nlimbs, highbits = divrem(nd, 8*sizeof(Limb)) + highbits > 0 && (nlimbs += 1) + mask = highbits == 0 ? ~zero(Limb) : one(Limb)<= 6) - x.size = g.nlimbs + x.size = st.nlimbs while x.size > 0 @inbounds limbs[x.size] != 0 && break x.size -= 1 end - MPZ.add!(x, g.a) + MPZ.add!(x, st.a) end -#### arrays -function rand!(rng::AbstractRNG, A::AbstractArray, g::RangeGenerator) - for i in eachindex(A) - @inbounds A[i] = rand(rng, g) - end - return A -end - -### random values from UnitRange - -rand(rng::AbstractRNG, r::UnitRange{<:Integer}) = rand(rng, RangeGenerator(r)) +## random values from AbstractArray -rand!(rng::AbstractRNG, A::AbstractArray, r::UnitRange{<:Integer}) = - rand!(rng, A, RangeGenerator(r)) +State(rng::AbstractRNG, r::AbstractArray, n::Repetition) = + StateSimple(r, State(rng, 1:length(r), n)) -## random values from AbstractArray +rand(rng::AbstractRNG, st::StateSimple{<:AbstractArray,<:State}) = + @inbounds return st[][rand(rng, st.state)] -rand(rng::AbstractRNG, r::AbstractArray) = @inbounds return r[rand(rng, 1:length(r))] -rand( r::AbstractArray) = rand(GLOBAL_RNG, r) -### arrays +## random values from Dict, Set, IntSet -function rand!(rng::AbstractRNG, A::AbstractArray, r::AbstractArray) - g = RangeGenerator(1:(length(r))) - for i in eachindex(A) - @inbounds A[i] = r[rand(rng, g)] +for x in (1, Inf) # eval because of ambiguity otherwise + for T in (Dict, Set, IntSet) + @eval State(::AbstractRNG, t::$T, ::Val{$x}) = StateTrivial(t) end - return A end -rand!(A::AbstractArray, r::AbstractArray) = rand!(GLOBAL_RNG, A, r) - -rand(rng::AbstractRNG, r::AbstractArray{T}, dims::Dims) where {T} = - rand!(rng, Array{T}(dims), r) -rand( r::AbstractArray, dims::Dims) = rand(GLOBAL_RNG, r, dims) -rand(rng::AbstractRNG, r::AbstractArray, dims::Integer...) = rand(rng, r, Dims(dims)) -rand( r::AbstractArray, dims::Integer...) = rand(GLOBAL_RNG, r, Dims(dims)) - - -## random values from Dict, Set, IntSet - -function rand(r::AbstractRNG, t::Dict) - isempty(t) && throw(ArgumentError("collection must be non-empty")) - rg = RangeGenerator(1:length(t.slots)) +function rand(rng::AbstractRNG, st::StateTrivial{<:Dict}) + isempty(st[]) && throw(ArgumentError("collection must be non-empty")) + rst = State(rng, 1:length(st[].slots)) while true - i = rand(r, rg) - Base.isslotfilled(t, i) && @inbounds return (t.keys[i] => t.vals[i]) + i = rand(rng, rst) + Base.isslotfilled(st[], i) && @inbounds return (st[].keys[i] => st[].vals[i]) end end -rand(r::AbstractRNG, s::Set) = rand(r, s.dict).first +rand(rng::AbstractRNG, st::StateTrivial{<:Set}) = rand(rng, st[].dict).first -function rand(r::AbstractRNG, s::IntSet) - isempty(s) && throw(ArgumentError("collection must be non-empty")) - # s can be empty while s.bits is not, so we cannot rely on the - # length check in RangeGenerator below - rg = RangeGenerator(1:length(s.bits)) +function rand(rng::AbstractRNG, st::StateTrivial{IntSet}) + isempty(st[]) && throw(ArgumentError("collection must be non-empty")) + # st[] can be empty while st[].bits is not, so we cannot rely on the + # length check in State below + rst = State(rng, 1:length(st[].bits)) while true - n = rand(r, rg) - @inbounds b = s.bits[n] + n = rand(rng, rst) + @inbounds b = st[].bits[n] b && return n end end -function nth(iter, n::Integer)::eltype(iter) - for (i, x) in enumerate(iter) - i == n && return x - end -end -nth(iter::AbstractArray, n::Integer) = iter[n] +## random values from Associative/AbstractSet -rand(r::AbstractRNG, s::Union{Associative,AbstractSet}) = nth(s, rand(r, 1:length(s))) +# avoid linear complexity for repeated calls +State(rng::AbstractRNG, t::Union{Associative,AbstractSet}, n::Repetition) = + State(rng, collect(t), n) -rand(s::Union{Associative,AbstractSet}) = rand(GLOBAL_RNG, s) +# when generating only one element, avoid the call to collect +State(::AbstractRNG, t::Union{Associative,AbstractSet}, ::Val{1}) = + StateTrivial(t) -### arrays - -function rand!(r::AbstractRNG, A::AbstractArray, s::Union{Dict,Set,IntSet}) - for i in eachindex(A) - @inbounds A[i] = rand(r, s) +function nth(iter, n::Integer)::eltype(iter) + for (i, x) in enumerate(iter) + i == n && return x end - A end -# avoid linear complexity for repeated calls with generic containers -rand!(r::AbstractRNG, A::AbstractArray, s::Union{Associative,AbstractSet}) = - rand!(r, A, collect(s)) - -rand!(A::AbstractArray, s::Union{Associative,AbstractSet}) = rand!(GLOBAL_RNG, A, s) +rand(rng::AbstractRNG, st::StateTrivial{<:Union{Associative,AbstractSet}}) = + nth(st[], rand(rng, 1:length(st[]))) -rand(r::AbstractRNG, s::Associative{K,V}, dims::Dims) where {K,V} = - rand!(r, Array{Pair{K,V}}(dims), s) -rand(r::AbstractRNG, s::AbstractSet{T}, dims::Dims) where {T} = rand!(r, Array{T}(dims), s) -rand(r::AbstractRNG, s::Union{Associative,AbstractSet}, dims::Integer...) = - rand(r, s, Dims(dims)) -rand(s::Union{Associative,AbstractSet}, dims::Integer...) = rand(GLOBAL_RNG, s, Dims(dims)) -rand(s::Union{Associative,AbstractSet}, dims::Dims) = rand(GLOBAL_RNG, s, dims) +## random characters from a string +# we use collect(str), which is most of the time more efficient than specialized methods +# (except maybe for very small arrays) +State(rng::AbstractRNG, str::AbstractString, n::Repetition) = State(rng, collect(str), n) -## random characters from a string +# when generating only one char from a string, the specialized method below +# is usually more efficient +State(::AbstractRNG, str::AbstractString, ::Val{1}) = StateTrivial(str) isvalid_unsafe(s::String, i) = !Base.is_valid_continuation(Base.@gc_preserve s unsafe_load(pointer(s), i)) isvalid_unsafe(s::AbstractString, i) = isvalid(s, i) _endof(s::String) = sizeof(s) _endof(s::AbstractString) = endof(s) -function rand(rng::AbstractRNG, s::AbstractString)::Char - g = RangeGenerator(1:_endof(s)) +function rand(rng::AbstractRNG, st::StateTrivial{<:AbstractString})::Char + str = st[] + st_pos = State(rng, 1:_endof(str)) while true - pos = rand(rng, g) - isvalid_unsafe(s, pos) && return s[pos] + pos = rand(rng, st_pos) + isvalid_unsafe(str, pos) && return str[pos] end end - -rand(s::AbstractString) = rand(GLOBAL_RNG, s) - -### arrays - -# we use collect(str), which is most of the time more efficient than specialized methods -# (except maybe for very small arrays) -rand!(rng::AbstractRNG, A::AbstractArray, str::AbstractString) = rand!(rng, A, collect(str)) -rand!(A::AbstractArray, str::AbstractString) = rand!(GLOBAL_RNG, A, str) -rand(rng::AbstractRNG, str::AbstractString, dims::Dims) = - rand!(rng, Array{eltype(str)}(dims), str) - -rand(rng::AbstractRNG, str::AbstractString, d::Integer, dims::Integer...) = - rand(rng, str, Dims((d, dims...))) - -rand(str::AbstractString, dims::Dims) = rand(GLOBAL_RNG, str, dims) -rand(str::AbstractString, d::Integer, dims::Integer...) = rand(GLOBAL_RNG, str, d, dims...) diff --git a/base/random/misc.jl b/base/random/misc.jl index c749211aed81b..ad402075e5176 100644 --- a/base/random/misc.jl +++ b/base/random/misc.jl @@ -145,8 +145,7 @@ randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p) "Return a random `Int` (masked with `mask`) in ``[0, n)``, when `n <= 2^52`." @inline function rand_lt(r::AbstractRNG, n::Int, mask::Int=nextpow2(n)-1) - # this duplicates the functionality of RangeGenerator objects, - # to optimize this special case + # this duplicates the functionality of rand(1:n), to optimize this special case while true x = (rand_ui52_raw(r) % Int) & mask x < n && return x diff --git a/base/random/random.jl b/base/random/random.jl index 9d3cc99bc69fe..177831dffadd5 100644 --- a/base/random/random.jl +++ b/base/random/random.jl @@ -20,8 +20,12 @@ export srand, GLOBAL_RNG, randjump +## general definitions + abstract type AbstractRNG end +### floats + abstract type FloatInterval{T<:AbstractFloat} end struct CloseOpen{ T<:AbstractFloat} <: FloatInterval{T} end # interval [0,1) @@ -34,8 +38,110 @@ const Close1Open2_64 = Close1Open2{Float64} CloseOpen( ::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen{T}() Close1Open2(::Type{T}=Float64) where {T<:AbstractFloat} = Close1Open2{T}() +Base.eltype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T + const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}} +### State + +abstract type State end + +# temporarily for BaseBenchmarks +RangeGenerator(x) = State(GLOBAL_RNG, x) + +# In some cases, when only 1 random value is to be generated, +# the optimal sampler can be different than if multiple values +# have to be generated. Hence a `Repetition` parameter is used +# to choose the best one depending on the need. +const Repetition = Union{Val{1},Val{Inf}} + +# these default fall-back for all RNGs would be nice, +# but generate difficult-to-solve ambiguities +# State(::AbstractRNG, X, ::Val{Inf}) = State(X) +# State(::AbstractRNG, ::Type{X}, ::Val{Inf}) where {X} = State(X) + +State(rng::AbstractRNG, st::State, ::Repetition) = + throw(ArgumentError("State for this object is not defined")) + +# default shortcut for the general case +State(rng::AbstractRNG, X) = State(rng, X, Val(Inf)) +State(rng::AbstractRNG, ::Type{X}) where {X} = State(rng, X, Val(Inf)) + +#### pre-defined useful State subtypes + +# default fall-back for types +struct StateType{T} <: State end + +State(::AbstractRNG, ::Type{T}, ::Repetition) where {T} = StateType{T}() + +Base.getindex(st::StateType{T}) where {T} = T + +# default fall-back for values +struct StateTrivial{T} <: State + self::T +end + +State(::AbstractRNG, X, ::Repetition) = StateTrivial(X) + +Base.getindex(st::StateTrivial) = st.self + +struct StateSimple{T,S} <: State + self::T + state::S +end + +Base.getindex(st::StateSimple) = st.self + + +### machinery for generation with State + +#### scalars + +rand(rng::AbstractRNG, X) = rand(rng, State(rng, X, Val(1))) +rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{X}=Float64) where {X} = + rand(rng, State(rng, X, Val(1))) + +rand(X) = rand(GLOBAL_RNG, X) +rand(::Type{X}) where {X} = rand(GLOBAL_RNG, X) + +#### arrays + +rand!(A::AbstractArray{T}, X) where {T} = rand!(GLOBAL_RNG, A, X) +rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(GLOBAL_RNG, A, X) + +rand!(rng::AbstractRNG, A::AbstractArray{T}, X) where {T} = rand!(rng, A, State(rng, X)) +rand!(rng::AbstractRNG, A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(rng, A, State(rng, X)) + +function rand!(rng::AbstractRNG, A::AbstractArray{T}, st::State) where T + for i in eachindex(A) + @inbounds A[i] = rand(rng, st) + end + A +end + +rand(r::AbstractRNG, dims::Dims) = rand(r, Float64, dims) +rand( dims::Dims) = rand(GLOBAL_RNG, dims) +rand(r::AbstractRNG, dims::Integer...) = rand(r, Dims(dims)) +rand( dims::Integer...) = rand(Dims(dims)) + +rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{eltype(X)}(dims), X) +rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims) + +rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...))) +rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...))) +# note: the above methods would trigger an ambiguity warning if d was not separated out: +# rand(r, ()) would match both this method and rand(r, dims::Dims) +# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop + +rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{eltype(X)}(dims), X) +rand( ::Type{X}, dims::Dims) where {X} = rand(GLOBAL_RNG, X, dims) + +rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...))) +rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...))) + + +## __init__ & include + function __init__() try srand() diff --git a/test/random.jl b/test/random.jl index 5851567a64aca..d9e3f2d574acb 100644 --- a/test/random.jl +++ b/test/random.jl @@ -34,7 +34,7 @@ let A = zeros(2, 2) 0.9103565379264364 0.17732884646626457] end let A = zeros(2, 2) - @test_throws BoundsError rand!(MersenneTwister(0), A, 5) + @test_throws ArgumentError rand!(MersenneTwister(0), A, 5) @test rand(MersenneTwister(0), Int64, 1) == [4439861565447045202] end let A = zeros(Int64, 2, 2) @@ -42,9 +42,6 @@ let A = zeros(Int64, 2, 2) @test A == [858542123778948672 5715075217119798169; 8690327730555225005 8435109092665372532] end -let A = zeros(UInt128, 2, 2) - @test_throws BoundsError rand!(MersenneTwister(0), A, 5) -end # rand from AbstractArray let mt = MersenneTwister() @@ -601,7 +598,7 @@ let b = ['0':'9';'A':'Z';'a':'z'] end # this shouldn't crash (#22403) -@test_throws MethodError rand!(Union{UInt,Int}[1, 2, 3]) +@test_throws ArgumentError rand!(Union{UInt,Int}[1, 2, 3]) @testset "$RNG() & srand(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice) m = RNG() @@ -630,3 +627,9 @@ end srand(m, seed) @test a == [rand(m) for _=1:100] end + +struct RandomStruct23964 end +@testset "error message when rand not defined for a type" begin + @test_throws ArgumentError rand(nothing) + @test_throws ArgumentError rand(RandomStruct23964()) +end