Monitor on training_epoch_end with ModelCheckpoint #5084

Closed · wants to merge 17 commits · Changes from 8 commits
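For context, below is a minimal sketch of the use case this PR targets, assuming the illustrative model, metric, and variable names shown (none of them come from the PR itself): the monitored metric is produced only in training_epoch_end, so it cannot exist until the first training epoch has finished.

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class EpochEndMonitorModel(pl.LightningModule):  # illustrative stand-in, not from the PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        # the monitored key only appears once the first training epoch has ended
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("epoch_end_train_loss", avg_loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = EpochEndMonitorModel()
checkpoint = ModelCheckpoint(monitor="epoch_end_train_loss", save_top_k=1)
trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint])
# trainer.fit(model, train_dataloader) with any DataLoader yielding (N, 32) float tensors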
10 changes: 7 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -32,8 +32,8 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -511,9 +511,13 @@ def _validate_monitor_key(self, trainer):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. "
f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
)
raise MisconfigurationException(m)
if not trainer.checkpoint_connector._one_training_epoch_completed:
m += "Running first epoch, a MisconfigurationException will be raise next epoch"
rank_zero_warn(m, UserWarning)
else:
m += f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
raise MisconfigurationException(m)

def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
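The practical effect of the hunk above, continuing the illustrative setup from the sketch near the top and using a deliberately wrong, made-up monitor key:

# Hypothetical misconfiguration: the monitored key is never logged anywhere.
bad_checkpoint = ModelCheckpoint(monitor="does_not_exist")
trainer = pl.Trainer(max_epochs=2, callbacks=[bad_checkpoint])
# trainer.fit(model) would now roughly behave as follows:
#   epoch 0: a UserWarning is emitted, ending in
#            "Running first epoch, a MisconfigurationException will be raised next epoch."
#   epoch 1: a MisconfigurationException is raised, ending in
#            "HINT: Did you call self.log('does_not_exist', tensor) in the LightningModule?"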
@@ -20,11 +20,11 @@
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import APEX_AVAILABLE, OMEGACONF_AVAILABLE, AMPType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if APEX_AVAILABLE:
from apex import amp
@@ -39,7 +39,8 @@ def __init__(self, trainer):
self.trainer = trainer

# used to validate checkpointing logic
self.has_trained = False
self._has_trained = False
self._one_training_epoch_completed = False

def restore_weights(self, model: LightningModule):
"""
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,6 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -47,6 +46,7 @@
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -494,7 +494,8 @@ def train(self):
# set stage for logging
self.logger_connector.set_stage("train")

self.checkpoint_connector.has_trained = False
self.checkpoint_connector._has_trained = False
self.checkpoint_connector._one_training_epoch_completed = False

# enable train mode
model = self.get_model()
@@ -526,6 +527,8 @@
# update LR schedulers
self.optimizer_connector.update_learning_rates(interval='epoch')

self.checkpoint_connector._one_training_epoch_completed = True

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -213,7 +213,7 @@ def on_train_end(self):

def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save and self.trainer.checkpoint_connector.has_trained:
if should_save and self.trainer.checkpoint_connector._has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

if is_last and any(c.save_last for c in checkpoint_callbacks):
@@ -599,7 +599,7 @@ def run_training_epoch(self):
# update LR schedulers
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.trainer.checkpoint_connector.has_trained = True
self.trainer.checkpoint_connector._has_trained = True

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
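Taken together, the connector, trainer, and training-loop hunks above drive two flags. A toy, self-contained illustration of that lifecycle follows; the classes here are stand-ins for the real Trainer and CheckpointConnector, not excerpts from the diff.

class FakeCheckpointConnector:  # stand-in for pytorch_lightning's CheckpointConnector
    def __init__(self):
        self._has_trained = False
        self._one_training_epoch_completed = False


class FakeTrainer:  # stand-in for pytorch_lightning.Trainer
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.checkpoint_connector = FakeCheckpointConnector()

    def run_training_epoch(self):
        # in the real training loop this flips once training batches have actually run
        self.checkpoint_connector._has_trained = True

    def train(self):
        # both flags are reset at the start of train(), as in the trainer.py hunk
        self.checkpoint_connector._has_trained = False
        self.checkpoint_connector._one_training_epoch_completed = False
        for epoch in range(self.max_epochs):
            self.run_training_epoch()
            # set only after an entire training epoch has completed
            self.checkpoint_connector._one_training_epoch_completed = True


trainer = FakeTrainer(max_epochs=2)
trainer.train()
assert trainer.checkpoint_connector._has_trained
assert trainer.checkpoint_connector._one_training_epoch_completed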
8 changes: 4 additions & 4 deletions pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
import os
import warnings
from functools import wraps
from typing import Any, Optional, Union

import torch

from pytorch_lightning import _logger as log
from typing import Union, Optional, Any

if torch.distributed.is_available():
from torch.distributed import ReduceOp
from torch.distributed import group
from torch.distributed import ReduceOp, group
else:
class ReduceOp:
SUM = None
@@ -145,7 +145,7 @@ def sync_ddp(
if group is None:
group = torch.distributed.group.WORLD

if reduce_op is None:
if reduce_op is None or reduce_op == "sum":
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = torch.distributed.ReduceOp.SUM
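For reference, a small self-contained sketch of the string handling shown in the sync_ddp hunk above. The helper name is made up, and the divide-by-world-size step for "avg"/"mean" is an assumption about code outside the visible hunk.

import torch


def normalize_reduce_op(reduce_op):  # hypothetical helper, not part of the PR
    """Map string reduce ops onto torch.distributed.ReduceOp, as in the hunk above."""
    divide_by_world_size = False
    if reduce_op is None or reduce_op == "sum":
        reduce_op = torch.distributed.ReduceOp.SUM
    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
        # assumed: the all-reduce still uses SUM and the caller divides by the world size
        reduce_op = torch.distributed.ReduceOp.SUM
        divide_by_world_size = True
    return reduce_op, divide_by_world_size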
45 changes: 39 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
@@ -27,15 +27,16 @@
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch.utils.data import DataLoader, Dataset, random_split

import pytorch_lightning as pl
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
from tests.base import BoringModel, RandomDataset


class LogInTwoMethods(BoringModel):
@@ -767,7 +768,7 @@ def validation_step(self, batch, batch_idx):
return {"val_loss": loss}

def assert_trainer_init(trainer):
assert not trainer.checkpoint_connector.has_trained
assert not trainer.checkpoint_connector._has_trained
assert trainer.global_step == 0
assert trainer.current_epoch == 0

@@ -815,7 +816,7 @@ def get_model():
assert_trainer_init(trainer)

trainer.fit(model)
assert trainer.checkpoint_connector.has_trained
assert trainer.checkpoint_connector._has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs - 1
assert_checkpoint_log_dir(0)
@@ -841,12 +842,12 @@ def get_model():
assert_trainer_init(trainer)

trainer.test(model)
assert not trainer.checkpoint_connector.has_trained
assert not trainer.checkpoint_connector._has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs

trainer.fit(model)
assert not trainer.checkpoint_connector.has_trained
assert not trainer.checkpoint_connector._has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs
assert_checkpoint_log_dir(idx)
@@ -1020,3 +1021,35 @@ def __init__(self, hparams):
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type


def test_model_checkpoint_with_training_epoch_end(tmpdir):

"""
This test assert ModelCheckpoint a warming is issued when monitor metric is used in training_epoch_end
"""
class TestedModel(BoringModel):

def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('train_loss', loss)
return {"loss": loss}

def training_epoch_end(self, outputs) -> None:
avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
self.log('epoch_end_train_loss', avg_loss)
self.log('gb_step', self.global_step)

model = TestedModel()

chk = ModelCheckpoint(monitor='epoch_end_train_loss', save_top_k=-1)
with pytest.warns(UserWarning, match="Running first epoch, a MisconfigurationException"):
trainer = pl.Trainer(
default_root_dir=tmpdir,
max_epochs=3,
progress_bar_refresh_rate=1,
callbacks=[chk],
)
trainer.current_epoch = 2
trainer.fit(model)