
on_after_backward should always run after backward #7924

Closed
carmocca opened this issue Jun 10, 2021 · 4 comments · Fixed by #8048
@carmocca
Contributor

🐛 Bug

If accumulate_grad_batches is enabled, we don't call on_after_backward until we step the optimizers.

https://github.com/PyTorchLightning/pytorch-lightning/blob/d209b689796719d1ab4fcc8e1c26b8b57cd348c4/pytorch_lightning/trainer/training_loop.py#L757-L763

This means on_after_backward is acting like on_before_optimizer_step.

So we should add an on_before_optimizer_step hook and always run on_after_backward immediately after backward.
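
For illustration, a minimal sketch of the proposed semantics; the on_before_optimizer_step signature here is purely illustrative, nothing has been decided yet:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def on_after_backward(self):
        # proposed: fires after every backward call, even while
        # gradients are still accumulating
        print("backward finished for this micro-batch")

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # proposed hook (signature illustrative only): fires once the
        # accumulated gradients are ready and the optimizer is about to step
        print(f"about to step optimizer {optimizer_idx}")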

@carmocca carmocca added bug Something isn't working help wanted Open to be worked on good first issue Good for newcomers design Includes a design discussion labels Jun 10, 2021
@carmocca carmocca added this to the v1.4 milestone Jun 10, 2021
@ddrevicky
Contributor

I'll take a look at this.

@carmocca
Contributor Author

Awesome! Ask away if you need any help.

@ddrevicky
Contributor

ddrevicky commented Jun 15, 2021

Hey @carmocca, would something as simple as this work?

# backward pass
if result is not None:
    with self.trainer.profiler.profile("backward"):
        self.backward(result, optimizer, opt_idx)

    self.on_after_backward(batch_idx, result.loss)

    # hook - call this hook only when gradients have finished accumulating
    if not self.should_accumulate():
        self.on_before_optimizer_step(batch_idx, result.loss)  # TODO: result.loss corresponds only to the last batch

    # check if loss or model weights are NaN
    if self.trainer.terminate_on_nan:
        self._check_finite(result.loss)

A couple of questions for you:

  • Is this the right place to put it? Couldn't find anything more appropriate.

  • What should be passed to the on_before_optimizer_step hook? result.loss corresponds to only the last batch's loss, is that enough for the user? Or would they want the averaged loss across the batches being accumulated over? (I'm not sure about the use case for this hook)

  • pl_module.trainer.call_hook("on_after_backward") is being called in some of the precision plugins in the pre_optimizer_step if pl_module.automatic_optimization == False. As I understand it, that's because those plugins do not make the optimizer step in the pre_optimizer_step (that's why they return False as the make_optimizer_step value).

  • Should on_before_optimizer_step be called in this case as well? I don't understand why on_after_backward is being called explicitly there, since it is already called in the lambda_closure, which should just be the training_step_and_backward method. It seems to me like it would get called twice for each step, unless pl_module.automatic_optimization == False means that on_after_backward is not called in the closure. See for example: https://github.com/PyTorchLightning/pytorch-lightning/blob/971908a1aa644cf243719071c879421323b88888/pytorch_lightning/plugins/precision/native_amp.py#L86-L92

Also @awaelchli, I notice there is a training loop refactor in progress. Would this change affect your work or vice versa?

@carmocca
Contributor Author

Is this the right place to put it? Couldn't find anything more appropriate.

I would put it right before this line:

https://github.com/PyTorchLightning/pytorch-lightning/blob/cdcc483e9b7a79de3e5a7ac9c1e9dfd12ab77f4f/pytorch_lightning/loops/training_batch_loop.py#L187

What should be passed to the on_before_optimizer_step hook?

Maybe just the batch idx, optimizer, and optimizer idx. Could also include some of the other arguments of optimizer_step but I don't think that's necessary right now.
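
i.e., something along these lines (purely illustrative, not a final signature):

def on_before_optimizer_step(self, batch_idx, optimizer, optimizer_idx):
    """Called once gradient accumulation finishes, right before optimizer.step()."""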

result.loss corresponds to only the last batch's loss, is that enough for the user? Or would they want the averaged loss across the batches being accumulated over? (I'm not sure about the use case for this hook)

Note that for a hook with these attributes, we already have on_after_backward. If necessary, the users can check should_accumulate() themselves to mimic the old behaviour.
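
For example, a sketch of mimicking the old behaviour, assuming should_accumulate() is reachable from the module via the train loop (as it was at the time):

class MyModel(pl.LightningModule):
    def on_after_backward(self):
        # mimic the old behaviour: only act once accumulation is done
        if not self.trainer.train_loop.should_accumulate():
            ...  # e.g. inspect or log the accumulated gradients here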

As I understand it, that's because those plugins do not make the optimizer step in the pre_optimizer_step (that's why they return False as the make_optimizer_step value).

Yes, those need explicit calls.

Unless if pl_module.automatic_optimization == False means that on_after_backward is not called in the closure

That is the case, see:

https://github.com/PyTorchLightning/pytorch-lightning/blob/cdcc483e9b7a79de3e5a7ac9c1e9dfd12ab77f4f/pytorch_lightning/loops/training_batch_loop.py#L561

Regardless, maybe we shouldn't call it in that case, as it won't get called when using manual optimization without AMP.
In those cases, it should be replaced with on_before_optimizer_step.
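
For context, with manual optimization the user drives backward themselves, so there is no closure for Lightning to run the hook from. A minimal sketch (compute_loss is a hypothetical helper):

class ManualModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.compute_loss(batch)  # hypothetical helper
        self.manual_backward(loss)  # no closure involved here
        opt.step()
        opt.zero_grad()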

Would this change affect your work or vice versa?

It has already landed (the large training changes, at least), so it's more like it impacts yours. But fixing the conflicts shouldn't be hard.

@edenlightning edenlightning modified the milestones: v1.4, v1.3.x Jul 1, 2021
@Borda Borda modified the milestones: v1.3.x, v1.4 Jul 6, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.3.x Jul 6, 2021