diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f27f9c4c61476..2517888ad5ed7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -517,6 +517,10 @@ def _update_best_and_save( self.best_k_models.pop(self.kth_best_model_path) del_list.append(delpath) + # do not save nan; replace it with +/- inf (per mode) so it never ranks as best + if torch.isnan(current): + current = {"min": torch.tensor(float('inf')), "max": torch.tensor(-float('inf'))}[self.mode] + self.best_k_models[filepath] = current if len(self.best_k_models) == k: # monitor dict has reached k elements diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b3b8204166ecc..ee988bb8f4b60 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -15,7 +15,7 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -457,3 +457,28 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ) for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()): assert w0.eq(w1).all() + + +@pytest.mark.parametrize('mode', ['min', 'max']) +def test_checkpointing_with_nan_as_first(tmpdir, mode): + os.environ['PL_DEV_DEBUG'] = '1' + monitor = [float('nan')] + monitor += [5, 7, 8] if mode == 'max' else [8, 7, 5] + + class CurrentModel(BoringModel): + def validation_epoch_end(self, outputs): + val_loss = monitor[self.current_epoch] + self.log('abc', val_loss) + + model = CurrentModel() + + trainer = Trainer( + checkpoint_callback=ModelCheckpoint(monitor='abc', mode=mode, save_top_k=1, filepath=tmpdir), 
default_root_dir=tmpdir, + val_check_interval=1.0, + max_epochs=len(monitor), + ) + trainer.fit(model) + + # check that last one is also the best one + assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1