From 72d0433387572c1d91081f30869ef9a8d51b4dd7 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 14 Oct 2021 15:07:44 +1300 Subject: [PATCH 01/59] Save StochasticWeightAveraging callback data in checkpoints --- .../callbacks/stochastic_weight_avg.py | 59 +++++++++++++- tests/callbacks/test_stochastic_weight_avg.py | 76 ++++++++++++++++++- 2 files changed, 129 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index bde9c1b5c2407..01f2563f4c9be 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,7 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import nn @@ -115,6 +115,8 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") + self.momenta = None + self.n_averaged = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs @@ -123,6 +125,8 @@ def __init__( self._device = device self._model_contains_batch_norm = None self._average_model = None + self._initialized = False + self._swa_scheduler = None @property def swa_start(self) -> int: @@ -162,7 +166,10 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): trainer.fit_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - if trainer.current_epoch == self.swa_start: + resuming_after_start = trainer.current_epoch > self.swa_start and not self._initialized + if trainer.current_epoch == self.swa_start or resuming_after_start: + self._initialized = True + # move average model to request device. 
self._average_model = self._average_model.to(self._device or pl_module.device) @@ -198,7 +205,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo else: trainer.lr_schedulers.append(default_scheduler_cfg) - self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) + if self.n_averaged is None: + self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) if self.swa_start <= trainer.current_epoch <= self.swa_end: self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) @@ -280,3 +288,48 @@ def avg_fn( ) -> torch.FloatTensor: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) + + def on_save_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + ) -> dict: + checkpoint_data = { + "momenta": self.momenta, + "n_averaged": self.n_averaged, + "swa_lrs": self._swa_lrs, + "annealing_epochs": self._annealing_epochs, + "annealing_strategy": self._annealing_strategy, + "average_model_parameters": self._get_average_model_parameters(), + } + return checkpoint_data + + def on_load_checkpoint( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] + ) -> None: + if callback_state: + self.momenta = callback_state["momenta"] + self.n_averaged = callback_state["n_averaged"] + self._swa_lrs = callback_state["swa_lrs"] + self._annealing_strategy = callback_state["annealing_strategy"] + self._annealing_epochs = callback_state["annealing_epochs"] + self._load_average_model_parameters(callback_state["average_model_parameters"]) + else: + rank_zero_warn( + f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." + ) + + def _get_average_model_parameters(self) -> Any: + if self._average_model is None: + return None + parameters = [] + for p_swa in self._average_model.parameters(): + parameters.append(p_swa.detach()) + return parameters + + def _load_average_model_parameters(self, parameter_state: Any): + if self._average_model is None: + return + for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): + device = p_swa.device + p_swa_ = p_swa.detach() + p_checkpoint_ = p_checkpoint.detach().to(device) + p_swa_.copy_(p_checkpoint_) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 82c4b257b9ef4..c861332e08907 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
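The state dict returned by on_save_checkpoint above is stored by the Trainer under the checkpoint's "callbacks" entry, keyed by the callback's default state key (the class qualname). A minimal inspection sketch, assuming a checkpoint file written while the unsubclassed callback was attached (the filename is illustrative):

    import torch

    # Load a Lightning checkpoint produced with StochasticWeightAveraging attached.
    ckpt = torch.load("swa_run.ckpt", map_location="cpu")

    # The state saved by on_save_checkpoint lives under the callbacks mapping;
    # a subclass would be keyed by its own qualified class name instead.
    swa_state = ckpt["callbacks"]["StochasticWeightAveraging"]
    print(sorted(swa_state))
    # -> ['annealing_epochs', 'annealing_strategy', 'average_model_parameters',
    #     'momenta', 'n_averaged', 'swa_lrs']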
import logging +import os +from pathlib import Path from unittest import mock import pytest @@ -31,7 +33,9 @@ class SwaTestModel(BoringModel): - def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False): + def __init__( + self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_after_epoch=None + ): super().__init__() layers = [nn.Linear(32, 32)] if batchnorm: @@ -40,6 +44,8 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat self.layer = nn.Sequential(*layers) self.interval = interval self.iterable_dataset = iterable_dataset + self.crash_after_epoch = crash_after_epoch + self._epoch_count = 0 def training_step(self, batch, batch_idx): output = self.forward(batch) @@ -63,8 +69,23 @@ def configure_optimizers(self): }, } + def training_epoch_end(self, _): + if not self.crash_after_epoch: + return + self._epoch_count += 1 + if self._epoch_count >= self.crash_after_epoch: + raise RuntimeError("Crash test") + class SwaTestCallback(StochasticWeightAveraging): + def __init__(self, *args, **kwargs): + if "resuming_from_epoch" in kwargs: + self.resuming_from_epoch = kwargs["resuming_from_epoch"] + del kwargs["resuming_from_epoch"] + else: + self.resuming_from_epoch = 0 + super().__init__(*args, **kwargs) + update_parameters_calls: int = 0 transfer_weights_calls: int = 0 @@ -102,10 +123,16 @@ def on_train_end(self, trainer, pl_module): if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. the batchnorm update epoch should not backward - assert trainer.accelerator.backward.call_count == trainer.max_epochs * trainer.limit_train_batches + assert trainer.accelerator.backward.call_count == ( + (trainer.max_epochs - self.resuming_from_epoch) * trainer.limit_train_batches + ) # check call counts - assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1) + if self.resuming_from_epoch >= self._swa_epoch_start: + expected_update_calls = trainer.max_epochs - self.resuming_from_epoch + else: + expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1) + assert self.update_parameters_calls == expected_update_calls assert self.transfer_weights_calls == 1 @@ -273,3 +300,46 @@ def on_train_epoch_start(self): ) trainer.fit(model) assert model.on_train_epoch_start_called + + +def test_swa_resume_training_from_checkpoint(tmpdir): + model = SwaTestModel(crash_after_epoch=3) + swa_start = 2 + max_epochs = 5 + swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) + + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=max_epochs, + limit_train_batches=5, + limit_val_batches=0, + callbacks=[swa_callback], + accumulate_grad_batches=2, + num_processes=1, + ) + + with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward), pytest.raises(RuntimeError): + trainer.fit(model) + + checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" + checkpoint_files = os.listdir(checkpoint_dir) + assert len(checkpoint_files) == 1 + checkpoint_path = checkpoint_dir / checkpoint_files[0] + + model = SwaTestModel() + swa_callback = SwaTestCallback(resuming_from_epoch=2, swa_epoch_start=swa_start, swa_lrs=0.1) + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=max_epochs, + limit_train_batches=5, + limit_val_batches=0, + callbacks=[swa_callback], + accumulate_grad_batches=2, + num_processes=1, + 
resume_from_checkpoint=checkpoint_path, + ) + + with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): + trainer.fit(model) From 3d2bf65feb2d586e203184a899d259ad5cd8d4c5 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Fri, 15 Oct 2021 10:50:52 +1300 Subject: [PATCH 02/59] Add option to use SWA parameters during validation --- .../callbacks/stochastic_weight_avg.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 01f2563f4c9be..0b75cffa34bc5 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -40,6 +40,7 @@ def __init__( annealing_strategy: str = "cos", avg_fn: Optional[_AVG_FN] = None, device: Optional[Union[torch.device, str]] = torch.device("cpu"), + swa_validation: bool = False, ): r""" @@ -93,6 +94,9 @@ def __init__( When None is provided, it will infer the `device` from ``pl_module``. (default: ``"cpu"``) + swa_validation: if True, then the averaged model weights are used during validation + (default: ``False``) + """ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." @@ -122,9 +126,11 @@ def __init__( self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn + self._swa_validation = swa_validation self._device = device self._model_contains_batch_norm = None self._average_model = None + self._temp_model = None self._initialized = False self._swa_scheduler = None @@ -144,6 +150,9 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: # copy the model before moving it to accelerator device. with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) + if self._swa_validation: + # Also create a model for temporarily copying weights to during validation + self._temp_model = deepcopy(pl_module) def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers @@ -172,6 +181,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # move average model to request device. self._average_model = self._average_model.to(self._device or pl_module.device) + if self._temp_model: + self._temp_model = self._temp_model.to(self._device or pl_module.device) optimizer = trainer.optimizers[0] if self._swa_lrs is None: @@ -243,6 +254,18 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): # Last SWA epoch. 
Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): + # Take a temporary copy of the model parameters + self.transfer_weights(pl_module, self._temp_model) + # Update the model with the averaged parameters + self.transfer_weights(self._average_model, pl_module) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): + # Copy original model parameters back + self.transfer_weights(self._temp_model, pl_module) + @staticmethod def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): From 16962737bb2c782acb7fdbf305d0572ffd40e040 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Fri, 15 Oct 2021 11:22:22 +1300 Subject: [PATCH 03/59] Allow restoring SWA parameters to a model from a checkpoint --- .../callbacks/stochastic_weight_avg.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 0b75cffa34bc5..2d8684f0dd831 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,7 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, IO, List, Optional, Type, Union import torch from torch import nn @@ -26,6 +26,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -340,6 +341,52 @@ def on_load_checkpoint( f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." ) + @classmethod + def restore_average_parameters_from_checkpoint( + cls, + pl_module: "pl.LightningModule", + checkpoint_path: Union[str, IO], + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + ) -> bool: + r""" + Set model weights to the SWA averaged weights saved in a checkpoint. + + Arguments: + pl_module: The module to set weights on + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object + map_location: + If your checkpoint saved a GPU model and you now load on CPUs + or a different number of GPUs, use this to map to the new setup. + The behaviour is the same as in :func:`torch.load`. + + Return: + A `bool` indicating whether averaged weights were loaded. If `False`, this means the checkpoint is + from an epoch before the SWA epoch start. 
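A usage sketch for this method, following the flow exercised later in this series by test_swa_load_best_checkpoint: train with swa_validation=True plus a ModelCheckpoint monitoring a validation metric, then reload the best checkpoint and apply the averaged weights explicitly. MyLitModel is a placeholder for a user-defined LightningModule that logs "val_loss"; it is not part of this patch:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

    model = MyLitModel()  # hypothetical user model that logs "val_loss" in validation_step
    swa = StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.05, swa_validation=True)
    best_ckpt = ModelCheckpoint(monitor="val_loss", mode="min")

    trainer = pl.Trainer(max_epochs=10, callbacks=[swa, best_ckpt])
    trainer.fit(model)

    # load_from_checkpoint restores the ordinary (non-averaged) weights, so the SWA
    # weights saved by the callback have to be applied to the module explicitly.
    restored = MyLitModel.load_from_checkpoint(best_ckpt.best_model_path)
    used_swa = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
        restored, best_ckpt.best_model_path
    )
    if not used_swa:
        print("best checkpoint was saved before swa_epoch_start; plain weights kept")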
+ """ + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") + if not callback_states: + raise ValueError("callback states are not present in the checkpoint") + + state_key = cls.__qualname__ # Default state key defined in Callback base class + state = callback_states.get(state_key) + if not state: + raise ValueError(f"no {state_key} state found in the checkpoint") + state = deepcopy(state) + average_model_parameters = state["average_model_parameters"] + + if not average_model_parameters: + return False + + for p_model, p_swa in zip(pl_module.parameters(), average_model_parameters): + device = p_model.device + p_swa_ = p_swa.detach().to(device) + p_model.detach().copy_(p_swa_) + return True + def _get_average_model_parameters(self) -> Any: if self._average_model is None: return None From c8db9d8d1f405e9a4b9d20d0c634a3d33c4a2c6c Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 18 Oct 2021 15:08:45 +1300 Subject: [PATCH 04/59] Refactor SWA batch norm moment update to work with validation --- .../callbacks/stochastic_weight_avg.py | 79 +++++++++---------- tests/callbacks/test_stochastic_weight_avg.py | 43 +++++++--- 2 files changed, 67 insertions(+), 55 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 2d8684f0dd831..c03cec13b7abe 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -120,7 +120,6 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") - self.momenta = None self.n_averaged = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs @@ -134,6 +133,7 @@ def __init__( self._temp_model = None self._initialized = False self._swa_scheduler = None + self._batch_norm_moments = None @property def swa_start(self) -> int: @@ -171,12 +171,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) self._max_epochs = trainer.max_epochs - if self._model_contains_batch_norm: - # virtually increase max_epochs to perform batch norm update on latest epoch. 
- trainer.fit_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - resuming_after_start = trainer.current_epoch > self.swa_start and not self._initialized + resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end) if trainer.current_epoch == self.swa_start or resuming_after_start: self._initialized = True @@ -223,37 +220,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.swa_start <= trainer.current_epoch <= self.swa_end: self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) - # Note: No > here in case the callback is saved with the model and training continues - if trainer.current_epoch == self.swa_end + 1: - - # Transfer weights from average model to pl_module - self.transfer_weights(self._average_model, pl_module) - - # Reset BatchNorm for update - self.reset_batch_norm_and_save_state(pl_module) - - # There is no need to perform either backward or optimizer.step as we are - # performing only one pass over the train data-loader to compute activation statistics - # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. - trainer.num_training_batches += 1 - trainer.fit_loop._skip_backward = True - self._accumulate_grad_batches = trainer.accumulate_grad_batches - - trainer.accumulate_grad_batches = trainer.num_training_batches - - def on_train_epoch_end(self, trainer: "pl.Trainer", *args): - trainer.fit_loop._skip_backward = False - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: - # BatchNorm epoch update. Reset state - trainer.accumulate_grad_batches = self._accumulate_grad_batches - trainer.num_training_batches -= 1 - trainer.fit_loop.max_epochs -= 1 - self.reset_momenta() - elif trainer.current_epoch == self.swa_end: + if trainer.current_epoch == self.swa_end: # Last SWA epoch. 
Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) + if self._model_contains_batch_norm: + self._update_batch_norm_moments(trainer, pl_module, store_moments=False) def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): @@ -261,37 +233,60 @@ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod self.transfer_weights(pl_module, self._temp_model) # Update the model with the averaged parameters self.transfer_weights(self._average_model, pl_module) + if self._model_contains_batch_norm: + self._update_batch_norm_moments(trainer, pl_module) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): # Copy original model parameters back self.transfer_weights(self._temp_model, pl_module) + if self._model_contains_batch_norm: + self._restore_batch_norm_moments() @staticmethod def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) - def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" - self.momenta = {} + def _update_batch_norm_moments( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True + ): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166.""" + prev_momenta = {} + self._batch_norm_moments = {} + + was_training = pl_module.training + pl_module.train() + for module in pl_module.modules(): if not isinstance(module, nn.modules.batchnorm._BatchNorm): continue + prev_momenta[module] = module.momentum + if store_moments: + self._batch_norm_moments[module] = (module.running_mean, module.running_var) module.running_mean = torch.zeros_like( module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype ) module.running_var = torch.ones_like( module.running_var, device=pl_module.device, dtype=module.running_var.dtype ) - self.momenta[module] = module.momentum module.momentum = None module.num_batches_tracked *= 0 - def reset_momenta(self): - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" - for bn_module in self.momenta: - bn_module.momentum = self.momenta[bn_module] + # Recompute mean and variance for all batch norm layers by doing a full pass over the training data + for batch, _ in trainer.data_connector.train_data_fetcher: + batch = batch.to(pl_module.device) + pl_module(batch) + + # Reset model state + for bn_module, momenta in prev_momenta.items(): + bn_module.momentum = momenta + pl_module.train(was_training) + + def _restore_batch_norm_moments(self): + for bn_module, (mean, variance) in self._batch_norm_moments.items(): + bn_module.running_mean = mean + bn_module.running_var = variance @staticmethod def update_parameters( @@ -317,7 +312,6 @@ def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: checkpoint_data = { - "momenta": self.momenta, "n_averaged": self.n_averaged, "swa_lrs": self._swa_lrs, "annealing_epochs": self._annealing_epochs, @@ 
-330,7 +324,6 @@ def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] ) -> None: if callback_state: - self.momenta = callback_state["momenta"] self.n_averaged = callback_state["n_averaged"] self._swa_lrs = callback_state["swa_lrs"] self._annealing_strategy = callback_state["annealing_strategy"] diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index c861332e08907..34764fe33a068 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -52,6 +52,11 @@ def training_step(self, batch, batch_idx): loss = self.loss(batch, output) return {"loss": loss} + def validation_step(self, batch, batch_idx): + output = self.forward(batch) + loss = self.loss(batch, output) + return {"x": loss} + def train_dataloader(self): dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset @@ -59,6 +64,9 @@ def train_dataloader(self): return DataLoader(dset, batch_size=2) + def val_dataloader(self): + return self.train_dataloader() + def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return { @@ -86,6 +94,7 @@ def __init__(self, *args, **kwargs): self.resuming_from_epoch = 0 super().__init__(*args, **kwargs) + validation_calls: int = 0 update_parameters_calls: int = 0 transfer_weights_calls: int = 0 @@ -93,13 +102,16 @@ def update_parameters(self, *args, **kwargs): self.update_parameters_calls += 1 return StochasticWeightAveraging.update_parameters(*args, **kwargs) + def on_validation_start(self, *args, **kwargs): + self.validation_calls += 1 + return super().on_validation_start(*args, **kwargs) + def transfer_weights(self, *args, **kwargs): self.transfer_weights_calls += 1 return StochasticWeightAveraging.transfer_weights(*args, **kwargs) def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) - assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) assert trainer.lr_schedulers[0]["interval"] == "epoch" @@ -116,11 +128,6 @@ def on_train_epoch_end(self, trainer, *args): def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) - # make sure these are correctly set again - assert not trainer.fit_loop._skip_backward - assert trainer.accumulate_grad_batches == 2 - assert trainer.num_training_batches == 5 - if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. 
the batchnorm update epoch should not backward assert trainer.accelerator.backward.call_count == ( @@ -133,16 +140,27 @@ def on_train_end(self, trainer, pl_module): else: expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1) assert self.update_parameters_calls == expected_update_calls - assert self.transfer_weights_calls == 1 + if self._swa_validation: + # 3 weight transfers are needed per SWA validation step + assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1 + else: + assert self.transfer_weights_calls == 1 def train_with_swa( - tmpdir, batchnorm=True, strategy=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False + tmpdir, + batchnorm=True, + strategy=None, + gpus=None, + num_processes=1, + interval="epoch", + iterable_dataset=False, + validation=False, ): model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset) swa_start = 2 max_epochs = 5 - swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) + swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=validation) assert swa_callback.update_parameters_calls == 0 assert swa_callback.transfer_weights_calls == 0 @@ -151,7 +169,7 @@ def train_with_swa( enable_progress_bar=False, max_epochs=max_epochs, limit_train_batches=5, - limit_val_batches=0, + limit_val_batches=1.0 if validation else 0.0, callbacks=[swa_callback], accumulate_grad_batches=2, strategy=strategy, @@ -188,8 +206,9 @@ def test_swa_callback_1_gpu(tmpdir): @pytest.mark.parametrize("batchnorm", (True, False)) @pytest.mark.parametrize("iterable_dataset", (True, False)) -def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool): - train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset) +@pytest.mark.parametrize("validation", (True, False)) +def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool, validation: bool): + train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset, validation=validation) @pytest.mark.parametrize("interval", ("epoch", "step")) From 004959b4fd06cf5f4bea53714898d9303293fe07 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 19 Oct 2021 14:04:45 +1300 Subject: [PATCH 05/59] Add test for loading a model from a checkpoint with SWA parameters --- .../callbacks/stochastic_weight_avg.py | 15 ++++-- tests/callbacks/test_stochastic_weight_avg.py | 46 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index c03cec13b7abe..9202e028f843c 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -255,6 +255,11 @@ def _update_batch_norm_moments( prev_momenta = {} self._batch_norm_moments = {} + train_data_fetcher = trainer.data_connector.train_data_fetcher + if train_data_fetcher is None: + # Training data not yet connected, could be in a validation sanity check + return + was_training = pl_module.training pl_module.train() @@ -274,7 +279,7 @@ def _update_batch_norm_moments( module.num_batches_tracked *= 0 # Recompute mean and variance for all batch norm layers by doing a full pass over the training data - for batch, _ in trainer.data_connector.train_data_fetcher: + for batch, _ in train_data_fetcher: batch = batch.to(pl_module.device) pl_module(batch) @@ -316,7 +321,7 @@ def on_save_checkpoint( "swa_lrs": self._swa_lrs, 
"annealing_epochs": self._annealing_epochs, "annealing_strategy": self._annealing_strategy, - "average_model_parameters": self._get_average_model_parameters(), + "average_model_parameters": self._get_average_model_parameters(trainer), } return checkpoint_data @@ -380,8 +385,10 @@ def restore_average_parameters_from_checkpoint( p_model.detach().copy_(p_swa_) return True - def _get_average_model_parameters(self) -> Any: - if self._average_model is None: + def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any: + if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): + # If we're not within the SWA epochs then when loading checkpoint data we would want + # to use parameters from the underlying model rather than the SWA parameters. return None parameters = [] for p_swa in self._average_model.parameters(): diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 34764fe33a068..cfe3a15ebd1d3 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -24,7 +24,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.callbacks import StochasticWeightAveraging +from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -46,6 +46,7 @@ def __init__( self.iterable_dataset = iterable_dataset self.crash_after_epoch = crash_after_epoch self._epoch_count = 0 + self.save_hyperparameters() def training_step(self, batch, batch_idx): output = self.forward(batch) @@ -55,6 +56,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): output = self.forward(batch) loss = self.loss(batch, output) + self.log("val_loss", loss) return {"x": loss} def train_dataloader(self): @@ -142,7 +144,7 @@ def on_train_end(self, trainer, pl_module): assert self.update_parameters_calls == expected_update_calls if self._swa_validation: # 3 weight transfers are needed per SWA validation step - assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1 + assert self.transfer_weights_calls == (self.validation_calls - self.swa_start) * 3 + 1 else: assert self.transfer_weights_calls == 1 @@ -169,7 +171,8 @@ def train_with_swa( enable_progress_bar=False, max_epochs=max_epochs, limit_train_batches=5, - limit_val_batches=1.0 if validation else 0.0, + limit_val_batches=5 if validation else 0, + num_sanity_val_steps=0, callbacks=[swa_callback], accumulate_grad_batches=2, strategy=strategy, @@ -362,3 +365,40 @@ def test_swa_resume_training_from_checkpoint(tmpdir): with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): trainer.fit(model) + + +@pytest.mark.parametrize("batchnorm", (True, False)) +@pytest.mark.parametrize("within_swa_epochs", (True, False)) +def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool): + model = SwaTestModel(batchnorm=batchnorm) + if within_swa_epochs: + # Start at epoch 1 so we can guarantee the best checkpoint should be saved with SWA weights + swa_start = 1 + else: + # Start after the last epoch, so we never save a checkpoint with SWA parameters + swa_start = 6 + max_epochs = 5 + + swa_callback = 
SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=True) + checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min") + + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=max_epochs, + limit_train_batches=5, + limit_val_batches=5, + num_sanity_val_steps=0, + callbacks=[swa_callback, checkpoint_callback], + accumulate_grad_batches=2, + num_processes=1, + ) + + with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): + trainer.fit(model) + + checkpoint_path = checkpoint_callback.best_model_path + new_model = SwaTestModel.load_from_checkpoint(checkpoint_path) + parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint(new_model, checkpoint_path) + + assert parameters_loaded == within_swa_epochs From d76528b4365ab04b8755f62517aabde713be4e0f Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 19 Oct 2021 16:38:33 +1300 Subject: [PATCH 06/59] Recompute batch norm moments when updating parameters from a checkpoint --- .../callbacks/stochastic_weight_avg.py | 47 +++++++++++++++---- tests/callbacks/test_stochastic_weight_avg.py | 18 ++++++- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 9202e028f843c..8feb301720c9a 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -21,10 +21,13 @@ import torch from torch import nn from torch.optim.swa_utils import SWALR +from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -251,15 +254,26 @@ def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.Lig def _update_batch_norm_moments( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True ): - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166.""" - prev_momenta = {} self._batch_norm_moments = {} - train_data_fetcher = trainer.data_connector.train_data_fetcher - if train_data_fetcher is None: + train_dataloader = trainer.train_dataloader + if train_dataloader is None: # Training data not yet connected, could be in a validation sanity check return + self._update_module_batch_norm_moments( + train_dataloader, pl_module, self._batch_norm_moments if store_moments else None + ) + + @staticmethod + def _update_module_batch_norm_moments( + data_loader: Union[DataLoader, CombinedLoader], + pl_module: "pl.LightningModule", + moment_cache: Optional[Dict[nn.Module, Any]] = None, + ): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166.""" + prev_momenta = {} + was_training = pl_module.training pl_module.train() @@ -267,8 +281,8 @@ def _update_batch_norm_moments( if not isinstance(module, nn.modules.batchnorm._BatchNorm): continue prev_momenta[module] = module.momentum - if store_moments: - self._batch_norm_moments[module] = (module.running_mean, 
module.running_var) + if moment_cache is not None: + moment_cache[module] = (module.running_mean, module.running_var) module.running_mean = torch.zeros_like( module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype ) @@ -279,7 +293,7 @@ def _update_batch_norm_moments( module.num_batches_tracked *= 0 # Recompute mean and variance for all batch norm layers by doing a full pass over the training data - for batch, _ in train_data_fetcher: + for batch in data_loader: batch = batch.to(pl_module.device) pl_module(batch) @@ -345,18 +359,24 @@ def restore_average_parameters_from_checkpoint( pl_module: "pl.LightningModule", checkpoint_path: Union[str, IO], map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + datamodule: Optional[LightningDataModule] = None, ) -> bool: r""" Set model weights to the SWA averaged weights saved in a checkpoint. Arguments: pl_module: The module to set weights on + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object - map_location: - If your checkpoint saved a GPU model and you now load on CPUs + + map_location: If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:`torch.load`. + datamodule: If the module uses batch normalization and does not implement the train_dataloder method, + a data module must be provided in order to allow recomputing the batch normalization parameters after + loading the SWA weights. + Return: A `bool` indicating whether averaged weights were loaded. If `False`, this means the checkpoint is from an epoch before the SWA epoch start. @@ -383,6 +403,15 @@ def restore_average_parameters_from_checkpoint( device = p_model.device p_swa_ = p_swa.detach().to(device) p_model.detach().copy_(p_swa_) + + if cls.pl_module_contains_batch_norm(pl_module): + if datamodule is not None: + train_dataloaders = datamodule.train_dataloader() + else: + train_dataloaders = pl_module.train_dataloader() + train_dataloaders = CombinedLoader(train_dataloaders, mode="max_size_cycle") + cls._update_module_batch_norm_moments(train_dataloaders, pl_module) + return True def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index cfe3a15ebd1d3..13123172615f2 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -25,6 +25,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -369,7 +370,8 @@ def test_swa_resume_training_from_checkpoint(tmpdir): @pytest.mark.parametrize("batchnorm", (True, False)) @pytest.mark.parametrize("within_swa_epochs", (True, False)) -def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool): +@pytest.mark.parametrize("use_datamodule", (True, False)) +def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool, use_datamodule: bool): model = SwaTestModel(batchnorm=batchnorm) if within_swa_epochs: # Start at epoch 1 so 
we can guarantee the best checkpoint should be saved with SWA weights @@ -397,8 +399,20 @@ def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bo with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): trainer.fit(model) + if use_datamodule: + + class TestDataModule(LightningDataModule): + def train_dataloader(self): + return model.train_dataloader() + + datamodule = TestDataModule() + else: + datamodule = None + checkpoint_path = checkpoint_callback.best_model_path new_model = SwaTestModel.load_from_checkpoint(checkpoint_path) - parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint(new_model, checkpoint_path) + parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint( + new_model, checkpoint_path, datamodule=datamodule + ) assert parameters_loaded == within_swa_epochs From 0ea22e0529ed64fbb7e7ca3444dd4c5d031be07e Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 21 Oct 2021 09:21:35 +1300 Subject: [PATCH 07/59] Handle when data batch is a list or tuple --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 8feb301720c9a..871cff28c40c5 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -294,6 +294,8 @@ def _update_module_batch_norm_moments( # Recompute mean and variance for all batch norm layers by doing a full pass over the training data for batch in data_loader: + if isinstance(batch, (list, tuple)): + batch = batch[0] batch = batch.to(pl_module.device) pl_module(batch) From 01ca2a7e39137d1563d4921d71bcbf95beaf2bc3 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 28 Oct 2021 09:12:25 +1300 Subject: [PATCH 08/59] Save SWA scheduler step count in checkpoints --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 6 ++++++ tests/callbacks/test_stochastic_weight_avg.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 871cff28c40c5..3e9ce8a8f5037 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -137,6 +137,7 @@ def __init__( self._initialized = False self._swa_scheduler = None self._batch_norm_moments = None + self._scheduler_step_count = None @property def swa_start(self) -> int: @@ -201,6 +202,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ) + if self._scheduler_step_count is not None: + # Restore scheduler step count from checkpoint + self._swa_scheduler._step_count = self._scheduler_step_count default_scheduler_cfg = _get_default_scheduler_config() assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 default_scheduler_cfg["scheduler"] = self._swa_scheduler @@ -337,6 +341,7 @@ def on_save_checkpoint( "swa_lrs": self._swa_lrs, "annealing_epochs": self._annealing_epochs, "annealing_strategy": self._annealing_strategy, + "scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count, "average_model_parameters": self._get_average_model_parameters(trainer), } return checkpoint_data @@ -349,6 +354,7 @@ def 
on_load_checkpoint( self._swa_lrs = callback_state["swa_lrs"] self._annealing_strategy = callback_state["annealing_strategy"] self._annealing_epochs = callback_state["annealing_epochs"] + self._scheduler_step_count = callback_state["scheduler_step_count"] self._load_average_model_parameters(callback_state["average_model_parameters"]) else: rank_zero_warn( diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 13123172615f2..1b87a9f106250 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -125,6 +125,9 @@ def on_train_epoch_end(self, trainer, *args): if self.swa_start <= trainer.current_epoch <= self.swa_end: swa_epoch = trainer.current_epoch - self.swa_start assert self.n_averaged == swa_epoch + 1 + assert self._swa_scheduler is not None + # Scheduler is stepped once on initialization and then at the end of each epoch + assert self._swa_scheduler._step_count == swa_epoch + 2 elif trainer.current_epoch > self.swa_end: assert self.n_averaged == self._max_epochs - self.swa_start From 08d655b66b5d13f1a32dbdc565023a14184cc10a Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 28 Oct 2021 10:44:41 +1300 Subject: [PATCH 09/59] Update SWA documentation and changelog --- CHANGELOG.md | 2 ++ docs/source/advanced/training_tricks.rst | 22 +++++++++++++++++++ .../callbacks/stochastic_weight_avg.py | 11 +++++++--- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37945300774d5..a0731eb2ffda6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,6 +193,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) +- Added support for using the Stochastic Weight Averaging (SWA) weights during validation and resuming from a checkpoint when using SWA ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) + ### Changed diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 28f81d98dcbd3..3a0eb834bf5d5 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -70,6 +70,28 @@ read `this post ` documentation + ---------- Auto scaling of batch size diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 3e9ce8a8f5037..aecd8a1a9cc56 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -372,6 +372,11 @@ def restore_average_parameters_from_checkpoint( r""" Set model weights to the SWA averaged weights saved in a checkpoint. + When loading a model that was trained using SWA from a checkpoint, + the loaded weights will not be the SWA averaged weights, so this method is required if you + wish to use SWA in conjunction with the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` + callback to select the best performing model during validation for example. + Arguments: pl_module: The module to set weights on @@ -381,13 +386,13 @@ def restore_average_parameters_from_checkpoint( or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:`torch.load`. 
- datamodule: If the module uses batch normalization and does not implement the train_dataloder method, + datamodule: If the module uses batch normalization and does not implement the ``train_dataloder`` method, a data module must be provided in order to allow recomputing the batch normalization parameters after loading the SWA weights. Return: - A `bool` indicating whether averaged weights were loaded. If `False`, this means the checkpoint is - from an epoch before the SWA epoch start. + Whether averaged weights were loaded. If ``False``, this means the checkpoint is + from an epoch before the SWA epoch start. """ if map_location is not None: checkpoint = pl_load(checkpoint_path, map_location=map_location) From 91ab35746875fa4a2c13de0d6a7057dcdd26e063 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 28 Oct 2021 11:42:05 +1300 Subject: [PATCH 10/59] Fix DeepSource code style issues --- tests/callbacks/test_stochastic_weight_avg.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 1b87a9f106250..5126d97528662 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -153,6 +153,26 @@ def on_train_end(self, trainer, pl_module): assert self.transfer_weights_calls == 1 +class SwaTestDataModule(LightningDataModule): + """Shim data module that just wraps a model.""" + + def __init__(self, model: LightningModule): + super().__init__() + self._model = model + + def train_dataloader(self): + return self._model.train_dataloader() + + def test_dataloader(self): + return self._model.test_dataloader() + + def predict_dataloader(self): + return self._model.predict_dataloader() + + def val_dataloader(self): + return self._model.val_dataloader() + + def train_with_swa( tmpdir, batchnorm=True, @@ -297,9 +317,10 @@ def test_swa_multiple_lrs(tmpdir): class TestModel(BoringModel): def __init__(self): - super(BoringModel, self).__init__() + super().__init__() self.layer1 = torch.nn.Linear(32, 32) self.layer2 = torch.nn.Linear(32, 2) + self.on_train_epoch_start_called = False def forward(self, x): x = self.layer1(x) @@ -402,15 +423,7 @@ def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bo with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): trainer.fit(model) - if use_datamodule: - - class TestDataModule(LightningDataModule): - def train_dataloader(self): - return model.train_dataloader() - - datamodule = TestDataModule() - else: - datamodule = None + datamodule = SwaTestDataModule(model) if use_datamodule else None checkpoint_path = checkpoint_callback.best_model_path new_model = SwaTestModel.load_from_checkpoint(checkpoint_path) From 22e5d515b6a7280d29f9b105d9a6d15e8c029096 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 10 Nov 2021 12:04:05 +1300 Subject: [PATCH 11/59] Revert SWA validation changes --- CHANGELOG.md | 2 +- docs/source/advanced/training_tricks.rst | 22 --- .../callbacks/stochastic_weight_avg.py | 187 ++++-------------- tests/callbacks/test_stochastic_weight_avg.py | 110 ++--------- 4 files changed, 54 insertions(+), 267 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0731eb2ffda6..9f532068f6eed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,7 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) -- Added support for using the Stochastic Weight Averaging (SWA) weights during validation and resuming from a checkpoint when using SWA ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) +- Added support for resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) ### Changed diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 3a0eb834bf5d5..28f81d98dcbd3 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -70,28 +70,6 @@ read `this post ` documentation - ---------- Auto scaling of batch size diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index aecd8a1a9cc56..33b4b1211aa9e 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,20 +16,16 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Any, Callable, Dict, IO, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import nn from torch.optim.swa_utils import SWALR -from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config -from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -44,7 +40,6 @@ def __init__( annealing_strategy: str = "cos", avg_fn: Optional[_AVG_FN] = None, device: Optional[Union[torch.device, str]] = torch.device("cpu"), - swa_validation: bool = False, ): r""" @@ -98,9 +93,6 @@ def __init__( When None is provided, it will infer the `device` from ``pl_module``. (default: ``"cpu"``) - swa_validation: if True, then the averaged model weights are used during validation - (default: ``False``) - """ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." @@ -129,15 +121,13 @@ def __init__( self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn - self._swa_validation = swa_validation self._device = device self._model_contains_batch_norm = None self._average_model = None - self._temp_model = None self._initialized = False self._swa_scheduler = None - self._batch_norm_moments = None self._scheduler_step_count = None + self.momenta = None @property def swa_start(self) -> int: @@ -155,9 +145,6 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: # copy the model before moving it to accelerator device. 
with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) - if self._swa_validation: - # Also create a model for temporarily copying weights to during validation - self._temp_model = deepcopy(pl_module) def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers @@ -175,6 +162,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) self._max_epochs = trainer.max_epochs + if self._model_contains_batch_norm: + # virtually increase max_epochs to perform batch norm update on latest epoch. + trainer.fit_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end) @@ -183,8 +173,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # move average model to request device. self._average_model = self._average_model.to(self._device or pl_module.device) - if self._temp_model: - self._temp_model = self._temp_model.to(self._device or pl_module.device) optimizer = trainer.optimizers[0] if self._swa_lrs is None: @@ -227,91 +215,62 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.swa_start <= trainer.current_epoch <= self.swa_end: self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - if trainer.current_epoch == self.swa_end: - # Last SWA epoch. Transfer weights from average model to pl_module + # Note: No > here in case the callback is saved with the model and training continues + if trainer.current_epoch == self.swa_end + 1: + # Transfer weights from average model to pl_module self.transfer_weights(self._average_model, pl_module) - if self._model_contains_batch_norm: - self._update_batch_norm_moments(trainer, pl_module, store_moments=False) - - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): - # Take a temporary copy of the model parameters - self.transfer_weights(pl_module, self._temp_model) - # Update the model with the averaged parameters - self.transfer_weights(self._average_model, pl_module) - if self._model_contains_batch_norm: - self._update_batch_norm_moments(trainer, pl_module) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end): - # Copy original model parameters back - self.transfer_weights(self._temp_model, pl_module) - if self._model_contains_batch_norm: - self._restore_batch_norm_moments() + # Reset BatchNorm for update + self.reset_batch_norm_and_save_state(pl_module) - @staticmethod - def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): - for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): - dst_param.detach().copy_(src_param.to(dst_param.device)) + # There is no need to perform either backward or optimizer.step as we are + # performing only one pass over the train data-loader to compute activation statistics + # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward. 
+ trainer.num_training_batches += 1 + trainer.fit_loop._skip_backward = True + self._accumulate_grad_batches = trainer.accumulate_grad_batches - def _update_batch_norm_moments( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True - ): - self._batch_norm_moments = {} + trainer.accumulate_grad_batches = trainer.num_training_batches - train_dataloader = trainer.train_dataloader - if train_dataloader is None: - # Training data not yet connected, could be in a validation sanity check - return + def on_train_epoch_end(self, trainer: "pl.Trainer", *args): + trainer.fit_loop._skip_backward = False - self._update_module_batch_norm_moments( - train_dataloader, pl_module, self._batch_norm_moments if store_moments else None - ) + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1: + # BatchNorm epoch update. Reset state + trainer.accumulate_grad_batches = self._accumulate_grad_batches + trainer.num_training_batches -= 1 + trainer.fit_loop.max_epochs -= 1 + self.reset_momenta() + elif trainer.current_epoch == self.swa_end: + # Last SWA epoch. Transfer weights from average model to pl_module + self.transfer_weights(self._average_model, pl_module) @staticmethod - def _update_module_batch_norm_moments( - data_loader: Union[DataLoader, CombinedLoader], - pl_module: "pl.LightningModule", - moment_cache: Optional[Dict[nn.Module, Any]] = None, - ): - """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166.""" - prev_momenta = {} - - was_training = pl_module.training - pl_module.train() + def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): + for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): + dst_param.detach().copy_(src_param.to(dst_param.device)) + def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" + self.momenta = {} for module in pl_module.modules(): if not isinstance(module, nn.modules.batchnorm._BatchNorm): continue - prev_momenta[module] = module.momentum - if moment_cache is not None: - moment_cache[module] = (module.running_mean, module.running_var) module.running_mean = torch.zeros_like( module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype ) module.running_var = torch.ones_like( module.running_var, device=pl_module.device, dtype=module.running_var.dtype ) + self.momenta[module] = module.momentum module.momentum = None module.num_batches_tracked *= 0 - # Recompute mean and variance for all batch norm layers by doing a full pass over the training data - for batch in data_loader: - if isinstance(batch, (list, tuple)): - batch = batch[0] - batch = batch.to(pl_module.device) - pl_module(batch) - - # Reset model state - for bn_module, momenta in prev_momenta.items(): - bn_module.momentum = momenta - pl_module.train(was_training) - - def _restore_batch_norm_moments(self): - for bn_module, (mean, variance) in self._batch_norm_moments.items(): - bn_module.running_mean = mean - bn_module.running_var = variance + def reset_momenta(self): + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" + for bn_module in self.momenta: + bn_module.momentum = self.momenta[bn_module] @staticmethod def update_parameters( @@ -361,72 +320,6 @@ def 
on_load_checkpoint( f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." ) - @classmethod - def restore_average_parameters_from_checkpoint( - cls, - pl_module: "pl.LightningModule", - checkpoint_path: Union[str, IO], - map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, - datamodule: Optional[LightningDataModule] = None, - ) -> bool: - r""" - Set model weights to the SWA averaged weights saved in a checkpoint. - - When loading a model that was trained using SWA from a checkpoint, - the loaded weights will not be the SWA averaged weights, so this method is required if you - wish to use SWA in conjunction with the :class:`~pytorch_lightning.callbacks.ModelCheckpoint` - callback to select the best performing model during validation for example. - - Arguments: - pl_module: The module to set weights on - - checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object - - map_location: If your checkpoint saved a GPU model and you now load on CPUs - or a different number of GPUs, use this to map to the new setup. - The behaviour is the same as in :func:`torch.load`. - - datamodule: If the module uses batch normalization and does not implement the ``train_dataloder`` method, - a data module must be provided in order to allow recomputing the batch normalization parameters after - loading the SWA weights. - - Return: - Whether averaged weights were loaded. If ``False``, this means the checkpoint is - from an epoch before the SWA epoch start. - """ - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") - if not callback_states: - raise ValueError("callback states are not present in the checkpoint") - - state_key = cls.__qualname__ # Default state key defined in Callback base class - state = callback_states.get(state_key) - if not state: - raise ValueError(f"no {state_key} state found in the checkpoint") - state = deepcopy(state) - average_model_parameters = state["average_model_parameters"] - - if not average_model_parameters: - return False - - for p_model, p_swa in zip(pl_module.parameters(), average_model_parameters): - device = p_model.device - p_swa_ = p_swa.detach().to(device) - p_model.detach().copy_(p_swa_) - - if cls.pl_module_contains_batch_norm(pl_module): - if datamodule is not None: - train_dataloaders = datamodule.train_dataloader() - else: - train_dataloaders = pl_module.train_dataloader() - train_dataloaders = CombinedLoader(train_dataloaders, mode="max_size_cycle") - cls._update_module_batch_norm_moments(train_dataloaders, pl_module) - - return True - def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any: if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): # If we're not within the SWA epochs then when loading checkpoint data we would want diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 5126d97528662..af78e4ab0388e 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -24,8 +24,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging -from 
pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.plugins import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -54,12 +53,6 @@ def training_step(self, batch, batch_idx): loss = self.loss(batch, output) return {"loss": loss} - def validation_step(self, batch, batch_idx): - output = self.forward(batch) - loss = self.loss(batch, output) - self.log("val_loss", loss) - return {"x": loss} - def train_dataloader(self): dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset @@ -67,9 +60,6 @@ def train_dataloader(self): return DataLoader(dset, batch_size=2) - def val_dataloader(self): - return self.train_dataloader() - def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) return { @@ -97,7 +87,6 @@ def __init__(self, *args, **kwargs): self.resuming_from_epoch = 0 super().__init__(*args, **kwargs) - validation_calls: int = 0 update_parameters_calls: int = 0 transfer_weights_calls: int = 0 @@ -105,16 +94,13 @@ def update_parameters(self, *args, **kwargs): self.update_parameters_calls += 1 return StochasticWeightAveraging.update_parameters(*args, **kwargs) - def on_validation_start(self, *args, **kwargs): - self.validation_calls += 1 - return super().on_validation_start(*args, **kwargs) - def transfer_weights(self, *args, **kwargs): self.transfer_weights_calls += 1 return StochasticWeightAveraging.transfer_weights(*args, **kwargs) def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) + assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) assert trainer.lr_schedulers[0]["interval"] == "epoch" @@ -134,6 +120,11 @@ def on_train_epoch_end(self, trainer, *args): def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) + # make sure these are correctly set again + assert not trainer.fit_loop._skip_backward + assert trainer.accumulate_grad_batches == 2 + assert trainer.num_training_batches == 5 + if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. 
the batchnorm update epoch should not backward assert trainer.accelerator.backward.call_count == ( @@ -146,47 +137,16 @@ def on_train_end(self, trainer, pl_module): else: expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1) assert self.update_parameters_calls == expected_update_calls - if self._swa_validation: - # 3 weight transfers are needed per SWA validation step - assert self.transfer_weights_calls == (self.validation_calls - self.swa_start) * 3 + 1 - else: - assert self.transfer_weights_calls == 1 - - -class SwaTestDataModule(LightningDataModule): - """Shim data module that just wraps a model.""" - - def __init__(self, model: LightningModule): - super().__init__() - self._model = model - - def train_dataloader(self): - return self._model.train_dataloader() - - def test_dataloader(self): - return self._model.test_dataloader() - - def predict_dataloader(self): - return self._model.predict_dataloader() - - def val_dataloader(self): - return self._model.val_dataloader() + assert self.transfer_weights_calls == 1 def train_with_swa( - tmpdir, - batchnorm=True, - strategy=None, - gpus=None, - num_processes=1, - interval="epoch", - iterable_dataset=False, - validation=False, + tmpdir, batchnorm=True, strategy=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False ): model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset) swa_start = 2 max_epochs = 5 - swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=validation) + swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) assert swa_callback.update_parameters_calls == 0 assert swa_callback.transfer_weights_calls == 0 @@ -195,8 +155,7 @@ def train_with_swa( enable_progress_bar=False, max_epochs=max_epochs, limit_train_batches=5, - limit_val_batches=5 if validation else 0, - num_sanity_val_steps=0, + limit_val_batches=0, callbacks=[swa_callback], accumulate_grad_batches=2, strategy=strategy, @@ -233,9 +192,8 @@ def test_swa_callback_1_gpu(tmpdir): @pytest.mark.parametrize("batchnorm", (True, False)) @pytest.mark.parametrize("iterable_dataset", (True, False)) -@pytest.mark.parametrize("validation", (True, False)) -def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool, validation: bool): - train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset, validation=validation) +def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool): + train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset) @pytest.mark.parametrize("interval", ("epoch", "step")) @@ -390,45 +348,3 @@ def test_swa_resume_training_from_checkpoint(tmpdir): with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): trainer.fit(model) - - -@pytest.mark.parametrize("batchnorm", (True, False)) -@pytest.mark.parametrize("within_swa_epochs", (True, False)) -@pytest.mark.parametrize("use_datamodule", (True, False)) -def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool, use_datamodule: bool): - model = SwaTestModel(batchnorm=batchnorm) - if within_swa_epochs: - # Start at epoch 1 so we can guarantee the best checkpoint should be saved with SWA weights - swa_start = 1 - else: - # Start after the last epoch, so we never save a checkpoint with SWA parameters - swa_start = 6 - max_epochs = 5 - - swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=True) - checkpoint_callback = ModelCheckpoint(monitor="val_loss", 
save_top_k=3, mode="min") - - trainer = Trainer( - default_root_dir=tmpdir, - enable_progress_bar=False, - max_epochs=max_epochs, - limit_train_batches=5, - limit_val_batches=5, - num_sanity_val_steps=0, - callbacks=[swa_callback, checkpoint_callback], - accumulate_grad_batches=2, - num_processes=1, - ) - - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): - trainer.fit(model) - - datamodule = SwaTestDataModule(model) if use_datamodule else None - - checkpoint_path = checkpoint_callback.best_model_path - new_model = SwaTestModel.load_from_checkpoint(checkpoint_path) - parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint( - new_model, checkpoint_path, datamodule=datamodule - ) - - assert parameters_loaded == within_swa_epochs From 11963f6999a10f2b45110105b3338230dbb702d3 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 10 Nov 2021 12:26:32 +1300 Subject: [PATCH 12/59] Fix resuming from epoch before SWA start and add extra test --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- tests/callbacks/test_stochastic_weight_avg.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 33b4b1211aa9e..fd6141d527754 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -331,7 +331,7 @@ def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any: return parameters def _load_average_model_parameters(self, parameter_state: Any): - if self._average_model is None: + if self._average_model is None or parameter_state is None: return for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): device = p_swa.device diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index ad0b16072df5f..adb258e3642be 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -305,9 +305,10 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -def test_swa_resume_training_from_checkpoint(tmpdir): - model = SwaTestModel(crash_after_epoch=3) - swa_start = 2 +@pytest.mark.parametrize("crash_after_epoch", [2, 4]) +def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): + model = SwaTestModel(crash_after_epoch=crash_after_epoch) + swa_start = 3 max_epochs = 5 swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) @@ -331,7 +332,8 @@ def test_swa_resume_training_from_checkpoint(tmpdir): checkpoint_path = checkpoint_dir / checkpoint_files[0] model = SwaTestModel() - swa_callback = SwaTestCallback(resuming_from_epoch=2, swa_epoch_start=swa_start, swa_lrs=0.1) + restart_epoch = crash_after_epoch - 1 + swa_callback = SwaTestCallback(resuming_from_epoch=restart_epoch, swa_epoch_start=swa_start, swa_lrs=0.1) trainer = Trainer( default_root_dir=tmpdir, enable_progress_bar=False, From 226d8aa2b841b559bdc4fe2139110a9362df0087 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 10 Nov 2021 12:55:09 +1300 Subject: [PATCH 13/59] Don't save state derived from constructor parameters into checkpoints Users may want to restart with different parameters, and this is more consistent with how other callbacks behave. 
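[Editorial sketch, not part of the patch: the practical effect is that a resumed run honours whatever hyperparameters the new callback instance is constructed with, while only runtime state is restored from the checkpoint. A hypothetical resume, assuming `model` is the LightningModule being trained and `last.ckpt` was written by the interrupted run:]

    # swa_lrs and annealing_epochs come from this constructor call, not from the
    # checkpoint; only n_averaged, the scheduler step count and the averaged
    # model parameters are restored.
    swa_callback = StochasticWeightAveraging(swa_epoch_start=3, swa_lrs=0.05, annealing_epochs=5)
    trainer = Trainer(max_epochs=10, callbacks=[swa_callback])
    trainer.fit(model, ckpt_path="last.ckpt")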
--- pytorch_lightning/callbacks/stochastic_weight_avg.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index fd6141d527754..ec40840ecd4a2 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -297,9 +297,6 @@ def on_save_checkpoint( ) -> dict: checkpoint_data = { "n_averaged": self.n_averaged, - "swa_lrs": self._swa_lrs, - "annealing_epochs": self._annealing_epochs, - "annealing_strategy": self._annealing_strategy, "scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count, "average_model_parameters": self._get_average_model_parameters(trainer), } @@ -310,9 +307,6 @@ def on_load_checkpoint( ) -> None: if callback_state: self.n_averaged = callback_state["n_averaged"] - self._swa_lrs = callback_state["swa_lrs"] - self._annealing_strategy = callback_state["annealing_strategy"] - self._annealing_epochs = callback_state["annealing_epochs"] self._scheduler_step_count = callback_state["scheduler_step_count"] self._load_average_model_parameters(callback_state["average_model_parameters"]) else: From 5d03d962937816b205e524de44eb7e00d9750db9 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 16 Nov 2021 10:09:11 +1300 Subject: [PATCH 14/59] Tidy ups from code review * Remove unnecessary detaches * Logic tidy ups * More type annotations --- .../callbacks/stochastic_weight_avg.py | 25 ++++++++----------- tests/callbacks/test_stochastic_weight_avg.py | 25 +++++++------------ 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index ec40840ecd4a2..a27b8a9af6d01 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -115,19 +115,19 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}") - self.n_averaged = None + self.n_averaged: Optional[torch.Tensor] = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._model_contains_batch_norm = None - self._average_model = None + self._model_contains_batch_norm: Optional[bool] = None + self._average_model: Optional[pl.LightningModule] = None self._initialized = False - self._swa_scheduler = None - self._scheduler_step_count = None - self.momenta = None + self._swa_scheduler: Optional[SWALR] = None + self._scheduler_step_count: Optional[int] = None + self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @property def swa_start(self) -> int: @@ -167,8 +167,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): trainer.fit_loop.max_epochs += 1 def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): - resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end) - if trainer.current_epoch == self.swa_start or resuming_after_start: + if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): self._initialized = True # move average model to request device. 
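# [Editorial note, not part of the diff: `_initialized` is deliberately not saved
# to the checkpoint, so a run resumed anywhere inside [swa_start, swa_end] takes
# the branch above once more and rebuilds the SWALR scheduler; the n_averaged
# value, scheduler step count and averaged parameters restored in
# on_load_checkpoint then let averaging carry on where the interrupted run stopped.]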
@@ -314,15 +313,12 @@ def on_load_checkpoint( f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." ) - def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any: + def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]: if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): # If we're not within the SWA epochs then when loading checkpoint data we would want # to use parameters from the underlying model rather than the SWA parameters. return None - parameters = [] - for p_swa in self._average_model.parameters(): - parameters.append(p_swa.detach()) - return parameters + return list(self._average_model.parameters()) def _load_average_model_parameters(self, parameter_state: Any): if self._average_model is None or parameter_state is None: @@ -330,5 +326,4 @@ def _load_average_model_parameters(self, parameter_state: Any): for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): device = p_swa.device p_swa_ = p_swa.detach() - p_checkpoint_ = p_checkpoint.detach().to(device) - p_swa_.copy_(p_checkpoint_) + p_swa_.copy_(p_checkpoint.to(device)) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index adb258e3642be..3f041aafded31 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -14,6 +14,7 @@ import logging import os from pathlib import Path +from typing import Optional from unittest import mock import pytest @@ -78,16 +79,10 @@ def training_epoch_end(self, _): class SwaTestCallback(StochasticWeightAveraging): - def __init__(self, *args, **kwargs): - if "resuming_from_epoch" in kwargs: - self.resuming_from_epoch = kwargs["resuming_from_epoch"] - del kwargs["resuming_from_epoch"] - else: - self.resuming_from_epoch = 0 - super().__init__(*args, **kwargs) - update_parameters_calls: int = 0 transfer_weights_calls: int = 0 + # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0 + first_epoch: Optional[int] = None def update_parameters(self, *args, **kwargs): self.update_parameters_calls += 1 @@ -99,6 +94,8 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) + if self.first_epoch is None: + self.first_epoch = trainer.current_epoch assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR) @@ -127,15 +124,12 @@ def on_train_end(self, trainer, pl_module): if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin): # check backward call count. 
the batchnorm update epoch should not backward assert trainer.accelerator.backward.call_count == ( - (trainer.max_epochs - self.resuming_from_epoch) * trainer.limit_train_batches + (trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches ) # check call counts - if self.resuming_from_epoch >= self._swa_epoch_start: - expected_update_calls = trainer.max_epochs - self.resuming_from_epoch - else: - expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1) - assert self.update_parameters_calls == expected_update_calls + first_swa_epoch = max(self.first_epoch, self.swa_start) + assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch assert self.transfer_weights_calls == 1 @@ -332,8 +326,7 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): checkpoint_path = checkpoint_dir / checkpoint_files[0] model = SwaTestModel() - restart_epoch = crash_after_epoch - 1 - swa_callback = SwaTestCallback(resuming_from_epoch=restart_epoch, swa_epoch_start=swa_start, swa_lrs=0.1) + swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) trainer = Trainer( default_root_dir=tmpdir, enable_progress_bar=False, From 02a04da3cfa1009a681022938b8de2c970d3c6ae Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 16 Nov 2021 11:55:53 +1300 Subject: [PATCH 15/59] Fix handling of n_averaged checkpoint data with multiple processes --- .../callbacks/stochastic_weight_avg.py | 7 +++--- tests/callbacks/test_stochastic_weight_avg.py | 25 +++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index a27b8a9af6d01..69211b60585c4 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -127,6 +127,7 @@ def __init__( self._initialized = False self._swa_scheduler: Optional[SWALR] = None self._scheduler_step_count: Optional[int] = None + self._init_n_averaged = 0 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @property @@ -209,7 +210,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.lr_schedulers.append(default_scheduler_cfg) if self.n_averaged is None: - self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) + self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device) if self.swa_start <= trainer.current_epoch <= self.swa_end: self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn) @@ -295,7 +296,7 @@ def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: checkpoint_data = { - "n_averaged": self.n_averaged, + "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count, "average_model_parameters": self._get_average_model_parameters(trainer), } @@ -305,7 +306,7 @@ def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] ) -> None: if callback_state: - self.n_averaged = callback_state["n_averaged"] + self._init_n_averaged = callback_state["n_averaged"] self._scheduler_step_count = callback_state["scheduler_step_count"] self._load_average_model_parameters(callback_state["average_model_parameters"]) else: diff --git 
a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 3f041aafded31..bb49c26e2af06 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -299,13 +299,15 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -@pytest.mark.parametrize("crash_after_epoch", [2, 4]) -def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): +def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): model = SwaTestModel(crash_after_epoch=crash_after_epoch) swa_start = 3 max_epochs = 5 swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) + num_processes = 2 if ddp else 1 + strategy = "ddp_spawn" if ddp else None + trainer = Trainer( default_root_dir=tmpdir, enable_progress_bar=False, @@ -314,10 +316,12 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): limit_val_batches=0, callbacks=[swa_callback], accumulate_grad_batches=2, - num_processes=1, + num_processes=num_processes, + strategy=strategy, ) - with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward), pytest.raises(RuntimeError): + exception_type = torch.multiprocessing.ProcessRaisedException if ddp else RuntimeError + with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward), pytest.raises(exception_type): trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" @@ -335,9 +339,20 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): limit_val_batches=0, callbacks=[swa_callback], accumulate_grad_batches=2, - num_processes=1, + num_processes=num_processes, + strategy=strategy, resume_from_checkpoint=checkpoint_path, ) with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): trainer.fit(model) + + +@pytest.mark.parametrize("crash_after_epoch", [2, 4]) +def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): + swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch) + + +@RunIf(skip_windows=True) +def test_swa_resume_training_from_checkpoint_ddp(tmpdir): + swa_resume_training_from_checkpoint(tmpdir, ddp=True) From 5763e05ed991fae55f7c6ff48eda0254cb5a88f4 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 09:16:32 +1300 Subject: [PATCH 16/59] Fix deprecation warning in test --- tests/callbacks/test_stochastic_weight_avg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index aa1c9915bbffb..c0d7f7f830aee 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -346,11 +346,10 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): accumulate_grad_batches=2, num_processes=num_processes, strategy=strategy, - resume_from_checkpoint=checkpoint_path, ) with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward): - trainer.fit(model) + trainer.fit(model, ckpt_path=checkpoint_path.as_posix()) @pytest.mark.parametrize("crash_after_epoch", [2, 4]) From d46be834ddb2eef6db6b36452dec6d094db862a4 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 09:22:46 +1300 Subject: [PATCH 17/59] Remove check for non-empty callback state in checkpoint This is already handled by the TrainerCallbackHookMixin --- 
pytorch_lightning/callbacks/stochastic_weight_avg.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 69211b60585c4..65b7dc05601a0 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -305,14 +305,9 @@ def on_save_checkpoint( def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] ) -> None: - if callback_state: - self._init_n_averaged = callback_state["n_averaged"] - self._scheduler_step_count = callback_state["scheduler_step_count"] - self._load_average_model_parameters(callback_state["average_model_parameters"]) - else: - rank_zero_warn( - f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state." - ) + self._init_n_averaged = callback_state["n_averaged"] + self._scheduler_step_count = callback_state["scheduler_step_count"] + self._load_average_model_parameters(callback_state["average_model_parameters"]) def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]: if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): From e0fd0cb859ac0cfddc04afba4a46db40754dfce4 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 10:38:22 +1300 Subject: [PATCH 18/59] Raise MisconfigurationException when using SWA with sharded models --- .../callbacks/stochastic_weight_avg.py | 5 +++++ tests/callbacks/test_stochastic_weight_avg.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 65b7dc05601a0..505fa80fd6ea8 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,6 +24,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin +from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin, FullyShardedDataParallel from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -151,6 +153,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers + if isinstance(trainer.training_type_plugin, (DDPFullyShardedPlugin, FullyShardedDataParallel, DeepSpeedPlugin)): + raise MisconfigurationException("SWA does not currently support sharded models.") + if len(optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index c0d7f7f830aee..5707ab09b18fb 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -360,3 +360,19 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): @RunIf(skip_windows=True) def test_swa_resume_training_from_checkpoint_ddp(tmpdir): swa_resume_training_from_checkpoint(tmpdir, ddp=True) + + +@RunIf(min_gpus=1) +def test_misconfiguration_error_with_sharded_model(tmpdir): + model 
= SwaTestModel() + swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) + trainer = Trainer( + default_root_dir=tmpdir, + enable_progress_bar=False, + max_epochs=5, + callbacks=[swa_callback], + strategy="ddp_fully_sharded", + gpus=1, + ) + with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"): + trainer.fit(model) From 2a83f050f555a19e6939b6f4e31e779ef52088e0 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 11:07:58 +1300 Subject: [PATCH 19/59] Fix test failure with torch 1.7 With torch 1.7 subprocess errors are raised as plain Exceptions --- tests/callbacks/test_stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 5707ab09b18fb..785ee9f02bf2d 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -325,7 +325,7 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): strategy=strategy, ) - exception_type = torch.multiprocessing.ProcessRaisedException if ddp else RuntimeError + exception_type = Exception if ddp else RuntimeError with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward), pytest.raises(exception_type): trainer.fit(model) From 4a8d81c222980fe0d08566dc25af7b876d0e546d Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 11:16:26 +1300 Subject: [PATCH 20/59] Fix crash when fairscale isn't installed --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 505fa80fd6ea8..c39348cc0f3e8 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -25,11 +25,16 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin -from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin, FullyShardedDataParallel +from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +try: + from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedDataParallel +except ImportError: + FullyShardedDataParallel = None + _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -153,7 +158,10 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers - if isinstance(trainer.training_type_plugin, (DDPFullyShardedPlugin, FullyShardedDataParallel, DeepSpeedPlugin)): + sharded_plugins = [DDPFullyShardedPlugin, DeepSpeedPlugin] + if FullyShardedDataParallel: + sharded_plugins.append(FullyShardedDataParallel) + if isinstance(trainer.training_type_plugin, tuple(sharded_plugins)): raise MisconfigurationException("SWA does not currently support sharded models.") if len(optimizers) != 1: From dab0ef4a52a4db884b1f2fea1d681c1d6c53c842 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 13:16:18 
+1300 Subject: [PATCH 21/59] Skip segfaulting test under pytorch < 1.8 --- tests/callbacks/test_stochastic_weight_avg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 785ee9f02bf2d..758026ca1c95e 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -357,8 +357,10 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch) -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, min_torch="1.8") def test_swa_resume_training_from_checkpoint_ddp(tmpdir): + # Requires PyTorch >= 1.8 to include this segfault fix: + # https://github.com/pytorch/pytorch/pull/50998 swa_resume_training_from_checkpoint(tmpdir, ddp=True) From a0d52c86ce3ef23c7d42f27e5c51e00542240189 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 30 Nov 2021 14:35:56 +1300 Subject: [PATCH 22/59] Changelog merge fix --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19f177c140112..61c07e0de9382 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,9 +193,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) -- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470)) - - - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746)) From cdf47341e9252ec0dca0a1904dfcbf2fce5cc563 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 1 Dec 2021 08:21:41 +1300 Subject: [PATCH 23/59] Remove unnecessary intermediate variable --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index c39348cc0f3e8..1556aab4c83b9 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -308,12 +308,11 @@ def avg_fn( def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: - checkpoint_data = { + return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count, "average_model_parameters": self._get_average_model_parameters(trainer), } - return checkpoint_data def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] From ba5b8ab866b98389ad91e1afd1d9db018aefb3fe Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 1 Dec 2021 08:35:03 +1300 Subject: [PATCH 24/59] Fix checking for sharded plugins --- .../callbacks/stochastic_weight_avg.py | 21 +++++++-------- tests/callbacks/test_stochastic_weight_avg.py | 27 ++++++++++++++++--- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 1556aab4c83b9..cd8ae88e5c646 100644 --- 
a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,17 +24,16 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin -from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin +from pytorch_lightning.plugins.training_type import ( + DDPFullyShardedPlugin, + DDPShardedPlugin, + DDPSpawnShardedPlugin, + DeepSpeedPlugin, +) from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -try: - from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedDataParallel -except ImportError: - FullyShardedDataParallel = None - _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] @@ -158,10 +157,10 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers - sharded_plugins = [DDPFullyShardedPlugin, DeepSpeedPlugin] - if FullyShardedDataParallel: - sharded_plugins.append(FullyShardedDataParallel) - if isinstance(trainer.training_type_plugin, tuple(sharded_plugins)): + if isinstance( + trainer.training_type_plugin, + (DDPFullyShardedPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin, DeepSpeedPlugin), + ): raise MisconfigurationException("SWA does not currently support sharded models.") if len(optimizers) != 1: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 758026ca1c95e..355fdd8de434f 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -364,8 +364,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): swa_resume_training_from_checkpoint(tmpdir, ddp=True) -@RunIf(min_gpus=1) -def test_misconfiguration_error_with_sharded_model(tmpdir): +def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None): model = SwaTestModel() swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) trainer = Trainer( @@ -373,8 +372,28 @@ def test_misconfiguration_error_with_sharded_model(tmpdir): enable_progress_bar=False, max_epochs=5, callbacks=[swa_callback], - strategy="ddp_fully_sharded", - gpus=1, + strategy=strategy, + gpus=gpus, ) with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"): trainer.fit(model) + + +@RunIf(fairscale_fully_sharded=True, min_gpus=1) +def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): + _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp", 1) + + +@RunIf(fairscale=True) +def test_misconfiguration_error_with_ddp_sharded(tmpdir): + _test_misconfiguration_error_with_sharded_model(tmpdir, "ddp_sharded") + + +@RunIf(fairscale=True) +def test_misconfiguration_error_with_ddp_spawn_sharded(tmpdir): + _test_misconfiguration_error_with_sharded_model(tmpdir, "ddp_sharded_spawn") + + +@RunIf(deepspeed=True) +def test_misconfiguration_error_with_deep_speed(tmpdir): + _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") From d2bb0ad289ae9ab4dc76013d0880026f983c0581 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 1 Dec 2021 09:53:53 +1300 Subject: [PATCH 25/59] Don't raise an error for DDPSharded and DDPSpawnSharded with SWA --- 
pytorch_lightning/callbacks/stochastic_weight_avg.py | 12 ++---------- tests/callbacks/test_stochastic_weight_avg.py | 10 ---------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index cd8ae88e5c646..962a438e4fe45 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,12 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.plugins.training_type import ( - DDPFullyShardedPlugin, - DDPShardedPlugin, - DDPSpawnShardedPlugin, - DeepSpeedPlugin, -) +from pytorch_lightning.plugins.training_type import DDPFullyShardedPlugin, DeepSpeedPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -157,10 +152,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers - if isinstance( - trainer.training_type_plugin, - (DDPFullyShardedPlugin, DDPShardedPlugin, DDPSpawnShardedPlugin, DeepSpeedPlugin), - ): + if isinstance(trainer.training_type_plugin, (DDPFullyShardedPlugin, DeepSpeedPlugin)): raise MisconfigurationException("SWA does not currently support sharded models.") if len(optimizers) != 1: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 355fdd8de434f..fec0ecc2ebd61 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -384,16 +384,6 @@ def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp", 1) -@RunIf(fairscale=True) -def test_misconfiguration_error_with_ddp_sharded(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "ddp_sharded") - - -@RunIf(fairscale=True) -def test_misconfiguration_error_with_ddp_spawn_sharded(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "ddp_sharded_spawn") - - @RunIf(deepspeed=True) def test_misconfiguration_error_with_deep_speed(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") From ffcf0110d79c9cf2b2fbc3daa741279593b6c819 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 6 Dec 2021 12:57:58 +1300 Subject: [PATCH 26/59] Fix incorrect multiple context manager syntax for Python < 3.9 Parenthesized context managers only work since 3.9 --- tests/callbacks/test_stochastic_weight_avg.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index a25deb3f6ed0e..a3503bdff0d83 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -326,10 +326,8 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): ) exception_type = Exception if ddp else RuntimeError - with ( - mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward), - pytest.raises(exception_type), - ): + backward_patch = mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) + with backward_patch, pytest.raises(exception_type): trainer.fit(model) checkpoint_dir = 
Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" From 8e848dcdb53a59fa40f4e72c563750cf70d10de7 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 15 Dec 2021 15:26:22 +1300 Subject: [PATCH 27/59] Code review tidy up and fix CHANGELOG merge error --- CHANGELOG.md | 6 +++--- pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70e53878adff4..e0d5392a70c3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -239,9 +239,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed unnessesary `_move_optimizer_state` method overrides from `TPUSpawnPlugin` and `SingleTPUPlugin` ([#10849](https://github.com/PyTorchLightning/pytorch-lightning/pull/10849)) -- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) - - - Removed `model_sharded_context` method from `Accelerator` ([#10886](https://github.com/PyTorchLightning/pytorch-lightning/pull/10886)) @@ -273,6 +270,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991)) +- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938)) + + - diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 962a438e4fe45..d20aa47b0fcf9 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -316,10 +316,10 @@ def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[ if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): # If we're not within the SWA epochs then when loading checkpoint data we would want # to use parameters from the underlying model rather than the SWA parameters. 
- return None + return return list(self._average_model.parameters()) - def _load_average_model_parameters(self, parameter_state: Any): + def _load_average_model_parameters(self, parameter_state: Any) -> None: if self._average_model is None or parameter_state is None: return for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): From 11757d5e889fe457c14d3f1f1dc55a035d24d6ff Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 16 Dec 2021 13:29:12 +1300 Subject: [PATCH 28/59] Add a warning with initializing SWA after start but without checkpoint data --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index d20aa47b0fcf9..493e8b1ffdf46 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -197,6 +197,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self._scheduler_step_count is not None: # Restore scheduler step count from checkpoint self._swa_scheduler._step_count = self._scheduler_step_count + elif trainer.current_epoch != self.swa_start: + # Log a warning if we're initializing after start without any checkpoint data, + # as behaviour will be different compared to having checkpoint data. + rank_zero_warn( + "SWA is initializing after swa_start without any checkpoint data. " + "This may be caused by loading a checkpoint from an older version of PyTorch Lightning." + ) + default_scheduler_cfg = _get_default_scheduler_config() assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 default_scheduler_cfg["scheduler"] = self._swa_scheduler From fd59c41a26516a9ebce324dbfc5221fc2665ae84 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 22 Dec 2021 10:03:49 +1300 Subject: [PATCH 29/59] Fixes to account for changes merged from master --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 ++-- tests/callbacks/test_stochastic_weight_avg.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 493e8b1ffdf46..b9eb0d5cc8698 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,7 +24,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.plugins.training_type import DDPFullyShardedPlugin, DeepSpeedPlugin +from pytorch_lightning.plugins.training_type import DDPFullyShardedStrategy, DeepSpeedStrategy from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -152,7 +152,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers - if isinstance(trainer.training_type_plugin, (DDPFullyShardedPlugin, DeepSpeedPlugin)): + if isinstance(trainer.training_type_plugin, (DDPFullyShardedStrategy, DeepSpeedStrategy)): raise MisconfigurationException("SWA does not currently support sharded models.") if len(optimizers) != 1: diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 
84a5b74088a3c..a4e9326e5c639 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -326,7 +326,7 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): ) exception_type = Exception if ddp else RuntimeError - backward_patch = mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward) + backward_patch = mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward) with backward_patch, pytest.raises(exception_type): trainer.fit(model) @@ -349,7 +349,7 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): strategy=strategy, ) - with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward): + with mock.patch.object(Strategy, "backward", wraps=trainer.training_type_plugin.backward): trainer.fit(model, ckpt_path=checkpoint_path.as_posix()) @@ -385,6 +385,6 @@ def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp", 1) -@RunIf(deepspeed=True) +@RunIf(deepspeed=True, min_gpus=1) def test_misconfiguration_error_with_deep_speed(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") + _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed", 1) From b10261e03c7022d3106842259122af4407361f40 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 12 Jan 2022 14:47:36 +1300 Subject: [PATCH 30/59] Fix SWA scheduler not being stepped This was broken by 98ea79b8b07fd30ed718d28052d1bd8a586d3796 SWA only works with one optimizer, so always set opt_idx to zero --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index bf13925f405ed..c50b18a837236 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -208,6 +208,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo default_scheduler_cfg = _get_default_scheduler_config() assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1 default_scheduler_cfg["scheduler"] = self._swa_scheduler + default_scheduler_cfg["opt_idx"] = 0 if trainer.lr_schedulers: scheduler_cfg = trainer.lr_schedulers[0] From 5bc9bee00ee5f1a5457ebdb19ab639675dd93806 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jan 2022 01:51:11 +0000 Subject: [PATCH 31/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index c50b18a837236..1f97385566401 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -24,8 +24,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy from pytorch_lightning.core.optimizer import _get_default_scheduler_config +from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy from 
pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException From 9b5fbfcec1ccd89840ee68a4b586bfb5f31877f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 5 Feb 2022 06:44:26 +0100 Subject: [PATCH 32/59] mark test helper protected --- tests/callbacks/test_stochastic_weight_avg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 70c692ce4a600..a74564cf9fe2e 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -309,7 +309,7 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): +def _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): model = SwaTestModel(crash_after_epoch=crash_after_epoch) swa_start = 3 max_epochs = 5 @@ -360,14 +360,14 @@ def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): @pytest.mark.parametrize("crash_after_epoch", [2, 4]) def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): - swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch) + _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch) @RunIf(skip_windows=True, min_torch="1.8") def test_swa_resume_training_from_checkpoint_ddp(tmpdir): # Requires PyTorch >= 1.8 to include this segfault fix: # https://github.com/pytorch/pytorch/pull/50998 - swa_resume_training_from_checkpoint(tmpdir, ddp=True) + _swa_resume_training_from_checkpoint(tmpdir, ddp=True) def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None): From 8e0c2554f0b08e648d4ed7a55120ee945dfd2abe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 5 Feb 2022 06:49:43 +0100 Subject: [PATCH 33/59] avoid warning for find_unused_parameters --- tests/callbacks/test_stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index a74564cf9fe2e..8190b59e398d7 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -316,7 +316,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False) swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) num_processes = 2 if ddp else 1 - strategy = "ddp_spawn" if ddp else None + strategy = "ddp_spawn_find_unused_parameters_false" if ddp else None trainer = Trainer( default_root_dir=tmpdir, From c44279f2f9201d37060fdb8acf7013750291c623 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 10 Feb 2022 13:28:54 +1300 Subject: [PATCH 34/59] Use _LRScheduler.state_dict/load_state_dict instead of accessing private _step_count --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 56803e7722ada..33ae5b44065ec 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -127,7 +127,7 @@ def __init__( self._average_model: Optional[pl.LightningModule] = None self._initialized = False self._swa_scheduler: 
Optional[SWALR] = None - self._scheduler_step_count: Optional[int] = None + self._scheduler_state: Optional[Dict] = None self._init_n_averaged = 0 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @@ -191,9 +191,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, ) - if self._scheduler_step_count is not None: - # Restore scheduler step count from checkpoint - self._swa_scheduler._step_count = self._scheduler_step_count + if self._scheduler_state is not None: + # Restore scheduler state from checkpoint + self._swa_scheduler.load_state_dict(self._scheduler_state) elif trainer.current_epoch != self.swa_start: # Log a warning if we're initializing after start without any checkpoint data, # as behaviour will be different compared to having checkpoint data. @@ -305,7 +305,7 @@ def on_save_checkpoint( ) -> dict: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), - "scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count, + "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), "average_model_parameters": self._get_average_model_parameters(trainer), } @@ -313,7 +313,7 @@ def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] ) -> None: self._init_n_averaged = callback_state["n_averaged"] - self._scheduler_step_count = callback_state["scheduler_step_count"] + self._scheduler_state = callback_state["scheduler_state"] self._load_average_model_parameters(callback_state["average_model_parameters"]) def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]: From b3eee593e3ddd111694ef04411b866b4b8228cf5 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 10 Feb 2022 12:45:11 +1300 Subject: [PATCH 35/59] Add test to reproduce crash when resuming with SWA and a custom scheduler --- tests/callbacks/test_stochastic_weight_avg.py | 40 ++++++++++++++++--- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 8190b59e398d7..7f41dc075d95d 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -20,6 +20,7 @@ import pytest import torch from torch import nn +from torch.optim.lr_scheduler import LambdaLR from torch.optim.swa_utils import SWALR from torch.utils.data import DataLoader @@ -309,8 +310,7 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -def _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False): - model = SwaTestModel(crash_after_epoch=crash_after_epoch) +def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False): swa_start = 3 max_epochs = 5 swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) @@ -340,7 +340,6 @@ def _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False) assert len(checkpoint_files) == 1 checkpoint_path = checkpoint_dir / checkpoint_files[0] - model = SwaTestModel() swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) trainer = Trainer( default_root_dir=tmpdir, @@ -355,19 +354,48 @@ def _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False) ) with mock.patch.object(Strategy, "backward", 
wraps=trainer.strategy.backward): - trainer.fit(model, ckpt_path=checkpoint_path.as_posix()) + trainer.fit(resume_model, ckpt_path=checkpoint_path.as_posix()) + + +class CustomSchedulerModel(SwaTestModel): + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + + def lr_lambda(current_step: int): + return 0.1 + + scheduler = LambdaLR(optimizer, lr_lambda, -1) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.interval, + }, + } @pytest.mark.parametrize("crash_after_epoch", [2, 4]) def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): - _swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch) + model = SwaTestModel(crash_after_epoch=crash_after_epoch) + resume_model = SwaTestModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) + + +@pytest.mark.parametrize("crash_after_epoch", [2, 4]) +def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_after_epoch): + # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 + model = CustomSchedulerModel(crash_after_epoch=crash_after_epoch) + resume_model = CustomSchedulerModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) @RunIf(skip_windows=True, min_torch="1.8") def test_swa_resume_training_from_checkpoint_ddp(tmpdir): # Requires PyTorch >= 1.8 to include this segfault fix: # https://github.com/pytorch/pytorch/pull/50998 - _swa_resume_training_from_checkpoint(tmpdir, ddp=True) + model = SwaTestModel(crash_after_epoch=4) + resume_model = SwaTestModel() + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None): From 0107ff1205ed39509f0487ca5039899f64b36428 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 10 Feb 2022 14:40:48 +1300 Subject: [PATCH 36/59] Prevent trying to restore scheduler state into the wrong type of scheduler when using SWA --- .../callbacks/stochastic_weight_avg.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 33ae5b44065ec..abc67d0e23598 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -128,6 +128,7 @@ def __init__( self._initialized = False self._swa_scheduler: Optional[SWALR] = None self._scheduler_state: Optional[Dict] = None + self._scheduler_configs: Optional[List] = None self._init_n_averaged = 0 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @@ -168,6 +169,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): # virtually increase max_epochs to perform batch norm update on latest epoch. 
trainer.fit_loop.max_epochs += 1 + if self._scheduler_state is not None: + self._clear_schedulers(trainer) + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): self._initialized = True @@ -205,6 +209,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 + if self._scheduler_configs: + trainer.lr_scheduler_configs[:] = self._scheduler_configs + self._scheduler_configs = None + if trainer.lr_scheduler_configs: scheduler_cfg = trainer.lr_scheduler_configs[0] if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: @@ -315,6 +323,21 @@ def on_load_checkpoint( self._init_n_averaged = callback_state["n_averaged"] self._scheduler_state = callback_state["scheduler_state"] self._load_average_model_parameters(callback_state["average_model_parameters"]) + if self._scheduler_state is not None: + self._clear_schedulers(trainer) + + def _clear_schedulers(self, trainer: "pl.Trainer") -> None: + # If we have scheduler state saved, clear the scheduler configs so that we don't try to + # load state into the wrong type of schedulers when restoring scheduler checkpoint state. + # We'll configure the scheduler and re-load its state in on_train_epoch_start. + # Note that this is called from both on_load_checkpoint and on_fit_start, to handle when the + # training strategy's restore_checkpoint_after_setup is both True and False, and relies + # on the callback state being restored before the schedulers. + # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background. 
+ if trainer.lr_scheduler_configs: + assert len(trainer.lr_scheduler_configs) == 1 + self._scheduler_configs = list(trainer.strategy.lr_scheduler_configs) + trainer.lr_scheduler_configs.clear() def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]: if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): From 81ac19547407a56793bb4da3c2ef61bb16f7f911 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 14 Feb 2022 15:21:21 +1300 Subject: [PATCH 37/59] Add test case where trainer.strategy.restore_checkpoint_after_setup is True --- tests/callbacks/test_stochastic_weight_avg.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 7f41dc075d95d..cbf2ad3f85972 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -14,7 +14,7 @@ import logging import os from pathlib import Path -from typing import Optional +from typing import ContextManager, Optional from unittest import mock import pytest @@ -162,7 +162,7 @@ def train_with_swa( devices=devices, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward): + with _backward_patch(trainer): trainer.fit(model) # check the model is the expected @@ -310,7 +310,7 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False): +def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, restore_after_setup=False): swa_start = 3 max_epochs = 5 swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) @@ -331,8 +331,9 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False) ) exception_type = Exception if ddp else RuntimeError - backward_patch = mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) - with backward_patch, pytest.raises(exception_type): + backward_patch = _backward_patch(trainer) + restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) + with backward_patch, restore_patch, pytest.raises(exception_type): trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" @@ -353,7 +354,9 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False) strategy=strategy, ) - with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward): + backward_patch = _backward_patch(trainer) + restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) + with restore_patch, backward_patch: trainer.fit(resume_model, ckpt_path=checkpoint_path.as_posix()) @@ -382,11 +385,12 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): @pytest.mark.parametrize("crash_after_epoch", [2, 4]) -def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_after_epoch): +@pytest.mark.parametrize("restore_after_setup", [False, True]) +def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_after_epoch, restore_after_setup): # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 model = CustomSchedulerModel(crash_after_epoch=crash_after_epoch) resume_model = CustomSchedulerModel() - _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, 
restore_after_setup=restore_after_setup) @RunIf(skip_windows=True, min_torch="1.8") @@ -421,3 +425,17 @@ def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): @RunIf(deepspeed=True, min_gpus=1) def test_misconfiguration_error_with_deep_speed(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed", 1) + + +def _backward_patch(trainer: Trainer) -> ContextManager: + return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) + + +def _restore_after_setup_patch(trainer: Trainer, restore_after_setup: bool) -> ContextManager: + return mock.patch.object( + Strategy, + "restore_checkpoint_after_setup", + wraps=trainer.strategy.restore_checkpoint_after_setup, + new_callable=mock.PropertyMock, + return_value=restore_after_setup, + ) From 20393b1d7e6b3b674c56af56423e36357d8c0801 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 15 Feb 2022 04:00:34 +0100 Subject: [PATCH 38/59] Minor test refactoring --- tests/callbacks/test_stochastic_weight_avg.py | 55 ++++++------------- 1 file changed, 17 insertions(+), 38 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index cbf2ad3f85972..2f04855fd749e 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -45,7 +45,6 @@ def __init__( self.interval = interval self.iterable_dataset = iterable_dataset self.crash_after_epoch = crash_after_epoch - self._epoch_count = 0 self.save_hyperparameters() def training_step(self, batch, batch_idx): @@ -54,10 +53,8 @@ def training_step(self, batch, batch_idx): return {"loss": loss} def train_dataloader(self): - dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset dset = dset_cls(32, 64) - return DataLoader(dset, batch_size=2) def configure_optimizers(self): @@ -73,8 +70,7 @@ def configure_optimizers(self): def training_epoch_end(self, _): if not self.crash_after_epoch: return - self._epoch_count += 1 - if self._epoch_count >= self.crash_after_epoch: + if self.trainer.current_epoch + 1 >= self.crash_after_epoch: raise RuntimeError("Crash test") @@ -312,52 +308,35 @@ def on_train_epoch_start(self): def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, restore_after_setup=False): swa_start = 3 - max_epochs = 5 - swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) - - num_processes = 2 if ddp else 1 - strategy = "ddp_spawn_find_unused_parameters_false" if ddp else None - - trainer = Trainer( - default_root_dir=tmpdir, - enable_progress_bar=False, - max_epochs=max_epochs, - limit_train_batches=5, - limit_val_batches=0, - callbacks=[swa_callback], - accumulate_grad_batches=2, - num_processes=num_processes, - strategy=strategy, - ) + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 5, + "accelerator": "cpu", + "strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None, + "devices": 2 if ddp else 1, + "limit_train_batches": 5, + "limit_val_batches": 0, + "accumulate_grad_batches": 2, + "enable_progress_bar": False, + } + trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) - exception_type = Exception if ddp else RuntimeError backward_patch = _backward_patch(trainer) restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) - with backward_patch, restore_patch, pytest.raises(exception_type): + with backward_patch, restore_patch, pytest.raises(Exception if ddp else RuntimeError): 
trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" checkpoint_files = os.listdir(checkpoint_dir) assert len(checkpoint_files) == 1 - checkpoint_path = checkpoint_dir / checkpoint_files[0] + ckpt_path = str(checkpoint_dir / checkpoint_files[0]) - swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1) - trainer = Trainer( - default_root_dir=tmpdir, - enable_progress_bar=False, - max_epochs=max_epochs, - limit_train_batches=5, - limit_val_batches=0, - callbacks=[swa_callback], - accumulate_grad_batches=2, - num_processes=num_processes, - strategy=strategy, - ) + trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) backward_patch = _backward_patch(trainer) restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) with restore_patch, backward_patch: - trainer.fit(resume_model, ckpt_path=checkpoint_path.as_posix()) + trainer.fit(resume_model, ckpt_path=ckpt_path) class CustomSchedulerModel(SwaTestModel): From 14f9f20c364b7d3f251fae37a1078a669d9c4f68 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 15 Feb 2022 04:34:47 +0100 Subject: [PATCH 39/59] Fix test_swa_resume_training_from_checkpoint[2] --- pytorch_lightning/loops/fit_loop.py | 2 -- tests/callbacks/test_stochastic_weight_avg.py | 5 ++++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8cbe4c167a29d..ce266100062a3 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -141,8 +141,6 @@ def restarting(self, restarting: bool) -> None: self.epoch_progress.current.processed, ) finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values) - if finished_before_on_train_end: - self.epoch_progress.current.completed = self.epoch_progress.current.processed restarting &= finished_before_on_train_end Loop.restarting.fset(self, restarting) # call the parent setter diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 2f04855fd749e..b7781171f6419 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -90,7 +90,10 @@ def transfer_weights(self, *args, **kwargs): def on_train_epoch_start(self, trainer, *args): super().on_train_epoch_start(trainer, *args) - if self.first_epoch is None: + if self.first_epoch is None and not trainer.fit_loop.restarting: + # since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will + # not update the model and just call the epoch-level hooks, for that reason, we check that we are not + # restarting before choosing the first epoch self.first_epoch = trainer.current_epoch assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end) if self.swa_start <= trainer.current_epoch: From c6771415af72ccce17ea5c98e53775bc1667c25f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 15 Feb 2022 04:42:27 +0100 Subject: [PATCH 40/59] Did not mean to remove this --- pytorch_lightning/loops/fit_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index ce266100062a3..8cbe4c167a29d 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -141,6 +141,8 @@ def restarting(self, restarting: bool) -> None: self.epoch_progress.current.processed, ) 
finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values) + if finished_before_on_train_end: + self.epoch_progress.current.completed = self.epoch_progress.current.processed restarting &= finished_before_on_train_end Loop.restarting.fset(self, restarting) # call the parent setter From 5cf5e1bfc9109c8c04221bc9ea4122b448d8c5e1 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 16 Feb 2022 11:43:14 +1300 Subject: [PATCH 41/59] Test tidy up from PR review comments --- tests/callbacks/test_stochastic_weight_avg.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index b7781171f6419..4e05a1063c6c9 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -34,7 +34,7 @@ class SwaTestModel(BoringModel): def __init__( - self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_after_epoch=None + self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None ): super().__init__() layers = [nn.Linear(32, 32)] @@ -44,10 +44,11 @@ def __init__( self.layer = nn.Sequential(*layers) self.interval = interval self.iterable_dataset = iterable_dataset - self.crash_after_epoch = crash_after_epoch - self.save_hyperparameters() + self.crash_on_epoch = crash_on_epoch def training_step(self, batch, batch_idx): + if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: + raise DummyError() output = self.forward(batch) loss = self.loss(batch, output) return {"loss": loss} @@ -67,12 +68,6 @@ def configure_optimizers(self): }, } - def training_epoch_end(self, _): - if not self.crash_after_epoch: - return - if self.trainer.current_epoch + 1 >= self.crash_after_epoch: - raise RuntimeError("Crash test") - class SwaTestCallback(StochasticWeightAveraging): update_parameters_calls: int = 0 @@ -132,6 +127,13 @@ def on_train_end(self, trainer, pl_module): assert self.transfer_weights_calls == 1 +class DummyError(Exception): + """Dummy error used to simulate a crash during training.""" + + def __init__(self): + super().__init__("Crash test") + + def train_with_swa( tmpdir, batchnorm=True, @@ -326,7 +328,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, backward_patch = _backward_patch(trainer) restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) - with backward_patch, restore_patch, pytest.raises(Exception if ddp else RuntimeError): + with backward_patch, restore_patch, pytest.raises(Exception if ddp else DummyError): trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" @@ -359,18 +361,18 @@ def lr_lambda(current_step: int): } -@pytest.mark.parametrize("crash_after_epoch", [2, 4]) -def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch): - model = SwaTestModel(crash_after_epoch=crash_after_epoch) +@pytest.mark.parametrize("crash_on_epoch", [1, 3]) +def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch): + model = SwaTestModel(crash_on_epoch=crash_on_epoch) resume_model = SwaTestModel() _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) -@pytest.mark.parametrize("crash_after_epoch", [2, 4]) +@pytest.mark.parametrize("crash_on_epoch", [1, 3]) @pytest.mark.parametrize("restore_after_setup", [False, True]) -def 
test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_after_epoch, restore_after_setup): +def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch, restore_after_setup): # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 - model = CustomSchedulerModel(crash_after_epoch=crash_after_epoch) + model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch) resume_model = CustomSchedulerModel() _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, restore_after_setup=restore_after_setup) @@ -379,7 +381,7 @@ def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_afte def test_swa_resume_training_from_checkpoint_ddp(tmpdir): # Requires PyTorch >= 1.8 to include this segfault fix: # https://github.com/pytorch/pytorch/pull/50998 - model = SwaTestModel(crash_after_epoch=4) + model = SwaTestModel(crash_on_epoch=3) resume_model = SwaTestModel() _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) From fe79d6c8c902389cbb9298af4c6e2d1869c8d50e Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 16 Feb 2022 11:59:00 +1300 Subject: [PATCH 42/59] Store most recent update epoch in the SWA checkpoint data --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index f72687dfabe6e..9a54632cbb2f9 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -130,6 +130,7 @@ def __init__( self._scheduler_state: Optional[Dict] = None self._scheduler_configs: Optional[List] = None self._init_n_averaged = 0 + self._latest_update_epoch = -1 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @property @@ -228,8 +229,11 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if self.n_averaged is None: self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device) - if self.swa_start <= trainer.current_epoch <= self.swa_end: + if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( + trainer.current_epoch > self._latest_update_epoch + ): self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) + self._latest_update_epoch = trainer.current_epoch # Note: No > here in case the callback is saved with the model and training continues if trainer.current_epoch == self.swa_end + 1: @@ -314,6 +318,7 @@ def on_save_checkpoint( ) -> dict: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), + "latest_update_epoch": self._latest_update_epoch, "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), "average_model_parameters": self._get_average_model_parameters(trainer), } @@ -322,6 +327,7 @@ def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] ) -> None: self._init_n_averaged = callback_state["n_averaged"] + self._latest_update_epoch = callback_state["latest_update_epoch"] self._scheduler_state = callback_state["scheduler_state"] self._load_average_model_parameters(callback_state["average_model_parameters"]) if self._scheduler_state is not None: From c7c2818862c9a9d8f10967067d3db677d6b3aec2 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 28 Feb 2022 10:40:19 +1300 Subject: [PATCH 
43/59] Fix for master change that broke resuming without validation dataloaders Changes in https://github.com/PyTorchLightning/pytorch-lightning/pull/11576 caused a crash when restarting with limit_val_batches = 0 --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index cb0e79ae89448..26a741c91431d 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -78,7 +78,7 @@ def dataloaders(self) -> Sequence[DataLoader]: """Returns the validation or test dataloaders.""" dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders if dataloaders is None: - raise RuntimeError("Dataloaders should be available.") + return [] return dataloaders def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override] From d2ed468a3adea2e3e14e592fb604dae61ef35d0c Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 28 Feb 2022 11:08:03 +1300 Subject: [PATCH 44/59] Adjust SWA tests to account for current checkpoint resume behaviour --- tests/callbacks/test_stochastic_weight_avg.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 4e05a1063c6c9..fcc9819b654ff 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -103,7 +103,13 @@ def on_train_epoch_end(self, trainer, *args): assert self.n_averaged == swa_epoch + 1 assert self._swa_scheduler is not None # Scheduler is stepped once on initialization and then at the end of each epoch - assert self._swa_scheduler._step_count == swa_epoch + 2 + expected_step_count = swa_epoch + 2 + if trainer.fit_loop.restarting or (self.first_epoch is not None and self.first_epoch > self.swa_start): + # TODO: Remove this adjustment after fixing checkpoint resume behaviour. + # The number of scheduler steps is currently one less than expected as the scheduler isn't stepped on + # the first epoch after resuming. + expected_step_count -= 1 + assert self._swa_scheduler._step_count == expected_step_count elif trainer.current_epoch > self.swa_end: assert self.n_averaged == self._max_epochs - self.swa_start @@ -123,6 +129,12 @@ def on_train_end(self, trainer, pl_module): # check call counts first_swa_epoch = max(self.first_epoch, self.swa_start) + if first_swa_epoch > self.swa_start: + # TODO: Remove this adjustment after fixing checkpoint resume behaviour. + # When resuming from a checkpoint, first_epoch is currently incorrect as we resume with epoch set + # to the next epoch after the checkpoint. 
+ # See https://github.com/PyTorchLightning/pytorch-lightning/pull/9938#discussion_r807441500 + first_swa_epoch -= 1 assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch assert self.transfer_weights_calls == 1 From b71b690804022beda2160c1b62d779e800ff45bc Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 29 Mar 2022 11:16:44 +1300 Subject: [PATCH 45/59] Revert workarounds for first epoch after resume having no batches --- tests/callbacks/test_stochastic_weight_avg.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 12ba69684de7d..fecf500d8d492 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -103,13 +103,7 @@ def on_train_epoch_end(self, trainer, *args): assert self.n_averaged == swa_epoch + 1 assert self._swa_scheduler is not None # Scheduler is stepped once on initialization and then at the end of each epoch - expected_step_count = swa_epoch + 2 - if trainer.fit_loop.restarting or (self.first_epoch is not None and self.first_epoch > self.swa_start): - # TODO: Remove this adjustment after fixing checkpoint resume behaviour. - # The number of scheduler steps is currently one less than expected as the scheduler isn't stepped on - # the first epoch after resuming. - expected_step_count -= 1 - assert self._swa_scheduler._step_count == expected_step_count + assert self._swa_scheduler._step_count == swa_epoch + 2 elif trainer.current_epoch > self.swa_end: assert self.n_averaged == self._max_epochs - self.swa_start @@ -129,12 +123,6 @@ def on_train_end(self, trainer, pl_module): # check call counts first_swa_epoch = max(self.first_epoch, self.swa_start) - if first_swa_epoch > self.swa_start: - # TODO: Remove this adjustment after fixing checkpoint resume behaviour. - # When resuming from a checkpoint, first_epoch is currently incorrect as we resume with epoch set - # to the next epoch after the checkpoint. 
- # See https://github.com/PyTorchLightning/pytorch-lightning/pull/9938#discussion_r807441500 - first_swa_epoch -= 1 assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch assert self.transfer_weights_calls == 1 From 15e6334493e89d4f7bcc478b01d0d88c8d5e1ec0 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 29 Mar 2022 12:34:38 +1300 Subject: [PATCH 46/59] Use state_dict/load_state_dict instead of on_save/load_checkpoint in SWA This required a bit of a hacky workaround to continue to correctly clear the scheduler state if it's going to be restored from the SWA checkpoint data --- .../callbacks/stochastic_weight_avg.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index c0b40124537b7..20654db76d720 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -15,6 +15,7 @@ Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ +import weakref from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union @@ -129,6 +130,7 @@ def __init__( self._swa_scheduler: Optional[SWALR] = None self._scheduler_state: Optional[Dict] = None self._scheduler_configs: Optional[List] = None + self._trainer: Optional[weakref.ref] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @@ -172,6 +174,11 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if self._scheduler_state is not None: self._clear_schedulers(trainer) + else: + # We're probably not restoring from a checkpoint, but possibly the checkpoint data just + # hasn't been loaded yet if strategy.restore_checkpoint_after_setup is True, + # so keep a hold of the trainer so that we can defer clearing schedulers if needed. 
+ self._trainer = weakref.ref(trainer) def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): @@ -314,31 +321,29 @@ def avg_fn( """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) - def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] - ) -> dict: + def state_dict(self) -> Dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), - "average_model_parameters": self._get_average_model_parameters(trainer), + "average_model_parameters": None if self._average_model is None else list(self._average_model.parameters()), } - def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] - ) -> None: - self._init_n_averaged = callback_state["n_averaged"] - self._latest_update_epoch = callback_state["latest_update_epoch"] - self._scheduler_state = callback_state["scheduler_state"] - self._load_average_model_parameters(callback_state["average_model_parameters"]) - if self._scheduler_state is not None: + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self._init_n_averaged = state_dict["n_averaged"] + self._latest_update_epoch = state_dict["latest_update_epoch"] + self._scheduler_state = state_dict["scheduler_state"] + self._load_average_model_parameters(state_dict["average_model_parameters"]) + # If we're loading state after on_fit_start, check if we need to clear schedulers + trainer = None if self._trainer is None else self._trainer() + if self._scheduler_state is not None and trainer is not None: self._clear_schedulers(trainer) def _clear_schedulers(self, trainer: "pl.Trainer") -> None: # If we have scheduler state saved, clear the scheduler configs so that we don't try to # load state into the wrong type of schedulers when restoring scheduler checkpoint state. # We'll configure the scheduler and re-load its state in on_train_epoch_start. - # Note that this is called from both on_load_checkpoint and on_fit_start, to handle when the + # Note that this is called from both load_state_dict and on_fit_start, to handle when the # training strategy's restore_checkpoint_after_setup is both True and False, and relies # on the callback state being restored before the schedulers. # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background. @@ -347,13 +352,6 @@ def _clear_schedulers(self, trainer: "pl.Trainer") -> None: self._scheduler_configs = list(trainer.strategy.lr_scheduler_configs) trainer.lr_scheduler_configs.clear() - def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]: - if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end): - # If we're not within the SWA epochs then when loading checkpoint data we would want - # to use parameters from the underlying model rather than the SWA parameters. 
- return - return list(self._average_model.parameters()) - def _load_average_model_parameters(self, parameter_state: Any) -> None: if self._average_model is None or parameter_state is None: return From e3104bc81962d1dfda33ef746bab4b4ce1c9227e Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 20 Apr 2022 13:41:12 +1200 Subject: [PATCH 47/59] Remove unnecessary workaround for handling restore_checkpoint_after_setup --- .../callbacks/stochastic_weight_avg.py | 17 +++---------- tests/callbacks/test_stochastic_weight_avg.py | 25 ++++--------------- 2 files changed, 8 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 20654db76d720..3ab51e1f8d594 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -15,7 +15,6 @@ Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ -import weakref from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union @@ -130,7 +129,6 @@ def __init__( self._swa_scheduler: Optional[SWALR] = None self._scheduler_state: Optional[Dict] = None self._scheduler_configs: Optional[List] = None - self._trainer: Optional[weakref.ref] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @@ -174,11 +172,6 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if self._scheduler_state is not None: self._clear_schedulers(trainer) - else: - # We're probably not restoring from a checkpoint, but possibly the checkpoint data just - # hasn't been loaded yet if strategy.restore_checkpoint_after_setup is True, - # so keep a hold of the trainer so that we can defer clearing schedulers if needed. - self._trainer = weakref.ref(trainer) def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): @@ -334,18 +327,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] self._load_average_model_parameters(state_dict["average_model_parameters"]) - # If we're loading state after on_fit_start, check if we need to clear schedulers - trainer = None if self._trainer is None else self._trainer() - if self._scheduler_state is not None and trainer is not None: - self._clear_schedulers(trainer) def _clear_schedulers(self, trainer: "pl.Trainer") -> None: # If we have scheduler state saved, clear the scheduler configs so that we don't try to # load state into the wrong type of schedulers when restoring scheduler checkpoint state. # We'll configure the scheduler and re-load its state in on_train_epoch_start. - # Note that this is called from both load_state_dict and on_fit_start, to handle when the - # training strategy's restore_checkpoint_after_setup is both True and False, and relies - # on the callback state being restored before the schedulers. + # Note that this relies on the callback state being restored before the scheduler state is + # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of + # writing that is only True for deepspeed which is already not supported by SWA. # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background. 
if trainer.lr_scheduler_configs: assert len(trainer.lr_scheduler_configs) == 1 diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index fecf500d8d492..586aa79b8be3d 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -311,7 +311,7 @@ def on_train_epoch_start(self): assert model.on_train_epoch_start_called -def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, restore_after_setup=False): +def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False): swa_start = 3 trainer_kwargs = { "default_root_dir": tmpdir, @@ -326,9 +326,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, } trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) - backward_patch = _backward_patch(trainer) - restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) - with backward_patch, restore_patch, pytest.raises(Exception if ddp else DummyError): + with _backward_patch(trainer), pytest.raises(Exception if ddp else DummyError): trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" @@ -338,9 +336,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False, trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) - backward_patch = _backward_patch(trainer) - restore_patch = _restore_after_setup_patch(trainer, restore_after_setup) - with restore_patch, backward_patch: + with _backward_patch(trainer): trainer.fit(resume_model, ckpt_path=ckpt_path) @@ -369,12 +365,11 @@ def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch): @pytest.mark.parametrize("crash_on_epoch", [1, 3]) -@pytest.mark.parametrize("restore_after_setup", [False, True]) -def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch, restore_after_setup): +def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch): # Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch) resume_model = CustomSchedulerModel() - _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, restore_after_setup=restore_after_setup) + _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) @RunIf(skip_windows=True, min_torch="1.8") @@ -413,13 +408,3 @@ def test_misconfiguration_error_with_deep_speed(tmpdir): def _backward_patch(trainer: Trainer) -> ContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) - - -def _restore_after_setup_patch(trainer: Trainer, restore_after_setup: bool) -> ContextManager: - return mock.patch.object( - Strategy, - "restore_checkpoint_after_setup", - wraps=trainer.strategy.restore_checkpoint_after_setup, - new_callable=mock.PropertyMock, - return_value=restore_after_setup, - ) From f509178fd6a4179896a7d4f7083d8174d1200186 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 18 May 2022 10:33:09 +1200 Subject: [PATCH 48/59] Fix deprecation warning in tests --- tests/callbacks/test_stochastic_weight_avg.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 55c7cb829b9cb..39c5d548e7dff 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ 
b/tests/callbacks/test_stochastic_weight_avg.py @@ -350,7 +350,7 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) -def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None): +def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy): model = SwaTestModel() swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) trainer = Trainer( @@ -359,7 +359,8 @@ def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None) max_epochs=5, callbacks=[swa_callback], strategy=strategy, - gpus=gpus, + accelerator="gpu", + devices=1, ) with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"): trainer.fit(model) @@ -367,12 +368,12 @@ def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None) @RunIf(fairscale_fully_sharded=True, min_gpus=1) def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp", 1) + _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp") @RunIf(deepspeed=True, min_gpus=1) def test_misconfiguration_error_with_deep_speed(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed", 1) + _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") def _backward_patch(trainer: Trainer) -> ContextManager: From 77f137cb8f02e4eb1c3d5afd1de595867336e53d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 25 Jul 2022 22:16:39 +0200 Subject: [PATCH 49/59] update runif --- tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index f52056097eafc..c0cbca667cb94 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -369,12 +369,12 @@ def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy): trainer.fit(model) -@RunIf(fairscale_fully_sharded=True, min_gpus=1) +@RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1) def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp") -@RunIf(deepspeed=True, min_gpus=1) +@RunIf(deepspeed=True, min_cuda_gpus=1) def test_misconfiguration_error_with_deep_speed(tmpdir): _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") From 324499ee5737515e66444a1e68db1176e2c6eab1 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 18:30:20 +1200 Subject: [PATCH 50/59] Remove no-longer required minimum torch version from test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index c0cbca667cb94..4ec098af0c884 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -344,10 +344,8 @@ def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_e _swa_resume_training_from_checkpoint(tmpdir, model, resume_model) -@RunIf(skip_windows=True, min_torch="1.8") 
+@RunIf(skip_windows=True) def test_swa_resume_training_from_checkpoint_ddp(tmpdir): - # Requires PyTorch >= 1.8 to include this segfault fix: - # https://github.com/pytorch/pytorch/pull/50998 model = SwaTestModel(crash_on_epoch=3) resume_model = SwaTestModel() _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) From ab8aca0b1e8e0d41951e94139e40ec60e0756640 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 19:07:58 +1200 Subject: [PATCH 51/59] Remove redundant None check that could hide a bug Co-authored-by: Rohit Gupta --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index da7492b3a5ce1..cb12680deff2b 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -337,7 +337,7 @@ def _clear_schedulers(self, trainer: "pl.Trainer") -> None: trainer.lr_scheduler_configs.clear() def _load_average_model_parameters(self, parameter_state: Any) -> None: - if self._average_model is None or parameter_state is None: + if self._average_model is None: return for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): device = p_swa.device From 7d6e7a82a3469daec2d79ebee1bbdf58d9281905 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 19:05:07 +1200 Subject: [PATCH 52/59] Don't save scheduler configs as they will only be overridden --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index cb12680deff2b..75d1fc01f43ca 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -125,7 +125,6 @@ def __init__( self._initialized = False self._swa_scheduler: Optional[SWALR] = None self._scheduler_state: Optional[Dict] = None - self._scheduler_configs: Optional[List] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None @@ -206,10 +205,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 - if self._scheduler_configs: - trainer.lr_scheduler_configs[:] = self._scheduler_configs - self._scheduler_configs = None - if trainer.lr_scheduler_configs: scheduler_cfg = trainer.lr_scheduler_configs[0] if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1: @@ -323,7 +318,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._scheduler_state = state_dict["scheduler_state"] self._load_average_model_parameters(state_dict["average_model_parameters"]) - def _clear_schedulers(self, trainer: "pl.Trainer") -> None: + @staticmethod + def _clear_schedulers(trainer: "pl.Trainer") -> None: # If we have scheduler state saved, clear the scheduler configs so that we don't try to # load state into the wrong type of schedulers when restoring scheduler checkpoint state. # We'll configure the scheduler and re-load its state in on_train_epoch_start. 
@@ -333,7 +329,6 @@ def _clear_schedulers(self, trainer: "pl.Trainer") -> None: # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background. if trainer.lr_scheduler_configs: assert len(trainer.lr_scheduler_configs) == 1 - self._scheduler_configs = list(trainer.strategy.lr_scheduler_configs) trainer.lr_scheduler_configs.clear() def _load_average_model_parameters(self, parameter_state: Any) -> None: From 9bf237ecad021823af063de3520c2593c06c8cdc Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 19:15:43 +1200 Subject: [PATCH 53/59] Use state_dict/load_state_dict to save and load average model state --- .../callbacks/stochastic_weight_avg.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 75d1fc01f43ca..45139f1ec417c 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -309,14 +309,14 @@ def state_dict(self) -> Dict[str, Any]: "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(), - "average_model_parameters": None if self._average_model is None else list(self._average_model.parameters()), + "average_model_state": None if self._average_model is None else self._average_model.state_dict(), } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] - self._load_average_model_parameters(state_dict["average_model_parameters"]) + self._load_average_model_state(state_dict["average_model_parameters"]) @staticmethod def _clear_schedulers(trainer: "pl.Trainer") -> None: @@ -331,10 +331,7 @@ def _clear_schedulers(trainer: "pl.Trainer") -> None: assert len(trainer.lr_scheduler_configs) == 1 trainer.lr_scheduler_configs.clear() - def _load_average_model_parameters(self, parameter_state: Any) -> None: + def _load_average_model_state(self, model_state: Any) -> None: if self._average_model is None: return - for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state): - device = p_swa.device - p_swa_ = p_swa.detach() - p_swa_.copy_(p_checkpoint.to(device)) + self._average_model.load_state_dict(model_state) From a9b6334dec9b7954a3865044904f0b483809e110 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 19:20:09 +1200 Subject: [PATCH 54/59] Parametrize misconfiguration error tests --- .../callbacks/test_stochastic_weight_avg.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 4ec098af0c884..b5ac8d181f885 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -351,7 +351,11 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) -def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy): +@pytest.mark.parametrize("strategy", [ + pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)), + pytest.param("deepspeed", 
marks=RunIf(deepspeed=True, min_cuda_gpus=1)), +]) +def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str): model = SwaTestModel() swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) trainer = Trainer( @@ -367,15 +371,5 @@ def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy): trainer.fit(model) -@RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1) -def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp") - - -@RunIf(deepspeed=True, min_cuda_gpus=1) -def test_misconfiguration_error_with_deep_speed(tmpdir): - _test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed") - - def _backward_patch(trainer: Trainer) -> ContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) From c24522b504fff6c79e936defeb4d47c1f194578a Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 19:51:25 +1200 Subject: [PATCH 55/59] Remove DummyError and match exception message --- .../callbacks/test_stochastic_weight_avg.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index b5ac8d181f885..cf288fb49b70d 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -49,7 +49,7 @@ def __init__( def training_step(self, batch, batch_idx): if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: - raise DummyError() + raise Exception("SWA crash test") output = self.forward(batch) loss = self.loss(batch, output) return {"loss": loss} @@ -128,13 +128,6 @@ def on_train_end(self, trainer, pl_module): assert self.transfer_weights_calls == 1 -class DummyError(Exception): - """Dummy error used to simulate a crash during training.""" - - def __init__(self): - super().__init__("Crash test") - - def train_with_swa( tmpdir, batchnorm=True, @@ -298,7 +291,7 @@ def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False) } trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs) - with _backward_patch(trainer), pytest.raises(Exception if ddp else DummyError): + with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"): trainer.fit(model) checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints" From ba7cb5e1ebd64050e0e2bb90e187e830f1af4c31 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Tue, 2 Aug 2022 21:58:33 +1200 Subject: [PATCH 56/59] Fix state dict key --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index cc32cf0b2053a..2f215d0b43c77 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -326,7 +326,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] - self._load_average_model_state(state_dict["average_model_parameters"]) + self._load_average_model_state(state_dict["average_model_state"]) @staticmethod def _clear_schedulers(trainer: "pl.Trainer") -> 
None: From 8bde4f460f6b9a6a2d3acb6b025a8d770d070bfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Aug 2022 10:02:31 +0000 Subject: [PATCH 57/59] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../callbacks/test_stochastic_weight_avg.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index cf288fb49b70d..65a0fea2fb4a5 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -344,10 +344,13 @@ def test_swa_resume_training_from_checkpoint_ddp(tmpdir): _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True) -@pytest.mark.parametrize("strategy", [ - pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)), - pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)), -]) +@pytest.mark.parametrize( + "strategy", + [ + pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)), + pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)), + ], +) def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str): model = SwaTestModel() swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1) From 3ed8ea4d674a49493d301855ed6409fd1102c772 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 3 Aug 2022 09:20:48 +1200 Subject: [PATCH 58/59] Type checking fixes --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 2f215d0b43c77..4e31e3e5bc0e0 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -123,7 +123,7 @@ def __init__( self._model_contains_batch_norm: Optional[bool] = None self._average_model: "pl.LightningModule" self._initialized = False - self._swa_scheduler: Optional[SWALR] = None + self._swa_scheduler: Optional[_LRScheduler] = None self._scheduler_state: Optional[Dict] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 @@ -228,6 +228,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ( trainer.current_epoch > self._latest_update_epoch ): + assert self.n_averaged is not None self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) self._latest_update_epoch = trainer.current_epoch @@ -293,6 +294,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No def reset_momenta(self) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" + assert self.momenta is not None for bn_module in self.momenta: bn_module.momentum = self.momenta[bn_module] From 15fe88e5ccd35209203f14d3cba11cab15872e2a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 4 Aug 2022 01:56:03 +0200 Subject: [PATCH 59/59] fix changelog conflicts --- src/pytorch_lightning/CHANGELOG.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 845e7fe5fcd8f..bcc6af4fd14cf 100644 --- 
a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -49,18 +49,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992)) -## [1.8.0] - YYYY-MM-DD - -### Added - -### Changed - -### Deprecated - -### Removed - -### Fixed - - Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
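
Usage sketch (not part of any patch in this series): a minimal example of the resume flow the series fixes. TinyModel, the learning rates and the checkpoint path below are illustrative assumptions, not code taken from the patches. With the callback state now carried through state_dict/load_state_dict, a run that is interrupted after SWA has started can be resumed via ckpt_path, and n_averaged, the SWALR scheduler state and the averaged weights continue from the checkpoint instead of being re-initialized.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging


class TinyModel(pl.LightningModule):
    """Hypothetical module used only for this sketch; any LightningModule works."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x, y = torch.randn(64, 32), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=8)


def make_trainer():
    # Every checkpoint written by this trainer now includes the SWA callback state
    # (n_averaged, latest_update_epoch, the SWALR scheduler state and the averaged
    # model weights) via the callback's state_dict hook.
    return pl.Trainer(
        max_epochs=10,
        callbacks=[StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=5)],
        enable_progress_bar=False,
    )


# First run: suppose it is interrupted after the SWA phase has started; a checkpoint
# is left under lightning_logs/version_0/checkpoints/.
make_trainer().fit(TinyModel())

# Resumed run: load_state_dict restores the averaging state, so parameter averaging
# carries on from the checkpoint rather than restarting (the path is illustrative --
# use whichever checkpoint the first run actually produced).
make_trainer().fit(TinyModel(), ckpt_path="lightning_logs/version_0/checkpoints/epoch=6-step=56.ckpt")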