
Commit

Merge cc73c5c into facfda8
awaelchli committed Mar 6, 2021
2 parents facfda8 + cc73c5c commit a16a75b
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -107,7 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


## [1.2.2] - 2021-03-02
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
@@ -27,7 +27,10 @@

class TrainerOptimizersMixin(ABC):

_lightning_optimizers: Optional[List[LightningOptimizer]]

def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
self._lightning_optimizers = None
optim_conf = model.configure_optimizers()
if optim_conf is None:
rank_zero_warn(
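The hunk above is the whole fix: clearing `_lightning_optimizers` inside `init_optimizers` forces the wrapper list to be rebuilt around the freshly created optimizers on every call to `fit()`. A minimal sketch of that reset-on-init pattern, using toy stand-ins for `Optimizer`, `LightningOptimizer`, and the mixin (the names mirror the real classes but the bodies are illustrative, not the pytorch_lightning implementation):

```python
from typing import List, Optional


class Optimizer:
    """Toy stand-in for a torch.optim optimizer."""


class LightningOptimizer:
    """Toy stand-in for the wrapper Lightning places around each optimizer."""

    def __init__(self, optimizer: Optimizer) -> None:
        self.optimizer = optimizer


class TrainerOptimizersMixin:
    # Class-level annotation only: the attribute is (re)assigned each time
    # init_optimizers runs, which is the point of the fix.
    _lightning_optimizers: Optional[List[LightningOptimizer]]

    def init_optimizers(self) -> List[Optimizer]:
        # The fix: drop wrappers cached by a previous fit() call so the
        # property below rebuilds them around the new optimizers.
        self._lightning_optimizers = None
        self.optimizers = [Optimizer()]
        return self.optimizers

    @property
    def lightning_optimizers(self) -> List[LightningOptimizer]:
        # Lazily (re)build the wrapper list from the current optimizers.
        if self._lightning_optimizers is None:
            self._lightning_optimizers = [
                LightningOptimizer(opt) for opt in self.optimizers
            ]
        return self._lightning_optimizers
```

After a second `init_optimizers()` call, `lightning_optimizers[0].optimizer` is identical to the new `optimizers[0]` rather than the stale optimizer from the first call.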
25 changes: 25 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1803,3 +1803,28 @@ def backward(self, *args, **kwargs):
"training_step",
"backward",
]


def test_init_optimizers_resets_lightning_optimizers(tmpdir):
""" Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """

def compare_optimizers():
assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]

model = BoringModel()
model.lr = 0.2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
auto_lr_find=True,
)

trainer.tune(model)
compare_optimizers()

trainer.fit(model)
compare_optimizers()

trainer.max_epochs = 2 # simulate multiple fit calls
trainer.fit(model)
compare_optimizers()
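For contrast, a toy reproduction of the bug this test guards against: a trainer that caches its wrappers but never resets them, so after a second `init_optimizers()` the wrappers still point at the first call's optimizers. All names here (`FakeOptimizer`, `FakeWrapper`, `BuggyTrainer`) are hypothetical, not part of the real pytorch_lightning API:

```python
class FakeOptimizer:
    """Hypothetical stand-in for an optimizer instance."""


class FakeWrapper:
    """Hypothetical stand-in for a LightningOptimizer wrapper."""

    def __init__(self, optimizer):
        self.optimizer = optimizer


class BuggyTrainer:
    _wrappers = None

    def init_optimizers(self):
        # Missing: self._wrappers = None  (the reset added by this commit)
        self.optimizers = [FakeOptimizer()]

    @property
    def lightning_optimizers(self):
        # Cached once, never invalidated.
        if self._wrappers is None:
            self._wrappers = [FakeWrapper(opt) for opt in self.optimizers]
        return self._wrappers


trainer = BuggyTrainer()
trainer.init_optimizers()            # first fit()
_ = trainer.lightning_optimizers     # wrappers built and cached
trainer.init_optimizers()            # second fit(): fresh optimizers
# Stale wrappers still reference the optimizers from the first call:
print(trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0])  # False
```

The identity check mirrors `compare_optimizers()` in the test above; with the reset in place it would print `True`.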
