Add a loss scale optimizer #851
Conversation
This is the big missing piece we need for feature parity when running mixed precision training compared to tf.keras.
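To make the intended workflow concrete, here is a minimal usage sketch mirroring tf.keras mixed precision training. The constructor arguments, the `SGD` inner optimizer, and the `dtype="mixed_float16"` handling are assumptions for illustration; only the `LossScaleOptimizer` export paths come from this PR.

```python
import numpy as np
import keras_core
from keras_core import optimizers

# Wrap an inner optimizer: the loss is scaled up before the backward pass so
# small float16 gradients do not underflow to zero, and the gradients are
# unscaled again before the inner optimizer applies them (assumed behavior,
# mirroring tf.keras).
optimizer = optimizers.LossScaleOptimizer(optimizers.SGD(learning_rate=0.1))

model = keras_core.Sequential(
    [keras_core.layers.Dense(1, dtype="mixed_float16")]
)
model.compile(optimizer=optimizer, loss="mse")
model.fit(np.ones((8, 4), "float32"), np.ones((8, 1), "float32"), verbose=0)
```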
Codecov Report
Patch coverage:
Additional details and impacted files
@@ Coverage Diff @@
## main #851 +/- ##
==========================================
+ Coverage 75.99% 76.09% +0.09%
==========================================
Files 328 329 +1
Lines 31099 31269 +170
Branches 6051 6083 +32
==========================================
+ Hits 23635 23793 +158
- Misses 5866 5874 +8
- Partials 1598 1602 +4
Flags with carried forward coverage won't be shown.
A few points of awkwardness/discussion:
Looks like there is some sort of device placement issue for [Edit: now working on all backends]
Thanks for the PR!
Addressed the initial round, though I may play with a test that deliberately triggers the underflow in the trainer and asserts that variable updates still appear. I don't think that should be too hard, but we will see.
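Roughly the shape such a test could take; this is only a sketch with a hypothetical test name and data, not the actual trigger for the underflow case:

```python
import numpy as np
import keras_core
from keras_core import optimizers


def test_mixed_precision_updates_variables():
    model = keras_core.Sequential(
        [
            keras_core.Input(shape=(2,)),
            keras_core.layers.Dense(2, dtype="mixed_float16"),
        ]
    )
    model.compile(
        optimizer=optimizers.LossScaleOptimizer(
            optimizers.SGD(learning_rate=1.0)
        ),
        loss="mse",
    )
    x = np.ones((8, 2), "float32")
    y = np.full((8, 2), 1e-4, "float32")
    before = [np.copy(w) for w in model.get_weights()]
    # The real test would construct a case where unscaled float16 gradients
    # underflow; here we only check that training through the wrapped
    # optimizer still updates the weights.
    model.fit(x, y, epochs=1, verbose=0)
    after = model.get_weights()
    assert any(not np.allclose(b, a) for b, a in zip(before, after))
```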
            return loss * self.loss_scale_factor
        return loss

    def stateless_scale_loss(self, optimizer_variables, loss):
I'm not sure this should take `optimizer_variables` -- it only reads the value of one variable. I'm also not sure it should exist at all: you could just use a stateless scope when you call it. No strong opinion though. What are the trade-offs?
Good idea. I like the stateless scope. That keeps the weirdness isolated to jax.
Overall the jax trainer is definitely a little wonky (pushing state through `aux` for `jax.value_and_grad` is kinda bleh), but the real grossness is the jax implementation of the loss scale optimizer itself. Basically, jax control flow cannot be written the same way as in tf/torch because a stateless scope is insufficient; you still have to return all state from each control flow callback. Not something to fix in this PR, and I'm not sure we can really do anything about it, but it does break the idea of a backend-agnostic graph of ops.
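For reference, a rough illustration of the stateless-scope alternative discussed above. The helper name is made up for the sketch, and it assumes `backend.StatelessScope(state_mapping=...)` and `optimizer.scale_loss(...)` behave as they do elsewhere in the codebase:

```python
from keras_core import backend


def scale_loss_stateless(optimizer, optimizer_variable_values, loss):
    # Map each optimizer variable to the value the jax trainer carries around;
    # inside the scope, reads (e.g. of the current loss scale factor) resolve
    # to those values and no state is mutated.
    mapping = list(zip(optimizer.variables, optimizer_variable_values))
    with backend.StatelessScope(state_mapping=mapping):
        return optimizer.scale_loss(loss)
```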
@keras_core_export(
    [
        "keras_core.optimizers.LossScaleOptimizer",
        "keras_core.mixed_precision.LossScaleOptimizer",
Do we need this export path for backward compat? It's a bit awkward to have an optimizer in the mixed precision namespace. If it doesn't break too many people, I'd just drop it.
Yeah, IIUC, this was the only name it was ever exposed as, so if we want backward compat we need this alias.
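A quick illustration of what the alias buys us (assuming both generated namespace modules are importable):

```python
# Code written against the tf.keras-style name keeps working, since both
# export paths resolve to the same class.
from keras_core import mixed_precision, optimizers

assert mixed_precision.LossScaleOptimizer is optimizers.LossScaleOptimizer
```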
Added an end-to-end test only looking at variable updates across
Great work! LGTM
Fixes #571