Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for multiple callbacks #6197

Merged
merged 7 commits into from
Feb 25, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`


- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
Expand All @@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))


- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
should_stop = False
else:
self.wait_count += 1
should_stop = self.wait_count >= self.patience
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -168,5 +167,5 @@ def _run_early_stopping_check(self, trainer, pl_module):
trainer.should_stop = True

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
trainer.should_stop = should_stop
Borda marked this conversation as resolved.
Show resolved Hide resolved
53 changes: 53 additions & 0 deletions tests/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,56 @@ def validation_epoch_end(self, outputs):
def test_early_stopping_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
EarlyStopping(mode="unknown_option")


@pytest.mark.parametrize(
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
"callbacks, expected_stop_epoch",
[
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3),
([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')], 3),
],
)
class EarlyStoppingModel(BoringModel):

def __init__(self, expected_end_epoch):
super().__init__()
self.expected_end_epoch = expected_end_epoch

def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
self.log('abc', torch.tensor(val_loss))
self.log('cba', torch.tensor(0))

def on_train_end(self) -> None:
assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
Comment on lines +362 to +363
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would drop this and rather check in the test trainer epoch is as expected, so there is not random inference

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I originally did but because of how DDP Spawn works, the local trainer's current epoch doesn't seem to be kept in sync which is fair (since it's only kept in sync during trainer on the processes). This is why I had to move it to on_train_end because this happens within the spawn process!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so can we have both?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could but I'd need to separate out the tests, I don't think its really worth it because it would be a lot of duplication



@pytest.mark.parametrize(
"callbacks, expected_stop_epoch, accelerator, num_processes",
[
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')], 3, None, 1),
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, 'ddp_cpu', 2),
([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')], 3, 'ddp_cpu', 2),
],
)
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
"""
Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
"""

model = EarlyStoppingModel(expected_stop_epoch)

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
overfit_batches=0.20,
max_epochs=20,
accelerator=accelerator,
num_processes=num_processes
)
trainer.fit(model)