
update checkpoint docs #1016

Merged (14 commits) on Mar 3, 2020
110 changes: 63 additions & 47 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -1,5 +1,5 @@
import os
import shutil
import glob
import logging as log
import warnings

@@ -13,17 +13,19 @@ class ModelCheckpoint(Callback):
Save the model after every epoch.

Args:
filepath: path to save the model file.
dirpath: path to save the model file.
Can contain named formatting options to be auto-filled.

Example::

# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, False or True.
save_top_k (int): if `save_top_k == k`,
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
# if such a model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt

monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
@@ -32,43 +34,54 @@ class ModelCheckpoint(Callback):
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
period: Interval (number of epochs) between checkpoints.
prefix: string prefix prepended to the checkpoint file names.

Example::
Example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to my_path whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my_path')
checkpoint_callback = ModelCheckpoint('my_path')
Trainer(checkpoint_callback=checkpoint_callback)
"""

def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
#: checkpoint extension
EXTENSION = '.ckpt'

def __init__(
self,
dirpath: str,
monitor: str = 'val_loss',
verbose: bool = False,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = 'auto',
period: int = 1,
prefix: str = ''
):
super().__init__()
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)

self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.dirpath = dirpath
os.makedirs(dirpath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
@@ -80,6 +93,14 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
self.best = 0
self.save_function = None

# create a unique prefix if the given one already exists
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
version_cnt = 0
while self.prefix in existing_names:
self.prefix = f'{prefix}-v{version_cnt}'
version_cnt += 1

mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
@@ -95,29 +116,34 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,

self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

def _del_model(self, filepath):
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _del_model(self, filepath: str) -> None:
# shutil.rmtree(filepath)
os.remove(filepath)

def _save_model(self, filepath):
def _save_model(self, filepath: str) -> None:
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
os.makedirs(self.dirpath, exist_ok=True)

# delegate the saving to the model
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError(".save_function() not set")

def check_monitor_top_k(self, current):
def check_monitor_top_k(self, current: float) -> bool:
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_validation_end(self, trainer, pl_module):
def _get_available_filepath(self, current: float, epoch: int) -> str:
current_str = f'{current:.2f}' if current is not None else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
return filepath

def on_validation_end(self, trainer, pl_module) -> None:
# only run on main process
if trainer.proc_rank != 0:
return
@@ -131,35 +157,27 @@ def on_validation_end(self, trainer, pl_module):
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
current = logs.get(self.monitor)
filepath = self._get_available_filepath(current, epoch)

if self.save_top_k != -1:
current = logs.get(self.monitor)

if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
warnings.warn(f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)

else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
log.info('Epoch %05d: saving model to %s', epoch, filepath)
self._save_model(filepath)

def _do_check_save(self, filepath, current, epoch):
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
# remove kth
if len(self.best_k_models) == self.save_top_k:
delpath = self.kth_best_model
@@ -178,8 +196,6 @@ def _do_check_save(self, filepath, current, epoch):
self.best = _op(self.best_k_models.values())

if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
log.info('Epoch %05d: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
self._save_model(filepath)
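
For context, here is a minimal usage sketch of the renamed dirpath argument and the new filename scheme described in the docstring above; the directory and prefix values are illustrative, not defaults:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints ranked by val_loss ('my/checkpoint/dir' and
# 'sample-mnist' are example values, not library defaults)
checkpoint_callback = ModelCheckpoint(
    dirpath='my/checkpoint/dir',
    monitor='val_loss',
    save_top_k=3,
    mode='min',
    prefix='sample-mnist',
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)

# files are written like my/checkpoint/dir/sample-mnist_epoch=2_val_loss=0.32.ckpt;
# if checkpoints with the 'sample-mnist' prefix already exist in that directory,
# the prefix is bumped to 'sample-mnist-v0', 'sample-mnist-v1', and so on.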
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_config.py
@@ -50,7 +50,7 @@ def configure_checkpoint_callback(self):

self.ckpt_path = ckpt_path
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path
dirpath=ckpt_path
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
@@ -62,7 +62,7 @@ def configure_checkpoint_callback(self):
self.checkpoint_callback.save_function = self.save_checkpoint

# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath
self.weights_save_path = self.checkpoint_callback.dirpath

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -804,7 +804,7 @@ def on_train_end(self):
self.amp_level = amp_level
self.precision = precision

assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if self.precision == 16 and num_tpu_cores is None:
use_amp = True
4 changes: 2 additions & 2 deletions tests/models/utils.py
@@ -32,7 +32,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):

# test model loading
pretrained_model = load_model(trainer.logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
path_expt=trainer_options.get('default_save_path'))

# test new model accuracy
@@ -70,7 +70,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
assert result == 1, 'amp + ddp model failed to complete'

# test model loading
pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath)
pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath)

# test new model accuracy
test_loaders = model.test_dataloader()
13 changes: 6 additions & 7 deletions tests/test_restore_models.py
@@ -1,3 +1,4 @@
import glob
import logging as log
import os

@@ -52,7 +53,7 @@ def test_running_test_pretrained_model_ddp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)

# run test set
@@ -96,7 +97,7 @@ def test_running_test_pretrained_model(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(
logger, trainer.checkpoint_callback.filepath, module_class=LightningTestModel
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
)

new_trainer = Trainer(**trainer_options)
@@ -132,9 +133,7 @@ def test_load_model_from_checkpoint(tmpdir):
assert result == 1, 'training failed to complete'

# load last checkpoint
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
if not os.path.isfile(last_checkpoint):
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)

# test that hparams loaded correctly
@@ -186,7 +185,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)

new_trainer = Trainer(**trainer_options)
@@ -346,7 +345,7 @@ def test_load_model_with_missing_hparams(tmpdir):

model = LightningTestModelWithoutHyperparametersArg()
trainer.fit(model)
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]

# try to load a checkpoint that has hparams but model is missing hparams arg
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
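
The test changes above stop hard-coding names such as _ckpt_epoch_0.ckpt and instead glob for the newest *.ckpt file, since the filename now depends on the prefix, epoch, and monitored value. A small helper sketching that pattern (last_checkpoint is a hypothetical name, not part of the test utilities):

import glob
import os

def last_checkpoint(dirpath):
    # return the lexicographically last *.ckpt file in dirpath, mirroring the
    # sorted(glob.glob(...))[-1] pattern used in the updated tests
    ckpts = sorted(glob.glob(os.path.join(dirpath, '*.ckpt')))
    assert ckpts, 'no checkpoints found in {}'.format(dirpath)
    return ckpts[-1]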