Skip to content

Commit

Permalink
Fix checkpoint callback & Trainer.test(_) issue for TPUs (#6654)
Browse files Browse the repository at this point in the history
* Fix checkpoint callback issue for TPUs

* update changelog

* add barrier

* apply code suggestions

* update trainer test

* remove spaces

* fix tpu tests

* Apply suggestions from code review

* add comment

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
kaushikb11 and Borda committed Mar 30, 2021
1 parent 014a6b7 commit a1829bc
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))


## [1.2.5] - 2021-03-23
Expand Down
12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
Expand Down Expand Up @@ -96,13 +95,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

# replace trainer save_checkpoint to use `xm.save`
trainer.save_checkpoint = self.save_checkpoint
self.barrier()
self.barrier("pre-run-stage")

results = trainer.train_or_test_or_predict()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)

self.barrier("end-process")

def __save_end_of_training_weights(self, model: LightningModule) -> None:
# when training ends on these platforms dump weights to get out of the main process
if on_colab_kaggle():
Expand All @@ -113,12 +114,11 @@ def model_to_device(self) -> None:
self._model.to(xm.xla_device())

def barrier(self, name: Optional[str] = None) -> None:
if torch_distrib.is_initialized():
rendezvous(f"pl.Trainer.{name}")
rendezvous(name)

def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

if self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.enums import LightningEnum
Expand Down Expand Up @@ -942,7 +942,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
)
return {}

self.training_type_plugin.barrier()
# only one process running at this point for TPUs, as spawn isn't triggered yet
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
Expand Down
17 changes: 15 additions & 2 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,14 @@ def test_reduce(rank):
xmp.spawn(test_reduce, nprocs=8, start_method='fork')


@pytest.mark.parametrize("clip_val", [0, 10])
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@RunIf(tpu=True)
@pl_multi_process_test
@pytest.mark.parametrize("clip_val", [10])
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that clip gradients is only called if the value is greater than 0.
TODO: Fix (test fails with parametrize)
"""
tutils.reset_seed()
trainer_options = dict(
Expand All @@ -375,3 +376,15 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
mock_clip_grad_norm.assert_called()
else:
mock_clip_grad_norm.assert_not_called()


@RunIf(tpu=True)
@pl_multi_process_test
def test_if_test_works_with_checkpoint_false(tmpdir):
"""Ensure that model trains properly when `checkpoint_callback` is set to False."""

# Train a model on TPU
model = BoringModel()
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

0 comments on commit a1829bc

Please sign in to comment.