Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DPP + SyncBN #6838

Merged
merged 2 commits into from
Apr 6, 2021
Merged

Fix DPP + SyncBN #6838

merged 2 commits into from
Apr 6, 2021

Conversation

BloodAxe
Copy link
Contributor

@BloodAxe BloodAxe commented Apr 5, 2021

What does this PR do?

This PR ensures that model is already on correct GPU before applying SyncBN conversion. According to PyTorch docs, SyncBN conversion should be applied to model that is already on target GPU. Not following this requirement will lead to hard-tracing errors of some BN layers kept on CPU.
Changing the order of model_to_device call seems to fix this issue without any side-effects to the rest of the code.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

Ensure that model is already on correct GPU before applying SyncBN conversion
@codecov
Copy link

codecov bot commented Apr 5, 2021

Codecov Report

Merging #6838 (9b4a71d) into master (22a266d) will decrease coverage by 5%.
The diff coverage is 50%.

@@           Coverage Diff           @@
##           master   #6838    +/-   ##
=======================================
- Coverage      91%     87%    -5%     
=======================================
  Files         192     192            
  Lines       12191   12191            
=======================================
- Hits        11145   10595   -550     
- Misses       1046    1596   +550     

@awaelchli awaelchli added the distributed Generic distributed-related topic label Apr 5, 2021
Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like that's the correct thing to do.
can you apply the same to ddp_spawn.py?

@awaelchli awaelchli added the bug Something isn't working label Apr 5, 2021
@BloodAxe
Copy link
Contributor Author

BloodAxe commented Apr 5, 2021

Greetings! Thanks for quick response. Added same changes to ddp_spawn in my last commit.

Copy link
Contributor

@tchaton tchaton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ! Thanks for this catch !

# move the model to the correct device
self.model_to_device()

if self.sync_batchnorm:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.sync_batchnorm:
 # `sync BatchNorm` requires a model on `cuda` device
if self.sync_batchnorm:

# move the model to the correct device
self.model_to_device()

if self.sync_batchnorm:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.sync_batchnorm:
# `sync BatchNorm` requires a model on `cuda` device
if self.sync_batchnorm:

@tchaton tchaton merged commit eafec7d into Lightning-AI:master Apr 6, 2021
@kaushikb11 kaushikb11 mentioned this pull request Apr 6, 2021
kaushikb11 pushed a commit to kaushikb11/pytorch-lightning that referenced this pull request Apr 6, 2021
* Fix DPP + SyncBN

Ensure that model is already on correct GPU before applying SyncBN conversion

* Fix order of SyncBN for ddp_spawn
lexierule pushed a commit that referenced this pull request Apr 7, 2021
* update readme by v1.2.x (#6728)

* [bugfix] Add support for omegaconf and tpu (#6741)

* fix_hydra

* update changelog

Co-authored-by: Your Name <you@example.com>

* [docs] Update Bolts link (#6743)

* Update Bolts link

* Update Bolts link

* formt

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update logic for checking TPUs availability (#6767)

* Update logic for checking TPUs availability

* fix flake8

* add fix

* resolve bug (#6781)

* Fix validation progress counter with check_val_every_n_epoch > 1 (#5952)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Remove extinct parameters from lightning_module.rst (#6801)

Fixes  #6800

* Update TPU docs for installation (#6794)

* fix boolean check on iterable dataset when len not defined (#6828)

* fix iterable dataset len check

* update predict and validate

* add validate to test

* add changelog

* add predict

* Sanitize `None` params during pruning (#6836)

* sanitize none params during pruning

* amend

* Fix `unfreeze_and_add_param_group` expects `modules` rather than `module` (#6822)

* Enforce an epoch scheduler interval when using SWA (#6588)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>

* Fix DPP + SyncBN (#6838)

* Fix DPP + SyncBN

Ensure that model is already on correct GPU before applying SyncBN conversion

* Fix order of SyncBN for ddp_spawn

* [Fix] TPU Training Type Plugin (#6816)

* Fix support for symlink save_dir in TensorBoardLogger (#6730)

* Add test for symlink support and initial fix

* Respond to comment and add docstring

* Update CHANGELOG.md

* Simplify

* Update pytorch_lightning/utilities/cloud_io.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Make `LightningLocalFileSystem` protected

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Fixed missing arguments in `lr_find` call (#6784)

There seem to be 3 arguments missing in the `lr_find` call in the tunining.py file.

* Update Changelog & version

* Fix TPU tests for checkpoint

Skip advanced profiler for torch > 1.8

Skip pytorch profiler for torch > 1.8

Fix save checkpoint logic for TPUs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: Yuan-Hang Zhang <sailordiary@users.noreply.github.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Elizaveta Logacheva <elimohl@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Karthik Prasad <prasadkr@uci.edu>
Co-authored-by: Sadiq Jaffer <sadiq@toao.com>
Co-authored-by: Michael Baumgartner <m.baumgartner@ymail.com>
Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Tharindu Hasthika <tharindubathigama@gmail.com>
lexierule pushed a commit that referenced this pull request Apr 7, 2021
* Fix DPP + SyncBN

Ensure that model is already on correct GPU before applying SyncBN conversion

* Fix order of SyncBN for ddp_spawn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working distributed Generic distributed-related topic
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants