
[Fix] Ensure we set the eval/train flag correctly on accelerator model #6877

Merged
merged 10 commits into from
Apr 8, 2021

Conversation

SeanNaren
Contributor

@SeanNaren SeanNaren commented Apr 7, 2021

What does this PR do?

Fixes #6876.

This PR is ready, with more information underneath.

We also require facebookresearch/fairscale#587 to be merged and included in the next release, but since we do not rely on the upstream release yet, this fix isn't going to change any FairScale-related tests on CI. This will probably also block #6152, as we move from our own fork to the latest FairScale release.

In terms of testing, once #6152 is merged to deprecate the current Pipe implementation, this fix will come to fruition, as the FairScale version will rely on the upstream release.

Overall, however, I noticed that in many places we call train/eval on just the LightningModule, not the accelerator-wrapped model. In the case of custom implementations like SDP (ShardedDataParallel), this is an issue. I swapped this to use self.model, which recursively sets the flag on the LightningModule when it is wrapped.
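A minimal pure-Python sketch of why this matters, mimicking torch.nn.Module's recursive train()/eval() behaviour (the Module class below is a simplified stand-in, not the real nn.Module, and the wrapper setup is hypothetical):

```python
class Module:
    """Simplified stand-in for torch.nn.Module's training-flag logic."""
    def __init__(self):
        self.training = True
        self._children = []

    def train(self, mode=True):
        self.training = mode
        for child in self._children:   # recurse like nn.Module.train()
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


lightning_module = Module()            # stand-in for the LightningModule
wrapped = Module()                     # stand-in for e.g. an SDP wrapper
wrapped._children.append(lightning_module)

wrapped.eval()                         # flag recurses into the inner module
assert lightning_module.training is False

lightning_module.train()               # inner-only call misses the wrapper
assert wrapped.training is False       # the wrapper's own flag is stale
```

Calling eval() on the wrapper updates every submodule, while calling it on the inner module alone leaves the wrapper's flag out of sync, which is the bug this PR fixes.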

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@SeanNaren SeanNaren added this to the 1.2.x milestone Apr 7, 2021
@SeanNaren SeanNaren changed the title Fix/sharded eval [Fix] Ensure we set the eval/train flag correctly on accelerator model Apr 7, 2021
@SeanNaren SeanNaren self-assigned this Apr 7, 2021
@SeanNaren SeanNaren added the bug Something isn't working label Apr 7, 2021
@SeanNaren SeanNaren requested a review from a team April 7, 2021 20:41
Contributor

@ananthsub ananthsub left a comment

great catch!!! let's add comments here for when to use self.model vs when to directly use self.lightning_module

@SeanNaren
Contributor Author

great catch!!! let's add comments here for when to use self.model vs when to directly use self.lightning_module

Thanks @ananthsub!

The difference between the names can be quite confusing, but in most cases the choice is clear: if you want to access LightningModule internals/functions, use self.trainer.lightning_module. For FSDP, or wherever you'd like to access the wrapped model, use self.trainer.model (or self.model inside the trainer) :)
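A rough illustration of that convention, using stub classes in place of the real Trainer, LightningModule, and accelerator wrapper (all names below are simplified stand-ins, not actual Lightning internals):

```python
class StubLightningModule:
    """Stands in for the user's LightningModule."""
    learning_rate = 0.01              # a user-defined internal attribute


class StubWrapper:
    """Stands in for an accelerator wrapper such as ShardedDataParallel."""
    def __init__(self, module):
        self.module = module
        self.training = True

    def eval(self):
        # real wrappers recurse into self.module as well
        self.training = False


class StubTrainer:
    def __init__(self):
        self.lightning_module = StubLightningModule()
        self.model = StubWrapper(self.lightning_module)


trainer = StubTrainer()

# Accessing LightningModule internals: go through trainer.lightning_module
lr = trainer.lightning_module.learning_rate

# Toggling train/eval flags: use trainer.model so wrappers are included
trainer.model.eval()
assert trainer.model.training is False
```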

@SeanNaren
Contributor Author

Ah great point, forgot lightning_module.eval() and lightning_module.train() are hooks. What's advised here? I could change the default to self.trainer.model.eval() and self.trainer.model.train() within the hooks or I could add this additional logic outside the hook. cc @awaelchli @tchaton

@tchaton
Contributor

tchaton commented Apr 8, 2021

Ah great point, forgot lightning_module.eval() and lightning_module.train() are hooks. What's advised here? I could change the default to self.trainer.model.eval() and self.trainer.model.train() within the hooks or I could add this additional logic outside the hook. cc @awaelchli @tchaton

Sounds good to me !

@SeanNaren
Contributor Author

Thanks @tchaton! Just a slight mistake: I should be referring to the module hooks on_test_model_train, on_test_model_eval and on_predict_model_eval, which will change.

Contributor

@awaelchli awaelchli left a comment

looks good <3

Contributor

@tchaton tchaton left a comment

LGTM !

@SeanNaren SeanNaren added the _Will label Apr 8, 2021
@carmocca
Contributor

carmocca commented Apr 8, 2021

@SeanNaren the test is failing

@SeanNaren
Contributor Author

The test showed that we called the predict hook twice; I've remedied this, and it now follows the same behaviour as run_evaluation and run_train.

@codecov

codecov bot commented Apr 8, 2021

Codecov Report

Merging #6877 (b8d473f) into master (19e67d1) will decrease coverage by 6%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #6877    +/-   ##
=======================================
- Coverage      91%     86%    -6%     
=======================================
  Files         193     193            
  Lines       12299   12539   +240     
=======================================
- Hits        11252   10779   -473     
- Misses       1047    1760   +713     

@lexierule lexierule disabled auto-merge April 8, 2021 18:04
@lexierule lexierule merged commit 742c48e into master Apr 8, 2021
@lexierule lexierule deleted the fix/sharded_eval branch April 8, 2021 18:04
@SeanNaren SeanNaren mentioned this pull request Apr 12, 2021
SeanNaren pushed a commit that referenced this pull request Apr 13, 2021
[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)

* Ensure we move the model to eval mode before running evaluation

* Ensure we set the flag appropriately across all stages

* Add test, move hooks logic

* Apply same fix to the validate loop

* Update pytorch_lightning/trainer/trainer.py

* Fix function name

* Fix order, add predict

* Shorten the name

* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function

* Use hook, remove double call

(cherry picked from commit 742c48e)
lexierule pushed a commit that referenced this pull request Apr 14, 2021
[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)

* Ensure we move the model to eval mode before running evaluation

* Ensure we set the flag appropriately across all stages

* Add test, move hooks logic

* Apply same fix to the validate loop

* Update pytorch_lightning/trainer/trainer.py

* Fix function name

* Fix order, add predict

* Shorten the name

* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function

* Use hook, remove double call

(cherry picked from commit 742c48e)
Development

Successfully merging this pull request may close these issues.

Latest FairScale + Sharded Training crashes using default trainer parameters
7 participants