-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support for native amp #1561
support for native amp #1561
Conversation
since which version is amp in pytorch native? |
1.6. but we don’t need to explicitly check. we can test the properties as i did |
saved_state = scaler.state_dict() |
pytorch_lightning/core/hooks.py
Outdated
@@ -138,11 +138,20 @@ def backward(self, use_amp, loss, optimizer): | |||
else: | |||
loss.backward() | |||
|
|||
.. note:: with PyTorch 1.6+ + precision=16 + multiple optimizers, set .backward(retrain_graph=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need this note.
The example is misleading, I guess. The retain_graph=True
bit has nothing to do with Amp, it's only present because both losses interleave outputs from multiple models. Both backward passes use the same model graphs, so the first backward() must not tear them down. retain_graph=True
would be necessary with or without Amp. That's unclear and maybe I should either change the example snippet so retain_graph=True
is not needed, or add a comment clarifying that retain_graph=True
there is not Amp-related.
pytorch_lightning/core/hooks.py
Outdated
return | ||
|
||
if self.trainer.use_native_amp: | ||
# don't forget to retain graph on backward with multiple optimizers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also remove this comment, see https://github.com/PyTorchLightning/pytorch-lightning/pull/1561/files#r413230111.
This pull request is now in conflict... :( |
@Borda these tests are failing bc amp is not installed... did we remove amp? |
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool): | |||
if on_gpu: | |||
model.cuda(self.root_gpu) | |||
|
|||
# restore amp scaling | |||
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mcarilli sanity check this loading?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good if you fix the saving https://github.com/PyTorchLightning/pytorch-lightning/pull/1561/files#r413418705
Like saving, loading should occur either at the very beginning of an iteration (before any training-related scaler
calls for that iteration) or at the end of an iteration, after scaler.update()
. It doesn't make a lot of sense to load state dicts at the end of an iteration, but if the saved state originated from a scaler.state_dict()
call at the end of, say, iteration 1000 (i.e. after iteration 1000's call to scaler.update()
), then it's ok to call load_state_dict
at the beginning of iteration 1001 to resume.
@@ -316,6 +320,10 @@ def dump_checkpoint(self): | |||
|
|||
checkpoint['state_dict'] = model.state_dict() | |||
|
|||
# restore native amp scaling | |||
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mcarilli sanity check this saving?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
state_dict
is a method, as for modules and optimizers, so checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
is what you want.
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict
would stash the bound-method object itself :P
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also you should make sure state_dict()
is retrieved either at the very beginning of an iteration (before any scaler
method calls) or at the very end (after scaler.update()
), and that the model and optimizer state dicts are saved at that same spot.
I can't tell from these lines alone if the calling code occurs at a spot that obeys those criteria.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i thought it was a property haha, but i guess it's consistent with the other state_dict() calls haha
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol i see. it's consistent with the rest
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another thing to consider is that with torch.cuda.amp
, it's permissible to
- load a checkpoint from a model + optimizer not trained with Amp, and resume training with Amp enabled, or
- load a checkpoint from a model + optimizer trained with Amp, and resume training without Amp.
I think your if
criteria are flexible enough that both those cases can happen naturally with the appropriate user args but I'm not sure just from looking at it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah this code works.
Case 1: Train with amp, load amp
works fine
case 2: Train amp, load and not use amp
in this case, lightning loads the amp state but amp is disabled so user doesn't use it at all
case 3: train regular, resume regular
works fine
case 4: train regular, resume with amp
in this case the checkpoint has no amp state and model starts normal but on amp.
@@ -316,6 +320,10 @@ def dump_checkpoint(self): | |||
|
|||
checkpoint['state_dict'] = model.state_dict() | |||
|
|||
# restore native amp scaling | |||
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint: | |||
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
probably, unfortunately, it happened here with Horovoed #1529 (comment) |
This pull request is now in conflict... :( |
Codecov Report
@@ Coverage Diff @@
## master #1561 +/- ##
======================================
- Coverage 89% 88% -0%
======================================
Files 68 68
Lines 3913 3955 +42
======================================
+ Hits 3473 3496 +23
- Misses 440 459 +19 |
Saving was introduced in Lightning-AI#1561.
Saving was introduced in #1561.
Fixes #1336
Fixes #1337
@mcarilli mind taking a look?
Issue 1
We have a slight issue with the DP API...
@ethanwharris suggested a way around this which we have in the PR
Issue 2
How do we save the state of the scaling factor to resume training?
@mcarilli