[feature] Add StochasticWeightAveraging (SWA) callback #5640

Merged
69 commits merged on Feb 11, 2021
Changes shown from 15 commits

Commits (69)
4665196  add swa callback (tchaton, Jan 24, 2021)
be261a8  switch back to 1.6.0 (tchaton, Jan 24, 2021)
d999312  remove optimizer_step (tchaton, Jan 24, 2021)
e9f89b4  move super (tchaton, Jan 24, 2021)
c70c35e  update (tchaton, Jan 24, 2021)
b7ceb5f  forgot update_parameters (tchaton, Jan 25, 2021)
6a897bc  update on comments (tchaton, Jan 25, 2021)
051d201  works for ddp (tchaton, Jan 25, 2021)
0bab9ac  resolve flake8 (tchaton, Jan 25, 2021)
0028fe7  remove set_model (tchaton, Jan 25, 2021)
3771884  Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc… (tchaton, Jan 25, 2021)
c886875  resolve flake8 (tchaton, Jan 25, 2021)
415e5c6  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 25, 2021)
663608e  resolve cpu (tchaton, Jan 25, 2021)
0379f3d  Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc… (tchaton, Jan 25, 2021)
9cf56a4  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 26, 2021)
774835d  resolve flake8 (tchaton, Jan 26, 2021)
377235f  Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc… (tchaton, Jan 26, 2021)
c457608  resolve flake8 (tchaton, Jan 26, 2021)
83eb0b1  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 26, 2021)
d25fb36  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 26, 2021)
efe2c64  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 26, 2021)
b0bb85e  update (tchaton, Jan 26, 2021)
52d5400  Merge branch 'feat/swa' of https://github.com/PyTorchLightning/pytorc… (tchaton, Jan 26, 2021)
b9b9264  Merge branch 'release/0.2-dev' of https://github.com/PyTorchLightning… (tchaton, Jan 26, 2021)
f0452f2  update on comments (tchaton, Jan 26, 2021)
4ce6ce0  Merge branch 'release/1.2-dev' into feat/swa (tchaton, Jan 27, 2021)
fce8327  Apply suggestions from code review (Borda, Jan 27, 2021)
8704023  fix (Borda, Jan 27, 2021)
33bf190  resolve on comments (tchaton, Jan 28, 2021)
8782eaa  update (tchaton, Jan 28, 2021)
def691f  typo (tchaton, Jan 28, 2021)
172cb68  credit to Pytorch Team (tchaton, Jan 28, 2021)
d683849  update (tchaton, Feb 3, 2021)
ddfe9d8  add space to docstring (tchaton, Feb 3, 2021)
76a057f  resolve some bugs (tchaton, Feb 3, 2021)
c0e36f0  resolve bug (Feb 3, 2021)
2e2c412  Revert finetuning changes (carmocca, Feb 4, 2021)
702b853  Minor changes (carmocca, Feb 4, 2021)
5ba2a48  Remove pruning check because it was added in 1.4.0 and that is our mi… (carmocca, Feb 4, 2021)
bbc377e  Fix SWA indices. Improve tests (carmocca, Feb 4, 2021)
1b47b99  Use properties. Move skip_backward on fn higher (carmocca, Feb 4, 2021)
dfcb9ce  Check backward call count (carmocca, Feb 4, 2021)
ca10126  Address comments (carmocca, Feb 4, 2021)
a4c2b0f  Add misconfig tests (carmocca, Feb 4, 2021)
8e3d328  Merge branch 'release/1.2-dev' into feat/swa (carmocca, Feb 4, 2021)
50b8326  Merge branch 'release/1.2-dev' into feat/swa (carmocca, Feb 4, 2021)
1100db2  pre-commit (carmocca, Feb 4, 2021)
e2003a5  Apply suggestions from code review (carmocca, Feb 5, 2021)
bf051fb  Update tests/callbacks/test_swa.py (carmocca, Feb 5, 2021)
53106cd  Revert "Remove pruning check because it was added in 1.4.0 and that i… (carmocca, Feb 5, 2021)
d019ba8  Use on_train_end. Fix tests. Add ddp_cpu test (carmocca, Feb 5, 2021)
8b6d55c  Apply suggestions from code review (Borda, Feb 9, 2021)
4ffcb42  WIP (carmocca, Feb 9, 2021)
53136cc  Merge branch 'release/1.2-dev' into feat/swa (carmocca, Feb 9, 2021)
9669cbd  yapf (carmocca, Feb 9, 2021)
3316b12  Revert change (carmocca, Feb 9, 2021)
4eb4bc4  Minor changes (carmocca, Feb 9, 2021)
5d67b23  Typo (carmocca, Feb 9, 2021)
cda86e1  Add beta warning (carmocca, Feb 10, 2021)
fc4ce1a  Docs fixes (carmocca, Feb 10, 2021)
781b367  Merge branch 'release/1.2-dev' into feat/swa (Feb 10, 2021)
c4d1669  Call base file (Feb 10, 2021)
70343c3  Merge branch 'release/1.2-dev' into feat/swa (carmocca, Feb 10, 2021)
4568bfa  Skip test on Windows (carmocca, Feb 10, 2021)
02da3a9  . (Borda, Feb 10, 2021)
adf7931  .. (Borda, Feb 10, 2021)
e46383e  Merge branch 'release/1.2-dev' into feat/swa (carmocca, Feb 10, 2021)
7fae40d  Correct variable (carmocca, Feb 10, 2021)
1 change: 1 addition & 0 deletions docs/source/callbacks.rst
@@ -105,6 +105,7 @@ Lightning has a few built-in callbacks.
ModelCheckpoint
ProgressBar
ProgressBarBase
StochasticWeightAveraging

----------

6 changes: 6 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -20,6 +20,7 @@
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0

__all__ = [
'BackboneLambdaFinetuningCallback',
@@ -34,3 +35,8 @@
'ProgressBar',
'ProgressBarBase',
]

if _PYTORCH_GREATER_EQUAL_1_6_0:
from pytorch_lightning.callbacks.swa import StochasticWeightAveraging

__all__ += ['StochasticWeightAveraging']
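
For orientation, here is a minimal usage sketch of the callback exposed by this conditional import. This is not code from the PR: it assumes PyTorch >= 1.6.0, and the epoch count, learning rate and `model` are placeholders.

from pytorch_lightning import Trainer
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0

if _PYTORCH_GREATER_EQUAL_1_6_0:
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # start averaging after 80% of max_epochs and anneal the LR towards 0.05
    swa_callback = StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=0.05)
    trainer = Trainer(max_epochs=10, callbacks=[swa_callback])
    # trainer.fit(model)  # `model` is any LightningModule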
223 changes: 223 additions & 0 deletions pytorch_lightning/callbacks/swa.py
@@ -0,0 +1,223 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Stochastic Weight Averaging Callback
====================================

"""
from copy import deepcopy
from typing import Callable, Optional, Union

import torch
from torch import nn

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _PYTORCH_GREATER_EQUAL_1_6_0:
from torch.optim.swa_utils import SWALR


class StochasticWeightAveraging(Callback):

def __init__(
self,
swa_epoch_start: Union[int, float] = 0.8,
swa_lrs: Optional[Union[float, list]] = None,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[Callable] = None,
device: Optional[torch.device] = torch.device("cpu"),
):

r"""Implements averaged model for Stochastic Weight Averaging (SWA) Callbacks.

Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).

The ``AveragedModel`` class creates a copy of the provided module :attr:`model`
on the device :attr:`device` and allows computing running averages of the
parameters of the :attr:`model`.

.. note:: ``StochasticWeightAveraging`` does not currently support multiple optimizers / schedulers.

Arguments:

swa_epoch_start (int, float): If provided as an int, the average model will start
updating from the ``swa_epoch_start``-th epoch. If provided as a float between 0 and 1,
the average model will start from the ``int(swa_epoch_start * max_epochs)``-th epoch

swa_lrs (float or list): the learning rate value for all param groups
together or separately for each group.

annealing_epochs (int): number of epochs in the annealing phase
(default: 10)

annealing_strategy (str): "cos" or "linear"; specifies the annealing
strategy: "cos" for cosine annealing, "linear" for linear annealing
(default: "cos")

avg_fn (function, optional): the averaging function used to update
parameters; the function must take in the current value of the
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)

device (torch.device, optional): if provided, the averaged model will be
stored on the ``device``. Default: ``cpu``.
When None is provided, the ``device`` will be inferred from ``pl_module``.

"""

if not isinstance(swa_epoch_start, (float, int)) \
or isinstance(swa_epoch_start, (float, int)) and swa_epoch_start < 0:
raise MisconfigurationException("swa_epoch_start should be a positive integer or a float between 0 and 1.")

if isinstance(swa_epoch_start, float):
if swa_epoch_start > 1:
raise MisconfigurationException("swa_epoch_start should be a float between 0 and 1.")

if not isinstance(swa_lrs, (float, list)) \
or isinstance(swa_lrs, float) and swa_lrs <= 0 \
or isinstance(swa_lrs, list) and not all(lr > 0 for lr in swa_lrs):
raise MisconfigurationException("swa_lrs should be a positive float or a list of positive floats.")

if avg_fn is not None and not isinstance(avg_fn, Callable):
raise MisconfigurationException("avg_fn should be function.")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

if device is not None and not isinstance(device, torch.device):
raise MisconfigurationException(f"device is expected to be None or a torch.device. Found {device}")

self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None

@property
def swa_model(self):
return getattr(self, "_average_model")

@staticmethod
def pl_module_contains_batch_norm(pl_module):
for module in pl_module.modules():
if isinstance(module, nn.modules.batchnorm._BatchNorm):
return True
return False

def reset_batch_norm_and_save_state(self, average_model, device):
self.momenta = {}
for module in average_model.modules():
if isinstance(module, nn.modules.batchnorm._BatchNorm):
running_mean_dtype = module.running_mean.dtype
running_var_dtype = module.running_var.dtype
module.running_mean = torch.zeros_like(module.running_mean, device=device, dtype=running_mean_dtype)
module.running_var = torch.ones_like(module.running_var, device=device, dtype=running_var_dtype)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0

def apply_momenta(self):
for bn_module in self.momenta.keys():
bn_module.momentum = self.momenta[bn_module]

def on_fit_start(self, trainer, pl_module):
self._average_model = deepcopy(pl_module).to(self._device or pl_module.device)
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers

if len(optimizers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 optimizer.")
carmocca marked this conversation as resolved.
Show resolved Hide resolved

if len(lr_schedulers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 lr_scheduler.")
tchaton marked this conversation as resolved.
Show resolved Hide resolved
carmocca marked this conversation as resolved.
Show resolved Hide resolved

self._max_epochs = trainer.max_epochs

# convert float to integer.
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(self._max_epochs * self._swa_epoch_start)

self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.max_epochs += 1

def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch == self._swa_epoch_start:
optimizers = trainer.optimizers
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]

self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=self._swa_lrs,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)

rank_zero_warn(f"swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")

trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

elif self._model_contains_batch_norm and trainer.current_epoch == self._max_epochs:
self.transfer_weights(self._average_model, pl_module)

# accumulate over the entire train_dataloader and skip backward,
# so optimizer.step() is never called during this extra epoch
self._accumulate_grad_batches = trainer.accumulate_grad_batches
trainer.accumulate_grad_batches = len(trainer.train_dataloader)
trainer.train_loop.do_backward = False

if trainer.current_epoch >= self._swa_epoch_start:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)

@staticmethod
def update_parameters(average_model, model, n_averaged, avg_fn):
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_model_ = p_model.detach().to(device)

if n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(
avg_fn(
p_swa.detach(),
p_model_,
n_averaged.to(device)
)
)
n_averaged += 1

@staticmethod
def transfer_weights(average_module, pl_module):
for p_swa, p_model in zip(average_module.parameters(), pl_module.parameters()):
device = p_model.device
p_model.detach().copy_(p_swa.to(device))

@staticmethod
def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
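
As a sanity check on the update rule above: repeatedly applying `avg_fn` with the `n_averaged` bookkeeping used in `update_parameters` yields the equally weighted mean of all parameters seen so far. A small standalone sketch with illustrative tensors (not part of the diff):

import torch

def avg_fn(averaged_param, model_param, num_averaged):
    # same rule as StochasticWeightAveraging.avg_fn
    return averaged_param + (model_param - averaged_param) / (num_averaged + 1)

snapshots = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(6.0)]
running = snapshots[0].clone()  # n_averaged == 0: plain copy
n_averaged = torch.tensor(1)
for param in snapshots[1:]:
    running = avg_fn(running, param, n_averaged)
    n_averaged += 1
assert torch.isclose(running, torch.stack(snapshots).mean())  # both equal 3.0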
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -45,6 +45,7 @@ def __init__(self, trainer, multiple_trainloader_mode):
self.automatic_optimization = True
self._curr_step_result = None
self._cur_grad_norm_dict = None
self.do_backward = True
self._multiple_trainloader_mode = multiple_trainloader_mode
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode

@@ -838,6 +839,9 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
return result

def backward(self, result, optimizer, opt_idx, *args, **kwargs):
if not self.do_backward:
return

self.trainer.dev_debugger.track_event("backward_call")

# backward can be called manually in the training loop
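To make the effect of the new `do_backward` flag concrete, here is a small illustrative loop, not the actual Lightning training loop; the model and batches are placeholders. Skipping backward leaves only forward passes, which still refresh the batch-norm running statistics, and because the callback also sets `accumulate_grad_batches` to the dataloader length, `optimizer.step()` is never reached either:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
batches = [torch.randn(8, 4) for _ in range(3)]

do_backward = False  # what StochasticWeightAveraging sets for the extra epoch
for batch in batches:
    loss = model(batch).sum()  # forward pass still updates running_mean / running_var
    if do_backward:  # mirrors the early return added to TrainLoop.backward
        loss.backward()
# the weights transferred from the averaged model therefore stay unchanged
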
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -34,6 +34,7 @@
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_PYTORCH_GREATER_EQUAL_1_6_0,
_RPC_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_XLA_AVAILABLE,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -53,3 +53,4 @@ def _module_available(module_path: str) -> bool:
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_PYTORCH_GREATER_EQUAL_1_6_0 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
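A quick illustrative check of how the `LooseVersion` gate above behaves; the version strings are assumed examples, not output from this CI run:

from distutils.version import LooseVersion

assert LooseVersion("1.7.1") >= LooseVersion("1.6.0")        # flag True: SWALR and the callback are importable
assert not (LooseVersion("1.5.1") >= LooseVersion("1.6.0"))  # flag False: the callback is not exported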
96 changes: 96 additions & 0 deletions tests/callbacks/test_swa.py
@@ -0,0 +1,96 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from torch import nn
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities import _PYTORCH_GREATER_EQUAL_1_6_0
from tests.base import BoringModel, RandomDataset

if _PYTORCH_GREATER_EQUAL_1_6_0:
from pytorch_lightning.callbacks import StochasticWeightAveraging

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(32, 32),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Linear(32, 2),
)

def training_step(self, batch, batch_idx):
output = self.forward(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

class SwaCheck(StochasticWeightAveraging):

def on_train_epoch_end(self, trainer, pl_module, *_):
super().on_train_epoch_end(trainer, pl_module, *_)
if self._model_contains_batch_norm and trainer.current_epoch == self._max_epochs:
assert self.n_averaged > 0


def train_with_swa(tmpdir, accelerator=None, gpus=None, num_processes=None):
model = TestModel()
swa_callback = SwaCheck(swa_epoch_start=2, swa_lrs=0.005)

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=4,
limit_train_batches=4,
callbacks=[swa_callback],
accelerator=accelerator,
gpus=gpus,
num_processes=num_processes
)
trainer.fit(model)

assert swa_callback.swa_model is not None
assert trainer.get_model() == model


@pytest.mark.skipif(not _PYTORCH_GREATER_EQUAL_1_6_0, reason="SWA is available from PyTorch 1.6.0")
def test_stochastic_weight_averaging_callback(tmpdir):
train_with_swa(tmpdir, num_processes=1)


@pytest.mark.skipif(not _PYTORCH_GREATER_EQUAL_1_6_0, reason="SWA is available from PyTorch 1.6.0")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_stochastic_weight_averaging_callback_ddp(tmpdir):
train_with_swa(tmpdir, accelerator="ddp" , gpus=2)


@pytest.mark.skipif(not _PYTORCH_GREATER_EQUAL_1_6_0, reason="SWA is available from PyTorch 1.6.0")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_stochastic_weight_averaging_callback_ddp_spawn(tmpdir):
train_with_swa(tmpdir, accelerator="ddp_spawn" , gpus=2)


@pytest.mark.skipif(not _PYTORCH_GREATER_EQUAL_1_6_0, reason="SWA is available from PyTorch 1.6.0")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires a GPU machine")
def test_stochastic_weight_averaging_callback_1_gpu(tmpdir):
train_with_swa(tmpdir, accelerator=None, gpus=1)
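
To reproduce the CPU path of these tests locally, a hypothetical snippet along these lines should be enough, assuming the test module layout in this PR and PyTorch >= 1.6.0:

import tempfile
from tests.callbacks.test_swa import train_with_swa

with tempfile.TemporaryDirectory() as tmpdir:
    # runs 4 epochs with SWA starting at epoch 2, as configured in train_with_swa above
    train_with_swa(tmpdir, num_processes=1)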