From 8266b141bae79300bc2fd1d353985c2871f43aac Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 19 May 2021 15:14:13 -0700
Subject: [PATCH] [feat] Support time-based checkpointing during training
 (#7515)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Jirka Borovec
Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                                 |  4 ++
 .../callbacks/model_checkpoint.py            | 59 +++++++++++++++----
 tests/checkpointing/test_model_checkpoint.py | 53 +++++++++++++++--
 3 files changed, 101 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 741a657740fdc..f5512ff4e5a01 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,9 +11,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
 
+
 - Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
 
+- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))
+
+
 - Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603))
 
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 905a246f3931e..34717e09a9681 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -21,7 +21,9 @@
 import logging
 import os
 import re
+import time
 from copy import deepcopy
+from datetime import timedelta
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Union
 
@@ -101,12 +103,17 @@ class ModelCheckpoint(Callback):
             is saved (``model.save(filepath)``).
         every_n_train_steps: Number of training steps between checkpoints.
             If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative.
-            This must be mutually exclusive with ``every_n_val_epochs``.
+            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
+        train_time_interval: Checkpoints are monitored at the specified time interval.
+            For all practical purposes, this cannot be smaller than the amount
+            of time it takes to process a single training batch. This is not
+            guaranteed to execute at the exact time specified, but should be close.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
         every_n_val_epochs: Number of validation epochs between checkpoints.
             If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
             To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``every_n_train_steps``.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
             Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
             ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
             will only save checkpoints at epochs 0 < E <= N
@@ -129,6 +136,9 @@ class ModelCheckpoint(Callback):
         For example, you can change the default last checkpoint name by doing
         ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
 
+    If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+    then you should create multiple ``ModelCheckpoint`` callbacks.
+
     Raises:
         MisconfigurationException:
             If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
@@ -190,6 +200,7 @@ def __init__(
         mode: str = "min",
         auto_insert_metric_name: bool = True,
         every_n_train_steps: Optional[int] = None,
+        train_time_interval: Optional[timedelta] = None,
         every_n_val_epochs: Optional[int] = None,
         period: Optional[int] = None,
     ):
@@ -201,6 +212,7 @@ def __init__(
         self.save_weights_only = save_weights_only
         self.auto_insert_metric_name = auto_insert_metric_name
         self._last_global_step_saved = -1
+        self._last_time_checked: Optional[float] = None
         self.current_score = None
         self.best_k_models = {}
         self.kth_best_model_path = ""
@@ -210,7 +222,7 @@ def __init__(
 
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
-        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
+        self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
         self.__validate_init_configuration()
         self._save_function = None
 
@@ -221,6 +233,9 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
         self.__resolve_ckpt_dir(trainer)
         self._save_function = trainer.save_checkpoint
 
+    def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+        self._last_time_checked = time.monotonic()
+
     def on_train_batch_end(
         self,
         trainer: 'pl.Trainer',
@@ -235,8 +250,22 @@ def on_train_batch_end(
             return
         step = trainer.global_step
         skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
-        if skip_batch:
+
+        train_time_interval = self._train_time_interval
+        skip_time = True
+        now = time.monotonic()
+        if train_time_interval:
+            prev_time_check = self._last_time_checked
+            skip_time = (prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds())
+            # in case we have time differences across ranks
+            # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
+            skip_time = trainer.training_type_plugin.broadcast(skip_time)
+
+        if skip_batch and skip_time:
             return
+        if not skip_time:
+            self._last_time_checked = now
+
         self.save_checkpoint(trainer)
 
     def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
@@ -322,12 +351,17 @@ def __validate_init_configuration(self) -> None:
             raise MisconfigurationException(
                 f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
             )
-        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
+
+        every_n_train_steps_triggered = self._every_n_train_steps >= 1
+        every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
+        train_time_interval_triggered = self._train_time_interval is not None
+        if (every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1):
             raise MisconfigurationException(
-                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
-                ' and every_n_val_epochs={self._every_n_val_epochs}.'
-                ' Both cannot be enabled at the same time.'
+ f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " + f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} " + "should be mutually exclusive." ) + if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -379,12 +413,13 @@ def __init_monitor_mode(self, mode: str) -> None: self.kth_value, self.mode = mode_dict[mode] def __init_triggers( - self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], + train_time_interval: Optional[timedelta], period: Optional[int] ) -> None: # Default to running once after each validation epoch if neither # every_n_train_steps nor every_n_val_epochs is set - if every_n_train_steps is None and every_n_val_epochs is None: + if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None: self._every_n_val_epochs = 1 self._every_n_train_steps = 0 log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") @@ -392,6 +427,8 @@ def __init_triggers( self._every_n_val_epochs = every_n_val_epochs or 0 self._every_n_train_steps = every_n_train_steps or 0 + self._train_time_interval: Optional[timedelta] = train_time_interval + # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: rank_zero_deprecation( diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index dedea751173f0..3d7a35917e095 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -16,7 +16,9 @@ import os import pickle import re +import time from argparse import Namespace +from datetime import timedelta from logging import INFO from pathlib import Path from typing import Union @@ -564,16 +566,24 @@ def test_invalid_every_n_train_steps(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) -def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir): +def test_invalid_trigger_combination(tmpdir): """ - Test that a MisconfigurationException is raised if both - every_n_val_epochs and every_n_train_steps are enabled together. + Test that a MisconfigurationException is raised if more than one of + every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together. 
""" - with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): + ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2) + with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'): + ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2) + # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + ModelCheckpoint( + dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1) + ) def test_none_every_n_train_steps_val_epochs(tmpdir): @@ -718,6 +728,41 @@ def test_ckpt_every_n_train_steps(tmpdir): assert set(os.listdir(tmpdir)) == set(expected) +@mock.patch("pytorch_lightning.callbacks.model_checkpoint.time") +def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None: + """Tests that the checkpoints are saved at the specified time interval.""" + seconds_per_batch = 7 + start_time = time.monotonic() + batches_per_epoch = 64 + num_epochs = 2 + max_batches = batches_per_epoch * num_epochs + 1 + mock_datetime.monotonic.side_effect = [start_time + seconds_per_batch * i for i in range(max_batches)] + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=num_epochs, + max_epochs=num_epochs, + progress_bar_refresh_rate=0, + callbacks=[ + ModelCheckpoint( + filename="{epoch}-{step}", + dirpath=tmpdir, + train_time_interval=timedelta(minutes=1), + save_top_k=-1, + save_last=False, + ) + ], + logger=False, + ) + + trainer.fit(model) + # Each batch takes 7 sec and we checkpoint every minute. There are 64 + # batches per epoch, so total time to run is 7*64*2 = 896 sec < 14.96 minutes, + # so we should have 14 checkpoints. + assert len(os.listdir(tmpdir)) == 14 + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. """ model = LogInTwoMethods()