Skip to content

Commit

Permalink
[bugfix] Add support for omegaconf and tpu (#6741)
Browse files Browse the repository at this point in the history
* fix_hydra

* update changelog

Co-authored-by: Your Name <you@example.com>
  • Loading branch information
tchaton and Your Name committed Mar 30, 2021
1 parent 583fcf2 commit bb92754
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))


- Fixed resolve a bug with omegaconf and xm.save ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))

## [1.2.4] - 2021-03-16

### Changed
Expand Down
9 changes: 8 additions & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.apply_func import apply_to_collection

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
Expand All @@ -37,6 +38,10 @@
else:
xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
from omegaconf import DictConfig, ListConfig


class TPUSpawnPlugin(DDPSpawnPlugin):

Expand Down Expand Up @@ -304,4 +309,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
filepath: write-target file's path
"""
# Todo: TypeError: 'mappingproxy' object does not support item assignment
if _OMEGACONF_AVAILABLE:
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)

0 comments on commit bb92754

Please sign in to comment.