update checkpoint docs (Lightning-AI#1016)
* update checkpoint docs

* fix tests

* fix tests

* formatting

* typing

* filename

* fix tests

* fixing tests

* fixing tests

* fixing tests

* unique name

* fixing

* fixing

* Update model_checkpoint.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
2 people authored and tullie committed Apr 3, 2020
1 parent 1ec182b commit ea25a3c
Showing 6 changed files with 95 additions and 78 deletions.
110 changes: 63 additions & 47 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -6,7 +6,7 @@
"""

import os
import shutil
import glob
import logging as log
import warnings

@@ -20,17 +20,19 @@ class ModelCheckpoint(Callback):
Save the model after every epoch.
Args:
filepath: path to save the model file.
dirpath: path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, False or True.
save_top_k (int): if `save_top_k == k`,
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
# if such a model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
@@ -39,43 +41,54 @@ class ModelCheckpoint(Callback):
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
period: Interval (number of epochs) between checkpoints.
prefix: String name for particular model
Example::
Example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# saves checkpoints to my_path whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my_path')
checkpoint_callback = ModelCheckpoint('my_path')
Trainer(checkpoint_callback=checkpoint_callback)
"""

def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
#: checkpoint extension
EXTENSION = '.ckpt'

def __init__(
self,
dirpath: str,
monitor: str = 'val_loss',
verbose: bool = False,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = 'auto',
period: int = 1,
prefix: str = ''
):
super().__init__()
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)

self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.dirpath = dirpath
os.makedirs(dirpath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
@@ -87,6 +100,14 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
self.best = 0
self.save_function = None

# this creates a unique prefix if the given one already exists
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
version_cnt = 0
while self.prefix in existing_names:
self.prefix = f'{prefix}-v{version_cnt}'
version_cnt += 1

mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
@@ -102,29 +123,34 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,

self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

def _del_model(self, filepath):
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _del_model(self, filepath: str) -> None:
# shutil.rmtree(filepath)
os.remove(filepath)

def _save_model(self, filepath):
def _save_model(self, filepath: str) -> None:
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
os.makedirs(self.dirpath, exist_ok=True)

# delegate the saving to the model
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current):
def check_monitor_top_k(self, current: float) -> bool:
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self, trainer, pl_module):
def _get_available_filepath(self, current: float, epoch: int) -> str:
current_str = f'{current:.2f}' if current else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
return filepath

def on_validation_end(self, trainer, pl_module) -> None:
# only run on main process
if trainer.proc_rank != 0:
return
@@ -138,35 +164,27 @@ def on_validation_end(self, trainer, pl_module):
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
current = logs.get(self.monitor)
filepath = self._get_available_filepath(current, epoch)

if self.save_top_k != -1:
current = logs.get(self.monitor)

if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
warnings.warn(f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)

else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
log.info('Epoch %05d: saving model to %s', epoch, filepath)
self._save_model(filepath)

def _do_check_save(self, filepath, current, epoch):
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
# remove kth
if len(self.best_k_models) == self.save_top_k:
delpath = self.kth_best_model
@@ -185,8 +203,6 @@ def _do_check_save(self, filepath, current, epoch):
self.best = _op(self.best_k_models.values())

if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
log.info('Epoch %05d: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
self._save_model(filepath)
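
For context, a minimal usage sketch of the updated callback under the new dirpath argument and filename scheme (the directory and prefix below are hypothetical):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# hypothetical directory and prefix; checkpoints land in dirpath as
# <prefix>_epoch=<epoch>_<monitor>=<value>.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath='my/path/here',
    monitor='val_loss',
    save_top_k=3,           # keep the 3 best checkpoints by val_loss
    mode='min',             # lower val_loss is better
    prefix='sample-mnist',
)
Trainer(checkpoint_callback=checkpoint_callback)

# produces e.g. my/path/here/sample-mnist_epoch=2_val_loss=0.32.ckpt; if
# 'sample-mnist' checkpoints already exist in dirpath, the prefix is bumped
# to 'sample-mnist-v0', 'sample-mnist-v1', ... so new filenames stay unique
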
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_config.py
@@ -50,7 +50,7 @@ def configure_checkpoint_callback(self):

self.ckpt_path = ckpt_path
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path
dirpath=ckpt_path
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
@@ -62,7 +62,7 @@ def configure_checkpoint_callback(self):
self.checkpoint_callback.save_function = self.save_checkpoint

# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath
self.weights_save_path = self.checkpoint_callback.dirpath

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -810,7 +810,7 @@ def on_train_end(self):
self.amp_level = amp_level
self.precision = precision

assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if self.precision == 16 and num_tpu_cores is None:
use_amp = True
4 changes: 2 additions & 2 deletions tests/models/utils.py
@@ -32,7 +32,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):

# test model loading
pretrained_model = load_model(trainer.logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
path_expt=trainer_options.get('default_save_path'))

# test new model accuracy
@@ -70,7 +70,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
assert result == 1, 'amp + ddp model failed to complete'

# test model loading
pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath)
pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath)

# test new model accuracy
test_loaders = model.test_dataloader()
13 changes: 6 additions & 7 deletions tests/test_restore_models.py
@@ -1,3 +1,4 @@
import glob
import logging as log
import os

@@ -52,7 +53,7 @@ def test_running_test_pretrained_model_ddp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)

# run test set
@@ -96,7 +97,7 @@ def test_running_test_pretrained_model(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(
logger, trainer.checkpoint_callback.filepath, module_class=LightningTestModel
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
)

new_trainer = Trainer(**trainer_options)
@@ -132,9 +133,7 @@ def test_load_model_from_checkpoint(tmpdir):
assert result == 1, 'training failed to complete'

# load last checkpoint
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
if not os.path.isfile(last_checkpoint):
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)

# test that hparams loaded correctly
@@ -186,7 +185,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)

new_trainer = Trainer(**trainer_options)
@@ -346,7 +345,7 @@ def test_load_model_with_missing_hparams(tmpdir):

model = LightningTestModelWithoutHyperparametersArg()
trainer.fit(model)
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]

# try to load a checkpoint that has hparams but model is missing hparams arg
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
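
The updated tests locate the newest checkpoint by lexicographically sorting the *.ckpt files in the callback's dirpath; a standalone sketch of that lookup (the helper name is hypothetical):

import glob
import os

def latest_checkpoint(dirpath):
    """Return the last *.ckpt file in dirpath when sorted by filename."""
    ckpts = sorted(glob.glob(os.path.join(dirpath, '*.ckpt')))
    if not ckpts:
        raise FileNotFoundError(f'no .ckpt files found in {dirpath}')
    return ckpts[-1]

# e.g. LightningTestModel.load_from_checkpoint(
#     latest_checkpoint(trainer.checkpoint_callback.dirpath))

Filename sorting matches what these tests rely on while epoch numbers stay single-digit; sorting by modification time would be more robust once epochs reach double digits.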