Horovod: fixed early stopping and added metrics aggregation #3775

Merged: 42 commits, Nov 5, 2020

@tgaddair (Contributor) commented Oct 1, 2020

What does this PR do?

Fixes #3381 by introducing Horovod metrics aggregation alongside support for DDP. Now, whenever a metric needs to be aggregated across workers, we will check if DDP is initialized, then Horovod if it is available, before returning the original value.

This means that sync_ddp has been renamed to sync_dist, and a new function, sync_dist_if_available, has been added to check each framework individually.

Because this change relies on functionality introduced in Horovod v0.20.0, the minimum supported Horovod version has been bumped accordingly.
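
As a rough illustration of the check order described above (a sketch, not the code merged in this PR), the fallback might look like the following. The function body, the `reduce_op` parameter, and the guarded imports are assumptions made for the example; only the `torch.distributed.is_available()`, `torch.distributed.is_initialized()`, and `hvd.is_initialized()` checks are named in this thread.

```python
def sync_dist_if_available(value, reduce_op="mean"):
    """Illustrative sketch: try PyTorch distributed, then Horovod, else pass through."""
    # 1) PyTorch distributed (DDP), if torch is installed and a process group exists.
    try:
        import torch.distributed as dist
    except ImportError:
        dist = None
    if dist is not None and dist.is_available() and dist.is_initialized():
        value = value.clone()  # all_reduce works in place; keep the caller's tensor intact
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        if reduce_op == "mean":
            value = value / dist.get_world_size()
        return value

    # 2) Horovod, if it is installed and hvd.init() has been called.
    try:
        import horovod.torch as hvd
    except ImportError:
        hvd = None
    if hvd is not None and hvd.is_initialized():
        op = hvd.Average if reduce_op == "mean" else hvd.Sum
        return hvd.allreduce(value, op=op)

    # 3) Neither framework is active: return the original value unchanged.
    return value
```

With no process group and no Horovod session running, the value passes through unchanged, which is the single-process behaviour described above.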

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@tgaddair changed the title from "Hvd early stop" to "Horovod: fixed early stopping and added metrics aggregation" Oct 1, 2020
@mergify mergify bot requested a review from a team October 1, 2020 20:09
@Borda Borda added the bug Something isn't working label Oct 2, 2020
@Borda (Member) left a comment

can we add a test for this case...

Return:
reduced value
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():

So first we try PyTorch distributed, and then Horovod?

@tgaddair (Contributor, Author), Oct 2, 2020

At the moment, this component does not have access to the Trainer state that tells us which distributed backend we're using. A better design would be to inject this somehow, but it's not immediately clear how we could do this, since the EvalResult is created by the LightningModule.

So for now, we need to check each framework individually.

pytorch_lightning/metrics/converters.py (outdated, resolved)
@mergify mergify bot requested a review from a team October 2, 2020 08:13
mergify bot (Contributor) commented Oct 4, 2020

This pull request is now in conflict... :(

mergify bot (Contributor) commented Oct 6, 2020

This pull request is now in conflict... :(

@Borda Borda added this to the 1.0 milestone Oct 6, 2020
pep8speaks commented Oct 6, 2020

Hello @tgaddair! Thanks for updating this PR.

Line 218:21: W503 line break before binary operator
Line 219:21: W503 line break before binary operator

Comment last updated at 2020-11-04 22:08:41 UTC

@williamFalcon (Contributor)

@tgaddair mind keeping all the Horovod code in the accelerator?
Then add the matching function for the other accelerators that need it.

I.e., we shouldn't call:

if ddp:
    x()
if horovod:
    y()

It should be:
accelerator.x()
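
A minimal sketch of the dispatch pattern suggested here, in plain Python. The class names, the `sync_tensor` signature, and the in-memory stand-in for a distributed backend are all hypothetical; a real DDP or Horovod accelerator would call its framework's all-reduce instead.

```python
class Accelerator:
    """Hypothetical base class: in a single process, syncing is a no-op."""

    def sync_tensor(self, value, reduce_op="mean"):
        return value


class SimulatedWorkersAccelerator(Accelerator):
    """Stand-in for a DDP or Horovod backend.

    It keeps the other workers' values in memory so the example runs
    without any distributed framework installed.
    """

    def __init__(self, peer_values):
        self.peer_values = list(peer_values)

    def sync_tensor(self, value, reduce_op="mean"):
        values = [value, *self.peer_values]
        total = sum(values)
        return total / len(values) if reduce_op == "mean" else total


def log_metric(accelerator, value):
    # The caller never branches on the backend; it only talks to the accelerator.
    return accelerator.sync_tensor(value)
```

The point of the design is that `log_metric` stays the same whichever backend is active; adding a new backend means adding one subclass rather than another `if` branch at every call site.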

@williamFalcon (Contributor) left a comment

Please use the horovod accelerator as described so we can keep the code clean and simple...
thanks!

@mergify mergify bot requested a review from a team October 6, 2020 23:01
@mergify mergify bot requested a review from a team October 6, 2020 23:04
@mergify mergify bot requested a review from a team October 6, 2020 23:05
@mergify mergify bot requested a review from a team October 6, 2020 23:05
@edenlightning edenlightning modified the milestones: 1.0, 0.10.0 Oct 6, 2020
@Borda Borda added the priority: 0 High priority task label Oct 7, 2020
mergify bot (Contributor) commented Oct 7, 2020

This pull request is now in conflict... :(

@teddykoker (Contributor)

Great! If I understand correctly, each accelerator backend will now (optionally) implement gather_all_tensors which can then be passed into metrics via the dist_sync_fn flag?
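
To make the idea concrete, here is a hedged sketch of a metric that accepts a gather function through a `dist_sync_fn` argument. `SumMetric` and `fake_gather_all` are hypothetical names invented for the example, and the real `Metric` base class may differ; the sketch also assumes the callable returns one value per worker.

```python
class SumMetric:
    """Hypothetical metric that accumulates a running sum on each worker."""

    def __init__(self, dist_sync_fn=None):
        # dist_sync_fn: callable that takes the local state and returns a
        # list with one entry per worker (assumed contract for this sketch).
        self.dist_sync_fn = dist_sync_fn
        self.total = 0.0

    def update(self, x):
        self.total += x

    def compute(self):
        if self.dist_sync_fn is None:
            # No sync function supplied: report the local value only.
            return self.total
        gathered = self.dist_sync_fn(self.total)
        return sum(gathered)


def fake_gather_all(local_value):
    """Pretend a second worker holds 10.0; stands in for a backend's gather."""
    return [local_value, 10.0]


local_only = SumMetric()
local_only.update(1.0)
local_only.update(2.0)

synced = SumMetric(dist_sync_fn=fake_gather_all)
synced.update(1.0)
synced.update(2.0)
```

Here `local_only.compute()` sees only this worker's state, while `synced.compute()` combines the local state with whatever the injected function gathers.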

@tgaddair (Contributor, Author) commented Nov 4, 2020

@williamFalcon @edenlightning @teddykoker this is ready to go.

@williamFalcon (Contributor)

@SeanNaren @tchaton mind prioritizing this to land? thanks!

@tchaton (Contributor) left a comment

Overall, looks great! Some minor changes and one question about your first test.

pytorch_lightning/utilities/distributed.py (resolved)
pytorch_lightning/metrics/metric.py (outdated, resolved)
tests/metrics/test_ddp.py (outdated, resolved)
tests/models/test_horovod.py (resolved)
pytorch_lightning/accelerators/accelerator.py (resolved)
pytorch_lightning/metrics/metric.py (outdated, resolved)
tests/models/test_horovod.py (outdated, resolved)
SeanNaren and others added 4 commits November 4, 2020 21:12
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
@Borda Borda added distributed Generic distributed-related topic Metrics labels Nov 4, 2020
@tchaton (Contributor) left a comment

LGTM !

@SeanNaren (Contributor)

Pinging @teddykoker to have a look, since it touches the metric package (albeit not too intrusively).

@teddykoker (Contributor) left a comment

Looks good to me! At some point we'll need a way of providing the sync function to metrics automatically

@williamFalcon williamFalcon merged commit 51cc7a8 into Lightning-AI:master Nov 5, 2020
SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
* Fixed early stopping for Horovod

* Refactored to sync_dist_if_available

* Bump min Horovod version to support hvd.is_initialized

* Changelog

* Added back change for Horovod

* Removed redundant checks for initialization

* Implement metrics gathering for Horovod

* Added test for EvalResult

* Renamed ddp_sync_on_step -> dist_sync_on_step

* Added metric test for Horovod

* Added option pass callable allgather function to metric base class

* Added dist_sync_fn

* Fixed calls to private _sync_dist

* Fixed Horovod test

* Added sync_tensor to the distributed backend

* Skip Windows

* Insert test path

* Removed redundant import

* Updated drone

* Unset HOROVOD_GPU_ALLREDUCE

* Unset

* No cache dir

* No uninstall

* Unset variables

* Uninstall Horovod during initialization

* Replaced more references to ddp_sync_on_step

* Fixed imports

* Fixed attribute

* Added back default

* Lint

* Added back docstring

* Made gather_all_tensors default

* Added whitespace

* Update tests/models/test_horovod.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/metric.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update CHANGELOG.md

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

(cherry picked from commit 51cc7a8)
SeanNaren pushed a commit that referenced this pull request Nov 11, 2020
rohitgr7 pushed a commit that referenced this pull request Nov 21, 2020
Labels: bug (Something isn't working), distributed (Generic distributed-related topic), priority: 0 (High priority task)
Development

Successfully merging this pull request may close these issues.

Early stopping fails on horovod with cannot unpack non-iterable NoneType object
8 participants