From 8a5f6fde348d35c13ee8b43be3b68c8c9bc3a3d8 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 10 May 2022 11:58:35 +0200 Subject: [PATCH 01/14] add BaseModelCheckpoint --- pytorch_lightning/callbacks/model_checkpoint.py | 6 +++++- .../trainer/connectors/callback_connector.py | 5 +++-- pytorch_lightning/trainer/trainer.py | 9 +++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 33468216ab85e..7a8e592dcc31c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -45,7 +45,11 @@ warning_cache = WarningCache() -class ModelCheckpoint(Callback): +class BaseModelCheckpoint(Callback): + pass + + +class ModelCheckpoint(BaseModelCheckpoint): r""" Save the model periodically by monitoring a quantity. Every metric logged with :meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 7514e5c85eef7..6daafe885cb62 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -26,6 +26,7 @@ RichProgressBar, TQDMProgressBar, ) +from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.enums import ModelSummaryMode @@ -286,8 +287,8 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the last elements are ModelCheckpoints if there were any present in the input. 
""" - checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)] - not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)] + checkpoints = [c for c in callbacks if isinstance(c, BaseModelCheckpoint)] + not_checkpoints = [c for c in callbacks if not isinstance(c, BaseModelCheckpoint)] return not_checkpoints + checkpoints diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f72c2a8d08df2..58dc51defe1ea 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -36,7 +36,8 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator -from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase +from pytorch_lightning.callbacks import Callback, EarlyStopping, ProgressBarBase +from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.optimizer import LightningOptimizer @@ -2309,17 +2310,17 @@ def prediction_writer_callbacks(self) -> List[BasePredictionWriter]: return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)] @property - def checkpoint_callback(self) -> Optional[ModelCheckpoint]: + def checkpoint_callback(self) -> Optional[BaseModelCheckpoint]: """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[ModelCheckpoint]: + def checkpoint_callbacks(self) -> List[BaseModelCheckpoint]: """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" - return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + return [c for c in self.callbacks if isinstance(c, BaseModelCheckpoint)] @property def progress_bar_callback(self) -> Optional[ProgressBarBase]: From 7583e28da751bc5fbd392cd488cedaa487951a45 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 10 May 2022 12:04:04 +0200 Subject: [PATCH 02/14] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bd36750c3dd1..b4636eb191b86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938)) -- +- Added `BaseModelCheckpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024)) ### Changed From 7b1aff36b6c592a32243e1218161a3341ee21791 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 11 May 2022 09:25:24 +0200 Subject: [PATCH 03/14] doc + other places --- docs/source/common/checkpointing_expert.rst | 7 +++++++ pytorch_lightning/callbacks/model_checkpoint.py | 6 +++++- pytorch_lightning/loggers/logger.py | 6 +++--- pytorch_lightning/loggers/neptune.py | 6 +++--- pytorch_lightning/loggers/wandb.py | 6 +++--- pytorch_lightning/trainer/connectors/callback_connector.py | 4 ++-- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/docs/source/common/checkpointing_expert.rst b/docs/source/common/checkpointing_expert.rst index c1859d60ecf52..2f775ed6a6844 100644 --- a/docs/source/common/checkpointing_expert.rst +++ b/docs/source/common/checkpointing_expert.rst @@ -87,3 +87,10 @@ Custom Checkpoint IO Plugin .. note:: Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable. + + +********************************* +Writing your own Checkpoint class +********************************* + +We provide ``BaseModelCheckpoint`` class, for easier subclassing. Users may want to subclass it in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7a8e592dcc31c..0b615c6403dd2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -46,7 +46,11 @@ class BaseModelCheckpoint(Callback): - pass + r""" + This is the base class for Model checkpointing. Expert users may want to subclass it in case of writing + custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that + the trainer recognizes the custom class as a checkpointing callback + """ class ModelCheckpoint(BaseModelCheckpoint): diff --git a/pytorch_lightning/loggers/logger.py b/pytorch_lightning/loggers/logger.py index 80c37f03e02d9..9b629a233e62c 100644 --- a/pytorch_lightning/loggers/logger.py +++ b/pytorch_lightning/loggers/logger.py @@ -25,7 +25,7 @@ import numpy as np import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only @@ -86,7 +86,7 @@ def __init__( else: self._agg_default_func = np.mean - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: """Called after model checkpoint callback saves a new checkpoint. 
Args: @@ -236,7 +236,7 @@ def __init__(self, logger_iterable: Iterable[Logger]): def __getitem__(self, index: int) -> Logger: return list(self._logger_iterable)[index] - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 7df7f5599ce31..4d503c4b33b9c 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -30,7 +30,7 @@ import torch from pytorch_lightning import __version__ -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params @@ -533,7 +533,7 @@ def log_model_summary(self, model, max_depth=-1): ) @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. Args: @@ -580,7 +580,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpo ) @staticmethod - def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str: + def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> str: """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" if not model_path.startswith(expected_model_path): diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index f3582eeeaa1e4..d7a9507e59fc7 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -23,7 +23,7 @@ import torch.nn as nn -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10 @@ -461,7 +461,7 @@ def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: # log checkpoints as artifacts if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) @@ -474,7 +474,7 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: 
"ReferenceType[ModelCheckpoint]") -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: # get checkpoints to be saved with associated score checkpoints = { checkpoint_callback.last_model_path: checkpoint_callback.current_score, diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 6daafe885cb62..db8195db94348 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -277,14 +277,14 @@ def _attach_model_callbacks(self) -> None: @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: - """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of + """Moves all BaseModelCheckpoint callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. Args: callbacks: A list of callbacks. Return: - A new list in which the last elements are ModelCheckpoints if there were any present in the + A new list in which the last elements are BaseModelCheckpoints if there were any present in the input. """ checkpoints = [c for c in callbacks if isinstance(c, BaseModelCheckpoint)] From 898feaf5e471a88d29a6d029129a355d0413f5b6 Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Mon, 16 May 2022 09:20:04 +0200 Subject: [PATCH 04/14] Update pytorch_lightning/callbacks/model_checkpoint.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0b615c6403dd2..2a3b3722a9fc3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -49,7 +49,7 @@ class BaseModelCheckpoint(Callback): r""" This is the base class for Model checkpointing. Expert users may want to subclass it in case of writing custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that - the trainer recognizes the custom class as a checkpointing callback + the trainer recognizes the custom class as a checkpointing callback. """ From 144a4c69385c432dce898177c793a5175db031bf Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 16 May 2022 09:23:47 +0200 Subject: [PATCH 05/14] move docs --- docs/source/common/checkpointing_expert.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/common/checkpointing_expert.rst b/docs/source/common/checkpointing_expert.rst index 2f775ed6a6844..f6d0c5d82ebc9 100644 --- a/docs/source/common/checkpointing_expert.rst +++ b/docs/source/common/checkpointing_expert.rst @@ -8,6 +8,13 @@ Checkpointing (expert) TODO: I don't understand this... +********************************* +Writing your own Checkpoint class +********************************* + +We provide ``BaseModelCheckpoint`` class, for easier subclassing. Users may want to subclass it in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. + + *********************** Customize Checkpointing *********************** @@ -87,10 +94,3 @@ Custom Checkpoint IO Plugin .. 
note:: Some ``TrainingTypePlugins`` like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable. - - -********************************* -Writing your own Checkpoint class -********************************* - -We provide ``BaseModelCheckpoint`` class, for easier subclassing. Users may want to subclass it in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. From a7ea8dc78a1460cfe85ab995b3d75c69f33c5f31 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 25 May 2022 09:03:34 +0200 Subject: [PATCH 06/14] checkpoint --- CHANGELOG.md | 2 +- docs/source/common/checkpointing_expert.rst | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- pytorch_lightning/loggers/logger.py | 6 +++--- pytorch_lightning/loggers/neptune.py | 6 +++--- pytorch_lightning/loggers/wandb.py | 6 +++--- .../trainer/connectors/callback_connector.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 8 ++++---- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 476c0dc6b6385..4447f294dd2f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938)) -- Added `BaseModelCheckpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024)) +- Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024)) ### Changed diff --git a/docs/source/common/checkpointing_expert.rst b/docs/source/common/checkpointing_expert.rst index f6d0c5d82ebc9..3e963ee8c5700 100644 --- a/docs/source/common/checkpointing_expert.rst +++ b/docs/source/common/checkpointing_expert.rst @@ -12,7 +12,7 @@ TODO: I don't understand this... Writing your own Checkpoint class ********************************* -We provide ``BaseModelCheckpoint`` class, for easier subclassing. Users may want to subclass it in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. +We provide ``Checkpoint`` class, for easier subclassing. Users may want to subclass this class in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback. *********************** diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2a3b3722a9fc3..ab4ed9ba10b3a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -45,7 +45,7 @@ warning_cache = WarningCache() -class BaseModelCheckpoint(Callback): +class Checkpoint(Callback): r""" This is the base class for Model checkpointing. Expert users may want to subclass it in case of writing custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that @@ -53,7 +53,7 @@ class BaseModelCheckpoint(Callback): """ -class ModelCheckpoint(BaseModelCheckpoint): +class ModelCheckpoint(Checkpoint): r""" Save the model periodically by monitoring a quantity. 
Every metric logged with :meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in diff --git a/pytorch_lightning/loggers/logger.py b/pytorch_lightning/loggers/logger.py index 9b629a233e62c..582d9bd5d56b6 100644 --- a/pytorch_lightning/loggers/logger.py +++ b/pytorch_lightning/loggers/logger.py @@ -25,7 +25,7 @@ import numpy as np import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only @@ -86,7 +86,7 @@ def __init__( else: self._agg_default_func = np.mean - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: """Called after model checkpoint callback saves a new checkpoint. Args: @@ -236,7 +236,7 @@ def __init__(self, logger_iterable: Iterable[Logger]): def __getitem__(self, index: int) -> Logger: return list(self._logger_iterable)[index] - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: for logger in self._logger_iterable: logger.after_save_checkpoint(checkpoint_callback) diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 4d503c4b33b9c..d98107798395c 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -30,7 +30,7 @@ import torch from pytorch_lightning import __version__ -from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params @@ -533,7 +533,7 @@ def log_model_summary(self, model, max_depth=-1): ) @rank_zero_only - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. 
Args: @@ -580,7 +580,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelChe ) @staticmethod - def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> str: + def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str: """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" if not model_path.startswith(expected_model_path): diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index d7a9507e59fc7..e0f40b979ffe9 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -23,7 +23,7 @@ import torch.nn as nn -from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10 @@ -461,7 +461,7 @@ def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id - def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: + def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # log checkpoints as artifacts if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) @@ -474,7 +474,7 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback: self._scan_and_log_checkpoints(self._checkpoint_callback) - def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[BaseModelCheckpoint]") -> None: + def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # get checkpoints to be saved with associated score checkpoints = { checkpoint_callback.last_model_path: checkpoint_callback.current_score, diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index cc4858863fa5d..327a8fc129939 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -26,7 +26,7 @@ RichProgressBar, TQDMProgressBar, ) -from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.enums import ModelSummaryMode @@ -278,8 +278,8 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the last elements are BaseModelCheckpoints if there were any present in the input. 
""" - checkpoints = [c for c in callbacks if isinstance(c, BaseModelCheckpoint)] - not_checkpoints = [c for c in callbacks if not isinstance(c, BaseModelCheckpoint)] + checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)] + not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)] return not_checkpoints + checkpoints diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3c965103e3ca9..3f7e99a8b0272 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -37,7 +37,7 @@ import pytorch_lightning as pl from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, HPUAccelerator, IPUAccelerator, TPUAccelerator from pytorch_lightning.callbacks import Callback, EarlyStopping, ProgressBarBase -from pytorch_lightning.callbacks.model_checkpoint import BaseModelCheckpoint +from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.optimizer import LightningOptimizer @@ -2301,17 +2301,17 @@ def prediction_writer_callbacks(self) -> List[BasePredictionWriter]: return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)] @property - def checkpoint_callback(self) -> Optional[BaseModelCheckpoint]: + def checkpoint_callback(self) -> Optional[Checkpoint]: """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[BaseModelCheckpoint]: + def checkpoint_callbacks(self) -> List[Checkpoint]: """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" - return [c for c in self.callbacks if isinstance(c, BaseModelCheckpoint)] + return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property def progress_bar_callback(self) -> Optional[ProgressBarBase]: From a366491daa4913533aa241d71899bc493080e5fc Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Wed, 25 May 2022 13:25:47 +0200 Subject: [PATCH 07/14] Update pytorch_lightning/trainer/connectors/callback_connector.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index fd7c1d28e6ee2..84b96d4832cb9 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -267,7 +267,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: callbacks: A list of callbacks. Return: - A new list in which the last elements are BaseModelCheckpoints if there were any present in the + A new list in which the last elements are Checkpoint if there were any present in the input. 
""" checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)] From 800043cb933313f41418165029f6146ec5a6ab8d Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Wed, 25 May 2022 13:26:00 +0200 Subject: [PATCH 08/14] Update pytorch_lightning/trainer/connectors/callback_connector.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 84b96d4832cb9..006cae107990e 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -260,7 +260,7 @@ def _attach_model_callbacks(self) -> None: @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: - """Moves all BaseModelCheckpoint callbacks to the end of the list. The sequential order within the group of + """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. Args: From 9a1b4922227a386485426d238a544245db8fad7f Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 22 Jun 2022 16:45:25 +0200 Subject: [PATCH 09/14] guarding with hasattr --- .gitignore | 2 +- .../cli/pl-app-template/core/callbacks.py | 6 +++- src/pytorch_lightning/callbacks/__init__.py | 2 ++ src/pytorch_lightning/callbacks/checkpoint.py | 9 +++++ .../callbacks/fault_tolerance.py | 4 +-- .../callbacks/model_checkpoint.py | 10 +----- src/pytorch_lightning/loggers/logger.py | 2 +- src/pytorch_lightning/loggers/neptune.py | 30 +++++++++------- src/pytorch_lightning/loggers/wandb.py | 36 ++++++++++++------- .../strategies/launchers/spawn.py | 8 +++-- .../strategies/launchers/xla_spawn.py | 6 +++- .../trainer/connectors/callback_connector.py | 2 +- src/pytorch_lightning/trainer/trainer.py | 12 ++++--- 13 files changed, 82 insertions(+), 47 deletions(-) create mode 100644 src/pytorch_lightning/callbacks/checkpoint.py diff --git a/.gitignore b/.gitignore index 1454933657f8b..53308a2f67041 100644 --- a/.gitignore +++ b/.gitignore @@ -139,7 +139,7 @@ ENV/ Datasets/ mnist/ MNIST/ -legacy/checkpoints/ +tests/legacy/checkpoints/ *.gz *ubyte diff --git a/src/lightning_app/cli/pl-app-template/core/callbacks.py b/src/lightning_app/cli/pl-app-template/core/callbacks.py index 93992c552f781..f0f053da6e10d 100644 --- a/src/lightning_app/cli/pl-app-template/core/callbacks.py +++ b/src/lightning_app/cli/pl-app-template/core/callbacks.py @@ -291,7 +291,11 @@ def setup( self._collect_logger_metadata(trainer) def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if trainer.checkpoint_callback and trainer.checkpoint_callback.dirpath is not None: + if ( + trainer.checkpoint_callback + and hasattr(trainer.checkpoint_callback, "dirpath") + and trainer.checkpoint_callback.dirpath is not None + ): self.work.checkpoint_dir = Path(trainer.checkpoint_callback.dirpath) def _collect_logger_metadata(self, trainer: "pl.Trainer") -> None: diff --git a/src/pytorch_lightning/callbacks/__init__.py b/src/pytorch_lightning/callbacks/__init__.py index 6e37b84ce204a..b3d2035f33496 100644 --- a/src/pytorch_lightning/callbacks/__init__.py +++ b/src/pytorch_lightning/callbacks/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.callbacks.checkpoint import Checkpoint from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning @@ -32,6 +33,7 @@ "BackboneFinetuning", "BaseFinetuning", "Callback", + "Checkpoint", "DeviceStatsMonitor", "EarlyStopping", "GradientAccumulationScheduler", diff --git a/src/pytorch_lightning/callbacks/checkpoint.py b/src/pytorch_lightning/callbacks/checkpoint.py new file mode 100644 index 0000000000000..91aab12290f70 --- /dev/null +++ b/src/pytorch_lightning/callbacks/checkpoint.py @@ -0,0 +1,9 @@ +from pytorch_lightning.callbacks.callback import Callback + + +class Checkpoint(Callback): + r""" + This is the base class for Model checkpointing. Expert users may want to subclass it in case of writing + custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that + the trainer recognizes the custom class as a checkpointing callback. + """ diff --git a/src/pytorch_lightning/callbacks/fault_tolerance.py b/src/pytorch_lightning/callbacks/fault_tolerance.py index 59b8d31f46506..9d04fc86b62ce 100644 --- a/src/pytorch_lightning/callbacks/fault_tolerance.py +++ b/src/pytorch_lightning/callbacks/fault_tolerance.py @@ -21,11 +21,11 @@ from typing import Any import pytorch_lightning as pl -from pytorch_lightning import Callback +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.types import _PATH -class _FaultToleranceCheckpoint(Callback): +class _FaultToleranceCheckpoint(Checkpoint): """Used to save a fault-tolerance checkpoint on exception.""" FILE_EXTENSION = ".ckpt" diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index ceb49b058637e..bb6d0a9a9b0b6 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -34,7 +34,7 @@ from torch import Tensor import pytorch_lightning as pl -from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.logger import _name, _version @@ -46,14 +46,6 @@ warning_cache = WarningCache() -class Checkpoint(Callback): - r""" - This is the base class for Model checkpointing. Expert users may want to subclass it in case of writing - custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that - the trainer recognizes the custom class as a checkpointing callback. - """ - - class ModelCheckpoint(Checkpoint): r""" Save the model periodically by monitoring a quantity. 
Every metric logged with diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 411bb28c54c88..d32a7d36c7ae1 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -25,7 +25,7 @@ import numpy as np import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import Checkpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only diff --git a/src/pytorch_lightning/loggers/neptune.py b/src/pytorch_lightning/loggers/neptune.py index 8a7d4f6fa0afc..44ae3f0f5bfdc 100644 --- a/src/pytorch_lightning/loggers/neptune.py +++ b/src/pytorch_lightning/loggers/neptune.py @@ -31,7 +31,7 @@ from torch import Tensor from pytorch_lightning import __version__ -from pytorch_lightning.callbacks.model_checkpoint import Checkpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params @@ -547,19 +547,20 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]" checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints") # save last model - if checkpoint_callback.last_model_path: + if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path: model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback) file_names.add(model_last_name) self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path) # save best k models - for key in checkpoint_callback.best_k_models.keys(): - model_name = self._get_full_model_name(key, checkpoint_callback) - file_names.add(model_name) - self.run[f"{checkpoints_namespace}/{model_name}"].upload(key) + if hasattr(checkpoint_callback, "best_k_models"): + for key in checkpoint_callback.best_k_models.keys(): + model_name = self._get_full_model_name(key, checkpoint_callback) + file_names.add(model_name) + self.run[f"{checkpoints_namespace}/{model_name}"].upload(key) # log best model path and checkpoint - if checkpoint_callback.best_model_path: + if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path: self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback) @@ -575,7 +576,7 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]" del self.run[f"{checkpoints_namespace}/{file_to_drop}"] # log best model score - if checkpoint_callback.best_model_score: + if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score: self.run[self._construct_path_with_prefix("model/best_model_score")] = ( checkpoint_callback.best_model_score.cpu().detach().numpy() ) @@ -583,11 +584,14 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]" @staticmethod def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str: """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" - expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" - if not 
model_path.startswith(expected_model_path): - raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") - # Remove extension from filepath - filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) + if hasattr(checkpoint_callback, "dirpath"): + expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" + if not model_path.startswith(expected_model_path): + raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") + # Remove extension from filepath + filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) + else: + filepath = model_path return filepath diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py index 162575a808ccd..88439cd9435db 100644 --- a/src/pytorch_lightning/loggers/wandb.py +++ b/src/pytorch_lightning/loggers/wandb.py @@ -23,7 +23,7 @@ import torch.nn as nn -from pytorch_lightning.callbacks.model_checkpoint import Checkpoint +from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10 @@ -463,7 +463,12 @@ def version(self) -> Optional[str]: def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # log checkpoints as artifacts - if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: + if ( + self._log_model == "all" + or self._log_model is True + and hasattr(checkpoint_callback, "save_top_k") + and checkpoint_callback.save_top_k == -1 + ): self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: self._checkpoint_callback = checkpoint_callback @@ -476,23 +481,31 @@ def finalize(self, status: str) -> None: def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None: # get checkpoints to be saved with associated score - checkpoints = { - checkpoint_callback.last_model_path: checkpoint_callback.current_score, - checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, - **checkpoint_callback.best_k_models, - } - checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file()) + checkpoints = dict() + if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"): + checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest") + + if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"): + checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best") + + if hasattr(checkpoint_callback, "best_k_models"): + for key, value in checkpoint_callback.best_k_models.items(): + checkpoints[key] = (value, "best_k") + + checkpoints = sorted( + (Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file() + ) checkpoints = [ c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0] ] # log iteratively all new checkpoints - for t, p, s in checkpoints: + for t, p, s, tag in checkpoints: metadata = ( { "score": s, "original_filename": Path(p).name, - "ModelCheckpoint": { + checkpoint_callback.__class__.__name__: { k: getattr(checkpoint_callback, k) for k in [ "monitor", @@ -511,7 +524,6 @@ def 
_scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoi ) artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata) artifact.add_file(p, name="model.ckpt") - aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] - self.experiment.log_artifact(artifact, aliases=aliases) + self.experiment.log_artifact(artifact, aliases=[tag]) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) self._logged_model_time[p] = t diff --git a/src/pytorch_lightning/strategies/launchers/spawn.py b/src/pytorch_lightning/strategies/launchers/spawn.py index 6af2688e47419..d94909b778a83 100644 --- a/src/pytorch_lightning/strategies/launchers/spawn.py +++ b/src/pytorch_lightning/strategies/launchers/spawn.py @@ -109,7 +109,7 @@ def _wrapping_function( def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None: # transfer back the best path to the trainer - if trainer.checkpoint_callback: + if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"): trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path) # TODO: pass also best score @@ -131,7 +131,11 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the DDP spawn environment.") checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + best_model_path = ( + checkpoint_callback.best_model_path + if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") + else None + ) # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() diff --git a/src/pytorch_lightning/strategies/launchers/xla_spawn.py b/src/pytorch_lightning/strategies/launchers/xla_spawn.py index b3e1bf3465203..13c948577ca5b 100644 --- a/src/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/src/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -115,7 +115,11 @@ def _wrapping_function( def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the TPU spawn environment.") checkpoint_callback = trainer.checkpoint_callback - best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None + best_model_path = ( + checkpoint_callback.best_model_path + if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path") + else None + ) # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py index 20358d19f177b..83881905beeb1 100644 --- a/src/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py @@ -19,6 +19,7 @@ from pytorch_lightning.callbacks import ( Callback, + Checkpoint, GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, @@ -26,7 +27,6 @@ RichProgressBar, TQDMProgressBar, ) -from pytorch_lightning.callbacks.model_checkpoint import Checkpoint from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from 
pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index fed164aa595c8..920545160e801 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1397,7 +1397,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.' ) - if not self.checkpoint_callback.best_model_path: + if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path: if self.fast_dev_run: raise MisconfigurationException( f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.' @@ -1407,11 +1407,15 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights - ckpt_path = self.checkpoint_callback.best_model_path + ckpt_path = ( + self.checkpoint_callback.best_model_path + if hasattr(self.checkpoint_callback, "best_model_path") + else None + ) if ckpt_path == "last": - candidates = [ft.ckpt_path for ft in ft_checkpoints] + [ - cb.last_model_path for cb in self.checkpoint_callbacks + candidates = [ft.ckpt_path if hasattr(ft, "ckpt_path") else None for ft in ft_checkpoints] + [ + cb.last_model_path if hasattr(cb, "last_model_path") else None for cb in self.checkpoint_callbacks ] candidates_fs = {path: get_filesystem(path) for path in candidates if path} candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)} From 131992490ad9d82313df3933a4fe94f8a30eaea7 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 22 Jun 2022 17:09:32 +0200 Subject: [PATCH 10/14] remove unrelated changelog line --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6847da633be5a..74be5f19547d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,9 +69,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024)) -- Added all DDP params to be exposed through hpu parallel strategy ([#13067](https://github.com/PyTorchLightning/pytorch-lightning/pull/13067)) - - - Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795)) From c3248fe2879b819f2263a018c8b6f48c12b30efa Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Fri, 24 Jun 2022 11:31:29 +0200 Subject: [PATCH 11/14] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/callbacks/checkpoint.py | 2 +- src/pytorch_lightning/trainer/trainer.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/pytorch_lightning/callbacks/checkpoint.py b/src/pytorch_lightning/callbacks/checkpoint.py index 91aab12290f70..5dde66f1ac6f5 100644 --- a/src/pytorch_lightning/callbacks/checkpoint.py +++ b/src/pytorch_lightning/callbacks/checkpoint.py @@ -3,7 +3,7 @@ class Checkpoint(Callback): r""" - This is the base class for Model checkpointing. 
Expert users may want to subclass it in case of writing + This is the base class for model checkpointing. Expert users may want to subclass it in case of writing custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that the trainer recognizes the custom class as a checkpointing callback. """ diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 920545160e801..b2e4df1e976fb 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1407,15 +1407,11 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_ f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.' ) # load best weights - ckpt_path = ( - self.checkpoint_callback.best_model_path - if hasattr(self.checkpoint_callback, "best_model_path") - else None - ) + ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None) if ckpt_path == "last": - candidates = [ft.ckpt_path if hasattr(ft, "ckpt_path") else None for ft in ft_checkpoints] + [ - cb.last_model_path if hasattr(cb, "last_model_path") else None for cb in self.checkpoint_callbacks + candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [ + getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks ] candidates_fs = {path: get_filesystem(path) for path in candidates if path} candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)} From 506218061e13bca7b733d6fc62cccd282f01ec90 Mon Sep 17 00:00:00 2001 From: otaj Date: Fri, 24 Jun 2022 11:37:00 +0200 Subject: [PATCH 12/14] drop changing in app part of the repo --- src/lightning_app/cli/pl-app-template/core/callbacks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/lightning_app/cli/pl-app-template/core/callbacks.py b/src/lightning_app/cli/pl-app-template/core/callbacks.py index f0f053da6e10d..93992c552f781 100644 --- a/src/lightning_app/cli/pl-app-template/core/callbacks.py +++ b/src/lightning_app/cli/pl-app-template/core/callbacks.py @@ -291,11 +291,7 @@ def setup( self._collect_logger_metadata(trainer) def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if ( - trainer.checkpoint_callback - and hasattr(trainer.checkpoint_callback, "dirpath") - and trainer.checkpoint_callback.dirpath is not None - ): + if trainer.checkpoint_callback and trainer.checkpoint_callback.dirpath is not None: self.work.checkpoint_dir = Path(trainer.checkpoint_callback.dirpath) def _collect_logger_metadata(self, trainer: "pl.Trainer") -> None: From c627027636b1fcaa0022c84813177a3547b3cca2 Mon Sep 17 00:00:00 2001 From: otaj Date: Fri, 24 Jun 2022 11:40:09 +0200 Subject: [PATCH 13/14] correct class in docstring --- src/pytorch_lightning/callbacks/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/checkpoint.py b/src/pytorch_lightning/callbacks/checkpoint.py index 5dde66f1ac6f5..405f29876c6fc 100644 --- a/src/pytorch_lightning/callbacks/checkpoint.py +++ b/src/pytorch_lightning/callbacks/checkpoint.py @@ -4,6 +4,6 @@ class Checkpoint(Callback): r""" This is the base class for model checkpointing. 
Expert users may want to subclass it in case of writing - custom :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback, so that + custom :class:`~pytorch_lightning.callbacks.Checkpoint` callback, so that the trainer recognizes the custom class as a checkpointing callback. """ From ff24752fb696b2525d325d48c45b88f57313b0b5 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 28 Jun 2022 14:05:38 +0200 Subject: [PATCH 14/14] move todo docs --- docs/source-pytorch/common/checkpointing_expert.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst index 3e963ee8c5700..c4a948a34cb9d 100644 --- a/docs/source-pytorch/common/checkpointing_expert.rst +++ b/docs/source-pytorch/common/checkpointing_expert.rst @@ -6,8 +6,6 @@ Checkpointing (expert) ###################### -TODO: I don't understand this... ********************************* Writing your own Checkpoint class ********************************* @@ -30,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` met what's saved in the checkpoint. +TODO: I don't understand this... ****************************** Built-in Checkpoint IO Plugins ******************************
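
As a usage sketch of what these patches enable (not taken from the diffs above), a custom checkpointing callback can subclass the new ``Checkpoint`` base class instead of ``ModelCheckpoint``. The class name ``EpochEndCheckpoint``, its save logic, and the directory path are illustrative assumptions; only ``Checkpoint``, ``Trainer.save_checkpoint``, and ``trainer.checkpoint_callbacks`` come from the code these patches touch.

```python
import os
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint  # exported by [PATCH 09/14]


class EpochEndCheckpoint(Checkpoint):
    """Hypothetical callback that saves a checkpoint after every validation epoch."""

    def __init__(self, dirpath: str) -> None:
        super().__init__()
        self.dirpath = dirpath
        # Optional attribute: loggers and spawn launchers now read fields such as
        # ``best_model_path`` only through ``hasattr``/``getattr`` guards, so a
        # subclass may expose as much or as little of the ModelCheckpoint API as it needs.
        self.best_model_path: Optional[str] = None

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Save a checkpoint for the epoch that just finished validating.
        os.makedirs(self.dirpath, exist_ok=True)
        path = os.path.join(self.dirpath, f"epoch={trainer.current_epoch}.ckpt")
        trainer.save_checkpoint(path)
        # Naively treat the most recent checkpoint as the best one.
        self.best_model_path = path


# Because the callback is an instance of ``Checkpoint``, the callback connector reorders
# it to run after all other callbacks and the Trainer reports it via
# ``checkpoint_callback`` / ``checkpoint_callbacks``.
trainer = pl.Trainer(callbacks=[EpochEndCheckpoint(dirpath="my_ckpts")], max_epochs=2)
assert any(isinstance(cb, EpochEndCheckpoint) for cb in trainer.checkpoint_callbacks)
```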