Skip to content

Commit

Permalink
MersenneTwister: more efficient integer generation with caching
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Dec 23, 2017
1 parent 058716e commit 4f78032
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 66 deletions.
17 changes: 15 additions & 2 deletions base/int.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,38 @@
# they are also used elsewhere where Int128/UInt128 support is separated out,
# such as in hashing2.jl

const BitSigned64_types = (Int8, Int16, Int32, Int64)
const BitUnsigned64_types = (UInt8, UInt16, UInt32, UInt64)
const BitSigned32_types = (Int8, Int16, Int32)
const BitUnsigned32_types = (UInt8, UInt16, UInt32)
const BitInteger32_types = (BitSigned32_types..., BitUnsigned32_types...)

const BitSigned64_types = (BitSigned32_types..., Int64)
const BitUnsigned64_types = (BitUnsigned32_types..., UInt64)
const BitInteger64_types = (BitSigned64_types..., BitUnsigned64_types...)

const BitSigned_types = (BitSigned64_types..., Int128)
const BitUnsigned_types = (BitUnsigned64_types..., UInt128)
const BitInteger_types = (BitSigned_types..., BitUnsigned_types...)

const BitSignedSmall_types = Int === Int64 ? ( Int8, Int16, Int32) : ( Int8, Int16)
const BitUnsignedSmall_types = Int === Int64 ? (UInt8, UInt16, UInt32) : (UInt8, UInt16)
const BitIntegerSmall_types = (BitSignedSmall_types..., BitUnsignedSmall_types...)

const BitSigned32 = Union{BitSigned32_types...}
const BitUnsigned32 = Union{BitUnsigned32_types...}
const BitInteger32 = Union{BitInteger32_types...}

const BitSigned64 = Union{BitSigned64_types...}
const BitUnsigned64 = Union{BitUnsigned64_types...}
const BitInteger64 = Union{BitInteger64_types...}

const BitSigned = Union{BitSigned_types...}
const BitUnsigned = Union{BitUnsigned_types...}
const BitInteger = Union{BitInteger_types...}

const BitSignedSmall = Union{BitSignedSmall_types...}
const BitUnsignedSmall = Union{BitUnsignedSmall_types...}
const BitIntegerSmall = Union{BitIntegerSmall_types...}

const BitSigned64T = Union{Type{Int8}, Type{Int16}, Type{Int32}, Type{Int64}}
const BitUnsigned64T = Union{Type{UInt8}, Type{UInt16}, Type{UInt32}, Type{UInt64}}

Expand Down
136 changes: 94 additions & 42 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,33 @@ srand(rng::RandomDevice) = rng

## MersenneTwister

const MTCacheLength = dsfmt_get_min_array_size()
const MT_CACHE_F = dsfmt_get_min_array_size()
const MT_CACHE_I = 501 << 4

mutable struct MersenneTwister <: AbstractRNG
seed::Vector{UInt32}
state::DSFMT_state
vals::Vector{Float64}
idx::Int

function MersenneTwister(seed, state, vals, idx)
length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength ||
throw(DomainError((length(vals), idx),
"`length(vals)` and `idx` must be consistent with $MTCacheLength"))
new(seed, state, vals, idx)
ints::Vector{UInt128}
idxF::Int
idxI::Int

function MersenneTwister(seed, state, vals, ints, idxF, idxI)
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
throw(DomainError((length(vals), idxF),
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
throw(DomainError((length(ints), idxI),
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
new(seed, state, vals, ints, idxF, idxI)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength)
MersenneTwister(seed, state,
Vector{Float64}(uninitialized, MT_CACHE_F),
Vector{UInt128}(uninitialized, MT_CACHE_I >> 4),
MT_CACHE_F, 0)

"""
MersenneTwister(seed)
Expand Down Expand Up @@ -120,27 +129,36 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
copy!(dst.state, src.state)
copyto!(dst.vals, src.vals)
dst.idx = src.idx
copyto!(dst.ints, src.ints)
dst.idxF = src.idxF
dst.idxI = src.idxI
dst
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx)
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI)


==(r1::MersenneTwister, r2::MersenneTwister) =
r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) &&
r1.idx == r2.idx
r1.seed == r2.seed && r1.state == r2.state &&
isequal(r1.vals, r2.vals) &&
isequal(r1.ints, r2.ints) &&
r1.idxF == r2.idxF && r1.idxI == r2.idxI

hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx))
hash(r::MersenneTwister, h::UInt) =
foldr(hash, h, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI))


### low level API

mt_avail(r::MersenneTwister) = MTCacheLength - r.idx
mt_empty(r::MersenneTwister) = r.idx == MTCacheLength
mt_setfull!(r::MersenneTwister) = r.idx = 0
mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]
#### floats

mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF
mt_empty(r::MersenneTwister) = r.idxF == MT_CACHE_F
mt_setfull!(r::MersenneTwister) = r.idxF = 0
mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1]

function gen_rand(r::MersenneTwister)
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
Expand All @@ -149,9 +167,56 @@ end

reserve_1(r::MersenneTwister) = (mt_empty(r) && gen_rand(r); nothing)
# `reserve` allows one to call `rand_inbounds` n times
# precondition: n <= MTCacheLength
# precondition: n <= MT_CACHE_F
reserve(r::MersenneTwister, n::Int) = (mt_avail(r) < n && gen_rand(r); nothing)

#### ints

logsizeof(::Type{<:Union{Bool,Int8,UInt8}}) = 0
logsizeof(::Type{<:Union{Int16,UInt16}}) = 1
logsizeof(::Type{<:Union{Int32,UInt32}}) = 2
logsizeof(::Type{<:Union{Int64,UInt64}}) = 3
logsizeof(::Type{<:Union{Int128,UInt128}}) = 4

idxmask(::Type{<:Union{Bool,Int8,UInt8}}) = 15
idxmask(::Type{<:Union{Int16,UInt16}}) = 7
idxmask(::Type{<:Union{Int32,UInt32}}) = 3
idxmask(::Type{<:Union{Int64,UInt64}}) = 1
idxmask(::Type{<:Union{Int128,UInt128}}) = 0


mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} =
r.idxI >> logsizeof(T)

function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger})
rand!(r, r.ints)
r.idxI = MT_CACHE_I
end

mt_setempty!(r::MersenneTwister, ::Type{<:BitInteger}) = r.idxI = 0

function reserve1(r::MersenneTwister, ::Type{T}) where T<:BitInteger
r.idxI < sizeof(T) && mt_setfull!(r, T)
nothing
end

function mt_pop!(r::MersenneTwister, ::Type{T}) where T<:BitInteger
reserve1(r, T)
r.idxI -= sizeof(T)
i = r.idxI
@inbounds x128 = r.ints[1 + i >> 4]
i128 = (i >> logsizeof(T)) & idxmask(T) # 0-based "indice" in x128
(x128 >> (i128 * (sizeof(T) << 3))) % T
end

# not necessary, but very slightly more efficient
function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}}
reserve1(r, T)
@inbounds res = r.ints[r.idxI >> 4]
r.idxI -= 16
res % T
end


### seeding

Expand Down Expand Up @@ -193,6 +258,9 @@ function srand(r::MersenneTwister, seed::Vector{UInt32})
copyto!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
mt_setempty!(r)
fill!(r.vals, 0.0) # not strictly necessary, but why not, makes comparing two MT easier
mt_setempty!(r, UInt128)
fill!(r.ints, 0)
return r
end

Expand Down Expand Up @@ -243,24 +311,8 @@ rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =

#### integers

rand(r::MersenneTwister,
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand(r, UInt52Raw()) % T[]

function rand(r::MersenneTwister, ::SamplerType{UInt64})
reserve(r, 2)
rand_inbounds(r, UInt52Raw()) << 32 rand_inbounds(r, UInt52Raw())
end

function rand(r::MersenneTwister, ::SamplerType{UInt128})
reserve(r, 3)
xor(rand_inbounds(r, UInt52Raw(UInt128)) << 96,
rand_inbounds(r, UInt52Raw(UInt128)) << 48,
rand_inbounds(r, UInt52Raw(UInt128)))
end

rand(r::MersenneTwister, ::SamplerType{Int64}) = rand(r, UInt64) % Int64
rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
rand(r::MersenneTwister, T::SamplerUnion(BitInteger)) = mt_pop!(r, T[])
rand(r::MersenneTwister, ::SamplerType{Bool}) = rand(r, UInt8) % Bool

#### arrays of floats

Expand Down Expand Up @@ -315,13 +367,13 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
mt_avail(r) == 0 && gen_rand(r)
# from now on, at most one call to gen_rand(r) will be necessary
m = min(n, mt_avail(r))
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idx+1), m)
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idxF+1), m)
if m == n
r.idx += m
r.idxF += m
else # m < n
gen_rand(r)
@gc_preserve r unsafe_copyto!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
r.idx = n-m
r.idxF = n-m
end
if I isa CloseOpen
for i=1:n
Expand Down Expand Up @@ -470,7 +522,7 @@ end

#### from a range

for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise
for T in BitInteger_types # eval because of ambiguity otherwise
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
SamplerRangeFast(r)
end
Expand Down
54 changes: 41 additions & 13 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# Note that the 1) is automated when the sampler is not intended to carry information,
# i.e. the default fall-backs SamplerType and SamplerTrivial are used.


## from types: rand(::Type, [dims...])

### random floats
Expand Down Expand Up @@ -101,6 +100,8 @@ rand(rng::AbstractRNG, sp::SamplerBigFloat{T}) where {T<:FloatInterval{BigFloat}

### random integers

#### UniformBits

rand(r::AbstractRNG, ::SamplerTrivial{UInt10Raw{UInt16}}) = rand(r, UInt16)
rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32)

Expand All @@ -111,7 +112,7 @@ _rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Ope
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)

rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
rand(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
rand(r, UInt52Raw(UInt128)) << 52 rand(r, UInt52Raw(UInt128))

rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
Expand All @@ -121,6 +122,32 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
rand(r, uint_default(sp[])) % T

#### BitInteger

# rand_generic methods are intended to help RNG implementors with common operations
# we don't call them simply `rand` as this can easily contribute to create
# amibuities with user-side methods (forcing the user to resort to @eval)

rand_generic(r::AbstractRNG, T::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}) =
rand(r, UInt52Raw()) % T[]

rand_generic(r::AbstractRNG, ::Type{UInt64}) =
rand(r, UInt52Raw()) << 32 rand(r, UInt52Raw())

rand_generic(r::AbstractRNG, ::Type{UInt128}) = _rand128(r, rng_native_52(r))

_rand128(r::AbstractRNG, ::Type{UInt64}) =
((rand(r, UInt64) % UInt128) << 64) rand(r, UInt64)

function _rand128(r::AbstractRNG, ::Type{Float64})
xor(rand(r, UInt52Raw(UInt128)) << 96,
rand(r, UInt52Raw(UInt128)) << 48,
rand(r, UInt52Raw(UInt128)))
end

rand_generic(r::AbstractRNG, ::Type{Int128}) = rand(r, UInt128) % Int128
rand_generic(r::AbstractRNG, ::Type{Int64}) = rand(r, UInt64) % Int64

### random complex numbers

rand(r::AbstractRNG, ::SamplerType{Complex{T}}) where {T<:Real} =
Expand Down Expand Up @@ -149,33 +176,34 @@ end

#### helper functions

uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
uint_sup(::Type{<:Base.BitInteger32}) = UInt32
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128

#### Fast

struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler
a::T # first element of the range
bw::UInt # bit width
m::U # range length - 1
mask::U # mask generated values before threshold rejection
end

SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:BitInteger =
SamplerRangeFast(r, uint_sup(T))

function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r) - first(r)) % U
m = (last(r) - first(r)) % unsigned(T) % U # % unsigned(T) to not propagate sign bit
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
mask = (1 % U << bw) - (1 % U)
SamplerRangeFast{U,T}(first(r), bw, m, mask)
end

function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
x = rand(rng, LessThan(m, Masked(mask, uniform(UInt32))))
# below, we don't use UInt32, to get reproducible values, whether Int is Int64 or Int32
x = rand(rng, LessThan(m, Masked(mask, UInt52Raw(UInt32))))
(x + a % UInt32) % T
end

Expand Down Expand Up @@ -215,21 +243,21 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
div(sup, k + (k == 0))*k - one(k)

struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
a::T # first element of the range
bw::Int # bit width
k::U # range length or zero for full range
u::U # rejection threshold
end


SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:BitInteger =
SamplerRangeInt(r, uint_sup(T))

function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
isempty(r) && throw(ArgumentError("range must be non-empty"))
a = first(r)
m = (last(r) - first(r)) % U
m = (last(r) - first(r)) % unsigned(T) % U
k = m + one(U)
bw = (sizeof(U) << 3 - leading_zeros(m)) % Int
mult = if U === UInt32
Expand All @@ -247,11 +275,11 @@ function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
end

Sampler(::AbstractRNG, r::AbstractUnitRange{T},
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
::Repetition) where {T<:BitInteger} = SamplerRangeInt(r)

rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T

rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:BitInteger} =
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, UInt52Raw(UInt32))), sp.k)) % T

# this function uses 52 bit entropy for small ranges of length <= 2^52
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
Expand Down
Loading

0 comments on commit 4f78032

Please sign in to comment.