Monitor on training_epoch_end with ModelCheckpoint #5084

Closed
wants to merge 17 commits
22 changes: 13 additions & 9 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -32,8 +32,8 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -232,7 +232,8 @@ def save_checkpoint(self, trainer, pl_module):
             return

         self._add_backward_monitor_support(trainer)
-        self._validate_monitor_key(trainer)
+        if not self._validate_monitor_key(trainer):
+            return

         # track epoch when ckpt was last checked
         self.last_global_step_saved = global_step
@@ -501,17 +502,20 @@ def _add_backward_monitor_support(self, trainer):
         if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1

-    def _validate_monitor_key(self, trainer):
+    def _validate_monitor_key(self, trainer) -> bool:
         metrics = trainer.logger_connector.callback_metrics

         # validate metric
         if self.monitor is not None and not self._is_valid_monitor_key(metrics):
-            m = (
-                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
-                f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
-            )
-            raise MisconfigurationException(m)
+            if not trainer.checkpoint_connector._one_training_epoch_completed:
+                return False
+            else:
+                raise MisconfigurationException(
+                    f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                    f" {list(metrics.keys())}. "
+                    f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+                )
+        return True

     def _get_metric_interpolated_filepath_name(
         self,
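In words, the change above turns a hard failure into a soft skip: while no training epoch has completed, a monitor missing from callback_metrics is treated as "not logged yet" and the checkpoint attempt is skipped; once an epoch has completed, a missing monitor is still reported as a misconfiguration. A minimal standalone sketch of that decision follows; should_attempt_checkpoint and its arguments are illustrative stand-ins for the real method and trainer state, not code from the PR.

class MisconfigurationException(Exception):
    """Stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


def should_attempt_checkpoint(monitor, metrics, one_training_epoch_completed) -> bool:
    """Return True to proceed with checkpointing, False to silently skip this call."""
    if monitor is not None and monitor not in metrics:
        if not one_training_epoch_completed:
            # the monitor may be logged from `training_epoch_end`, which has not run yet
            return False
        raise MisconfigurationException(
            f"ModelCheckpoint(monitor='{monitor}') not found in the returned metrics: {list(metrics)}"
        )
    return True


# before the first epoch finishes, a missing key just skips the save
assert should_attempt_checkpoint('epoch_end_train_loss', {'train_loss': 0.3}, False) is False
# after a full epoch, a missing key is still a configuration error
try:
    should_attempt_checkpoint('epoch_end_train_loss', {'train_loss': 0.3}, True)
except MisconfigurationException:
    pass
# and an available key proceeds as before
assert should_attempt_checkpoint('epoch_end_train_loss', {'epoch_end_train_loss': 0.2}, True)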
pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -25,8 +25,8 @@
 from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if APEX_AVAILABLE:
     from apex import amp
@@ -41,7 +41,8 @@ def __init__(self, trainer):
         self.trainer = trainer

         # used to validate checkpointing logic
-        self.has_trained = False
+        self._has_trained = False
+        self._one_training_epoch_completed = False

     def restore_weights(self, model: LightningModule):
         """
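For orientation, the two flags end up serving different purposes across this diff; the comments below paraphrase where they are set and read, as a summary rather than code from the patch.

class CheckpointConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        # _has_trained: flipped to True inside `run_training_epoch` once at least one
        # training batch has run; read by `TrainLoop.check_checkpoint_callback` so that
        # nothing is saved before any training has happened.
        self._has_trained = False
        # _one_training_epoch_completed: flipped to True in `Trainer.train()` only after a
        # full epoch (including `training_epoch_end`) has finished; read by
        # `ModelCheckpoint._validate_monitor_key` so that a monitor logged from
        # `training_epoch_end` is not reported as missing prematurely.
        self._one_training_epoch_completed = False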
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,6 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -47,6 +46,7 @@
 from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn, DeviceType
+from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -494,7 +494,8 @@ def train(self):
         # set stage for logging
         self.logger_connector.set_stage("train")

-        self.checkpoint_connector.has_trained = False
+        self.checkpoint_connector._has_trained = False
+        self.checkpoint_connector._one_training_epoch_completed = False

         # enable train mode
         model = self.get_model()
@@ -526,6 +527,8 @@ def train(self):
                 # update LR schedulers
                 self.optimizer_connector.update_learning_rates(interval='epoch')

+                self.checkpoint_connector._one_training_epoch_completed = True
+
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
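The placement of the new assignment is the point of the change: the flag only flips after run_training_epoch (and therefore training_epoch_end) has run, so checkpoint attempts made during the first epoch can still treat a missing monitor as "not available yet". A schematic of the loop after this patch follows; the loop structure is paraphrased and most of the body is elided.

def train(self):  # schematic of Trainer.train(), heavily condensed
    self.checkpoint_connector._has_trained = False
    self.checkpoint_connector._one_training_epoch_completed = False

    for epoch in range(self.current_epoch, self.max_epochs):
        # runs training_step for every batch and then training_epoch_end; any checkpoint
        # attempt made in here during the first epoch may not find the monitor yet and,
        # with this PR, skips instead of raising
        self.train_loop.run_training_epoch()

        self.optimizer_connector.update_learning_rates(interval='epoch')

        # from this point on, a monitor logged in training_epoch_end is in callback_metrics,
        # so a missing key is again treated as a misconfiguration
        self.checkpoint_connector._one_training_epoch_completed = True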
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -213,7 +213,7 @@ def on_train_end(self):

     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save and self.trainer.checkpoint_connector.has_trained:
+        if should_save and self.trainer.checkpoint_connector._has_trained:
             checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

             if is_last and any(c.save_last for c in checkpoint_callbacks):
@@ -597,7 +597,7 @@ def run_training_epoch(self):
             # update LR schedulers
             monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
-            self.trainer.checkpoint_connector.has_trained = True
+            self.trainer.checkpoint_connector._has_trained = True

             # max steps reached, end training
             if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
6 changes: 3 additions & 3 deletions pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union

 import torch

 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any

 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
50 changes: 43 additions & 7 deletions tests/checkpointing/test_model_checkpoint.py
@@ -25,15 +25,16 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch.utils.data import DataLoader, Dataset, random_split

 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import LightningModule, Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel
+from tests.base import BoringModel, RandomDataset


 class LogInTwoMethods(BoringModel):
@@ -702,7 +703,7 @@ def validation_epoch_end(self, *_):
         ...

 def assert_trainer_init(trainer):
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == 0
     assert trainer.current_epoch == 0
@@ -739,7 +740,7 @@ def assert_checkpoint_log_dir(idx):

     model = ExtendedBoringModel()
     trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained
+    assert trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
@@ -759,12 +760,12 @@ def assert_checkpoint_log_dir(idx):

     model = ExtendedBoringModel()
     trainer.test(model)
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs

     trainer.fit(model)
-    assert not trainer.checkpoint_connector.has_trained
+    assert not trainer.checkpoint_connector._has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs
     assert_checkpoint_log_dir(idx)
@@ -940,6 +941,41 @@ def __init__(self, hparams):
     assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


+def test_model_checkpoint_with_training_epoch_end(tmpdir):
+    """
+    This test ensures ModelCheckpoint issues a warning when the monitor is logged on training_epoch_end
+    """
+    class TestedModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss)
+            return {"loss": loss}
+
+        def training_epoch_end(self, outputs) -> None:
+            avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+            self.log('epoch_end_train_loss', avg_loss)
+
+    model = TestedModel()
+
+    chk = ModelCheckpoint(dirpath=tmpdir, monitor='epoch_end_train_loss', save_top_k=-1)
+    trainer = pl.Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=4,
+        progress_bar_refresh_rate=1,
+        callbacks=[chk],
+    )
+    trainer.current_epoch = 2
+    trainer.fit(model)
+
+    chks = os.listdir(tmpdir)
+    assert 'epoch=4.ckpt' not in chks
+    assert 'epoch=3.ckpt' not in chks
+    assert 'epoch=2.ckpt' not in chks
+
+
 @pytest.mark.parametrize('max_epochs', [3, 4])
 @pytest.mark.parametrize(
     'save_top_k, expected',
Review thread on the new test:

Contributor commented:
This test is missing a pytest.warns check with the warning thrown 😄

Also it might be better to test

os.listdir(tmpdir) == ['epoch=0.ckpt', ...]

instead of testing that the rest don't exist

Contributor (Author) commented:
The warning was removed. Also, I am thinking it might be better not to merge this PR, as it is a wacky solution, and to work on a train_sanity check instead. @carmocca @rohitgr7 What are your thoughts?

Contributor commented:
yes, let's close this one for now.




@@ -976,4 +1012,4 @@ def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
     assert set(ckpt_files) == set(expected)

     epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
-    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
\ No newline at end of file
+    assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
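For reference, the reviewer's two suggestions, assuming the warning had been kept rather than removed, might look roughly like the sketch below. The match string, the UserWarning category, and the expected file set are assumptions for illustration, not values confirmed by the PR.

import os

import pytest
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.base import BoringModel


class TestedModel(BoringModel):
    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('train_loss', loss)
        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log('epoch_end_train_loss', avg_loss)


def test_model_checkpoint_with_training_epoch_end_warns(tmpdir):
    model = TestedModel()
    chk = ModelCheckpoint(dirpath=tmpdir, monitor='epoch_end_train_loss', save_top_k=-1)
    trainer = pl.Trainer(default_root_dir=tmpdir, max_epochs=2, callbacks=[chk])

    # 1) assert the warning itself is emitted, not only its side effects
    with pytest.warns(UserWarning, match='epoch_end_train_loss'):
        trainer.fit(model)

    # 2) compare the full directory listing instead of asserting which files are absent
    expected = {'epoch=0.ckpt', 'epoch=1.ckpt'}  # hypothetical set; depends on the final behaviour
    assert set(os.listdir(tmpdir)) == expected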
assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))