proper checkpoint implementation (#1043)
* enabled early stopping/checkpoint even without val step

* name formatting

* version

* testing

* add test

* fix test

* Update model_checkpoint.py

* doctests

* pylint

* tests

* debug

* enabled early stopping/checkpoint even without val step

* fix MNIST download (#1044)

* simple

* rebased 1041

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
williamFalcon and Borda authored Mar 5, 2020
1 parent 165b9fb commit bcb45d9
Showing 14 changed files with 208 additions and 194 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
 - Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
+- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
 
 ### Changed
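To make the changelog entry concrete: a minimal sketch of the newly supported setup, assuming the 0.7-era API used in this diff (the `LitSketch` module, its random data, and the hyper-parameters are hypothetical):

```python
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitSketch(pl.LightningModule):
    """Hypothetical module that defines no validation_step at all."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # 'loss' is the only metric produced, so the default checkpoint
        # (and early stopping) callback falls back to monitoring it
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        x, y = torch.randn(64, 32), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=8)


# before this commit, the default ModelCheckpoint assumed 'val_loss' and was
# only triggered from the validation loop; fitting this module now still
# produces checkpoints
trainer = pl.Trainer(max_epochs=3)
trainer.fit(LitSketch())
```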
172 changes: 97 additions & 75 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -1,40 +1,36 @@
 r"""
 Model Checkpoint
 ==============
 Save the model as often as requested.
 """
 
 import os
-import glob
-import shutil
 import logging as log
 import warnings
+import re
 
 import numpy as np
 
-from .base import Callback
+from pytorch_lightning.callbacks.base import Callback
 
 
 class ModelCheckpoint(Callback):
     r"""
     Save the model after every epoch.
 
     Args:
-        dirpath: path to save the model file.
+        filepath: path to save the model file.
             Can contain named formatting options to be auto-filled.
 
             Example::
 
-                # save epoch and val_loss in name
-                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
+                # no path
+                ModelCheckpoint()
+                # saves like /my/path/epoch_0.ckpt
 
-                # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
-                # if model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
+                # save any arbitrary metrics like and val_loss, etc in name
+                ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
+                # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
 
-        monitor: quantity to monitor.
-        verbose: verbosity mode, False or True.
-        save_top_k: if `save_top_k == k`,
+        monitor (str): quantity to monitor.
+        verbose (bool): verbosity mode, False or True.
+        save_top_k (int): if `save_top_k == k`,
             the best k models according to
             the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -43,54 +39,51 @@ class ModelCheckpoint(Callback):
             if ``save_top_k >= 2`` and the callback is called multiple
             times inside an epoch, the name of the saved file will be
             appended with a version count starting with `v0`.
-        mode: one of {auto, min, max}.
+        mode (str): one of {auto, min, max}.
             If ``save_top_k != 0``, the decision
             to overwrite the current save file is made
             based on either the maximization or the
             minimization of the monitored quantity. For `val_acc`,
             this should be `max`, for `val_loss` this should
             be `min`, etc. In `auto` mode, the direction is
             automatically inferred from the name of the monitored quantity.
-        save_weights_only: if True, then only the model's weights will be
+        save_weights_only (bool): if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
-        period: Interval (number of epochs) between checkpoints.
-        prefix: String name for particular model
+        period (int): Interval (number of epochs) between checkpoints.
 
-    Example:
+    Example::
 
         from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint
 
         # saves checkpoints to my_path whenever 'val_loss' has a new min
-        checkpoint_callback = ModelCheckpoint('my_path')
+        checkpoint_callback = ModelCheckpoint(filepath='my_path')
         Trainer(checkpoint_callback=checkpoint_callback)
 
+        # save epoch and val_loss in name
+        ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
+        # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
     """
-    #: checkpoint extension
-    EXTENSION = '.ckpt'
 
-    def __init__(
-        self,
-        dirpath: str,
-        monitor: str = 'val_loss',
-        verbose: bool = False,
-        save_top_k: int = 1,
-        save_weights_only: bool = False,
-        mode: str = 'auto',
-        period: int = 1,
-        prefix: str = ''
-    ):
+    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
+                 save_top_k: int = 1, save_weights_only: bool = False,
+                 mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
+        if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
             warnings.warn(
-                f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
+                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
             )
 
         self.monitor = monitor
         self.verbose = verbose
-        self.dirpath = dirpath
-        os.makedirs(dirpath, exist_ok=True)
+        if os.path.isdir(filepath):
+            self.dirpath, self.filename = filepath, '{epoch}'
+        else:
+            self.dirpath, self.filename = os.path.split(filepath)
+
+        os.makedirs(self.dirpath, exist_ok=True)
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
@@ -102,14 +95,6 @@ def __init__(
         self.best = 0
         self.save_function = None
 
-        # this create unique prefix if the give already exists
-        existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
-        existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
-        version_cnt = 0
-        while self.prefix in existing_names:
-            self.prefix = f'{prefix}-v{version_cnt}'
-            version_cnt += 1
-
         mode_dict = {
             'min': (np.less, np.Inf, 'min'),
             'max': (np.greater, -np.Inf, 'max'),
@@ -125,39 +110,65 @@ def __init__(
 
         self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
 
-    def _del_model(self, filepath: str) -> None:
-        # shutil.rmtree(filepath)
+    def _del_model(self, filepath):
         os.remove(filepath)
 
-    def _save_model(self, filepath: str) -> None:
+    def _save_model(self, filepath):
         # make paths
-        os.makedirs(self.dirpath, exist_ok=True)
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the model
         if self.save_function is not None:
             self.save_function(filepath)
         else:
-            raise ValueError("Method `.save_function()` not set")
+            raise ValueError(".save_function() not set")
 
-    def check_monitor_top_k(self, current: float) -> bool:
+    def check_monitor_top_k(self, current):
         less_than_k_models = len(self.best_k_models) < self.save_top_k
         if less_than_k_models:
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
-    def _get_available_filepath(self, current: float, epoch: int) -> str:
-        current_str = f'{current:.2f}' if current else 'NaN'
-        fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
-        filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
-        assert not os.path.isfile(filepath)
+    def format_checkpoint_name(self, epoch, metrics, ver=None):
+        """Generate a filename according define template.
+
+        Examples
+        --------
+        >>> tmpdir = os.path.dirname(__file__)
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+        'epoch=0.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+        'epoch=005.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+        'epoch=2-val_loss=0.12.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+        'missing=0.ckpt'
+        """
+        # check if user passed in keys to the string
+        groups = re.findall(r'(\{.*?)[:\}]', self.filename)
+
+        if len(groups) == 0:
+            # default name
+            filename = f'{self.prefix}_ckpt_epoch_{epoch}'
+        else:
+            metrics['epoch'] = epoch
+            filename = self.filename
+            for tmp in groups:
+                name = tmp[1:]
+                filename = filename.replace(tmp, name + '={' + name)
+                if name not in metrics:
+                    metrics[name] = 0
+            filename = filename.format(**metrics)
+        str_ver = f'_v{ver}' if ver is not None else ''
+        filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
         return filepath
 
-    def on_validation_end(self, trainer, pl_module) -> None:
-        # only run on main process
-        if trainer.proc_rank != 0:
-            return
-
-        logs = trainer.callback_metrics
+    def on_validation_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
         self.epochs_since_last_check += 1
 
@@ -166,27 +177,36 @@ def on_validation_end(self, trainer, pl_module) -> None:
             return
         if self.epochs_since_last_check >= self.period:
             self.epochs_since_last_check = 0
-            current = logs.get(self.monitor)
-            filepath = self._get_available_filepath(current, epoch)
+
+            filepath = self.format_checkpoint_name(epoch, metrics)
+            version_cnt = 0
+            while os.path.isfile(filepath):
+                filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+                # this epoch called before
+                version_cnt += 1
 
             if self.save_top_k != -1:
+                current = metrics.get(self.monitor)
 
                 if current is None:
-                    warnings.warn(f'Can save best model only with {self.monitor} available,'
-                                  ' skipping.', RuntimeWarning)
+                    warnings.warn(
+                        f'Can save best model only with {self.monitor} available,'
+                        ' skipping.', RuntimeWarning)
                 else:
                     if self.check_monitor_top_k(current):
                         self._do_check_save(filepath, current, epoch)
                     else:
                         if self.verbose > 0:
-                            log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
+                            log.info(
+                                f'\nEpoch {epoch:05d}: {self.monitor}'
+                                f' was not in top {self.save_top_k}')
 
             else:
                 if self.verbose > 0:
-                    log.info('Epoch %05d: saving model to %s', epoch, filepath)
+                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                 self._save_model(filepath)
 
-    def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
+    def _do_check_save(self, filepath, current, epoch):
         # remove kth
         if len(self.best_k_models) == self.save_top_k:
             delpath = self.kth_best_model
@@ -205,6 +225,8 @@ def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
         self.best = _op(self.best_k_models.values())
 
         if self.verbose > 0:
-            log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
-                     epoch, self.monitor, current, self.best, filepath, self.save_top_k)
+            log.info(
+                f'\nEpoch {epoch:05d}: {self.monitor} reached'
+                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                f' {filepath} as top {self.save_top_k}')
         self._save_model(filepath)
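The templating in `format_checkpoint_name` is terse; the following standalone sketch (a re-implementation for illustration, not the library code itself, with `format_name` as a hypothetical helper) walks through the same parsing:

```python
import re


def format_name(template, epoch, metrics, prefix='', ver=None):
    """Mimic ModelCheckpoint.format_checkpoint_name for illustration."""
    # grab each named field up to its ':' or '}', e.g. '{epoch', '{val_loss'
    groups = re.findall(r'(\{.*?)[:\}]', template)
    if not groups:
        # no fields in the template -> fall back to a default name
        filename = f'{prefix}_ckpt_epoch_{epoch}'
    else:
        metrics = dict(metrics, epoch=epoch)
        filename = template
        for tmp in groups:
            name = tmp[1:]
            # rewrite '{epoch:02d}' to 'epoch={epoch:02d}' so keys show in the name
            filename = filename.replace(tmp, name + '={' + name)
            metrics.setdefault(name, 0)  # absent metrics are rendered as 0
        filename = filename.format(**metrics)
    version = f'_v{ver}' if ver is not None else ''
    return prefix + filename + version + '.ckpt'


print(format_name('{epoch:02d}-{val_loss:.2f}', 2, {'val_loss': 0.1234}))
# epoch=02-val_loss=0.12.ckpt
print(format_name('{epoch}', 5, {}, ver=0))
# epoch=5_v0.ckpt
```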
14 changes: 1 addition & 13 deletions pytorch_lightning/core/lightning.py
@@ -68,19 +68,7 @@ def __init__(self, *args, **kwargs):
         #: True if using amp
         self.use_amp = False
 
-    @property
-    def hparams(self) -> Namespace:
-        if not hasattr(self, '_hparams'):
-            return Namespace()
-        assert isinstance(self._hparams, dict)
-        return Namespace(**self._hparams)
-
-    @hparams.setter
-    def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
-        """Set the model hyper-parameters."""
-        if isinstance(params, Namespace):
-            params = vars(params)
-        self._hparams = params
+        self.hparams = None
 
     def print(self, *args, **kwargs):
         r"""
4 changes: 4 additions & 0 deletions pytorch_lightning/loggers/base.py
@@ -46,6 +46,10 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str,
         # in case converting from namespace
         if isinstance(params, Namespace):
             params = vars(params)
+
+        if params is None:
+            params = {}
+
         return params
 
     @abstractmethod
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/callback_config.py
@@ -48,9 +48,15 @@ def configure_checkpoint_callback(self):
             else:
                 ckpt_path = os.path.join(self.default_save_path, "checkpoints")
+
+            # when no val step is defined, use 'loss' otherwise 'val_loss'
+            train_step_only = not self.is_overriden('validation_step')
+            monitor_key = 'loss' if train_step_only else 'val_loss'
+
             self.ckpt_path = ckpt_path
             os.makedirs(ckpt_path, exist_ok=True)
             self.checkpoint_callback = ModelCheckpoint(
-                dirpath=ckpt_path
+                filepath=ckpt_path,
+                monitor=monitor_key
             )
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
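A sketch of what the default construction now amounts to, assuming the API in this diff and a hypothetical save path in place of the Trainer's own:

```python
import os

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_path = os.path.join('lightning_logs', 'checkpoints')  # hypothetical path

# mirror of the fallback above: no validation_step -> monitor 'loss'
has_val_step = False  # what is_overriden('validation_step') would report
monitor_key = 'val_loss' if has_val_step else 'loss'

os.makedirs(ckpt_path, exist_ok=True)
checkpoint_callback = ModelCheckpoint(filepath=ckpt_path, monitor=monitor_key)
# since ckpt_path is a directory, files use the default '{epoch}' template,
# e.g. lightning_logs/checkpoints/epoch=0.ckpt
```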
6 changes: 0 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC):
     process_output: ...
     training_tqdm_dict: ...
     proc_rank: int
-    checkpoint_callback: ...
     current_epoch: int
     callback_metrics: ...
     test_dataloaders: DataLoader
@@ -377,11 +376,6 @@ def run_evaluation(self, test_mode: bool = False):
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
-        else:
-            # model checkpointing
-            if self.checkpoint_callback is not None:
-                self.checkpoint_callback.on_validation_end(self, self.get_model())
-            self.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1132,9 +1132,6 @@ def run_pretrain_routine(self, model: LightningModule):
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
 
-        # set up checkpoint callback
-        self.configure_checkpoint_callback()
-
         # register auto-resubmit when on SLURM
         self.register_slurm_signal_handlers()
 
@@ -1151,6 +1148,9 @@ def run_pretrain_routine(self, model: LightningModule):
         # if cluster resets state, the model will update with the saved weights
         self.model = model
 
+        # set up checkpoint callback
+        self.configure_checkpoint_callback()
+
         # restore training and model before hpc call
         self.restore_weights(model)