From 5bb1838e453e10853489113be57b1846bec0770b Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Thu, 25 Mar 2021 16:07:37 +0530 Subject: [PATCH] Fix checkpoint callback & Trainer.test(_) issue for TPUs (#6654) * Fix checkpoint callback issue for TPUs * update changelog * add barrier * apply code suggestions * update trainer test * remove spaces * fix tpu tests * Apply suggestions from code review * add comment Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 ++ .../plugins/training_type/tpu_spawn.py | 12 ++++++------ pytorch_lightning/trainer/trainer.py | 6 ++++-- tests/models/test_tpu.py | 17 +++++++++++++++-- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 524e57ac48e03..6669050a56298 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) +- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) +- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654)) ## [1.2.5] - 2021-03-23 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1e951329b22cc..09603f9a22bc2 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch -import torch.distributed as torch_distrib import torch.multiprocessing as mp from pytorch_lightning.core.lightning import LightningModule @@ -96,13 +95,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: # replace trainer save_checkpoint to use `xm.save` trainer.save_checkpoint = self.save_checkpoint - self.barrier() + self.barrier("pre-run-stage") results = trainer.train_or_test_or_predict() self.__save_end_of_training_weights(self.lightning_module) self.transfer_distrib_spawn_state_on_fit_end(results) + self.barrier("end-process") + def __save_end_of_training_weights(self, model: LightningModule) -> None: # when training ends on these platforms dump weights to get out of the main process if on_colab_kaggle(): @@ -113,12 +114,11 @@ def model_to_device(self) -> None: self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: - if torch_distrib.is_initialized(): - rendezvous(f"pl.Trainer.{name}") + rendezvous(name) def transfer_distrib_spawn_state_on_fit_end(self, results): - # TODO: is there a better way than accessing callback through model -> trainer -> callback? - best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path + checkpoint_callback = self.lightning_module.trainer.checkpoint_callback + best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None if self.mp_queue is not None: rank_zero_warn("cleaning up ddp environment...") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f378ee830d261..2d5e2504a319f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -56,7 +56,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.enums import LightningEnum @@ -942,7 +942,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): ) return {} - self.training_type_plugin.barrier() + # only one process running at this point for TPUs, as spawn isn't triggered yet + if not self._device_type == DeviceType.TPU: + self.training_type_plugin.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 0554d924e6e9f..fbda891f0065f 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -349,13 +349,14 @@ def test_reduce(rank): xmp.spawn(test_reduce, nprocs=8, start_method='fork') -@pytest.mark.parametrize("clip_val", [0, 10]) -@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@RunIf(tpu=True) @pl_multi_process_test +@pytest.mark.parametrize("clip_val", [10]) @mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): """ Ensure that clip gradients is only called if the value is greater than 0. + TODO: Fix (test fails with parametrize) """ tutils.reset_seed() trainer_options = dict( @@ -375,3 +376,15 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): mock_clip_grad_norm.assert_called() else: mock_clip_grad_norm.assert_not_called() + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_if_test_works_with_checkpoint_false(tmpdir): + """Ensure that model trains properly when `checkpoint_callback` is set to False.""" + + # Train a model on TPU + model = BoringModel() + trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"