[feat] Support time-based checkpointing during training (#7515)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people authored May 19, 2021
1 parent 485554c commit 8266b14
Showing 3 changed files with 101 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,9 +11,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow


- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))


- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))


- Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603))


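A minimal usage sketch of the new option (illustrative script, not part of the commit; `train_time_interval` is the argument this PR adds, the `dirpath` values are placeholders, and the rest is existing `ModelCheckpoint`/`Trainer` API):

```python
from datetime import timedelta

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint roughly every 30 minutes of training wall time.
time_ckpt = ModelCheckpoint(dirpath="checkpoints/time", train_time_interval=timedelta(minutes=30))

# The three triggers are mutually exclusive within a single callback (see the
# updated docstring below), so combining schedules means using several callbacks.
step_ckpt = ModelCheckpoint(dirpath="checkpoints/steps", every_n_train_steps=1000)

trainer = pl.Trainer(callbacks=[time_ckpt, step_ckpt])
# trainer.fit(model) would then checkpoint on both schedules.
```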
59 changes: 48 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
Expand Up @@ -21,7 +21,9 @@
import logging
import os
import re
import time
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

@@ -101,12 +103,17 @@ class ModelCheckpoint(Callback):
is saved (``model.save(filepath)``).
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative.
This must be mutually exclusive with ``every_n_val_epochs``.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
train_time_interval: Checkpoints are monitored at the specified time interval.
For all practical purposes, this cannot be smaller than the amount
of time it takes to process a single training batch. This is not
guaranteed to execute at the exact time specified, but should be close.
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps``.
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
@@ -129,6 +136,9 @@ class ModelCheckpoint(Callback):
For example, you can change the default last checkpoint name by doing
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
then you should create multiple ``ModelCheckpoint`` callbacks.
Raises:
MisconfigurationException:
If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
@@ -190,6 +200,7 @@ def __init__(
mode: str = "min",
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
every_n_val_epochs: Optional[int] = None,
period: Optional[int] = None,
):
@@ -201,6 +212,7 @@ def __init__(
self.save_weights_only = save_weights_only
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self._last_time_checked: Optional[float] = None
self.current_score = None
self.best_k_models = {}
self.kth_best_model_path = ""
@@ -210,7 +222,7 @@

self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
self.__validate_init_configuration()
self._save_function = None

@@ -221,6 +233,9 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
self.__resolve_ckpt_dir(trainer)
self._save_function = trainer.save_checkpoint

def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
self._last_time_checked = time.monotonic()

def on_train_batch_end(
self,
trainer: 'pl.Trainer',
@@ -235,8 +250,22 @@ def on_train_batch_end(
return
step = trainer.global_step
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
if skip_batch:

train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
if train_time_interval:
prev_time_check = self._last_time_checked
skip_time = (prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds())
# in case we have time differences across ranks
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
skip_time = trainer.training_type_plugin.broadcast(skip_time)

if skip_batch and skip_time:
return
if not skip_time:
self._last_time_checked = now

self.save_checkpoint(trainer)

def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
@@ -322,12 +351,17 @@ def __validate_init_configuration(self) -> None:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
)
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:

every_n_train_steps_triggered = self._every_n_train_steps >= 1
every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
train_time_interval_triggered = self._train_time_interval is not None
if (every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1):
raise MisconfigurationException(
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
' and every_n_val_epochs={self._every_n_val_epochs}.'
' Both cannot be enabled at the same time.'
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
"should be mutually exclusive."
)

if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in (None, -1, 0):
@@ -379,19 +413,22 @@ def __init_monitor_mode(self, mode: str) -> None:
self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
train_time_interval: Optional[timedelta], period: Optional[int]
) -> None:

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
if every_n_train_steps is None and every_n_val_epochs is None:
if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
self._every_n_val_epochs = 1
self._every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
else:
self._every_n_val_epochs = every_n_val_epochs or 0
self._every_n_train_steps = every_n_train_steps or 0

self._train_time_interval: Optional[timedelta] = train_time_interval

# period takes precedence over every_n_val_epochs for backwards compatibility
if period is not None:
rank_zero_deprecation(
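Condensed, the trigger logic added to `on_train_batch_end` above reduces to the following pure function. This is a sketch for clarity only; `should_save` and its parameters are illustrative names, not library code. The diff also shows a distributed-safety design choice: the time-based decision is broadcast from rank 0, so every rank takes the same branch and cannot hang on collective operations that only some ranks would otherwise reach.

```python
from datetime import timedelta
from typing import Optional


def should_save(step: int, every_n_train_steps: int, now: float,
                last_time_checked: Optional[float],
                train_time_interval: Optional[timedelta]) -> bool:
    """Mirror of the skip_batch/skip_time decision above: either trigger may fire."""
    step_due = every_n_train_steps >= 1 and (step + 1) % every_n_train_steps == 0
    time_due = (
        train_time_interval is not None
        and last_time_checked is not None
        and (now - last_time_checked) >= train_time_interval.total_seconds()
    )
    return step_due or time_due


# Step trigger disabled, 60s interval, 70s elapsed since the last check -> save.
assert should_save(step=9, every_n_train_steps=0, now=70.0,
                   last_time_checked=0.0, train_time_interval=timedelta(seconds=60))
```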
53 changes: 49 additions & 4 deletions tests/checkpointing/test_model_checkpoint.py
Expand Up @@ -16,7 +16,9 @@
import os
import pickle
import re
import time
from argparse import Namespace
from datetime import timedelta
from logging import INFO
from pathlib import Path
from typing import Union
@@ -564,16 +566,24 @@ def test_invalid_every_n_train_steps(tmpdir):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)


def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
def test_invalid_trigger_combination(tmpdir):
"""
Test that a MisconfigurationException is raised if both
every_n_val_epochs and every_n_train_steps are enabled together.
Test that a MisconfigurationException is raised if more than one of
every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together.
"""
with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2)
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2)

# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
ModelCheckpoint(
dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1)
)


def test_none_every_n_train_steps_val_epochs(tmpdir):
@@ -718,6 +728,41 @@ def test_ckpt_every_n_train_steps(tmpdir):
assert set(os.listdir(tmpdir)) == set(expected)


@mock.patch("pytorch_lightning.callbacks.model_checkpoint.time")
def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None:
"""Tests that the checkpoints are saved at the specified time interval."""
seconds_per_batch = 7
start_time = time.monotonic()
batches_per_epoch = 64
num_epochs = 2
max_batches = batches_per_epoch * num_epochs + 1
mock_datetime.monotonic.side_effect = [start_time + seconds_per_batch * i for i in range(max_batches)]

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
min_epochs=num_epochs,
max_epochs=num_epochs,
progress_bar_refresh_rate=0,
callbacks=[
ModelCheckpoint(
filename="{epoch}-{step}",
dirpath=tmpdir,
train_time_interval=timedelta(minutes=1),
save_top_k=-1,
save_last=False,
)
],
logger=False,
)

trainer.fit(model)
# Each batch takes 7 sec and we checkpoint every minute. There are 64
# batches per epoch, so the total time to run is 7*64*2 = 896 sec (~14.9 minutes),
# so we should have 14 checkpoints.
assert len(os.listdir(tmpdir)) == 14


def test_model_checkpoint_topk_zero(tmpdir):
""" Test that no checkpoints are saved when save_top_k=0. """
model = LogInTwoMethods()
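For reference, the expected count of 14 in `test_model_checkpoint_train_time_interval` follows from a short calculation (a sketch of the arithmetic only, using the values mocked in the test above):

```python
import math
from datetime import timedelta

seconds_per_batch = 7
interval = timedelta(minutes=1).total_seconds()   # 60 s
total_batches = 64 * 2                            # two epochs of 64 batches

# The interval is only checked at batch boundaries, so a checkpoint lands on the
# first batch end at or past each full minute: every ceil(60 / 7) = 9 batches (63 s).
batches_between_ckpts = math.ceil(interval / seconds_per_batch)
expected_checkpoints = total_batches // batches_between_ckpts  # 128 // 9 = 14
assert expected_checkpoints == 14
```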
