With the new-style training, I think this should basically just work.
m16 = f16(m32) makes a low-precision copy of the model; you can use that to compute the gradient g16, and then update!(opt_state, m32, g16) will apply the change to the original model.
However, not all operations support Float16, e.g. I'm not sure about convolutions. There may also be other unanticipated problems.
It would be super-nice to have an example of this, e.g. a model zoo page which uses it.
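A rough sketch of what such an example might look like, following the recipe above. The toy model, data, and loss are made up for illustration, and it assumes a recent Flux with f16 and the new-style Flux.setup / Flux.update! API; it has not been checked against layers that may lack Float16 support.

```julia
using Flux

# Float32 "master" copy of the model; optimiser state is built against it.
m32 = Chain(Dense(10 => 32, relu), Dense(32 => 1))
opt_state = Flux.setup(Adam(), m32)

# Dummy data, just for illustration.
x, y = randn(Float32, 10, 64), randn(Float32, 1, 64)

for step in 1:100
    m16 = f16(m32)              # low-precision copy of the model
    x16, y16 = f16(x), f16(y)   # cast the batch to match
    # Gradient is taken w.r.t. the Float16 copy...
    g16 = gradient(m -> Flux.Losses.mse(m(x16), y16), m16)[1]
    # ...but the update is applied to the Float32 master weights.
    Flux.update!(opt_state, m32, g16)
end
```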
Motivation and description
Just wondering if there is a way to do mixed-precision training in Flux?
Possible Implementation
No response