From 95a7cb42c4d7577d775a53ec8d95f9e92849c1e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Jun 2021 21:08:46 +0200 Subject: [PATCH] add missing function --- .../trainer/connectors/checkpoint_connector.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 0cbc4386aeaab1..128c5501b79da4 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -165,6 +165,15 @@ def restore_model(self) -> None: # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None: + """ Restore only the model weights. """ + checkpoint = self._loaded_checkpoint + if checkpoint_path is not None: + checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path) + + self.trainer.lightning_module.on_load_checkpoint(checkpoint) + self.trainer.training_type_plugin.load_model_state_dict(checkpoint) + def restore_training_state(self) -> None: """ Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,