
Commit

awaelchli committed Jun 12, 2021
1 parent 0f17119 commit 3aef4e4
Showing 7 changed files with 190 additions and 79 deletions.
125 changes: 82 additions & 43 deletions pl_examples/bug_report_model.py
@@ -1,67 +1,106 @@
+import logging
 import os
+from typing import Any, Dict

 import torch
-from torch.utils.data import DataLoader, Dataset
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import AdamW
+from torch.utils.data import DataLoader

-from pytorch_lightning import LightningModule, Trainer
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


-class RandomDataset(Dataset):
+class ToyModel(nn.Module):

-    def __init__(self, size, length):
-        self.len = length
-        self.data = torch.randn(length, size)

-    def __getitem__(self, index):
-        return self.data[index]
+    def __init__(self):
+        super().__init__()
+        self.net1 = nn.Linear(10, 10)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(10, 5)

-    def __len__(self):
-        return self.len
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))


-class BoringModel(LightningModule):
+class ToyTask(pl.LightningModule):

     def __init__(self):
         super().__init__()
-        self.layer = torch.nn.Linear(32, 2)
+        self.loss_fn = nn.MSELoss()

+    def setup(self, stage: str):
+        if stage == "test":
+            return
+        self.setup_model_and_optimizer()
+        print("setup called")

+    def setup_model_and_optimizer(self):
+        self.model = ToyModel()
+        self.optimizer = AdamW(
+            self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False
+        )

     def forward(self, x):
-        return self.layer(x)
+        return self.model(x)

     def training_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("train_loss", loss)
-        return {"loss": loss}
+        targets = self.forward(batch["model_input"])
+        loss = self.loss_fn(targets, batch["label"])

-    def validation_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("valid_loss", loss)
+        # Log loss results per train step and per epoch
+        self.log("loss", loss)

-    def test_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("test_loss", loss)
+        # Tell Lightning to minimize loss
+        return loss

     def configure_optimizers(self):
-        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


-def run():
-    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
-    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
-    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=os.getcwd(),
-        limit_train_batches=1,
-        limit_val_batches=1,
-        num_sanity_val_steps=0,
-        max_epochs=1,
-        weights_summary=None,
+        return self.optimizer

+    # def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+    #     self.setup_model_and_optimizer()


+if __name__ == "__main__":
+    task = ToyTask()

+    dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)]

+    train_dataloader = DataLoader(dataset, batch_size=None)
+    val_dataloader = DataLoader(dataset, batch_size=None)

+    model_checkpoint = ModelCheckpoint(
+        save_last=True,
+        every_n_val_epochs=1,
     )
-    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
-    trainer.test(model, test_dataloaders=test_data)

+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=3,
+        progress_bar_refresh_rate=100,
+        log_gpu_memory=None,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+    )

+    results = trainer.fit(task, train_dataloader)

-if __name__ == '__main__':
-    run()
+    print(model_checkpoint.last_model_path)

+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=4,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+        resume_from_checkpoint=model_checkpoint.last_model_path,
+    )
+    trainer.fit(task, train_dataloader)
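
For orientation (not part of the commit): the script above saves checkpoints via ModelCheckpoint(save_last=True) and then resumes by pointing a second Trainer at model_checkpoint.last_model_path. What gets carried across is, roughly, the model weights, the optimizer state and the training progress, and those states can only be loaded into objects that already exist, which is why ToyTask builds its model and optimizer in setup() (the commented-out on_load_checkpoint hook marks an alternative place to do that). A minimal plain-PyTorch sketch of that save/resume round trip, with the made-up file name toy_resume.ckpt and key names that roughly mirror a Lightning checkpoint:

    import torch
    from torch.optim import AdamW

    # Stand-ins for what ToyTask.setup_model_and_optimizer() creates.
    model = torch.nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=0.001)

    # Save: roughly the payload a Lightning checkpoint carries for this example.
    torch.save(
        {"epoch": 3, "state_dict": model.state_dict(), "optimizer_states": [optimizer.state_dict()]},
        "toy_resume.ckpt",  # hypothetical path
    )

    # Resume: states are loaded *into* existing objects, so the model and optimizer
    # must be constructed before the checkpoint is applied.
    checkpoint = torch.load("toy_resume.ckpt")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_states"][0])
    start_epoch = checkpoint["epoch"]

In the Lightning scripts in this commit, that bookkeeping is handled by the Trainer and its checkpoint connector once resume_from_checkpoint is set.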
33 changes: 33 additions & 0 deletions pl_examples/model_resume.py
@@ -0,0 +1,33 @@
+import torch
+from torch.utils.data import DataLoader

+import pytorch_lightning as pl
+from pl_examples.bug_report_model import ToyTask
+from pytorch_lightning.callbacks import ModelCheckpoint

+if __name__ == "__main__":
+    task = ToyTask()

+    dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)]

+    train_dataloader = DataLoader(dataset, batch_size=None)
+    val_dataloader = DataLoader(dataset, batch_size=None)

+    model_checkpoint = ModelCheckpoint(
+        save_last=True,
+        every_n_val_epochs=1,
+    )

+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=4,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+        resume_from_checkpoint=
+        "/home/adrian/repositories/pytorch-lightning/lightning_logs/version_82/checkpoints/last.ckpt",
+    )
+    trainer.fit(task, train_dataloader)
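
The resume path in model_resume.py is hard-coded to the author's machine. A hedged sketch of picking up the newest last.ckpt programmatically instead, assuming the default lightning_logs layout and a ModelCheckpoint(save_last=True) run (not part of the commit):

    from pathlib import Path

    # Choose the most recently written `last.ckpt` under the default log directory.
    candidates = sorted(
        Path("lightning_logs").glob("version_*/checkpoints/last.ckpt"),
        key=lambda p: p.stat().st_mtime,
    )
    resume_path = str(candidates[-1]) if candidates else None

The resulting resume_path could then be passed as resume_from_checkpoint=resume_path in place of the hard-coded string.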
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -247,7 +247,7 @@ def _load_config(self, config):
                config = json.load(f)
        return config

-    def pre_dispatch(self):
+    def pre_dispatch(self) -> None:
        self.init_deepspeed()
        self.barrier()

84 changes: 54 additions & 30 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,6 +14,7 @@

import os
import re
+from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Optional, Union

@@ -29,7 +30,6 @@
    rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
-from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

@@ -53,6 +53,13 @@ def hpc_resume_path(self) -> Optional[str]:
        if max_version is not None:
            return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"

+    def resume_from_checkpoint(self, path: Union[str, Path], **kwargs) -> None:
+        """
+        Signals the Trainer to resume from the given path the next time Trainer.fit/validate/test/predict is called.
+        """
+        self.resume_checkpoint_path = path
+        # TODO: decide what to resume

    def resume_start(self) -> None:
        """
        Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -94,49 +101,69 @@ def resume_end(self) -> None:
        # wait for all to catch up
        self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")

-    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
+    # TODO: decide if we should use it or not (e.g., in Trainer.fit over self._run())
+    @contextmanager
+    def restore_ctx(self):
+        try:
+            self.resume_start()
+            yield
+        finally:
+            self.resume_end()

+    def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None:
        """
-        Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file
+        Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file
        through file-read and state-restore, in this priority:
        1. from HPC weights if found
        2. from `resume_from_checkpoint` file if provided
        3. don't restore
        All restored states are listed in return value description of `dump_checkpoint`.
+        Args:
+            checkpoint_path: Path to a PyTorch Lightning checkpoint file.
        """
-        self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path
+        self.resume_checkpoint_path = checkpoint_path
        self.resume_start()
-        model = self.trainer.lightning_module

-        self.restore_model_state(model, self._loaded_checkpoint)
+        # restore module states
+        self.restore_datamodule()
+        self.restore_model()

-        if self.trainer._device_type == DeviceType.GPU:
-            model.cuda(self.trainer.root_gpu)
+        # restore callback states
+        self.restore_callbacks()

        # restore training state
-        if self._loaded_checkpoint:
-            self.restore_training_state(self._loaded_checkpoint)

+        self.restore_training_state()
        self.resume_end()
-        return True

-    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
+    def restore_datamodule(self) -> None:
+        """ Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """
+        datamodule = self.trainer.datamodule
+        if datamodule is not None:
+            datamodule.on_load_checkpoint(self._loaded_checkpoint)

+    def restore_model(self) -> None:
        """
-        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first to give
+        the LightningModule a chance to modify the contents, then finally the model gets updated with
+        the loaded weights.
        """
-        if not checkpoint:
+        if not self._loaded_checkpoint:
            return

-        # restore datamodule states
-        if self.trainer.datamodule is not None:
-            self.trainer.datamodule.on_load_checkpoint(checkpoint)
+        model = self.trainer.lightning_module

        # hook: give user access to checkpoint if needed.
-        model.on_load_checkpoint(checkpoint)
+        model.on_load_checkpoint(self._loaded_checkpoint)

+        # call hpc specific hook
+        if self.hpc_resume_path is not None:
+            self.trainer.lightning_module.on_hpc_load(self._loaded_checkpoint)

        # restore model state_dict
-        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+        self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

    def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
        """ Restore only the model weights. """
@@ -147,19 +174,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

-    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+    def restore_training_state(self) -> None:
        """
        Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
        optimizer states and learning rate scheduler states.
        """
-        if not checkpoint:
+        if not self._loaded_checkpoint:
            return

        # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)

-        self.restore_callbacks()

+        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
        # restore progress (loops etc.)
        self.restore_progress()

@@ -229,10 +253,8 @@ def restore_optimizers(self) -> None:
            return

        # restore the optimizers
-        optimizer_states = self._loaded_checkpoint['optimizer_states']
-        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
-            optimizer.load_state_dict(opt_state)

+        self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
+        for optimizer in self.trainer.optimizers:
            # move optimizer to GPU 1 weight at a time
            # avoids OOM
            if self.trainer.root_gpu is not None:
@@ -257,6 +279,7 @@ def restore_lr_schedulers(self) -> None:
    def hpc_load(self, checkpoint_path: str):
        """
        Attempts to restore the full training and model state from a HPC checkpoint file.
+        .. deprecated::
            `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6.
            Use `CheckpointConnector.restore` instead.
@@ -364,6 +387,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                lr_schedulers.append(scheduler['scheduler'].state_dict())
            checkpoint['lr_schedulers'] = lr_schedulers

+            # dump amp scaling
            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)

            # dump hyper-parameters

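Taken together, the checkpoint_connector changes break the old monolithic restore() into narrower steps. A condensed paraphrase of the resulting call order, for orientation only (the comments summarize the hunks above; this is not the verbatim class):

    # Condensed paraphrase of the refactored CheckpointConnector.restore() shown above.
    def restore(self, checkpoint_path=None):
        self.resume_checkpoint_path = checkpoint_path
        self.resume_start()            # pre-load the checkpoint file into memory

        self.restore_datamodule()      # datamodule.on_load_checkpoint(...)
        self.restore_model()           # LightningModule hooks, then load_model_state_dict(...)
        self.restore_callbacks()       # callback states
        self.restore_training_state()  # precision plugin, loop progress, optimizers, LR schedulers

        self.resume_end()              # barrier so all processes finish restoring together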