From feae0a05c4870d0102bad87c045d9fb8b8482920 Mon Sep 17 00:00:00 2001 From: Yuan-Hang Zhang Date: Fri, 2 Apr 2021 16:40:41 +0800 Subject: [PATCH] Fix validation progress counter with check_val_every_n_epoch > 1 (#5952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: rohitgr7 Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/progress.py | 5 +- .../flags/test_check_val_every_n_epoch.py | 53 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 tests/trainer/flags/test_check_val_every_n_epoch.py diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 46331e004c1c7..649243f7600ba 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -146,9 +146,10 @@ def total_val_batches(self) -> int: validation dataloader is of infinite size. """ total_val_batches = 0 - if not self.trainer.disable_validation: - is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0 + if self.trainer.enable_validation: + is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0 + return total_val_batches @property diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py new file mode 100644 index 0000000000000..f7f1403ecdbfd --- /dev/null +++ b/tests/trainer/flags/test_check_val_every_n_epoch.py @@ -0,0 +1,53 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.trainer.states import TrainerState +from tests.helpers import BoringModel + + +@pytest.mark.parametrize( + 'max_epochs,expected_val_loop_calls,expected_val_batches', [ + (1, 0, [0]), + (4, 2, [0, 2, 0, 2]), + (5, 2, [0, 2, 0, 2, 0]), + ] +) +def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls, expected_val_batches): + + class TestModel(BoringModel): + val_epoch_calls = 0 + val_batches = [] + + def on_train_epoch_end(self, *args, **kwargs): + self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches) + + def on_validation_epoch_start(self) -> None: + self.val_epoch_calls += 1 + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=max_epochs, + num_sanity_val_steps=0, + limit_val_batches=2, + check_val_every_n_epoch=2, + logger=False, + ) + trainer.fit(model) + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert model.val_epoch_calls == expected_val_loop_calls + assert model.val_batches == expected_val_batches