split restore_training_state into logical parts [2 / 2] #7900

Merged
merged 9 commits on Jun 10, 2021
Changes from 3 commits
81 changes: 58 additions & 23 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -115,7 +115,7 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:

# restore training state
if self._loaded_checkpoint:
self.restore_training_state(self._loaded_checkpoint, self._load_optimizer_states)
self.restore_training_state()

self.resume_end()
return True
@@ -135,36 +135,48 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
# restore model state_dict
model.load_state_dict(checkpoint['state_dict'])

def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
def restore_training_state(self) -> None:
"""
Restore trainer state.
Model will get its change to update
:param checkpoint:
:return:
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
optimizer states and learning rate scheduler states.
"""
if not self._loaded_checkpoint:
return

# validation
if load_optimizer_states and ('optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint):
raise KeyError(
'Trying to restore training state but checkpoint contains only the model.'
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)
# restore precision plugin (scaler etc.)
self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)

self.restore_callbacks()

# restore progress (loops etc.)
self.restore_progress()

if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
self.restore_optimizers_and_schedulers()

def restore_callbacks(self) -> None:
""" Restores all callbacks from the pre-loaded checkpoint. """
if not self._loaded_checkpoint:
return

if any([key in self._loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)
self.trainer.on_load_checkpoint(self._loaded_checkpoint)

self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

# restore callback states
self.trainer.on_load_checkpoint(checkpoint)
def restore_progress(self) -> None:
"""
Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step
and current epoch.
"""
if not self._loaded_checkpoint:
return

self.trainer.train_loop.global_step = checkpoint['global_step']
self.trainer.train_loop.current_epoch = checkpoint['epoch']
self.trainer.train_loop.global_step = self._loaded_checkpoint['global_step']
self.trainer.train_loop.current_epoch = self._loaded_checkpoint['epoch']

# crash if max_epochs is lower then the current epoch from the checkpoint
if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
@@ -186,11 +198,27 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
" consider using an end of epoch checkpoint."
)

if not load_optimizer_states:
def restore_optimizers_and_schedulers(self) -> None:
""" Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """
if not self._load_optimizer_states or not self._loaded_checkpoint:
return

# validation
if "optimizer_states" not in self._loaded_checkpoint or "lr_schedulers" not in self._loaded_checkpoint:
raise KeyError(
"Trying to restore training state but checkpoint contains only the model."
" This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`."
)
self.restore_optimizers()
self.restore_lr_schedulers()

def restore_optimizers(self) -> None:
""" Restores the optimizer states from the pre-loaded checkpoint. """
if not self._load_optimizer_states or not self._loaded_checkpoint:
return

# restore the optimizers
optimizer_states = checkpoint['optimizer_states']
optimizer_states = self._loaded_checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)

@@ -202,8 +230,13 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True)
if isinstance(v, torch.Tensor):
state[k] = v.cuda(self.trainer.root_gpu)

def restore_lr_schedulers(self) -> None:
""" Restores the learning rate scheduler states from the pre-loaded checkpoint. """
if not self._load_optimizer_states or not self._loaded_checkpoint:
return

# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
lr_schedulers = self._loaded_checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)

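For orientation, and not part of the diff: a minimal sketch of the checkpoint entries the helpers above read. The file name is an assumption, and a real Lightning checkpoint contains more keys than the ones shown here.

# --- sketch (not part of the diff) ---
import torch

ckpt = torch.load("example.ckpt", map_location="cpu")  # assumed local checkpoint file
print(ckpt["epoch"], ckpt["global_step"])  # read by restore_progress()
print(len(ckpt["optimizer_states"]))       # read by restore_optimizers()
print(len(ckpt["lr_schedulers"]))          # read by restore_lr_schedulers()
print("state_dict" in ckpt)                # read by restore_model_state()
# --- end sketch ---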
@@ -346,7 +379,9 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
model.cuda(self.trainer.root_gpu)

# restore training state
self.restore_training_state(checkpoint)
self._loaded_checkpoint = checkpoint
self.restore_training_state()
self._loaded_checkpoint = dict()

# call hpc specific hook
model.on_hpc_load(checkpoint)
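To show where these connector methods come into play, a small end-to-end sketch follows; it is not taken from this PR. The module name, paths and hyperparameters are our own, and resume_from_checkpoint is the Trainer argument of the Lightning versions around this PR; resuming is what drives restore() and, through it, restore_training_state() and the new helpers.

# --- sketch (not part of the diff) ---
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    # throwaway module, only to demonstrate the resume path
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


train_data = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8)

# first run writes a checkpoint
trainer = pl.Trainer(max_epochs=1, default_root_dir="demo")
trainer.fit(TinyModel(), train_data)

# resuming loads it again: CheckpointConnector.restore() calls restore_training_state(),
# which now delegates to restore_callbacks(), restore_progress() and
# restore_optimizers_and_schedulers()
ckpt_path = trainer.checkpoint_callback.best_model_path
resumed = pl.Trainer(max_epochs=2, resume_from_checkpoint=ckpt_path, default_root_dir="demo")
resumed.fit(TinyModel(), train_data)
# --- end sketch ---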
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -391,7 +391,7 @@ def test_model_checkpoint_only_weights(tmpdir):

# assert restoring train state fails
with pytest.raises(KeyError, match="checkpoint contains only the model"):
trainer.checkpoint_connector.restore_training_state(checkpoint)
trainer.checkpoint_connector.restore(new_weights_path)


def test_model_freeze_unfreeze():
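Finally, a sketch of the scenario this updated test covers, reusing TinyModel and train_data from the sketch above; the paths are assumptions. A weights-only checkpoint can restore the model but not the training state, so restore() raises the KeyError from the validation now living in restore_optimizers_and_schedulers().

# --- sketch (not part of the diff) ---
import pytest

trainer = pl.Trainer(max_epochs=1, default_root_dir="demo_weights_only")
trainer.fit(TinyModel(), train_data)

# a weights-only checkpoint drops 'optimizer_states' and 'lr_schedulers'
weights_path = "demo_weights_only/weights.ckpt"
trainer.save_checkpoint(weights_path, weights_only=True)

# restoring the full training state from it fails with the message asserted above
with pytest.raises(KeyError, match="checkpoint contains only the model"):
    trainer.checkpoint_connector.restore(weights_path)
# --- end sketch ---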