Skip to content

Commit

Permalink
Fix for multiple callbacks (Lightning-AI#6197)
Browse files Browse the repository at this point in the history
* Fix for multiple callbacks

* Add CHANGELOG.md

* Remove old params

* Skip tests on windows using ddp

* Rename the variable so it does not clash with `should_stop`, which is a separate flag

* Apply suggestions from code review

* Fix params

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
2 people authored and kaushikb11 committed Mar 2, 2021
1 parent 0151ab6 commit 246c65b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))


- Fixed multiple `EarlyStopping` callbacks overwriting each other's decision to stop, so training now stops when any one of them signals it ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
7 changes: 2 additions & 5 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,12 @@ def _run_early_stopping_check(self, trainer, pl_module):
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
should_stop = False
else:
self.wait_count += 1
should_stop = self.wait_count >= self.patience

if bool(should_stop):
if self.wait_count >= self.patience:
self.stopped_epoch = trainer.current_epoch
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
trainer.should_stop = should_stop
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
55 changes: 55 additions & 0 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
import pickle
import sys
from unittest import mock

import cloudpickle
Expand Down Expand Up @@ -344,3 +345,57 @@ def validation_epoch_end(self, outputs):
def test_early_stopping_mode_options():
    """Constructing ``EarlyStopping`` with an unsupported ``mode`` must raise."""
    expected_message = "`mode` can be auto, .* got unknown_option"
    with pytest.raises(MisconfigurationException, match=expected_message):
        EarlyStopping(mode="unknown_option")


class EarlyStoppingModel(BoringModel):
    """Model that logs a scripted ``abc`` metric and a constant ``cba`` metric.

    When training ends it asserts that the trainer's current epoch equals
    the ``expected_end_epoch`` given at construction.
    """

    def __init__(self, expected_end_epoch):
        super().__init__()
        self.expected_end_epoch = expected_end_epoch

    def validation_epoch_end(self, outputs):
        # One scripted value per epoch, looked up by the current epoch index.
        scripted_losses = (8, 4, 2, 3, 4, 5, 8, 10)
        epoch_loss = scripted_losses[self.current_epoch]
        self.log('abc', torch.tensor(epoch_loss))
        # 'cba' is logged as the same value every epoch.
        self.log('cba', torch.tensor(0))

    def on_train_end(self) -> None:
        final_epoch = self.trainer.current_epoch
        assert final_epoch == self.expected_end_epoch, 'Early Stopping Failed'


# Shared mark for the ``ddp_cpu`` parametrizations below.
_skip_ddp_on_windows = pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")


@pytest.mark.parametrize(
    "callbacks, expected_stop_epoch, accelerator, num_processes",
    [
        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
        ([EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')], 3, None, 1),
        pytest.param(
            [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)],
            3,
            'ddp_cpu',
            2,
            marks=_skip_ddp_on_windows,
        ),
        pytest.param(
            [EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')],
            3,
            'ddp_cpu',
            2,
            marks=_skip_ddp_on_windows,
        ),
    ],
)
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
    """
    With several early stopping callbacks attached, training must stop as soon as
    any one of them requests it. The epoch at which training ended is asserted
    inside ``EarlyStoppingModel.on_train_end``.
    """
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        overfit_batches=0.20,
        max_epochs=20,
        accelerator=accelerator,
        num_processes=num_processes,
    )

    model = EarlyStoppingModel(expected_stop_epoch)
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

0 comments on commit 246c65b

Please sign in to comment.