This repository has been archived by the owner on Sep 28, 2022. It is now read-only.

Accelerator model state dict (Lightning-AI#7474)
* Fix some test errors

* checkpoint consolidation

* Update ddp_spawn.py

* Update test_metric_result_integration.py

* Update test_results.py

* Update utils.py

* Update utils.py

* Update test_all_gather_grad.py

* Update test_all_gather_grad.py

* Update test_results.py

* Revert "Update test_results.py"

This reverts commit 9d4a2b8.

* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate"

This reverts commit c5053da, reversing
changes made to 0d23d75.

* Revert "Update test_all_gather_grad.py"

This reverts commit 0d23d75.

* Revert "Update utils.py"

This reverts commit 70fe5da.

* Revert "Update utils.py"

This reverts commit a9aae99.

* Revert "Update test_results.py"

This reverts commit ea74906.

* Revert "Update test_metric_result_integration.py"

This reverts commit bf70e43.

* Revert "Update ddp_spawn.py"

This reverts commit f172101.

* Revert "checkpoint consolidation"

This reverts commit 536c132.

* Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde9.

* Revert "Revert "Revert "checkpoint consolidation"""

This reverts commit 7a369f4.

* Revert "Revert "Update ddp_spawn.py""

This reverts commit 8222dc9.

* Revert "Revert "Update test_metric_result_integration.py""

This reverts commit 6c095b2.

* Revert "Revert "Update test_results.py""

This reverts commit 250d0aa.

* Revert "Revert "Update utils.py""

This reverts commit 8651d54.

* Revert "Revert "Update test_all_gather_grad.py""

This reverts commit dcdcd29.

* modify distributed environment to make test pass

* modify model state dict to training type plugin

* remove changes

* add changelog

* fixing isort for pre-commit failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <sean@grid.ai>
3 people committed May 11, 2021
1 parent a1a655d commit 8538c1f
Showing 4 changed files with 16 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
 
 
+- Changed `model.state_dict()` in `CheckpointConnector` to allow `training_type_plugin` to customize the model's `state_dict()` ([7474](https://github.com/PyTorchLightning/pytorch-lightning/pull/7474))
+
+
 ### Deprecated
6 changes: 6 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -420,6 +420,12 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         return getattr(self.training_type_plugin, 'optimizer_state', lambda x: x.state_dict())(optimizer)
 
+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """
+        Returns state of model. Allows for syncing/collating model state from processes in custom plugins.
+        """
+        return self.training_type_plugin.lightning_module_state_dict()
+
     def on_save(self, checkpoint: Dict[str, Union[Any, Tensor]]) -> Dict[str, Union[Any, Tensor]]:
         return self.training_type_plugin.on_save(checkpoint)
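Note the two delegation styles in this accelerator hunk: `optimizer_state` uses a `getattr` fallback so plugins that don't define the hook still get the optimizer's own `state_dict()`, while the new `lightning_module_state_dict` delegates to the plugin unconditionally. A minimal sketch with toy mocks (these are not the real Lightning classes):

```python
# Toy stand-ins illustrating the getattr-with-fallback pattern used by
# Accelerator.optimizer_state: the plugin's hook wins when it exists,
# otherwise the optimizer's own state_dict() is used.

class FakeOptimizer:
    def state_dict(self):
        return {"lr": 0.1}


class PlainPlugin:
    """A plugin that defines no optimizer_state hook."""


class CustomPlugin:
    """A plugin that wraps the optimizer state (e.g. to tag it)."""

    def optimizer_state(self, optimizer):
        return {"wrapped": optimizer.state_dict()}


def accelerator_optimizer_state(plugin, optimizer):
    # Same pattern as the hunk above: hook if present, else fall back.
    return getattr(plugin, "optimizer_state", lambda opt: opt.state_dict())(optimizer)


opt = FakeOptimizer()
print(accelerator_optimizer_state(PlainPlugin(), opt))   # falls back: {'lr': 0.1}
print(accelerator_optimizer_state(CustomPlugin(), opt))  # hook wins: {'wrapped': {'lr': 0.1}}
```

The new `lightning_module_state_dict` skips the fallback entirely because the base training type plugin (see the next hunk) now always provides a default implementation.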
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TypeVar, Union
 
 import torch
+from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -241,6 +242,11 @@ def update_global_step(self, total_batch_idx: int, current_global_step: int) ->
         """
         return current_global_step + 1
 
+    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
+        """Returns model state."""
+        model = self.lightning_module
+        return model.state_dict()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
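The base-class hook above just returns the wrapped model's local `state_dict()`. The point of making this a plugin hook is that a distributed plugin can override it to sync or consolidate state across processes before checkpointing. A hypothetical sketch of such an override (plain dicts stand in for tensors; `ShardedPlugin` and the gathering mechanism are illustrative, not the real Lightning implementation):

```python
# Hypothetical sketch: a distributed plugin overriding the new hook to merge
# per-rank shards into one full state dict. None of these classes are the
# real Lightning implementations; remote_shards stands in for state that
# would actually be gathered from other processes.

class FakeModel:
    def state_dict(self):
        return {"layer0.weight": [1.0], "layer0.bias": [0.0]}


class TrainingTypePlugin:
    def __init__(self, model):
        self._model = model

    @property
    def lightning_module(self):
        return self._model

    def lightning_module_state_dict(self):
        # Base behaviour, as in the hunk above: local state only.
        return self.lightning_module.state_dict()


class ShardedPlugin(TrainingTypePlugin):
    def __init__(self, model, remote_shards):
        super().__init__(model)
        self._remote_shards = remote_shards

    def lightning_module_state_dict(self):
        # Consolidate: start from the local shard, merge in the others.
        full = dict(self.lightning_module.state_dict())
        for shard in self._remote_shards:
            full.update(shard)
        return full


plugin = ShardedPlugin(FakeModel(), remote_shards=[{"layer1.weight": [2.0]}])
print(sorted(plugin.lightning_module_state_dict()))
# ['layer0.bias', 'layer0.weight', 'layer1.weight']
```

With the default base implementation, behaviour is unchanged for single-process training; only plugins that opt in to an override see a difference.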
@@ -273,7 +273,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
-            'state_dict': model.state_dict(),
+            'state_dict': self.trainer.accelerator.lightning_module_state_dict(),
         }
 
         if not weights_only:
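Taken together, the hunks form one delegation chain: `dump_checkpoint` no longer calls `model.state_dict()` directly, so whatever the training type plugin returns is what lands under `'state_dict'` in the checkpoint. An end-to-end sketch with simplified stand-ins (not the real classes or full signatures):

```python
# End-to-end sketch of the delegation chain this commit introduces:
# CheckpointConnector.dump_checkpoint -> Accelerator -> TrainingTypePlugin
# -> model.state_dict(). All names here are simplified stand-ins.

class Model:
    def state_dict(self):
        return {"w": [0.5]}


class Plugin:
    def __init__(self, model):
        self.lightning_module = model

    def lightning_module_state_dict(self):
        return self.lightning_module.state_dict()


class Accelerator:
    def __init__(self, plugin):
        self.training_type_plugin = plugin

    def lightning_module_state_dict(self):
        return self.training_type_plugin.lightning_module_state_dict()


def dump_checkpoint(accelerator, current_epoch, global_step):
    # Before this commit the 'state_dict' entry came straight from
    # model.state_dict(); now it is routed through the accelerator.
    return {
        "epoch": current_epoch,
        "global_step": global_step,
        "state_dict": accelerator.lightning_module_state_dict(),
    }


ckpt = dump_checkpoint(Accelerator(Plugin(Model())), current_epoch=2, global_step=40)
print(ckpt["state_dict"])  # {'w': [0.5]}
```

With default plugins the checkpoint contents are identical to before; the indirection only matters once a plugin overrides the hook.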
