I have been comparing some Flux models on the CPU and GPU and was surprised to find they were not giving similar predictions. It took me quite some time to nail this down to the dropout problem described in the title and demonstrated below. Maybe I'm missing something, but it seems to me this is a serious impediment to reproducibility using Flux.
using Flux
import Random.seed!
seed!(123);
data = [(rand(Float32, 5, 1), rand(Float32, 5)),]
model = Flux.Chain(Flux.Dense(5, 2, identity),
                   Flux.Dropout(0.5),
                   Flux.Dense(2, 1))
data_gpu = gpu(data);
model_gpu = gpu(model)
loss(x, y) = Flux.mse(model(x), y)
loss_gpu(x, y) = Flux.mse(model_gpu(x), y)
optimiser = Flux.ADAM()

# cpu training:
seed!(123);
Flux.train!(loss, Flux.params(model), data, optimiser)
rand() # 0.6739586945680673

# gpu training:
seed!(123);
Flux.train!(loss_gpu, Flux.params(model_gpu), data_gpu, optimiser)
rand() # 0.13672511011651545 <----------------------- should be the same as the `rand()` above
If one removes Dropout from the chain, the rand() calls return the same value, as expected.
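For what it's worth, a possible workaround sketch: if the GPU dropout mask is drawn from CUDA's device-side RNG rather than the global MersenneTwister (my assumption, not confirmed), then seeding both generators before each run should at least make each device reproducible on its own. `CUDA.seed!` is my assumption about CUDA.jl's API.

```julia
using Flux, Random, CUDA

# Seed both generators before each training run (assumption: CUDA.seed!
# seeds the device-side RNG used for the GPU dropout mask):
Random.seed!(123)   # CPU global RNG
CUDA.seed!(123)     # GPU RNG
```

Note this would not make CPU and GPU results identical (they are different generators); it would only make each run reproducible on its own device.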
julia> versioninfo()
Julia Version 1.5.2
Commit 539f3ce943 (2020-09-23 23:17 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel Xeon Processor (Skylake, IBRS)
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-9.0.1 (ORCJIT, skylake-avx512)
(jl_4sVBIY) pkg> st
Status `/tmp/jl_4sVBIY/Project.toml`
[587475ba] Flux v0.11.1
I don't recall an issue with the CPU's global RNG behaving differently depending on whether the model is on the GPU or not. @ablaom you could try redefining some inner method of e.g. MersenneTwister to have it trace where randomness is requested.
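A lighter-weight check than redefining methods, assuming the repro session above: compare the global RNG's state around a plain forward pass to see whether the dropout mask consumed CPU randomness.

```julia
using Random

# Capture the global RNG state, run one forward pass, and compare
# (Random.default_rng() is available from Julia 1.3 onward):
x = rand(Float32, 5, 1)
state_before = copy(Random.default_rng())
model(x)   # forward pass; Dropout is active outside testmode
advanced = state_before != copy(Random.default_rng())
# `advanced` is true if the dropout mask drew from the CPU's global RNG
```

Running this once with `model` and once with `model_gpu` would show whether the GPU path skips the global RNG entirely, which would explain the diverging `rand()` values.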