`train!` using Metal and stateful optimizers fails #2310
Comments
We should be doing calculations in Float32, so either some intermediate ones aren't or Float64 gradients are somehow being passed to the optimizers. Hence I've moved this to Optimisers.jl. I would recommend first checking the gradients to make sure all are Float32. If you've confirmed that, then it should be straightforward to create a MWE using a single array and Optimisers.jl (no Flux), which should make it easier for us to investigate.
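A minimal sketch of the Flux-free check suggested here, assuming Metal.jl and Optimisers.jl are installed; the array sizes and hyper-parameter values are illustrative, not taken from the original report:

```julia
using Metal, Optimisers

x  = MtlArray(rand(Float32, 4))    # a single Float32 parameter array on the GPU
dx = MtlArray(ones(Float32, 4))    # a hand-made Float32 "gradient"

# Float32 hyper-parameters (the Optimisers.jl defaults) -- expected to work:
state = Optimisers.setup(Optimisers.Nesterov(), x)
state, x = Optimisers.update!(state, x, dx)

# Float64 hyper-parameters, mimicking what the legacy Flux rule carries --
# expected to reproduce the failure on Metal:
y = MtlArray(rand(Float32, 4))
state64 = Optimisers.setup(Optimisers.Nesterov(0.001, 0.9), y)
state64, y = Optimisers.update!(state64, y, dx)
```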
@ToucheSir Thanks for your reply. `Flux.Optimise.Nesterov()` gives `Nesterov(0.001, 0.9, IdDict{Any, Any}())`; replacing it with `Optimisers.Nesterov()`, which gives `Nesterov{Float32}(0.001f0, 0.9f0)`, fixes the problem. It seems that in Optimisers the hyper-parameters are parameterised by the element type, while in Flux they are explicitly typed as Float64. So it's a Flux bug, but I guess you're going to move to Optimisers at some point? This also seems to be contradicted by the Flux documentation.
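A sketch of this workaround with Flux's explicit-style `train!`; the model, data, and backend selection are illustrative assumptions, not the reporter's code:

```julia
using Flux, Metal, Optimisers

# Assumes the Metal backend was selected earlier, e.g. via Flux.gpu_backend!("Metal").
model = Dense(2 => 1) |> gpu
x = rand(Float32, 2, 8) |> gpu
y = rand(Float32, 1, 8) |> gpu

# Optimisers.Nesterov() carries Float32 hyper-parameters, so no Float64 values
# should reach the Metal kernels during the update:
opt_state = Flux.setup(Optimisers.Nesterov(), model)
Flux.train!((m, xb, yb) -> Flux.mse(m(xb), yb), model, [(x, y)], opt_state)
```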
We already did. If you use Flux per the tutorials, everything should run on Optimisers.jl under the hood by default. https://fluxml.ai/Flux.jl/stable/training/optimisers/#man-optimisers touches on this. You should never have to touch the legacy optimiser interface yourself.
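For reference, a sketch of the tutorial-style training loop being referred to (model and data are illustrative); `Flux.setup` hands the rule over to Optimisers.jl under the hood:

```julia
using Flux

model = Dense(2 => 1)
data  = [(rand(Float32, 2, 8), rand(Float32, 1, 8))]

opt_state = Flux.setup(Adam(0.01), model)   # optimiser state managed by Optimisers.jl

for (xb, yb) in data
    grads = Flux.gradient(m -> Flux.mse(m(xb), yb), model)
    Flux.update!(opt_state, model, grads[1])
end
```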
Great! One should, however, explicitly use the Optimisers.jl rules. Also, with `using Flux, Optimisers`, calling `Adam()` warns that the name is exported by both packages. You may want to consider making it clearer that one should explicitly use the Optimisers.jl versions. Thanks!
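A small illustration of the name clash described above (the exact warning text is not reproduced here):

```julia
using Flux, Optimisers

# Both packages export `Adam`, so the unqualified name triggers a warning that
# uses of it must be qualified. Qualifying the name avoids the ambiguity:
opt = Optimisers.Adam()
```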
You shouldn't have to use it explicitly; the translation layer should take care of it. In this case it appears the translation layer isn't converting hyper-parameters from Float64 to Float32, so this remains an issue. It should be fixed by FluxML/Optimisers.jl#151, so leaving this open until that's released.
Using the "Metal" GPU backend and
Flux.Optimise.train!
with anything else thanFlux.Optimise.Descent
fails (triedNesterov
andAdam
). Seems like stateful optimizers implicitly work withFloat64
. I'm not sure if this is a bug or expected behaviour due to the experimental nature of Metal backend.Thanks!
Minimal example to reproduce:
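The snippet itself was not preserved here; below is a minimal sketch along the lines described in the report (layer sizes and data are illustrative), assuming the Metal backend was selected beforehand with `Flux.gpu_backend!("Metal")`:

```julia
using Flux, Metal

model = Dense(2 => 1) |> gpu
x = rand(Float32, 2, 8) |> gpu
y = rand(Float32, 1, 8) |> gpu

loss(xb, yb) = Flux.mse(model(xb), yb)
ps = Flux.params(model)

# Works:
Flux.Optimise.train!(loss, ps, [(x, y)], Flux.Optimise.Descent())

# Fails with stateful optimisers such as Nesterov and Adam:
Flux.Optimise.train!(loss, ps, [(x, y)], Flux.Optimise.Nesterov())
Flux.Optimise.train!(loss, ps, [(x, y)], Flux.Optimise.Adam())
```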
Backtrace when using `Nesterov`
Backtrace when using `Adam`
Julia version: 1.9.2
Flux version: 0.14.2
Metal version: 0.5.0