Skip to content

Commit

Permalink
rename State -> Sample
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Nov 18, 2017
1 parent 103e98c commit 1a6807a
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 97 deletions.
51 changes: 26 additions & 25 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

## RandomDevice

# StateTypes(Union{X,Y,...}) = Union{StateType{X},StateType{Y},...}
StateTypes(U::Union) = Union{map(T->StateType{T}, Base.uniontypes(U))...}
const StateBoolBitInteger = StateTypes(Union{Bool, Base.BitInteger})
# SampleTypes(Union{X,Y,...}) = Union{SampleType{X},SampleType{Y},...}
SampleTypes(U::Union) = Union{map(T->SampleType{T}, Base.uniontypes(U))...}
const SampleBoolBitInteger = SampleTypes(Union{Bool, Base.BitInteger})

if Sys.iswindows()
struct RandomDevice <: AbstractRNG
Expand All @@ -13,7 +13,7 @@ if Sys.iswindows()
RandomDevice() = new(Vector{UInt128}(1))
end

function rand(rd::RandomDevice, st::StateBoolBitInteger)
function rand(rd::RandomDevice, st::SampleBoolBitInteger)
rand!(rd, rd.buffer)
@inbounds return rd.buffer[1] % st[]
end
Expand All @@ -26,19 +26,19 @@ else # !windows
new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited)
end

rand(rd::RandomDevice, st::StateBoolBitInteger) = read( rd.file, st[])
rand(rd::RandomDevice, st::SampleBoolBitInteger) = read( rd.file, st[])
end # os-test

# NOTE: this can't be put in within the if-else block above
for T in (Bool, Base.BitInteger_types...)
if Sys.iswindows()
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T})
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::SampleType{$T})
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
A, sizeof(A))
A
end
else
@eval rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T}) = read!(rd.file, A)
@eval rand!(rd::RandomDevice, A::Array{$T}, ::SampleType{$T}) = read!(rd.file, A)
end
end

Expand All @@ -56,7 +56,7 @@ srand(rng::RandomDevice) = rng

### generation of floats

rand(r::RandomDevice, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])
rand(r::RandomDevice, st::SampleTrivial{<:FloatInterval}) = rand_generic(r, st[])


## MersenneTwister
Expand Down Expand Up @@ -236,30 +236,31 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)

#### floats

rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval_64}) = (reserve_1(r); rand_inbounds(r, st[]))
rand(r::MersenneTwister, st::SampleTrivial{<:FloatInterval_64}) =
(reserve_1(r); rand_inbounds(r, st[]))

rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])
rand(r::MersenneTwister, st::SampleTrivial{<:FloatInterval}) = rand_generic(r, st[])

#### integers

rand(r::MersenneTwister,
T::StateTypes(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
T::SampleTypes(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand_ui52_raw(r) % T[]

function rand(r::MersenneTwister, ::StateType{UInt64})
function rand(r::MersenneTwister, ::SampleType{UInt64})
reserve(r, 2)
rand_ui52_raw_inbounds(r) << 32 rand_ui52_raw_inbounds(r)
end

function rand(r::MersenneTwister, ::StateType{UInt128})
function rand(r::MersenneTwister, ::SampleType{UInt128})
reserve(r, 3)
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
rand_ui52_raw_inbounds(r) % UInt128 << 48,
rand_ui52_raw_inbounds(r))
end

rand(r::MersenneTwister, ::StateType{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::StateType{Int128}) = reinterpret(Int128, rand(r, UInt128))
rand(r::MersenneTwister, ::SampleType{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::SampleType{Int128}) = reinterpret(Int128, rand(r, UInt128))

#### arrays of floats

Expand All @@ -285,7 +286,7 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
A
end

rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::StateTrivial{<:FloatInterval_64}) =
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::SampleTrivial{<:FloatInterval_64}) =
rand_AbstractArray_Float64!(r, A, length(A), I[])

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
Expand Down Expand Up @@ -325,7 +326,7 @@ function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
A
end

rand!(r::MersenneTwister, A::Array{Float64}, st::StateTrivial{<:FloatInterval_64}) =
rand!(r::MersenneTwister, A::Array{Float64}, st::SampleTrivial{<:FloatInterval_64}) =
_rand!(r, A, length(A), st[])

mask128(u::UInt128, ::Type{Float16}) =
Expand All @@ -335,7 +336,7 @@ mask128(u::UInt128, ::Type{Float32}) =
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000

for T in (Float16, Float32)
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{Close1Open2{$T}})
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SampleTrivial{Close1Open2{$T}})
n = length(A)
n128 = n * sizeof($T) ÷ 16
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
Expand Down Expand Up @@ -363,7 +364,7 @@ for T in (Float16, Float32)
A
end

@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{CloseOpen{$T}})
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SampleTrivial{CloseOpen{$T}})
rand!(r, A, Close1Open2($T))
I32 = one(Float32)
for i in eachindex(A)
Expand All @@ -375,7 +376,7 @@ end

#### arrays of integers

function rand!(r::MersenneTwister, A::Array{UInt128}, ::StateType{UInt128})
function rand!(r::MersenneTwister, A::Array{UInt128}, ::SampleType{UInt128})
n::Int=length(A)
# FIXME: This code is completely invalid!!!
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
Expand Down Expand Up @@ -404,7 +405,7 @@ end

for T in Base.BitInteger_types
T === UInt128 && continue
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateType{$T})
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SampleType{$T})
n = length(A)
n128 = n * sizeof($T) ÷ 16
# FIXME: This code is completely invalid!!!
Expand All @@ -426,7 +427,7 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
end

function rand(rng::MersenneTwister,
st::StateTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
st::SampleTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
r = st[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = last(r) % UInt64 - first(r) % UInt64
Expand All @@ -438,7 +439,7 @@ function rand(rng::MersenneTwister,
end

function rand(rng::MersenneTwister,
st::StateTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
st::SampleTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
r = st[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r)-first(r)) % UInt128
Expand All @@ -451,8 +452,8 @@ function rand(rng::MersenneTwister,
end

for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
@eval State(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
StateTrivial(r)
@eval Sample(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
SampleTrivial(r)
end


Expand Down
Loading

0 comments on commit 1a6807a

Please sign in to comment.