Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpoint callback & Trainer.test(_) issue for TPUs #6654

Merged
merged 11 commits into from
Mar 25, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))


- Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))


- Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))


- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))


Expand Down
11 changes: 6 additions & 5 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp

from pytorch_lightning.core.lightning import LightningModule
Expand Down Expand Up @@ -109,13 +108,15 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

# replace trainer save_checkpoint to use `xm.save`
trainer.save_checkpoint = self.save_checkpoint
self.barrier()
self.barrier("pre-run-stage")

results = trainer.run_stage()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)

self.barrier("end-process")
Borda marked this conversation as resolved.
Show resolved Hide resolved

def __save_end_of_training_weights(self, model: LightningModule) -> None:
# when training ends on these platforms dump weights to get out of the main process
if on_colab_kaggle():
Expand All @@ -126,11 +127,11 @@ def model_to_device(self) -> None:
self._model.to(xm.xla_device())

def barrier(self, name: Optional[str] = None) -> None:
if torch_distrib.is_initialized():
rendezvous(f"pl.Trainer.{name}")
rendezvous(name)

def transfer_distrib_spawn_state_on_fit_end(self, results):
best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
Comment on lines +133 to +134
Copy link
Contributor

@ananthsub ananthsub Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if there are multiple checkpoint callbacks attached? should we save once per path?

@awaelchli @carmocca this is gonna be amplified if people are tracking multiple versions of "best model paths" at the same time in an example like this

trainer = Trainer(...., callbacks=[checkpoint1, checkpoint2])
trainer.fit(module)
trainer.test()  <--- what checkpoint path path is used for running this?

should this raise an error due to ambiguity?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd rather use the first best path and log the path used when running test


if self.mp_queue is not None:
rank_zero_warn("cleaning up ddp environment...")
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -983,7 +983,9 @@ def __load_ckpt_weights(
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)

self.training_type_plugin.barrier()
# only one process running at this point for TPUs, as spawn isn't triggered yet
if not self._device_type == DeviceType.TPU:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
Expand Down
15 changes: 14 additions & 1 deletion tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,14 @@ def test_reduce(rank):
xmp.spawn(test_reduce, nprocs=8, start_method='fork')


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(tpu=True)
@pl_multi_process_test
@pytest.mark.parametrize("clip_val", [10])
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that clip gradients is only called if the value is greater than 0.
TODO: Fix (test fails with parametrize)
"""
tutils.reset_seed()
trainer_options = dict(
Expand All @@ -383,3 +384,15 @@ def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
mock_clip_grad_norm.assert_called()
else:
mock_clip_grad_norm.assert_not_called()


@RunIf(tpu=True)
@pl_multi_process_test
carmocca marked this conversation as resolved.
Show resolved Hide resolved
def test_if_test_works_with_checkpoint_false(tmpdir):
"""Ensure that model trains properly when `checkpoint_callback` is set to False."""

# Train a model on TPU
model = BoringModel()
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"