update ckpt_connector
ananthsub committed Feb 7, 2022
1 parent 44ec05e commit dd79460
Showing 2 changed files with 2 additions and 12 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/strategy.py
@@ -32,7 +32,7 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

 TBroadcast = TypeVar("TBroadcast")
@@ -314,6 +314,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         optimizer_states = checkpoint["optimizer_states"]
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
+            optimizer_to_device(optimizer, self.root_device)

     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         """The actual training step.
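Note: the added optimizer_to_device(optimizer, self.root_device) call takes over the job of the inline GPU-move loop deleted from checkpoint_connector.py below. A minimal sketch of the behaviour that call is relied on for, assuming the utility simply moves every tensor held in optimizer.state to the target device (the helper name _optimizer_state_to_device_sketch is illustrative, not the library function):

import torch
from torch.optim import Optimizer

def _optimizer_state_to_device_sketch(optimizer: Optimizer, device: torch.device) -> None:
    # Move every tensor held in the optimizer state to `device`, one entry at a
    # time -- the same "1 weight at a time / avoids OOM" intent as the loop
    # removed from checkpoint_connector.py, but keyed off the strategy's device.
    for param, state in optimizer.state.items():
        if isinstance(state, dict):
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)
        elif isinstance(state, torch.Tensor):
            optimizer.state[param] = state.to(device)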
11 changes: 0 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -279,17 +279,6 @@ def restore_optimizers(self) -> None:

         # restore the optimizers
         self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
-        for optimizer in self.trainer.optimizers:
-            # move optimizer to GPU 1 weight at a time
-            # avoids OOM
-            if self.trainer.root_gpu is not None:
-                for param, state in optimizer.state.items():
-                    if isinstance(state, dict):
-                        for k, v in state.items():
-                            if isinstance(v, torch.Tensor):
-                                state[k] = v.cuda(self.trainer.root_gpu)
-                    elif isinstance(state, torch.Tensor):
-                        optimizer.state[param] = state.cuda(self.trainer.root_gpu)

     def restore_lr_schedulers(self) -> None:
         """Restores the learning rate scheduler states from the pre-loaded checkpoint."""
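Note: with the device move handled inside Strategy.load_optimizer_state_dict, the connector only has to call the strategy. A minimal standalone usage sketch of that pattern, assuming a CUDA device is available (falling back to CPU otherwise); the checkpoint dict here is constructed on the spot purely for illustration:

import torch
from pytorch_lightning.utilities.optimizer import optimizer_to_device

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
# Fake a checkpoint payload using the same key the connector reads.
checkpoint = {"optimizer_states": [optimizer.state_dict()]}

root_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for opt, opt_state in zip([optimizer], checkpoint["optimizer_states"]):
    opt.load_state_dict(opt_state)         # restore state from the checkpoint
    optimizer_to_device(opt, root_device)  # then place its tensors on the root device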
