Fix trainer not resetting lightning_optimizers (#6372)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
2 people authored and lexierule committed Mar 9, 2021
1 parent 60d0c95 commit be0351f
Showing 3 changed files with 32 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
```diff
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330))
 
 
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added
```
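For context on the entry above: the trainer builds `LightningOptimizer` wrappers around the raw optimizers and caches them, and before this commit that cache was never cleared when the optimizers were re-created by a later `fit()`. A minimal self-contained sketch of the failure mode follows; the names (`TrainerLike`, `LightningOptimizerLike`, `_wrappers`) are illustrative stand-ins, not Lightning's actual classes.

```python
# Self-contained sketch of the stale-cache bug this commit fixes.
# All names here are illustrative stand-ins, not Lightning's API.
from typing import List, Optional


class LightningOptimizerLike:
    """Wraps a raw optimizer, as LightningOptimizer wraps torch optimizers."""

    def __init__(self, optimizer: object) -> None:
        self.optimizer = optimizer


class TrainerLike:
    _wrappers: Optional[List[LightningOptimizerLike]] = None

    def init_optimizers(self, optimizers: List[object]) -> None:
        # The fix: drop wrappers left over from an earlier run. Before this
        # commit the cache was kept, so it still wrapped the old optimizers.
        self._wrappers = None
        self.optimizers = optimizers

    @property
    def lightning_optimizers(self) -> List[LightningOptimizerLike]:
        if self._wrappers is None:  # built lazily on first access, then cached
            self._wrappers = [LightningOptimizerLike(o) for o in self.optimizers]
        return self._wrappers


trainer = TrainerLike()
trainer.init_optimizers(["opt_from_first_fit"])
assert trainer.lightning_optimizers[0].optimizer == "opt_from_first_fit"

# A second fit() re-creates the optimizers; without the reset above, the
# cached wrapper would still point at "opt_from_first_fit".
trainer.init_optimizers(["opt_from_second_fit"])
assert trainer.lightning_optimizers[0].optimizer == "opt_from_second_fit"
```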
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/optimizers.py
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC
-from typing import List, Optional, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import optim
@@ -27,7 +27,10 @@
 
 class TrainerOptimizersMixin(ABC):
 
+    _lightning_optimizers: Optional[List[LightningOptimizer]]
+
     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
+        self._lightning_optimizers = None
         optim_conf = model.configure_optimizers()
         if optim_conf is None:
             rank_zero_warn(
```
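A note on this hunk: the class-level `_lightning_optimizers: Optional[List[LightningOptimizer]]` line is a bare annotation, which informs type checkers but creates no attribute at runtime, so the `self._lightning_optimizers = None` assignment at the top of `init_optimizers` is what actually guarantees the cache starts empty on every call. A standalone demonstration of that Python behavior (not Lightning code):

```python
from typing import List, Optional


class Example:
    # A bare annotation records the name in __annotations__ but does not
    # create a class attribute.
    cache: Optional[List[int]]


assert "cache" in Example.__annotations__
assert not hasattr(Example, "cache")  # nothing exists until assigned

obj = Example()
obj.cache = None  # the explicit assignment creates (and clears) the attribute
assert obj.cache is None
```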
25 changes: 25 additions & 0 deletions tests/trainer/test_trainer.py
```diff
@@ -1825,3 +1825,28 @@ def backward(self, *args, **kwargs):
         "training_step",
         "backward",
     ]
+
+
+def test_init_optimizers_resets_lightning_optimizers(tmpdir):
+    """ Test that the Trainer resets the `lightning_optimizers` list every time new optimizers get initialized. """
+
+    def compare_optimizers():
+        assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]
+
+    model = BoringModel()
+    model.lr = 0.2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        auto_lr_find=True,
+    )
+
+    trainer.tune(model)
+    compare_optimizers()
+
+    trainer.fit(model)
+    compare_optimizers()
+
+    trainer.max_epochs = 2  # simulate multiple fit calls
+    trainer.fit(model)
+    compare_optimizers()
```
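The new test can be run on its own with `pytest tests/trainer/test_trainer.py::test_init_optimizers_resets_lightning_optimizers`. It exercises `init_optimizers` three ways: through `trainer.tune()` (triggered by `auto_lr_find=True` together with the `model.lr` attribute), through a first `trainer.fit()`, and through a second `fit()` after raising `max_epochs`, asserting each time that the cached wrapper tracks the current optimizer.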
