Skip to content

Commit

Permalink
[Random] Add s4 field to Xoshiro type (#51332)
Browse files Browse the repository at this point in the history
This PR adds an optional field to the existing `Xoshiro` struct to be
able to faithfully copy the task-local RNG state.

Fixes #51255
Redo of #51271

Background context: #49110 added an additional state to the task-local
RNG. However, before this PR `copy(default_rng())` did not include this
extra state, causing subtle errors in `Test` where `copy(default_rng())`
is assumed to contain the full task-local RNG state.

(cherry picked from commit 41b41ab)
  • Loading branch information
nhz2 authored and nalimilan committed Nov 5, 2023
1 parent a00e2d4 commit 51ff820
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,37 @@ mutable struct Xoshiro <: AbstractRNG
s1::UInt64
s2::UInt64
s3::UInt64
s4::UInt64 # internal splitmix state

Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = new(s0, s1, s2, s3)
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer, s4::Integer) = new(s0, s1, s2, s3, s4)
Xoshiro(s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64) = new(s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3)
Xoshiro(seed=nothing) = seed!(new(), seed)
end

function setstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = Xoshiro(UInt64(s0), UInt64(s1), UInt64(s2), UInt64(s3))

function setstate!(
x::Xoshiro,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64, # internal splitmix state
)
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x.s4 = s4
x
end

copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)
copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4)

function copy!(dst::Xoshiro, src::Xoshiro)
dst.s0, dst.s1, dst.s2, dst.s3 = src.s0, src.s1, src.s2, src.s3
dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4
dst
end

function ==(a::Xoshiro, b::Xoshiro)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4
end

rng_native_52(::Xoshiro) = UInt64
Expand Down Expand Up @@ -116,7 +125,7 @@ rng_native_52(::TaskLocalRNG) = UInt64
function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
s4::UInt64, # internal splitmix state
)
t = current_task()
t.rngState0 = s0
Expand Down Expand Up @@ -148,14 +157,20 @@ end
function seed!(rng::Union{TaskLocalRNG,Xoshiro})
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
setstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
s0 = rand(rd, UInt64)
s1 = rand(rd, UInt64)
s2 = rand(rd, UInt64)
s3 = rand(rd, UInt64)
s4 = 1s0 + 3s1 + 5s2 + 7s3
setstate!(rng, s0, s1, s2, s3, s4)
end

function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
c = SHA.SHA2_256_CTX()
SHA.update!(c, reinterpret(UInt8, seed))
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
setstate!(rng, s0, s1, s2, s3)
s4 = 1s0 + 3s1 + 5s2 + 7s3
setstate!(rng, s0, s1, s2, s3, s4)
end

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
Expand All @@ -178,24 +193,30 @@ end

function copy(rng::TaskLocalRNG)
t = current_task()
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3)
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
end

function copy!(dst::TaskLocalRNG, src::Xoshiro)
t = current_task()
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
setstate!(dst, src.s0, src.s1, src.s2, src.s3, src.s4)
return dst
end

function copy!(dst::Xoshiro, src::TaskLocalRNG)
t = current_task()
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
return dst
end

function ==(a::Xoshiro, b::TaskLocalRNG)
t = current_task()
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3
(
a.s0 == t.rngState0 &&
a.s1 == t.rngState1 &&
a.s2 == t.rngState2 &&
a.s3 == t.rngState3 &&
a.s4 == t.rngState4
)
end

==(a::TaskLocalRNG, b::Xoshiro) = b == a
Expand Down

0 comments on commit 51ff820

Please sign in to comment.