Skip to content

Commit

Permalink
fix best score on wrong device in EarlyStopping callback (#8295)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Jul 6, 2021
1 parent 8fead58 commit 1e1d182
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer: 'pl.Trainer') -> None:
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

should_stop, reason = self._evalute_stopping_criteria(current, trainer)
should_stop, reason = self._evalute_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
Expand All @@ -210,7 +210,7 @@ def _run_early_stopping_check(self, trainer: 'pl.Trainer') -> None:
if reason and self.verbose:
self._log_info(trainer, reason)

def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
should_stop = False
reason = None
if self.check_finite and not torch.isfinite(current):
Expand All @@ -233,7 +233,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
" Signaling Trainer to stop."
)
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
should_stop = False
reason = self._improvement_message(current)
self.best_score = current
Expand Down

0 comments on commit 1e1d182

Please sign in to comment.