From 1f83436b71ec02c2a5a2b829978c3391d13a0f26 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 25 Apr 2021 20:30:00 -0700 Subject: [PATCH 1/9] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8a007086fb380..57158d6354251 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -189,10 +189,10 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): - if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch: + if self.trainer.train_dataloader is None: self.trainer.reset_train_dataloader(model) - if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: + if self.trainer.val_dataloaders is None: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): From e37b9d59db14f82dd40c11437fec5575551c5fe9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 25 Apr 2021 21:14:24 -0700 Subject: [PATCH 2/9] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 33 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 2a744c9c05c73..f8070033b9145 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -25,6 +25,7 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6 from pytorch_lightning.utilities.data import has_iterable_dataset, has_len @@ -1199,7 +1200,16 @@ def test_dataloaders_load_every_epoch(tmpdir): @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir): - model = EvalModelTemplate() + class TestModel(BoringModel): + + def validation_step(self, batch, batch_idx): + self.log("dummy_val", 5.0) + return super().validation_step(batch, batch_idx) + + model = TestModel() + + # This callback tests that the evaluation metrics are available by the time we run checkpointing + checkpoint_callback = ModelCheckpoint(monitor="dummy_val", save_top_k=1) # logger file to get meta trainer = Trainer( @@ -1209,26 +1219,29 @@ def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir): num_sanity_val_steps=0, reload_dataloaders_every_epoch=True, max_epochs=3, + callbacks=[checkpoint_callback], ) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" trainer.test() - assert len(trainer.dev_debugger.val_dataloader_calls) == 3 + assert len(trainer.dev_debugger.val_dataloader_calls) == 4 assert len(trainer.dev_debugger.train_dataloader_calls) == 3 assert len(trainer.dev_debugger.test_dataloader_calls) == 1 - # verify the sequence + # # verify the sequence calls = trainer.dev_debugger.dataloader_sequence_calls + expected_sequence = [ - 'train_dataloader', - 'val_dataloader', - 'train_dataloader', - 'val_dataloader', - 'train_dataloader', - 'val_dataloader', - 'test_dataloader', + "train_dataloader", + "val_dataloader", + "val_dataloader", + "train_dataloader", + "val_dataloader", + "train_dataloader", + "val_dataloader", + "test_dataloader", ] for call, expected in zip(calls, expected_sequence): assert call['name'] == expected From 8176238209ede5782500a45e5f351bb8e1d49faa Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 25 Apr 2021 21:47:38 -0700 Subject: [PATCH 3/9] changelog --- CHANGELOG.md | 4 ++++ tests/trainer/test_dataloaders.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04c25eefcad66..e23630076eec6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -281,6 +281,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed + +- Fixed attaching train and validation dataloaders when `reload_dataloaders_every_epoch=True` and `num_sanity_val_steps=0` ([#7207](https://github.com/PyTorchLightning/pytorch-lightning/pull/7207)) + + - Added a barrier in the accelerator `teardown` to synchronize processes before execution finishes ([#6814](https://github.com/PyTorchLightning/pytorch-lightning/pull/6814)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f8070033b9145..69a38f3d3c5da 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1230,7 +1230,7 @@ def validation_step(self, batch, batch_idx): assert len(trainer.dev_debugger.train_dataloader_calls) == 3 assert len(trainer.dev_debugger.test_dataloader_calls) == 1 - # # verify the sequence + # verify the sequence calls = trainer.dev_debugger.dataloader_sequence_calls expected_sequence = [ From e218dabd14c9a8a956ed087a54f689a4258dfc1b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 25 Apr 2021 22:31:25 -0700 Subject: [PATCH 4/9] delay reload --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- tests/trainer/test_dataloaders.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 9f10ca8306ff3..1df12cef2fd71 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -60,7 +60,7 @@ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[U max_batches = self.trainer.num_test_batches else: # val - if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: + if self.trainer.val_dataloaders is None or (self.trainer.reload_dataloaders_every_epoch and self.trainer.current_epoch > 0): self.trainer.reset_val_dataloader(model) if self.trainer.sanity_checking: self.trainer.num_sanity_val_batches = [ diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 69a38f3d3c5da..97f14e0e68622 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1226,7 +1226,7 @@ def validation_step(self, batch, batch_idx): trainer.test() - assert len(trainer.dev_debugger.val_dataloader_calls) == 4 + assert len(trainer.dev_debugger.val_dataloader_calls) == 3 assert len(trainer.dev_debugger.train_dataloader_calls) == 3 assert len(trainer.dev_debugger.test_dataloader_calls) == 1 @@ -1236,7 +1236,6 @@ def validation_step(self, batch, batch_idx): expected_sequence = [ "train_dataloader", "val_dataloader", - "val_dataloader", "train_dataloader", "val_dataloader", "train_dataloader", From c585f23f6d65ecd17161bd9161daf9a47edf68d3 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 25 Apr 2021 22:41:48 -0700 Subject: [PATCH 5/9] go back --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- tests/trainer/test_dataloaders.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1df12cef2fd71..9f10ca8306ff3 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -60,7 +60,7 @@ def get_evaluation_dataloaders(self) -> Tuple[Optional[List[DataLoader]], List[U max_batches = self.trainer.num_test_batches else: # val - if self.trainer.val_dataloaders is None or (self.trainer.reload_dataloaders_every_epoch and self.trainer.current_epoch > 0): + if self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) if self.trainer.sanity_checking: self.trainer.num_sanity_val_batches = [ diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 97f14e0e68622..97006a047bbf7 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1226,7 +1226,7 @@ def validation_step(self, batch, batch_idx): trainer.test() - assert len(trainer.dev_debugger.val_dataloader_calls) == 3 + assert len(trainer.dev_debugger.val_dataloader_calls) == 4 assert len(trainer.dev_debugger.train_dataloader_calls) == 3 assert len(trainer.dev_debugger.test_dataloader_calls) == 1 @@ -1234,13 +1234,14 @@ def validation_step(self, batch, batch_idx): calls = trainer.dev_debugger.dataloader_sequence_calls expected_sequence = [ - "train_dataloader", - "val_dataloader", - "train_dataloader", - "val_dataloader", - "train_dataloader", - "val_dataloader", - "test_dataloader", + 'train_dataloader', + 'val_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'train_dataloader', + 'val_dataloader', + 'test_dataloader', ] for call, expected in zip(calls, expected_sequence): assert call['name'] == expected From 06587f86a430e58505252483c4e1159c6edad1e9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 21:53:58 -0700 Subject: [PATCH 6/9] comments --- pytorch_lightning/trainer/training_loop.py | 6 ++++++ tests/trainer/test_dataloaders.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 57158d6354251..94b1b8fffee33 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -189,6 +189,12 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): + """ + Resets train and val dataloaders if none are attached to the trainer. + + The val dataloader must be initialized before training loop starts, as the training loop + inspects the val dataloader to determine whether to run the evaluation loop. + """ if self.trainer.train_dataloader is None: self.trainer.reset_train_dataloader(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 97006a047bbf7..d3205242656f5 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1112,6 +1112,14 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): expected_sequence = [ 'val_dataloader', 'train_dataloader', + # This has subsequent calls to val_dataloader + # because the training loop runs the evaluation loop, + # which reloads the val dataloader again. + # We cannot yet rely on trainer.current_epoch=0 to skip reloading + # the val dataloader on the first epoch because this only tracks the training epoch + # meaning multiple passes through the validation data within a single training epoch + # would not have the datalodaer reloaded. + # This breaks the assumption behind reload_dataloaders_every_epoch=True 'val_dataloader', 'val_dataloader', 'val_dataloader', From 727628ae032ee56b1ba0c64f6a4db19f3f93cd7a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 21:55:18 -0700 Subject: [PATCH 7/9] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 94b1b8fffee33..cef70e2bf7811 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -188,7 +188,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() - def reset_train_val_dataloaders(self, model): + def reset_train_val_dataloaders(self, model) -> None: """ Resets train and val dataloaders if none are attached to the trainer. From 783168804e78248c8900240cc9effdfeb423e1f9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 21:58:28 -0700 Subject: [PATCH 8/9] Update test_dataloaders.py --- tests/trainer/test_dataloaders.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d3205242656f5..83b9a7dbdecb9 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1112,14 +1112,6 @@ def test_dataloaders_load_only_once_val_interval(tmpdir): expected_sequence = [ 'val_dataloader', 'train_dataloader', - # This has subsequent calls to val_dataloader - # because the training loop runs the evaluation loop, - # which reloads the val dataloader again. - # We cannot yet rely on trainer.current_epoch=0 to skip reloading - # the val dataloader on the first epoch because this only tracks the training epoch - # meaning multiple passes through the validation data within a single training epoch - # would not have the datalodaer reloaded. - # This breaks the assumption behind reload_dataloaders_every_epoch=True 'val_dataloader', 'val_dataloader', 'val_dataloader', @@ -1244,6 +1236,14 @@ def validation_step(self, batch, batch_idx): expected_sequence = [ 'train_dataloader', 'val_dataloader', + # This has subsequent calls to val_dataloader + # because the training loop runs the evaluation loop, + # which reloads the val dataloader again. + # We cannot yet rely on trainer.current_epoch=0 to skip reloading + # the val dataloader on the first epoch because this only tracks the training epoch + # meaning multiple passes through the validation data within a single training epoch + # would not have the datalodaer reloaded. + # This breaks the assumption behind reload_dataloaders_every_epoch=True 'val_dataloader', 'train_dataloader', 'val_dataloader', From c97d172fbaabdb5a6397fd0935977bf8ed59e956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 30 Apr 2021 00:07:02 +0200 Subject: [PATCH 9/9] Update tests/trainer/test_dataloaders.py --- tests/trainer/test_dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 83b9a7dbdecb9..a935fbd401e7e 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -1242,7 +1242,7 @@ def validation_step(self, batch, batch_idx): # We cannot yet rely on trainer.current_epoch=0 to skip reloading # the val dataloader on the first epoch because this only tracks the training epoch # meaning multiple passes through the validation data within a single training epoch - # would not have the datalodaer reloaded. + # would not have the dataloader reloaded. # This breaks the assumption behind reload_dataloaders_every_epoch=True 'val_dataloader', 'train_dataloader',