Add a loss scale optimizer #851

Merged (6 commits into keras-team:main, Sep 9, 2023)
Conversation

@mattdangerw (Member) commented Sep 7, 2023

This is the big missing piece we need for feature parity with tf.keras when running mixed precision training.

Fixes #571

@codecov bot commented Sep 7, 2023

Codecov Report

Patch coverage: 86.30% and project coverage change: +0.09% 🎉

Comparison is base (ab45558) 75.99% compared to head (fd58cfb) 76.09%.
Report is 6 commits behind head on main.

Additional details and impacted files
```
@@            Coverage Diff             @@
##             main     #851      +/-   ##
==========================================
+ Coverage   75.99%   76.09%   +0.09%
==========================================
  Files         328      329       +1
  Lines       31099    31269     +170
  Branches     6051     6083      +32
==========================================
+ Hits        23635    23793     +158
- Misses       5866     5874       +8
- Partials     1598     1602       +4
```
| Flag | Coverage Δ |
|---|---|
| keras_core | 75.99% <86.30%> (+0.08%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Files Changed | Coverage Δ |
|---|---|
| keras_core/backend/tensorflow/optimizer.py | 90.56% <ø> (-0.35%) ⬇️ |
| keras_core/callbacks/tensorboard.py | 83.84% <ø> (+0.11%) ⬆️ |
| keras_core/ops/core.py | 74.09% <0.00%> (-4.05%) ⬇️ |
| keras_core/backend/torch/trainer.py | 89.56% <50.00%> (-0.35%) ⬇️ |
| keras_core/backend/tensorflow/trainer.py | 78.53% <66.66%> (-0.14%) ⬇️ |
| keras_core/optimizers/base_optimizer.py | 74.76% <90.00%> (+0.56%) ⬆️ |
| keras_core/optimizers/loss_scale_optimizer.py | 94.28% <94.28%> (ø) |
| keras_core/backend/jax/numpy.py | 97.69% <100.00%> (+0.01%) ⬆️ |
| keras_core/backend/jax/trainer.py | 96.08% <100.00%> (+0.11%) ⬆️ |
| keras_core/optimizers/__init__.py | 92.10% <100.00%> (+0.21%) ⬆️ |

... and 1 more

... and 5 files with indirect coverage changes

☔ View full report in Codecov by Sentry.

@mattdangerw (Member Author) commented:

A few points of awkwardness/discussion:

  1. ops.cond needs to be stateless for jax, but the auto-scaling logic has two cond branches with variable updates in both of them. To handle this I overrode stateless_apply separately and had to use a lot of StatelessScopes. This feels very verbose and awkward, but I wasn't able to think of a great way around it. Suggestions welcome!
  2. We want learning_rate to proxy the inner optimizer's learning rate, so I overrode the learning_rate property to do so. But the base optimizer still has to be created with a learning rate, so I ended up just passing a zero-valued variable which never gets used. This feels awkward and a bit confusing.
  3. Because the loss scale optimizer needs a variable to scale the loss, and we don't sync jax state except on epoch boundaries, I added both a scale_loss and a stateless_scale_loss. This is probably fine, it just feels like a lot of code. (A rough sketch of points 2 and 3 follows this list.)
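
To make points 2 and 3 concrete, here is a rough sketch of the shape being described: a learning_rate property that proxies the inner optimizer, plus a stateful scale_loss and a stateless_scale_loss that resolves the scale variable through a StatelessScope. This is illustrative only, not the code in this PR; the class, its constructor arguments, and the exact StatelessScope usage are assumptions based on the discussion above.

```python
# Illustrative sketch only -- not the merged implementation.
from keras_core import backend


class LossScaleOptimizerSketch:
    def __init__(self, inner_optimizer, initial_scale=2.0**15):
        self.inner_optimizer = inner_optimizer
        # The dynamic loss scale lives on the wrapper as a non-trainable
        # variable so it can be halved/doubled during training.
        self.loss_scale_factor = backend.Variable(initial_scale, trainable=False)

    @property
    def variables(self):
        # Wrapper state first, then the inner optimizer's state.
        return [self.loss_scale_factor] + self.inner_optimizer.variables

    @property
    def learning_rate(self):
        # Point 2: proxy the inner optimizer's learning rate instead of
        # tracking a separate, unused one on the wrapper.
        return self.inner_optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self.inner_optimizer.learning_rate = learning_rate

    def scale_loss(self, loss):
        # Point 3, stateful path: read the scale variable directly.
        return loss * self.loss_scale_factor

    def stateless_scale_loss(self, optimizer_variables, loss):
        # Point 3, jax path: map each variable to its proposed value so the
        # read inside scale_loss resolves without touching real state.
        mapping = list(zip(self.variables, optimizer_variables))
        with backend.StatelessScope(state_mapping=mapping):
            return self.scale_loss(loss)
```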

@mattdangerw requested a review from fchollet on September 7, 2023 at 17:03
@mattdangerw (Member Author) commented Sep 7, 2023

Looks like there is some sort of device placement issue for tensorflow GPU that isn't picked up in our CPU-only tests? Will poke around. Confirmed torch/jax are working.

[Edit: now working on all backends]

@fchollet (Contributor) left a comment:

Thanks for the PR!

Review threads (resolved):
- keras_core/backend/tensorflow/optimizer.py
- keras_core/optimizers/loss_scale_optimizer.py (outdated)
- keras_core/trainers/trainer.py (outdated)
- keras_core/trainers/trainer.py
@mattdangerw (Member Author) commented:

Addressed the initial round of comments, though I may play with a test that deliberately triggers underflow in the trainer and asserts that variable updates still appear. I don't think that should be too hard, but we will see.

```python
        return loss * self.loss_scale_factor
    return loss

def stateless_scale_loss(self, optimizer_variables, loss):
```
@fchollet (Contributor) commented:

I'm not sure this should take optimizer_variables -- it only reads the value of one variable. I'm also not sure it should exist at all: you could just use a stateless scope when you call it. No strong opinion though. What are the trade-offs?

@mattdangerw (Member Author) replied:

Good idea. I like the stateless scope. That keeps the weirdness isolated to jax.

Overall the jax trainer is definitely a little wonky (pushing state through aux for jax.value_and_grad is kinda bleh), but the real grossness is the jax implementation of the loss scale optimizer itself. Basically, jax control flow cannot be written the same way as tf/torch because a stateless scope is insufficient; you still have to return all state from each control flow callback. Not something to fix in this PR, and I'm not sure we can really do anything about it, but it does break the idea of a backend-agnostic graph of ops.
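
To illustrate that constraint, here is a hedged sketch of what a dynamic scale adjustment step has to look like when written against a jax-compatible ops.cond: neither branch assigns to a variable; both return the full state tuple and the caller threads it back. The function and argument names (and the exact halving/counting rule) are made up for illustration.

```python
# Illustrative only: functional-style cond where every branch returns state.
from keras_core import ops


def adjust_loss_scale(scale, steps_since_overflow, grads_are_finite):
    def on_finite():
        # No overflow this step: keep the scale and bump the counter.
        return scale, steps_since_overflow + 1

    def on_overflow():
        # Overflow detected: halve the scale and reset the counter.
        return scale / 2.0, ops.zeros_like(steps_since_overflow)

    # Neither branch mutates a variable; the caller receives the new state
    # tuple and is responsible for writing it back (or threading it onward).
    return ops.cond(grads_are_finite, on_finite, on_overflow)
```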

Review thread (resolved, outdated): keras_core/optimizers/loss_scale_optimizer_test.py
```python
@keras_core_export(
    [
        "keras_core.optimizers.LossScaleOptimizer",
        "keras_core.mixed_precision.LossScaleOptimizer",
```
@fchollet (Contributor) commented:

Do we need this export path for backward compat? It's a bit awkward to have an optimizer in the mixed precision namespace. If it doesn't break too many people, I'd just drop it.

@mattdangerw (Member Author) replied:

Yeah, IIUC, this was the only name it was ever exposed as, so if we want backward compat we need this alias.

Review thread (resolved): keras_core/backend/tensorflow/optimizer.py
@mattdangerw (Member Author) commented:

Added an end-to-end test that only looks at variable updates across fit(), which will hopefully keep us from accidentally breaking this.
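
For reference, here is a hedged sketch of what such a check could look like; this is not the test added in the PR, and the LossScaleOptimizer constructor arguments and the tiny model are assumptions.

```python
# Illustrative sketch: verify that fit() with a LossScaleOptimizer-wrapped
# inner optimizer actually updates the model weights.
import numpy as np
import keras_core

model = keras_core.Sequential([keras_core.layers.Dense(1)])
model.build((None, 4))
model.compile(
    optimizer=keras_core.optimizers.LossScaleOptimizer(
        keras_core.optimizers.SGD(learning_rate=0.1)
    ),
    loss="mse",
)

x = np.random.random((8, 4)).astype("float32")
y = np.random.random((8, 1)).astype("float32")

before = [np.copy(w) for w in model.get_weights()]
model.fit(x, y, epochs=1, batch_size=8, verbose=0)
after = model.get_weights()

# At least one weight tensor should have moved after training.
assert any(not np.allclose(b, a) for b, a in zip(before, after))
```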

@fchollet (Contributor) left a comment:

Great work! LGTM

@fchollet merged commit d0b53fd into keras-team:main on Sep 9, 2023
Development

Successfully merging this pull request may close these issues.

Port the loss scale optimizer to keras-core
2 participants