[checkpoint] Resolve 2 different checkpoint loading paths across fit vs validate/test/predict #9405
Labels: checkpointing (related to checkpointing), feature (is an improvement or enhancement), let's do it! (approved to implement), refactor
Proposed refactoring or deprecation
Consolidate checkpoint loading code across `fit` and `validate`/`test`/`predict`.
Motivation
We have 2 different code paths for checkpoint loading:
- `Trainer.fit` and the constructor argument `resume_from_checkpoint`
- `ckpt_path` passed to `Trainer.validate/test/predict`
Offering multiple code paths here risks divergence. Lightning must ensure a consistent experience for checkpoint loading across these different entry points.
Background
These are the paths today.
`Trainer.fit`:
- `resume_from_checkpoint`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L143
- `resume_from_checkpoint`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L383
- In `_run`, we call `Trainer._restore_modules_and_callbacks`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L960-L962

`Trainer.validate/test/predict`:
- `ckpt_path` is an argument to the function
- `Trainer._run`:
  - https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L950-L951
  - https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L1006-L1007
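To make the two paths listed above concrete, here is a minimal, library-free sketch. `ToyTrainer` and everything inside it are hypothetical stand-ins, not Lightning's actual implementation; only the argument names `resume_from_checkpoint` and `ckpt_path` and the attribute name `validated_ckpt_path` are taken from the real API.

```python
from typing import List, Optional


class ToyTrainer:
    """Hypothetical stand-in for Trainer with two separate loading paths."""

    def __init__(self, resume_from_checkpoint: Optional[str] = None) -> None:
        # Path 1: the checkpoint path is fixed when the trainer is constructed.
        self.resume_from_checkpoint = resume_from_checkpoint
        self.validated_ckpt_path: Optional[str] = None
        self.log: List[str] = []

    def fit(self) -> None:
        self._run(stage="fit")

    def validate(self, ckpt_path: Optional[str] = None) -> None:
        # Path 2: the checkpoint path is a per-call argument kept on the trainer.
        self.validated_ckpt_path = ckpt_path
        self._run(stage="validate")

    def _run(self, stage: str) -> None:
        # Two branches in one function: each can evolve independently,
        # which is the divergence risk this issue describes.
        if stage == "fit" and self.resume_from_checkpoint is not None:
            self.log.append(f"fit: restore from {self.resume_from_checkpoint}")
        elif stage == "validate" and self.validated_ckpt_path is not None:
            self.log.append(f"validate: load weights from {self.validated_ckpt_path}")


trainer = ToyTrainer(resume_from_checkpoint="last.ckpt")
trainer.fit()
trainer.validate(ckpt_path="best.ckpt")
print(trainer.log)
```

The point of the sketch is only that the same concept (a checkpoint to load) reaches `_run` through two different channels and is handled by two different branches.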
Notably, the `Trainer.fit` path calls `restore_model`, whereas `trainer.validate/test/predict` calls `restore_model_weights` in the checkpoint connector: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L116-L150
Pitch
Unify in CheckpointConnector
`CheckpointConnector.restore_model` is a superset of `CheckpointConnector.restore_model_weights`, which suggests we don't need both.
Unify in trainer.py
The overall sequence in `_run` can be shared.
We have 5 different properties exposed for the checkpoint path to resume from (excluding HPC stuff):
- `trainer.resume_from_checkpoint`
- `trainer.validated_ckpt_path`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L671-L674
- `trainer.tested_ckpt_path`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L760-L762
- `trainer.predicted_ckpt_path`: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L843-L845
- `trainer._ckpt_path`: https://github.com/PyTorchLightning/pytorch-lightning/blob/6e124e7207f6459cb43f540cfb5a1c6cc9b00f7a/pytorch_lightning/trainer/properties.py#L626-L633
This is also inconsistent: the attributes initialized in validate/test/predict are public, while `_ckpt_path` is private. Why?
It's unclear what the lifecycle of these properties should be. Do successive calls to validate/test/predict end up relying on this?
Proposal: pass the checkpoint path directly to `Trainer._run` to avoid our dependency on these properties.
Proposal: deprecate `resume_from_checkpoint` from the Trainer constructor, and add a new argument `ckpt_path` to `Trainer.fit`. This provides API consistency with `validate`/`test`/`predict`.
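A rough sketch of what the unified shape could look like, using only the standard library. It assumes the proposals mean that every entry point takes `ckpt_path` and hands it to `_run` as an argument, and that the connector exposes one restore method. `ToyTrainer`, `ToyCheckpointConnector`, `restore`, and the in-memory `checkpoints` dict are all hypothetical names for illustration, not Lightning APIs.

```python
import warnings
from typing import Any, Dict, Optional


class ToyCheckpointConnector:
    """Hypothetical connector with a single restore entry point."""

    def restore(self, checkpoint: Dict[str, Any], model: Dict[str, Any]) -> None:
        # One method serves every entry point, rather than keeping
        # restore_model and restore_model_weights side by side.
        model.update(checkpoint["state_dict"])


class ToyTrainer:
    """Hypothetical trainer where fit/validate share one loading path."""

    def __init__(self, resume_from_checkpoint: Optional[str] = None) -> None:
        if resume_from_checkpoint is not None:
            warnings.warn(
                "resume_from_checkpoint is deprecated; pass ckpt_path to fit()",
                DeprecationWarning,
            )
        self._deprecated_resume_path = resume_from_checkpoint
        self.connector = ToyCheckpointConnector()
        # Stand-in for files on disk: checkpoint path -> checkpoint contents.
        self.checkpoints: Dict[str, Dict[str, Any]] = {}

    def fit(self, model: Dict[str, Any], ckpt_path: Optional[str] = None) -> str:
        # The deprecated constructor value is only a fallback.
        return self._run(model, "fit", ckpt_path or self._deprecated_resume_path)

    def validate(self, model: Dict[str, Any], ckpt_path: Optional[str] = None) -> str:
        return self._run(model, "validate", ckpt_path)

    def _run(self, model: Dict[str, Any], stage: str, ckpt_path: Optional[str]) -> str:
        # The path arrives as an argument, so no trainer property is consulted.
        if ckpt_path is not None:
            self.connector.restore(self.checkpoints[ckpt_path], model)
        return stage


trainer = ToyTrainer()
trainer.checkpoints["best.ckpt"] = {"state_dict": {"w": 1.0}}
fit_model = {"w": 0.0}
val_model = {"w": 0.0}
trainer.fit(fit_model, ckpt_path="best.ckpt")
trainer.validate(val_model, ckpt_path="best.ckpt")
print(fit_model == val_model)
```

Because both entry points funnel into the same `_run` signature and the same `restore` call, there is no second branch that can drift, and nothing needs to be read back from per-stage trainer attributes.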