Skip to content

Commit

Permalink
Random: better handling of the "global seed" (using TLS) (#51526)
Browse files Browse the repository at this point in the history
We maintain a "global seed" for this feature of `@testset`:

> Before the execution of the body of a @testset, there is an implicit
call to Random.seed!(seed)
where seed is the current seed of the global RNG. Moreover, after the
execution of the body, the state of the global RNG is restored to what
it was before the @testset. This is meant to ease reproducibility in
case of failure, and to allow seamless re-arrangements of @testsets
regardless of their side-effect on the global RNG state.

But since we don't use `MersenneTwister` as the "global RNG" anymore, we
need to maintain a separate "global seed" object. So far we literally
used a global object `Random.GLOBAL_SEED` storing the original seed, but
it's not robust when multi-tasking is involved: e.g.
```julia
seed!(0)
x = rand()
seed!(0)
@sync begin
    @async @testset "A" begin
        seed!(1) # reset GLOBAL_SEED to V2
        sleep(2)
    end # reset GLOBAL_SEED to its original value V1
    sleep(0.5)
    @async @testset "B" begin
        # here seed!(2) above has already been called
        # so @testset B recorded value V2 as the "original" value of GLOBAL_SEED
        seed!(2)
        sleep(2)
        # here @testset A already finished
    end # reset GLOBAL_SEED to the wrong original value V2
end
@testset "main task" begin
    # async tests didn't mutate this task's global seed
    @test x == rand() # fails!
end
```

So we store here a "global seed" in `task_local_storage()`, which is set
when `seed!()` is invoked without an explicit RNG, and defaults to
`Random.GLOBAL_SEED`, which is set only once when `Random` is loaded.
And instead of actually storing a seed, we store a copy of the RNG
state.

This is still not ideal, in that at the beginning of `@testset "A"` or
`@testset "B"`, we can't do `@test x == rand()`, because these are in
separate tasks, so the global seed defaults to `Random.GLOBAL_SEED`, and
not to the global seed of the parent's task; there might be a nice way
to handle that, but at least different tasks don't corrupt each-other's
seeds.
  • Loading branch information
rfourquet authored Oct 7, 2023
1 parent fee7551 commit 0296599
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 40 deletions.
29 changes: 17 additions & 12 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,23 +387,28 @@ const GLOBAL_RNG = default_rng()
# the following feature of `@testset`:
# > Before the execution of the body of a `@testset`, there is an implicit
# > call to `Random.seed!(seed)` where `seed` is the current seed of the global RNG.
# But the global RNG is now TaskLocalRNG() and doesn't store its seed; in order to not break `@testset`, we now
# store the seed used in a call like `seed!(seed)` *without* an explicit RNG in `GLOBAL_SEED`; the wording of the
# feature above was sufficiently unprecise (e.g. what exactly is the "global RNG"?) that this solution seems fine
GLOBAL_SEED = 0
# only the Test module is allowed to use this function!
set_global_seed!(seed) = global GLOBAL_SEED = seed

# seed the "global" RNG
# But the global RNG is now `TaskLocalRNG()` and doesn't store its seed; in order to not break `@testset`,
# in a call like `seed!(seed)` *without* an explicit RNG, we now store the state of `TaskLocalRNG()` in
# `task_local_storage()`

# GLOBAL_SEED is used as a fall-back when no tls seed is found
# only `Random.__init__` is allowed to set it
const GLOBAL_SEED = Xoshiro(0, 0, 0, 0, 0)

get_tls_seed() = get!(() -> copy(GLOBAL_SEED), task_local_storage(),
:__RANDOM_GLOBAL_RNG_SEED_uBlmfA8ZS__)::Xoshiro

# seed the default RNG
function seed!(seed=nothing)
# the seed is not left as `nothing`, as storing `nothing` as the global seed wouldn't lead to reproducible streams
seed = @something seed rand(RandomDevice(), UInt128)
set_global_seed!(seed)
seed!(default_rng(), seed)
copy!(get_tls_seed(), default_rng())
default_rng()
end

function __init__()
seed!()
# do not call no-arg `seed!()` to not update `task_local_storage()` unnecessarily at startup
seed!(default_rng())
copy!(GLOBAL_SEED, TaskLocalRNG())
ccall(:jl_gc_init_finalizer_rng_state, Cvoid, ())
end

Expand Down
6 changes: 3 additions & 3 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,13 +718,13 @@ end

for seed=seeds
Random.seed!(seed)
@test Random.GLOBAL_SEED === seed
@test Random.get_tls_seed() == default_rng()
end

for ii = 1:8
iseven(ii) ? Random.seed!(nothing) : Random.seed!()
push!(seeds, Random.GLOBAL_SEED)
@test Random.GLOBAL_SEED isa UInt128 # could change, but must not be nothing
push!(seeds, copy(Random.get_tls_seed()))
@test Random.get_tls_seed() isa Xoshiro # could change, but must not be nothing
end
@test allunique(seeds)
end
Expand Down
28 changes: 12 additions & 16 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1589,11 +1589,11 @@ function testset_beginend_call(args, tests, source)
# we reproduce the logic of guardseed, but this function
# cannot be used as it changes slightly the semantic of @testset,
# by wrapping the body in a function
local oldrng = copy(default_rng())
local oldseed = Random.GLOBAL_SEED
local default_rng_orig = copy(default_rng())
local tls_seed_orig = copy(Random.get_tls_seed())
try
# default RNG is re-seeded with its own seed to ease reproduce a failed test
Random.seed!(Random.GLOBAL_SEED)
# default RNG is reset to its state from last `seed!()` to ease reproduce a failed test
copy!(Random.default_rng(), tls_seed_orig)
let
$(esc(tests))
end
Expand All @@ -1608,8 +1608,8 @@ function testset_beginend_call(args, tests, source)
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
end
finally
copy!(default_rng(), oldrng)
Random.set_global_seed!(oldseed)
copy!(default_rng(), default_rng_orig)
copy!(Random.get_tls_seed(), tls_seed_orig)
pop_testset()
ret = finish(ts)
end
Expand Down Expand Up @@ -1674,10 +1674,7 @@ function testset_forloop(args, testloop, source)
finish_errored = true
push!(arr, finish(ts))
finish_errored = false

# it's 1000 times faster to copy from tmprng rather than calling Random.seed!
copy!(default_rng(), tmprng)

copy!(default_rng(), tls_seed_orig)
end
ts = if ($testsettype === $DefaultTestSet) && $(isa(source, LineNumberNode))
$(testsettype)($desc; source=$(QuoteNode(source.file)), $options...)
Expand All @@ -1703,10 +1700,9 @@ function testset_forloop(args, testloop, source)
local first_iteration = true
local ts
local finish_errored = false
local oldrng = copy(default_rng())
local oldseed = Random.GLOBAL_SEED
Random.seed!(Random.GLOBAL_SEED)
local tmprng = copy(default_rng())
local default_rng_orig = copy(default_rng())
local tls_seed_orig = copy(Random.get_tls_seed())
copy!(Random.default_rng(), tls_seed_orig)
try
let
$(Expr(:for, Expr(:block, [esc(v) for v in loopvars]...), blk))
Expand All @@ -1717,8 +1713,8 @@ function testset_forloop(args, testloop, source)
pop_testset()
push!(arr, finish(ts))
end
copy!(default_rng(), oldrng)
Random.set_global_seed!(oldseed)
copy!(default_rng(), default_rng_orig)
copy!(Random.get_tls_seed(), tls_seed_orig)
end
arr
end
Expand Down
70 changes: 61 additions & 9 deletions stdlib/Test/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ end
# i.e. it behaves as if it was wrapped in a `guardseed(GLOBAL_SEED)` block
seed = rand(UInt128)
Random.seed!(seed)
seeded_state = copy(Random.default_rng())
a = rand()
@testset begin
# global RNG must re-seeded at the beginning of @testset
Expand All @@ -1043,31 +1044,82 @@ end
# the @testset's above must have no consequence for rand() below
b = rand()
Random.seed!(seed)
@test Random.default_rng() == seeded_state
@test a == rand()
@test b == rand()

# Even when seed!() is called within a testset A, subsequent testsets
# should start with the same "global RNG state" as what A started with,
# such that the test `refvalue == rand(Int)` below succeeds.
# Currently, this means that Random.GLOBAL_SEED has to be restored,
# Currently, this means that `Random.get_tls_seed()` has to be restored,
# in addition to the state of Random.default_rng().
GLOBAL_SEED_orig = Random.GLOBAL_SEED
tls_seed_orig = copy(Random.get_tls_seed())
local refvalue
@testset "GLOBAL_SEED is also preserved (setup)" begin
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
@testset "TLS seed is also preserved (setup)" begin
@test tls_seed_orig == Random.get_tls_seed()
refvalue = rand(Int)
Random.seed!()
@test GLOBAL_SEED_orig != Random.GLOBAL_SEED
@test tls_seed_orig != Random.get_tls_seed()
end
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
@testset "GLOBAL_SEED is also preserved (forloop)" for _=1:3
@test tls_seed_orig == Random.get_tls_seed()
@testset "TLS seed is also preserved (forloop)" for _=1:3
@test refvalue == rand(Int)
Random.seed!()
end
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
@testset "GLOBAL_SEED is also preserved (beginend)" begin
@test tls_seed_orig == Random.get_tls_seed()
@testset "TLS seed is also preserved (beginend)" begin
@test refvalue == rand(Int)
end

# @testset below is not compatible with e.g. v1.9, but it still fails there (at "main task")
# when deleting lines using get_tls_seed() or GLOBAL_SEED
@testset "TLS seed and concurrency" begin
# Even with multi-tasking, the TLS seed must stay consistent: the default_rng() state
# is reset to the "global seed" at the beginning, and the "global seed" is reset to what
# it was at the end of the testset; make sure that distinct tasks don't see the mutation
# of this "global seed" (iow, it's task-local)
seed = rand(UInt128)
Random.seed!(seed)
seeded_state = copy(Random.default_rng())
a = rand()

ch = Channel{Nothing}()
@sync begin
@async begin
@testset "task 1" begin
# tick 1
# this task didn't call seed! explicitly (yet), so its TaskLocalRNG() should have been
# reset to `Random.GLOBAL_SEED` at the beginning of `@testset`
@test Random.GLOBAL_SEED == Random.default_rng()
Random.seed!()
put!(ch, nothing) # tick 1 -> tick 2
take!(ch) # tick 3
end
put!(ch, nothing) # tick 3 -> tick 4
end
@async begin
take!(ch) # tick 2
# @testset below will record the current TLS "seed" and reset default_rng() to
# this value;
# it must not be affected by the fact that "task 1" called `seed!()` first
@test Random.get_tls_seed() == Random.GLOBAL_SEED

@testset "task 2" begin
@test Random.GLOBAL_SEED == Random.default_rng()
Random.seed!()
put!(ch, nothing) # tick 2 -> tick 3
take!(ch) # tick 4
end
# when `@testset` of task 2 finishes, which is after `@testset` from task 1,
# it resets `get_tls_seed()` to what it was before starting:
@test Random.get_tls_seed() == Random.GLOBAL_SEED
end
end
@testset "main task" begin
@test Random.default_rng() == seeded_state
@test a == rand()
end
end
end

@testset "InterruptExceptions #21043" begin
Expand Down

0 comments on commit 0296599

Please sign in to comment.