Skip to content

Commit

Permalink
Enforce an epoch scheduler interval when using SWA (#6588)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
2 people authored and lexierule committed Apr 7, 2021
1 parent bb4fd7e commit 9e5d84d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
_scheduler_config = _get_default_scheduler_config()
assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
_scheduler_config["scheduler"] = self._swa_scheduler

if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
trainer.lr_schedulers[0] = _scheduler_config
else:
_scheduler_config = _get_default_scheduler_config()
_scheduler_config["scheduler"] = self._swa_scheduler
trainer.lr_schedulers.append(_scheduler_config)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
Expand Down
29 changes: 25 additions & 4 deletions tests/callbacks/test_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_1_6:
from pytorch_lightning.callbacks import StochasticWeightAveraging
from torch.optim.swa_utils import SWALR

class SwaTestModel(BoringModel):

def __init__(self, batchnorm: bool = True):
def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
layers.append(nn.BatchNorm1d(32))
layers += [nn.ReLU(), nn.Linear(32, 2)]
self.layer = nn.Sequential(*layers)
self.interval = interval

def training_step(self, batch, batch_idx):
output = self.forward(batch)
Expand All @@ -46,6 +49,14 @@ def training_step(self, batch, batch_idx):
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return {
"optimizer": optimizer,
"scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
"interval": self.interval,
}

class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
Expand All @@ -61,6 +72,10 @@ def transfer_weights(self, *args, **kwargs):
def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
assert trainer.lr_schedulers[0]["interval"] == "epoch"
assert trainer.lr_schedulers[0]["frequency"] == 1

def on_train_epoch_end(self, trainer, *args):
super().on_train_epoch_end(trainer, *args)
Expand Down Expand Up @@ -89,8 +104,8 @@ def on_train_end(self, trainer, pl_module):


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
model = SwaTestModel(batchnorm=batchnorm)
def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
model = SwaTestModel(batchnorm=batchnorm, interval=interval)
swa_start = 2
max_epochs = 5
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
Expand Down Expand Up @@ -147,7 +162,13 @@ def test_swa_callback(tmpdir, batchnorm):
train_with_swa(tmpdir, batchnorm=batchnorm)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
@RunIf(min_torch="1.6.0")
@pytest.mark.parametrize("interval", ("epoch", "step"))
def test_swa_callback_scheduler_step(tmpdir, interval: bool):
train_with_swa(tmpdir, interval=interval)


@RunIf(min_torch="1.6.0")
def test_swa_raises():
with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
Expand Down

0 comments on commit 9e5d84d

Please sign in to comment.