
Share the training step output data via ClosureResult #9349

Merged · 42 commits · Sep 10, 2021

Conversation

@carmocca (Contributor) commented Sep 7, 2021

What does this PR do?

Remove the need to use deepcopy.

Fixes #8821

Follow-up PRs

  • Treat everything returned during manual_optimization as an extra.
  • Warn when extras are returned but training_epoch_end is not implemented.
  • Split out the result dataclasses, as the behavior would be different in automatic compared to manual:
@dataclass
class OutputResult:
    # all attributes are optional
    # defines utility methods shared by subclasses to avoid duplication
    ...

class ManualResult(OutputResult):
    def from_training_step_output(): ...

class ClosureResult(OutputResult):
    def from_training_step_output(): ...

Notes:

# manual optimization `training_step` possible return formats:
(1): None # most common (doesn't skip)
(2): a_loss # Tensor. what does this do?
(3): {'loss': ...}  # same as (2)
(4): {'loss': ..., 'anything': ...}  # includes an extra
(5): {'loss': None, 'anything': ...}  # you don't want to return a loss but need the extra
(6): {'loss': ..., 'hiddens': ...}
(7): {'loss': None, 'hiddens': ...}
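A minimal sketch of how a `ClosureResult`-style constructor could normalize the return formats above into a single object. The class body and attribute names here are hypothetical, not the exact Lightning implementation; in practice `loss` would be a `torch.Tensor`, but the sketch uses `Any` to stay dependency-free:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ClosureResult:
    """Hypothetical sketch: normalize `training_step` output into one object."""

    loss: Optional[Any] = None  # a `torch.Tensor` in the real implementation
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_training_step_output(cls, output: Any) -> "ClosureResult":
        if output is None:
            # format (1): nothing returned
            return cls()
        if isinstance(output, dict):
            # formats (3)-(7): 'loss' may be None; every other key is an extra
            extra = {k: v for k, v in output.items() if k != "loss"}
            return cls(loss=output.get("loss"), extra=extra)
        # format (2): a bare loss tensor
        return cls(loss=output)
```

Because the result is built once and shared by reference, downstream consumers (e.g. `training_epoch_end`) can read `loss` and `extra` without the loop needing to `deepcopy` the raw `training_step` output.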

Does your PR introduce any breaking changes? If yes, please list them.

None

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@carmocca carmocca added bug Something isn't working refactor labels Sep 7, 2021
@carmocca carmocca added this to the v1.4.x milestone Sep 7, 2021
@carmocca carmocca self-assigned this Sep 7, 2021
@carmocca carmocca removed this from the v1.4.x milestone Sep 7, 2021
Review threads (since resolved) were opened on:
  • pytorch_lightning/core/lightning.py
  • pytorch_lightning/loops/batch/training_batch_loop.py
  • pytorch_lightning/loops/closure.py
  • tests/loops/test_closure.py
  • pytorch_lightning/loops/optimizer/optimizer_loop.py
@carmocca carmocca changed the base branch from master to refactor/enforce-closure-execution September 8, 2021 00:57
@carmocca carmocca added this to the v1.5 milestone Sep 8, 2021
@carmocca carmocca enabled auto-merge (squash) September 8, 2021 18:13
@carmocca carmocca mentioned this pull request Sep 8, 2021
@tchaton (Contributor) left a comment:

Looks great!

Review threads were opened on:
  • pytorch_lightning/loops/batch/manual.py
  • pytorch_lightning/loops/optimizer/optimizer_loop.py
  • pytorch_lightning/loops/closure.py
  • pytorch_lightning/loops/utilities.py
@codecov codecov bot commented Sep 9, 2021

Codecov Report

Merging #9349 (65ab8d9) into master (c963bf6) will decrease coverage by 0%.
The diff coverage is 97%.

@@          Coverage Diff           @@
##           master   #9349   +/-   ##
======================================
- Coverage      93%     93%   -0%     
======================================
  Files         179     179           
  Lines       14927   14915   -12     
======================================
- Hits        13868   13854   -14     
- Misses       1059    1061    +2     

Review threads were opened on:
  • pytorch_lightning/loops/batch/manual.py
  • pytorch_lightning/loops/closure.py
  • pytorch_lightning/loops/utilities.py
@mergify mergify bot added the ready PRs ready to be merged label Sep 10, 2021
@mergify mergify bot removed the has conflicts label Sep 10, 2021
@carmocca carmocca merged commit e0f2e04 into master Sep 10, 2021
@carmocca carmocca deleted the bugfix/closure-result branch September 10, 2021 11:40
@leezu (Contributor) commented Sep 30, 2021

@carmocca this PR fixes a bug that was also backported/introduced in the 1.4.6 release (#9239). Is it possible to backport this PR as well to 1.4?

@carmocca (Contributor, Author)
We didn't cherry-pick this because it builds on top of several refactors that landed before it and aren't in the bug-fix branch.

A custom fix in the bug-fix branch could be explored, but it's not a high priority right now.

Labels: bug (Something isn't working), priority: 0 (High priority task), ready (PRs ready to be merged), refactor

Successfully merging this pull request may close these issues:
  • PyTorch Lightning 1.4.1 crashes during training

6 participants