Skip to content

Commit

Permalink
add missing function
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jun 11, 2021
1 parent e0e98e0 commit 95a7cb4
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ def restore_model(self) -> None:
# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
""" Restore only the model weights. """
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

def restore_training_state(self) -> None:
"""
Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
Expand Down

0 comments on commit 95a7cb4

Please sign in to comment.