diff --git a/base/partr.jl b/base/partr.jl index 8c95e3668ee74a..6cf33cc0df5f72 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -19,9 +19,95 @@ const heap_d = UInt32(8) const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)] const heaps_lock = [SpinLock(), SpinLock()] +""" + cong(max::UInt32) + +Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. +""" +cong(max::UInt32) = iszero(max) ? UInt32(0) : jl_rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check + + +""" + jl_rand_ptls(max::UInt32) + +Return a random UInt32 in the range `0:max-1` using the thread-local RNG +state. Max must be greater than 0. +""" +function jl_rand_ptls(max::UInt32) + ptls = Base.unsafe_convert(Ptr{UInt64}, Core.getptls()) + rngseed = Base.unsafe_load(ptls, 2) + val, seed = rand_uniform_max_int32(max, rngseed) + Base.unsafe_store!(ptls, seed, 2) + return val % UInt32 +end + +# This implementation is based on OpenSSLs implementation of rand_uniform +# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 +# Comments are vendored from their implemantation as well. +# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143. + +# Essentially it boils down to incrementally generating a fixed point +# number on the interval [0, 1) and multiplying this number by the upper +# range limit. Once it is certain what the fractional part contributes to +# the integral part of the product, the algorithm has produced a definitive +# result. +""" + rand_uniform_max_int32(max::UInt32, seed::UInt64) + +Return a random UInt32 in the range `0:max-1` using the given seed. +Max must be greater than 0. +""" +function rand_uniform_max_int32(max::UInt32, seed::UInt64) + if max == UInt32(1) + return UInt32(0), seed + end -cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1) +# We are generating a fixed point number on the interval [0, 1). +# Multiplying this by the range gives us a number on [0, upper). +# The high word of the multiplication result represents the integral +# part we want. The lower word is the fractional part. We can early exit if +# if the fractional part is small enough that no carry from the next lower +# word can cause an overflow and carry into the integer part. This +# happens when the fractional part is bounded by 2^32 - upper which +# can be simplified to just -upper (as an unsigned integer). + seed = UInt64(69069) * seed + UInt64(362437) + prod = (UInt64(max)) * (seed % UInt32) # 64 bit product + i = unsafe_trunc(UInt32, prod >> 32) # integral part + f = unsafe_trunc(UInt32, (prod & 0xffffffff)) # fractional part + if (f <= (UInt32(1) + ~max)) # likely + return unsafe_trunc(UInt32, i), seed + end +# We're in the position where the carry from the next word *might* cause +# a carry to the integral part. The process here is to generate the next +# word, multiply it by the range and add that to the current word. If +# it overflows, the carry propagates to the integer part (return i+1). +# If it can no longer overflow regardless of further lower order bits, +# we are done (return i). If there is still a chance of overflow, we +# repeat the process with the next lower word. +# +# Each *bit* of randomness has a probability of one half of terminating +# this process, so each each word beyond the first has a probability +# of 2^-32 of not terminating the process. That is, we're extremely +# likely to stop very rapidly. + for _ in 1:10 + seed = UInt64(69069) * seed + UInt64(362437) + prod = (UInt64(max)) * (seed % UInt32) + f2 = unsafe_trunc(UInt32,prod >> 32) # extra fractional part + f *= f2 % UInt32 + if f < f2 + return i + UInt32(1), seed + end + if (f != 0xffffffff) #unlikely + return i, seed + end + f = prod & 0xffffffff % UInt32 + end +# If we get here, we've consumed 32 * max_followup_iterations + 32 bits +# with no firm decision, this gives a bias with probability < 2^-(32*n), +# which is likely acceptable. + return i, seed +end function multiq_sift_up(heap::taskheap, idx::Int32) while idx > Int32(1) diff --git a/src/julia_internal.h b/src/julia_internal.h index ab81cf18623a59..c235f916f5deca 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1306,6 +1306,9 @@ JL_DLLEXPORT size_t jl_maxrss(void); // congruential random number generator // for a small amount of thread-local randomness +//TODO: utilize https://github.com/openssl/openssl/blob/master/crypto/rand/rand_uniform.c#L13-L99 +// for better performance, it does however require making users expect a 32bit random number. + STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT { if (max < 2)