From 440db6a42fa67060b138338bf83c4f75a1652013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 6 Mar 2021 03:41:28 +0100 Subject: [PATCH 1/5] bugfix --- pytorch_lightning/trainer/optimizers.py | 3 +++ tests/trainer/test_trainer.py | 32 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 5cafa438cffcc2..a247fb92cd22f8 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -27,7 +27,10 @@ class TrainerOptimizersMixin(ABC): + _lightning_optimizers: Optional[List[LightningOptimizer]] + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: rank_zero_warn( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1cd979c863d373..632428e6f1874f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -19,7 +19,7 @@ from copy import deepcopy from distutils.version import LooseVersion from pathlib import Path -from unittest.mock import ANY, call, patch +from unittest.mock import ANY, call, patch, Mock import cloudpickle import pytest @@ -1785,3 +1785,33 @@ def backward(self, *args, **kwargs): "training_step", "backward", ] + + +def test_init_optimizers_resets_lightning_optimizers(tmpdir): + """ Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """ + + def compare_optimizers(): + assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] + + class OptimizerSpy(Callback): + def on_fit_start(self, *args, **kwargs): + compare_optimizers() + + model = BoringModel() + model.lr = 0.2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_lr_find=True, + callbacks=[OptimizerSpy()] + ) + + trainer.tune(model) + compare_optimizers() + + trainer.fit(model) + compare_optimizers() + + trainer.max_epochs = 2 # simulate multiple fit calls + trainer.fit(model) + compare_optimizers() From 0a79a6985be57b195a78bc0fb8b609c442a01aaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 6 Mar 2021 03:43:28 +0100 Subject: [PATCH 2/5] add changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 703a2c7eec19f8..21414a1b8b329b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,7 +86,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) -- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272)) +- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272)) + + +- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) ## [1.2.2] - 2021-03-02 From a2ec567ab78444de7f2747dc1e07221288625d8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 6 Mar 2021 03:45:17 +0100 Subject: [PATCH 3/5] fix formatting --- tests/trainer/test_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 632428e6f1874f..948f2f201019ac 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -19,7 +19,7 @@ from copy import deepcopy from distutils.version import LooseVersion from pathlib import Path -from unittest.mock import ANY, call, patch, Mock +from unittest.mock import ANY, call, patch import cloudpickle import pytest @@ -1794,6 +1794,7 @@ def compare_optimizers(): assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] class OptimizerSpy(Callback): + def on_fit_start(self, *args, **kwargs): compare_optimizers() @@ -1803,7 +1804,7 @@ def on_fit_start(self, *args, **kwargs): default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True, - callbacks=[OptimizerSpy()] + callbacks=[OptimizerSpy()], ) trainer.tune(model) From bb5a1696991faf46a61d48a11493555d8478e050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 6 Mar 2021 03:47:39 +0100 Subject: [PATCH 4/5] simplify test --- tests/trainer/test_trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 948f2f201019ac..d123c22383f82a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1793,18 +1793,12 @@ def test_init_optimizers_resets_lightning_optimizers(tmpdir): def compare_optimizers(): assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] - class OptimizerSpy(Callback): - - def on_fit_start(self, *args, **kwargs): - compare_optimizers() - model = BoringModel() model.lr = 0.2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, auto_lr_find=True, - callbacks=[OptimizerSpy()], ) trainer.tune(model) From cc73c5c05771e73b1a62985c0c0995a4382902dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 6 Mar 2021 18:26:57 +0100 Subject: [PATCH 5/5] Update CHANGELOG.md --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21414a1b8b329b..21edc5fec95716 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,9 +86,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) -- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272)) - - - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))