From a5a17652293e53e3b14a8d68029afb91bd9705c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 09:51:30 +0200 Subject: [PATCH 01/10] repro script --- pl_examples/bug_report_model.py | 42 ++++++++------------------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index f83cc70d44526..7d6ae177791c4 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,8 +1,5 @@ -import os - import torch from torch.utils.data import DataLoader, Dataset - from pytorch_lightning import LightningModule, Trainer @@ -28,38 +25,19 @@ def forward(self, x): def training_step(self, batch, batch_idx): loss = self(batch).sum() - self.log("train_loss", loss) - return {"loss": loss} - - def validation_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("valid_loss", loss) - - def test_step(self, batch, batch_idx): - loss = self(batch).sum() - self.log("test_loss", loss) + print(f"training_step, {batch_idx=}: {loss=}") + return loss def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) + return torch.optim.SGD(self.parameters(), lr=0.1) - -def run(): - train_data = DataLoader(RandomDataset(32, 64), batch_size=2) - val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - weights_summary=None, - ) - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.test(model, dataloaders=test_data) + def training_epoch_end(self, outputs): + print("training_epoch_end:", [id(x["loss"]) for x in outputs]) if __name__ == "__main__": - run() + dl = DataLoader(RandomDataset(32, 100), batch_size=10) + + model = BoringModel() + trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0) + trainer.fit(model, dl) From 47861d7a054a98870dabda09e2b2ba85cf24dd86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 11:36:15 +0200 Subject: [PATCH 02/10] apply fix + tests --- pytorch_lightning/loops/batch/training_batch_loop.py | 3 ++- tests/trainer/loops/test_training_loop_flow_dict.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 3b49fe4a1f63a..360313546485d 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -14,6 +14,7 @@ from collections import OrderedDict from contextlib import contextmanager +from copy import copy from functools import partial, update_wrapper from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple @@ -318,7 +319,7 @@ def _training_step( closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. 
avoid any modifications to it loss = closure_loss.detach().clone() - return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) + return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=copy(training_step_output)) def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Optional[ResultCollection]: """Adds the :param:`training_step_output` to the trainer's results diff --git a/tests/trainer/loops/test_training_loop_flow_dict.py b/tests/trainer/loops/test_training_loop_flow_dict.py index f064dacb78844..58ccc8b6adb36 100644 --- a/tests/trainer/loops/test_training_loop_flow_dict.py +++ b/tests/trainer/loops/test_training_loop_flow_dict.py @@ -108,7 +108,7 @@ def training_step(self, batch, batch_idx): acc = acc + batch_idx self.training_step_called = True - out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]} + out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx} return out def training_epoch_end(self, outputs): @@ -116,11 +116,13 @@ def training_epoch_end(self, outputs): # verify we saw the current num of batches assert len(outputs) == 2 + assert len(set(id(output) for output in outputs)) == 2 + assert [output["batch_idx"] for output in outputs] == [0, 1] for b in outputs: assert isinstance(b, dict) assert self.count_num_graphs(b) == 0 - assert {"random_things", "loss"} == set(b.keys()) + assert {"random_things", "loss", "batch_idx"} == set(b.keys()) def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) @@ -155,7 +157,7 @@ def training_step(self, batch, batch_idx): acc = acc + batch_idx self.training_step_called = True - self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]} + self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx} return self.out def training_step_end(self, tr_step_output): @@ -169,11 +171,13 @@ def training_epoch_end(self, outputs): # verify we saw the current num of batches assert len(outputs) == 2 + assert len(set(id(output) for output in outputs)) == 2 + assert [output["batch_idx"] for output in outputs] == [0, 1] for b in outputs: assert isinstance(b, dict) assert self.count_num_graphs(b) == 0 - assert {"random_things", "loss"} == set(b.keys()) + assert {"random_things", "loss", "batch_idx"} == set(b.keys()) def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) From 5c7daace2e0d2cc3df143069a8dbd30762ef9b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 11:37:15 +0200 Subject: [PATCH 03/10] reset bugreport model --- pl_examples/bug_report_model.py | 42 +++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 7d6ae177791c4..f83cc70d44526 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,5 +1,8 @@ +import os + import torch from torch.utils.data import DataLoader, Dataset + from pytorch_lightning import LightningModule, Trainer @@ -25,19 +28,38 @@ def forward(self, x): def training_step(self, batch, batch_idx): loss = self(batch).sum() - print(f"training_step, {batch_idx=}: {loss=}") - return loss + self.log("train_loss", loss) + return {"loss": loss} - def configure_optimizers(self): - return torch.optim.SGD(self.parameters(), lr=0.1) + def validation_step(self, batch, 
batch_idx): + loss = self(batch).sum() + self.log("valid_loss", loss) - def training_epoch_end(self, outputs): - print("training_epoch_end:", [id(x["loss"]) for x in outputs]) + def test_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("test_loss", loss) + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) -if __name__ == "__main__": - dl = DataLoader(RandomDataset(32, 100), batch_size=10) + +def run(): + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + val_data = DataLoader(RandomDataset(32, 64), batch_size=2) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) model = BoringModel() - trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0) - trainer.fit(model, dl) + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.test(model, dataloaders=test_data) + + +if __name__ == "__main__": + run() From 1d3cae7bd3b217efe97323fe9678c700a0e8cff5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 09:42:02 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/trainer/loops/test_training_loop_flow_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/loops/test_training_loop_flow_dict.py b/tests/trainer/loops/test_training_loop_flow_dict.py index 58ccc8b6adb36..ab4d7979bbf39 100644 --- a/tests/trainer/loops/test_training_loop_flow_dict.py +++ b/tests/trainer/loops/test_training_loop_flow_dict.py @@ -116,7 +116,7 @@ def training_epoch_end(self, outputs): # verify we saw the current num of batches assert len(outputs) == 2 - assert len(set(id(output) for output in outputs)) == 2 + assert len({id(output) for output in outputs}) == 2 assert [output["batch_idx"] for output in outputs] == [0, 1] for b in outputs: @@ -171,7 +171,7 @@ def training_epoch_end(self, outputs): # verify we saw the current num of batches assert len(outputs) == 2 - assert len(set(id(output) for output in outputs)) == 2 + assert len({id(output) for output in outputs}) == 2 assert [output["batch_idx"] for output in outputs] == [0, 1] for b in outputs: From d3278fd45490c8464ef46951768132fe0944bfa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 11:43:26 +0200 Subject: [PATCH 05/10] update chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a33cdd17031bd..5c320befec7b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed -- +- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613)) - From ecf42db5bcc17b187fa8a56ea27136c9404105c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 15:08:57 +0200 Subject: [PATCH 06/10] move copy() as suggested by @carmocca --- pytorch_lightning/loops/batch/training_batch_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 360313546485d..4859b1060c7c8 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -144,12 +144,12 @@ def advance(self, batch, batch_idx, dataloader_idx): result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: - self.batch_outputs[opt_idx].append(result.training_step_output) + self.batch_outputs[opt_idx].append(copy(result.training_step_output)) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) if result: - self.batch_outputs[0].append(result.training_step_output) + self.batch_outputs[0].append(copy(result.training_step_output)) def teardown(self) -> None: # release memory @@ -319,7 +319,7 @@ def _training_step( closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. avoid any modifications to it loss = closure_loss.detach().clone() - return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=copy(training_step_output)) + return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output) def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Optional[ResultCollection]: """Adds the :param:`training_step_output` to the trainer's results From 6f3afb20786dc21a4665d82718b4d0c6b10a4684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 29 Jul 2021 22:52:32 +0200 Subject: [PATCH 07/10] gc collect on trainer teardown --- pytorch_lightning/trainer/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 212086dc0f678..d22898163a1b7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Trainer to automate the training.""" +import gc import logging import os import traceback @@ -982,6 +983,7 @@ def _post_dispatch(self): self.accelerator.teardown() self._active_loop.teardown() self.logger_connector.teardown() + gc.collect() def _dispatch(self): if self.evaluating: From 1a03fa5b2700828662a092bc5a4362c097e17891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 30 Jul 2021 02:40:57 +0200 Subject: [PATCH 08/10] add a comment for gc.collect() --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d22898163a1b7..513ea46f4695c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -983,6 +983,7 @@ def _post_dispatch(self): self.accelerator.teardown() self._active_loop.teardown() self.logger_connector.teardown() + # force release any leftover memory, including CUDA tensors gc.collect() def _dispatch(self): From d1029d63edb12e4f5b60604fcfd28aca357d2b95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 30 Jul 2021 11:56:39 +0200 Subject: [PATCH 09/10] move gc collect to the test --- pytorch_lightning/trainer/trainer.py | 3 --- tests/trainer/test_trainer.py | 6 ++++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 513ea46f4695c..212086dc0f678 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Trainer to automate the training.""" -import gc import logging import os import traceback @@ -983,8 +982,6 @@ def _post_dispatch(self): self.accelerator.teardown() self._active_loop.teardown() self.logger_connector.teardown() - # force release any leftover memory, including CUDA tensors - gc.collect() def _dispatch(self): if self.evaluating: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 86ca0d1fc5618..2378cf196dbbc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import logging import math import os @@ -1885,6 +1886,8 @@ def on_epoch_start(self, trainer, *_): assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu") assert trainer.callback_metrics["train_loss"].device == torch.device("cpu") + # before measuring the memory force release any leftover allocations, including CUDA tensors + gc.collect() memory_1 = torch.cuda.memory_allocated(0) deepcopy(trainer) memory_2 = torch.cuda.memory_allocated(0) @@ -1892,6 +1895,9 @@ def on_epoch_start(self, trainer, *_): trainer_2 = Trainer(**trainer_kwargs) trainer_2.fit(model) + + # before measuring the memory force release any leftover allocations, including CUDA tensors + gc.collect() memory_3 = torch.cuda.memory_allocated(0) assert initial == memory_1 == memory_3 From 437898c92cfdca5127f2e75e9197c07b66820a5c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Jul 2021 12:22:23 +0200 Subject: [PATCH 10/10] Split memory checks for better errors on failure --- tests/trainer/test_trainer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2378cf196dbbc..a5ac053395515 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1889,9 +1889,14 @@ def on_epoch_start(self, trainer, *_): # before measuring the memory force release any leftover allocations, including CUDA tensors gc.collect() memory_1 = torch.cuda.memory_allocated(0) + assert memory_1 == initial + deepcopy(trainer) + + # before measuring the memory force release any leftover allocations, including CUDA tensors + gc.collect() memory_2 = torch.cuda.memory_allocated(0) - assert memory_1 == memory_2 == initial + assert memory_2 == initial trainer_2 = Trainer(**trainer_kwargs) trainer_2.fit(model) @@ -1899,5 +1904,4 @@ def on_epoch_start(self, trainer, *_): # before measuring the memory force release any leftover allocations, including CUDA tensors gc.collect() memory_3 = torch.cuda.memory_allocated(0) - - assert initial == memory_1 == memory_3 + assert memory_3 == initial
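
The core fix in this series is the `copy(result.training_step_output)` call in `TrainingBatchLoop.advance()` (patch 06/10, after patch 02/10 first placed the copy in `_training_step`): each per-batch output appended to `self.batch_outputs` becomes a distinct dict, which is what the new `id()`-based assertions in the tests check and what keeps `training_epoch_end` from receiving several references to one mutated object. Below is a minimal sketch of that failure mode in plain Python, independent of Lightning internals; `run_steps` is a made-up helper for illustration only.

    from copy import copy

    # Minimal sketch of the aliasing bug this series fixes: if a loop reuses one
    # dict object for every step's output and mutates it in place, collecting the
    # raw reference means every collected entry reflects only the last step.
    def run_steps(collect_with_copy: bool):
        outputs = []      # stands in for TrainingBatchLoop.batch_outputs
        step_output = {}  # one dict object reused (and mutated) across steps
        for batch_idx in range(2):
            step_output["loss"] = float(batch_idx)
            outputs.append(copy(step_output) if collect_with_copy else step_output)
        return outputs

    print(run_steps(False))  # [{'loss': 1.0}, {'loss': 1.0}]  -> two references to one dict
    print(run_steps(True))   # [{'loss': 0.0}, {'loss': 1.0}]  -> one snapshot per step

The repro script in patch 01/10 surfaces the same aliasing by printing `id(x["loss"])` for each collected output.

Patches 07/10 through 10/10 settle on calling `gc.collect()` inside the memory test, immediately before each `torch.cuda.memory_allocated(0)` reading, rather than in `Trainer._post_dispatch()`. A sketch of that measurement pattern, assuming a CUDA device is available; `cuda_memory_snapshot` is a made-up helper name, not part of the PR.

    import gc

    import torch

    # Run a garbage-collection pass before reading the allocator counter so that
    # unreachable objects still holding CUDA tensors are released and do not
    # inflate torch.cuda.memory_allocated().
    def cuda_memory_snapshot(device: int = 0) -> int:
        gc.collect()
        return torch.cuda.memory_allocated(device)

    if torch.cuda.is_available():
        initial = cuda_memory_snapshot()
        tensor = torch.randn(1024, 1024, device="cuda")
        del tensor
        print(cuda_memory_snapshot() == initial)  # expect True: no lingering allocations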