[feature] Add StochasticWeightAveraging (SWA) callback #5640

Merged: 69 commits, Feb 11, 2021

Changes from 52 commits

Commits
4665196
add swa callback
tchaton Jan 24, 2021
be261a8
switch back to 1.6.0
tchaton Jan 24, 2021
d999312
remove optimizer_step
tchaton Jan 24, 2021
e9f89b4
move super
tchaton Jan 24, 2021
c70c35e
update
tchaton Jan 24, 2021
b7ceb5f
forgot update_parameters
tchaton Jan 25, 2021
6a897bc
update on comments
tchaton Jan 25, 2021
051d201
works for ddp
tchaton Jan 25, 2021
0bab9ac
resolve flake8
tchaton Jan 25, 2021
0028fe7
remove set_model
tchaton Jan 25, 2021
3771884
Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc…
tchaton Jan 25, 2021
c886875
resolve flake8
tchaton Jan 25, 2021
415e5c6
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 25, 2021
663608e
resolve cpu
tchaton Jan 25, 2021
0379f3d
Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc…
tchaton Jan 25, 2021
9cf56a4
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 26, 2021
774835d
resolve flake8
tchaton Jan 26, 2021
377235f
Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc…
tchaton Jan 26, 2021
c457608
resolve flake8
tchaton Jan 26, 2021
83eb0b1
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 26, 2021
d25fb36
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 26, 2021
efe2c64
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 26, 2021
b0bb85e
update
tchaton Jan 26, 2021
52d5400
Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc…
tchaton Jan 26, 2021
b9b9264
Merge branch 'release/0.2-dev' of https://github.com/PyTorchLightning…
tchaton Jan 26, 2021
f0452f2
update on comments
tchaton Jan 26, 2021
4ce6ce0
Merge branch 'release/1.2-dev' into feat/swa
tchaton Jan 27, 2021
fce8327
Apply suggestions from code review
Borda Jan 27, 2021
8704023
fix
Borda Jan 27, 2021
33bf190
resolve on comments
tchaton Jan 28, 2021
8782eaa
update
tchaton Jan 28, 2021
def691f
typo
tchaton Jan 28, 2021
172cb68
credit to Pytorch Team
tchaton Jan 28, 2021
d683849
update
tchaton Feb 3, 2021
ddfe9d8
add space to docstring
tchaton Feb 3, 2021
76a057f
resolve some bugs
tchaton Feb 3, 2021
c0e36f0
resolve bug
Feb 3, 2021
2e2c412
Revert finetuning changes
carmocca Feb 4, 2021
702b853
Minor changes
carmocca Feb 4, 2021
5ba2a48
Remove pruning check because it was added in 1.4.0 and that is our mi…
carmocca Feb 4, 2021
bbc377e
Fix SWA indices. Improve tests
carmocca Feb 4, 2021
1b47b99
Use properties. Move skip_backward on fn higher
carmocca Feb 4, 2021
dfcb9ce
Check backward call count
carmocca Feb 4, 2021
ca10126
Address comments
carmocca Feb 4, 2021
a4c2b0f
Add misconfig tests
carmocca Feb 4, 2021
8e3d328
Merge branch 'release/1.2-dev' into feat/swa
carmocca Feb 4, 2021
50b8326
Merge branch 'release/1.2-dev' into feat/swa
carmocca Feb 4, 2021
1100db2
pre-commit
carmocca Feb 4, 2021
e2003a5
Apply suggestions from code review
carmocca Feb 5, 2021
bf051fb
Update tests/callbacks/test_swa.py
carmocca Feb 5, 2021
53106cd
Revert "Remove pruning check because it was added in 1.4.0 and that i…
carmocca Feb 5, 2021
d019ba8
Use on_train_end. Fix tests. Add ddp_cpu test
carmocca Feb 5, 2021
8b6d55c
Apply suggestions from code review
Borda Feb 9, 2021
4ffcb42
WIP
carmocca Feb 9, 2021
53136cc
Merge branch 'release/1.2-dev' into feat/swa
carmocca Feb 9, 2021
9669cbd
yapf
carmocca Feb 9, 2021
3316b12
Revert change
carmocca Feb 9, 2021
4eb4bc4
Minor changes
carmocca Feb 9, 2021
5d67b23
Typo
carmocca Feb 9, 2021
cda86e1
Add beta warning
carmocca Feb 10, 2021
fc4ce1a
Docs fixes
carmocca Feb 10, 2021
781b367
Merge branch 'release/1.2-dev' into feat/swa
Feb 10, 2021
c4d1669
Call base file
Feb 10, 2021
70343c3
Merge branch 'release/1.2-dev' into feat/swa
carmocca Feb 10, 2021
4568bfa
Skip test on Windows
carmocca Feb 10, 2021
02da3a9
.
Borda Feb 10, 2021
adf7931
..
Borda Feb 10, 2021
e46383e
Merge branch 'release/1.2-dev' into feat/swa
carmocca Feb 10, 2021
7fae40d
Correct variable
carmocca Feb 10, 2021
1 change: 1 addition & 0 deletions docs/source/extensions/callbacks.rst
@@ -106,6 +106,7 @@ Lightning has a few built-in callbacks.
ModelPruning
ProgressBar
ProgressBarBase
StochasticWeightAveraging

----------

5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -21,6 +21,8 @@
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0

__all__ = [
'BackboneFinetuning',
@@ -36,3 +38,6 @@
'ProgressBarBase',
'ModelPruning',
]

if _PYTORCH_GREATER_EQUAL_1_6_0:
__all__ += ['StochasticWeightAveraging']
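
For context, here is a minimal sketch (not part of the diff) of what the guarded `__all__` above means on the user side: `StochasticWeightAveraging` is only advertised through `__all__` (and therefore through `from pytorch_lightning.callbacks import *`) when the installed PyTorch is at least 1.6, mirroring the `_PYTORCH_GREATER_EQUAL_1_6_0` flag.

```python
# Sketch only (not from the PR): check whether the SWA callback is advertised,
# mirroring the `_PYTORCH_GREATER_EQUAL_1_6_0` guard used above.
from distutils.version import LooseVersion

import torch
import pytorch_lightning.callbacks as pl_callbacks

has_swa = "StochasticWeightAveraging" in pl_callbacks.__all__
print(f"torch {torch.__version__}: SWA exported = {has_swa}")
assert has_swa == (LooseVersion(torch.__version__) >= LooseVersion("1.6.0"))
```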
256 changes: 256 additions & 0 deletions pytorch_lightning/callbacks/swa.py
@@ -0,0 +1,256 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Stochastic Weight Averaging Callback
====================================

"""
from copy import deepcopy
from typing import Callable, Optional, Union

import torch
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _PYTORCH_GREATER_EQUAL_1_6_0:
from torch.optim.swa_utils import SWALR


class StochasticWeightAveraging(Callback):

def __init__(
self,
swa_epoch_start: Union[int, float] = 0.8,
swa_lrs: Optional[Union[float, list]] = None,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[Callable] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
):

r"""

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to
Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).

This documentation is largely adapted from PyTorch's SWA documentation,
and this callback exposes the same arguments as PyTorch's ``swa_utils``.

Find the ``swa_utils`` source code here: https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py
Find an explanation of SWA here: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

.. note:: `StochasticWeightAveraging` is currently not supported for multiple optimizers / schedulers.

Arguments:

swa_epoch_start (int, float): If provided as int, the procedure will start from
the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch

swa_lrs (float or list): the learning rate value for all param groups
together or separately for each group.

annealing_epochs (int): number of epochs in the annealing phase
(default: 10)

annealing_strategy (str): "cos" or "linear"; specifies the annealing
strategy: "cos" for cosine annealing, "linear" for linear annealing
(default: "cos")

avg_fn (function, optional): the averaging function used to update
parameters; the function must take in the current value of the
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)

device (torch.device, optional): if provided, the averaged model will be
stored on the `device`. Default: `cpu`
When None is provided, it will infer the `device` from ``pl_module``.

"""

err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
raise MisconfigurationException(err_msg)
if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
raise MisconfigurationException(err_msg)

if not isinstance(swa_lrs, (float, list)) \
or isinstance(swa_lrs, float) and swa_lrs <= 0 \
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs):
raise MisconfigurationException("swa_lrs should be a positive float or a list of positive floats.")

if avg_fn is not None and not isinstance(avg_fn, Callable):
raise MisconfigurationException("avg_fn should be callable.")

if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None

@property
def swa_start(self) -> int:
return max(self._swa_epoch_start - 1, 0) # 0-based

@property
def swa_end(self) -> int:
return self._max_epochs - 1 # 0-based

@staticmethod
def pl_module_contains_batch_norm(pl_module):
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

def reset_batch_norm_and_save_state(self, average_model):
"""
Credit to PyTorch Team.
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L115
"""
self.momenta = {}
for module in average_model.modules():
if isinstance(module, nn.modules.batchnorm._BatchNorm):
module.running_mean = torch.zeros_like(
module.running_mean, device=average_model.device, dtype=module.running_mean.dtype)
module.running_var = torch.ones_like(
module.running_var, device=average_model.device, dtype=module.running_var.dtype)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0

def reset_momenta(self):
"""
Credit to PyTorch Team.
Taken from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164
"""
for bn_module in self.momenta.keys():
bn_module.momentum = self.momenta[bn_module]

def on_before_accelerator_backend_setup(self, trainer, pl_module):
# copy the model before moving it to the accelerator device.
self._average_model = deepcopy(pl_module)
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers

if len(optimizers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")

if len(lr_schedulers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.max_epochs += 1

def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch == self.swa_start:
# move the average model to the requested device.
self._average_model = self._average_model.to(self._device or pl_module.device)

optimizers = trainer.optimizers
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]

self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=self._swa_lrs,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)

rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

# Reset BatchNorm for update
self.reset_batch_norm_and_save_state(pl_module)

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
trainer.num_training_batches += 1
trainer.train_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches
trainer.accumulate_grad_batches = len(trainer.train_dataloader)

def on_train_epoch_end(self, trainer, *args):
trainer.train_loop._skip_backward = False

def on_train_end(self, trainer, pl_module):
if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

@staticmethod
def update_parameters(average_model, model, n_averaged, avg_fn):
"""
Credit to PyTorch Team.
Taken from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L103
"""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_swa_ = p_swa.detach()
p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
p_swa_.copy_(src)
n_averaged += 1

@staticmethod
def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'):
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

@staticmethod
def avg_fn(averaged_model_parameter, model_parameter, num_averaged) -> torch.FloatTensor:
"""
Credit to PyTorch Team.
Taken from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95
"""
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
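
As a quick illustration of how the callback above is intended to be used, here is a hedged sketch (the hyper-parameter values and the commented-out `model` are placeholders, not taken from this PR): the callback is passed to the `Trainer` like any other callback, and the default `avg_fn` computes a running mean of the parameters.

```python
# Sketch only: hyper-parameter values and `model` are illustrative placeholders.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_epoch_start=0.8,       # start SWA after 80% of max_epochs
    swa_lrs=0.05,              # SWA learning rate for all param groups
    annealing_epochs=10,
    annealing_strategy="cos",
    device="cpu",              # keep the averaged copy of the weights on CPU
)
trainer = Trainer(max_epochs=20, callbacks=[swa])
# trainer.fit(model)  # `model` is any LightningModule with a single optimizer/scheduler

# The default `avg_fn` is a running mean: avg_{n+1} = avg_n + (new - avg_n) / (n + 1)
avg, new, n = torch.tensor(1.0), torch.tensor(3.0), torch.tensor(2)
print(StochasticWeightAveraging.avg_fn(avg, new, n))  # tensor(1.6667) == mean of (1, 1, 3)
```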
23 changes: 8 additions & 15 deletions pytorch_lightning/trainer/training_loop.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from contextlib import suppress
from copy import copy
from copy import deepcopy
from contextlib import contextmanager, suppress
from copy import copy, deepcopy

import numpy as np
import torch
@@ -24,16 +22,10 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities import parsing
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -55,6 +47,7 @@ def __init__(self, trainer, multiple_trainloader_mode):
self._curr_step_result = None
self._cur_grad_norm_dict = None
self._multiple_trainloader_mode = multiple_trainloader_mode
self._skip_backward = False
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode

def on_trainer_init(
@@ -793,7 +786,7 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
self.warning_cache.warn("training_step returned None if it was on purpose, ignore this warning...")
return None

if self.trainer.train_loop.automatic_optimization:
if not self._skip_backward and self.trainer.train_loop.automatic_optimization:
# backward pass
with self.trainer.profiler.profile("model_backward"):
self.backward(result, optimizer, opt_idx)
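
The `_skip_backward` flag added above is what lets the callback run one extra pass over the training data purely to refresh BatchNorm statistics. Below is a stripped-down, self-contained illustration of that idea (the model, loop, and names are stand-ins, not Lightning internals): a forward pass in training mode updates BatchNorm running statistics even when backward and the optimizer step are skipped.

```python
# Stand-alone sketch, not Lightning's actual loop: a forward pass alone refreshes
# BatchNorm running statistics while backward/optimizer.step are skipped.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
skip_backward = True  # what the SWA callback sets for its extra BatchNorm-update epoch

for _ in range(3):  # one "epoch" over a fake loader
    batch = torch.randn(8, 4)
    loss = model(batch).pow(2).mean()  # forward pass updates running_mean/running_var
    if not skip_backward:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

print(model[1].running_mean)  # changed from zeros, even though no optimizer step ran
```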
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -35,6 +35,7 @@
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PYTORCH_GREATER_EQUAL_1_6_0,
_PYTORCH_PRUNE_AVAILABLE,
_RPC_AVAILABLE,
_TORCHTEXT_AVAILABLE,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -54,4 +54,5 @@ def _module_available(module_path: str) -> bool:
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_PYTORCH_GREATER_EQUAL_1_6_0 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_PYTORCH_PRUNE_AVAILABLE = _module_available('torch.nn.utils.prune')