From 3aef4e456332e7f4f0322b569f012805fb0d05ab Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Sat, 12 Jun 2021 14:58:59 +0200
Subject: [PATCH] all

---
 pl_examples/bug_report_model.py               | 125 ++++++++++++------
 pl_examples/model_resume.py                   |  33 +++++
 .../plugins/training_type/deepspeed.py        |   2 +-
 .../connectors/checkpoint_connector.py        |  84 +++++++-----
 pytorch_lightning/trainer/trainer.py          |  18 ++-
 tests/callbacks/test_finetuning_callback.py   |   2 +-
 tests/trainer/test_trainer.py                 |   5 +-
 7 files changed, 190 insertions(+), 79 deletions(-)
 create mode 100644 pl_examples/model_resume.py

diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py
index abb65ba86fd93..8e2d4009c6cea 100644
--- a/pl_examples/bug_report_model.py
+++ b/pl_examples/bug_report_model.py
@@ -1,67 +1,106 @@
+import logging
 import os
+from typing import Any, Dict
 
 import torch
-from torch.utils.data import DataLoader, Dataset
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
 
-from pytorch_lightning import LightningModule, Trainer
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 
 
-class RandomDataset(Dataset):
+class ToyModel(nn.Module):
 
-    def __init__(self, size, length):
-        self.len = length
-        self.data = torch.randn(length, size)
-
-    def __getitem__(self, index):
-        return self.data[index]
+    def __init__(self):
+        super().__init__()
+        self.net1 = nn.Linear(10, 10)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(10, 5)
 
-    def __len__(self):
-        return self.len
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
 
 
-class BoringModel(LightningModule):
+class ToyTask(pl.LightningModule):
 
     def __init__(self):
         super().__init__()
-        self.layer = torch.nn.Linear(32, 2)
+        self.loss_fn = nn.MSELoss()
+
+    def setup(self, stage: str):
+        if stage == "test":
+            return
+        self.setup_model_and_optimizer()
+        print("setup called")
+
+    def setup_model_and_optimizer(self):
+        self.model = ToyModel()
+        self.optimizer = AdamW(
+            self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False
+        )
 
     def forward(self, x):
-        return self.layer(x)
+        return self.model(x)
 
     def training_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("train_loss", loss)
-        return {"loss": loss}
+        targets = self.forward(batch["model_input"])
+        loss = self.loss_fn(targets, batch["label"])
 
-    def validation_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("valid_loss", loss)
+        # Log loss results per train step and per epoch
+        self.log("loss", loss)
 
-    def test_step(self, batch, batch_idx):
-        loss = self(batch).sum()
-        self.log("test_loss", loss)
+        # Tell Lightning to minimize loss
+        return loss
 
     def configure_optimizers(self):
-        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
-
-def run():
-    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
-    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
-    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
-
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=os.getcwd(),
-        limit_train_batches=1,
-        limit_val_batches=1,
-        num_sanity_val_steps=0,
-        max_epochs=1,
-        weights_summary=None,
+        return self.optimizer
+
+    # def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+    #     self.setup_model_and_optimizer()
+
+
+if __name__ == "__main__":
+    task = ToyTask()
+
+    dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)]
+
+    train_dataloader = DataLoader(dataset, batch_size=None)
+    val_dataloader = DataLoader(dataset, batch_size=None)
+
+    model_checkpoint = ModelCheckpoint(
+        save_last=True,
+        every_n_val_epochs=1,
     )
-    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
-    trainer.test(model, test_dataloaders=test_data)
+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=3,
+        progress_bar_refresh_rate=100,
+        log_gpu_memory=None,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+    )
+
+    results = trainer.fit(task, train_dataloader)
 
-if __name__ == '__main__':
-    run()
+    print(model_checkpoint.last_model_path)
+
+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=4,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+        resume_from_checkpoint=model_checkpoint.last_model_path,
+    )
+    trainer.fit(task, train_dataloader)
diff --git a/pl_examples/model_resume.py b/pl_examples/model_resume.py
new file mode 100644
index 0000000000000..f56e9750105d8
--- /dev/null
+++ b/pl_examples/model_resume.py
@@ -0,0 +1,33 @@
+import torch
+from torch.utils.data import DataLoader
+
+import pytorch_lightning as pl
+from pl_examples.bug_report_model import ToyTask
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+if __name__ == "__main__":
+    task = ToyTask()
+
+    dataset = [{"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)]
+
+    train_dataloader = DataLoader(dataset, batch_size=None)
+    val_dataloader = DataLoader(dataset, batch_size=None)
+
+    model_checkpoint = ModelCheckpoint(
+        save_last=True,
+        every_n_val_epochs=1,
+    )
+
+    trainer = pl.Trainer(
+        gpus=2,
+        precision=16,
+        max_epochs=4,
+        reload_dataloaders_every_epoch=True,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        limit_test_batches=10,
+        callbacks=[model_checkpoint],
+        resume_from_checkpoint=
+        "/home/adrian/repositories/pytorch-lightning/lightning_logs/version_82/checkpoints/last.ckpt",
+    )
+    trainer.fit(task, train_dataloader)
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index dc688de65cd34..5ccbad575ae15 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -247,7 +247,7 @@ def _load_config(self, config):
             config = json.load(f)
         return config
 
-    def pre_dispatch(self):
+    def pre_dispatch(self) -> None:
         self.init_deepspeed()
         self.barrier()
 
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index f203b28c09048..128c5501b79da 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -14,6 +14,7 @@
 
 import os
 import re
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
@@ -29,7 +30,6 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
-from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 
@@ -53,6 +53,13 @@ def hpc_resume_path(self) -> Optional[str]:
         if max_version is not None:
return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt" + def resume_from_checkpoint(self, path: Union[str, Path], **kwargs) -> None: + """ + Signals the Trainer to resume from the given path the next time Trainer.fit/validate/test/predict is called. + """ + self.resume_checkpoint_path = path + # TODO: decide what to resume + def resume_start(self) -> None: """ Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: @@ -94,9 +101,18 @@ def resume_end(self) -> None: # wait for all to catch up self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end") - def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool: + # TODO: decice if we should use it or not (e.g., in Trainer.fit over self._run()) + @contextmanager + def restore_ctx(self): + try: + self.resume_start() + yield + finally: + self.resume_end() + + def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: """ - Attempt to restore model/training states from a 'PyTorch-Lightning checkpoint' file + Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore, in this priority: 1. from HPC weights if found @@ -104,39 +120,50 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool: 3. don't restore All restored states are listed in return value description of `dump_checkpoint`. + + Args: + checkpoint_path: Path to a PyTorch Lightning checkpoint file. """ - self.resume_checkpoint_path = checkpoint_path or self.resume_checkpoint_path + self.resume_checkpoint_path = checkpoint_path self.resume_start() - model = self.trainer.lightning_module - self.restore_model_state(model, self._loaded_checkpoint) + # restore module states + self.restore_datamodule() + self.restore_model() - if self.trainer._device_type == DeviceType.GPU: - model.cuda(self.trainer.root_gpu) + # restore callback states + self.restore_callbacks() # restore training state - if self._loaded_checkpoint: - self.restore_training_state(self._loaded_checkpoint) - + self.restore_training_state() self.resume_end() - return True - def restore_model_state(self, model: LightningModule, checkpoint) -> None: + def restore_datamodule(self) -> None: + """ Calls hooks on the datamodule to give it a chance to restore its state from the checkpoint. """ + datamodule = self.trainer.datamodule + if datamodule is not None: + datamodule.on_load_checkpoint(self._loaded_checkpoint) + + def restore_model(self) -> None: """ - Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object + Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first go give + the LightningModule a chance to modify the contents, then finally the model gets updated with + the loaded weights. """ - if not checkpoint: + if not self._loaded_checkpoint: return - # restore datamodule states - if self.trainer.datamodule is not None: - self.trainer.datamodule.on_load_checkpoint(checkpoint) + model = self.trainer.lightning_module # hook: give user access to checkpoint if needed. 
-        model.on_load_checkpoint(checkpoint)
+        model.on_load_checkpoint(self._loaded_checkpoint)
+
+        # call hpc specific hook
+        if self.hpc_resume_path is not None:
+            self.trainer.lightning_module.on_hpc_load(self._loaded_checkpoint)
 
         # restore model state_dict
-        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+        self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)
 
     def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
         """ Restore only the model weights. """
@@ -147,19 +174,16 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
         self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
 
-    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+    def restore_training_state(self) -> None:
         """
         Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress,
         optimizer states and learning rate scheduler states.
         """
-        if not checkpoint:
+        if not self._loaded_checkpoint:
             return
 
         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(checkpoint)
-
-        self.restore_callbacks()
-
+        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
         # restore progress (loops etc.)
         self.restore_progress()
 
@@ -229,10 +253,8 @@ def restore_optimizers(self) -> None:
             return
 
         # restore the optimizers
-        optimizer_states = self._loaded_checkpoint['optimizer_states']
-        for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
-            optimizer.load_state_dict(opt_state)
-
+        self.trainer.training_type_plugin.load_optimizer_state_dict(self._loaded_checkpoint)
+        for optimizer in self.trainer.optimizers:
             # move optimizer to GPU 1 weight at a time
             # avoids OOM
             if self.trainer.root_gpu is not None:
@@ -257,6 +279,7 @@ def restore_lr_schedulers(self) -> None:
     def hpc_load(self, checkpoint_path: str):
         """
         Attempts to restore the full training and model state from a HPC checkpoint file.
+
         .. deprecated::
             `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6.
             Use `CheckpointConnector.restore` instead.
@@ -364,6 +387,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 lr_schedulers.append(scheduler['scheduler'].state_dict())
             checkpoint['lr_schedulers'] = lr_schedulers
 
+            # dump amp scaling
             self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
 
         # dump hyper-parameters
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a6f93d9b4263d..85889e2ddc9ca 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -456,6 +456,9 @@ def fit(
             model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
+        self.checkpoint_connector.resume_start()
+
+        # with self.checkpoint_connector.restore_ctx():
         self._run(model)
 
         assert self.state.stopped
@@ -732,7 +735,14 @@ def _run(self, model: LightningModule) -> Optional[Union[_EVALUATE_OUTPUT, _PRED
         self.call_hook("on_before_accelerator_backend_setup", model)
         self.accelerator.connect(model)
         self.accelerator.setup_environment()
-        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment
+        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator
+
+        # restore modules after setup
+        self.checkpoint_connector.restore_datamodule()
+        self.checkpoint_connector.restore_model()
+        # restore callback states
+        self.checkpoint_connector.restore_callbacks()
+
         self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
         self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
@@ -806,6 +816,9 @@ def _pre_dispatch(self):
             self.logger.log_graph(self.lightning_module)
             self.logger.save()
 
+        # restore optimizers, etc.
+        self.checkpoint_connector.restore_training_state()
+
     def _post_dispatch(self):
         self.accelerator.post_dispatch(self)
         self.accelerator.teardown()
@@ -849,8 +862,7 @@ def _pre_training_routine(self):
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
             ref_model.summarize(mode=self.weights_summary)
 
-        # restore training and model before hpc is called
-        self.checkpoint_connector.restore()
+        self.checkpoint_connector.resume_end()
 
         # on pretrain routine end
         self.on_pretrain_routine_end()
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 53d34c4645bef..fe8915e6e8443 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -275,7 +275,7 @@ def configure_optimizers(self):
     model = FreezeModel()
     cb = OnEpochLayerFinetuning()
     trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
-    with pytest.raises(IndexError, match="index 6 is out of range"):
+    with pytest.raises(ValueError, match="loaded state dict has a different number of parameter groups"):
         trainer.fit(model)
 
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d353c0941d3f6..c04191a57bfa8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -428,7 +428,10 @@ def test_model_checkpoint_only_weights(tmpdir):
 
     # assert restoring train state fails
     with pytest.raises(KeyError, match="checkpoint contains only the model"):
-        trainer.checkpoint_connector.restore(new_weights_path)
+        trainer.checkpoint_connector.resume_from_checkpoint(new_weights_path)
+        trainer.checkpoint_connector.resume_start()
+        trainer.checkpoint_connector.restore_training_state()
+        trainer.checkpoint_connector.resume_end()
 
 
 def test_model_freeze_unfreeze():