From cacaa6817e694f8dfc2cc241cbb121a8557951cf Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 13 Jun 2021 07:32:16 +0530 Subject: [PATCH] Fix(Early Stopping): move best score to device --- pytorch_lightning/callbacks/early_stopping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index e40bb7180c8c8..b6bff43fd6317 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -196,7 +196,7 @@ def _run_early_stopping_check(self, trainer) -> None: # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - should_stop, reason = self._evalute_stopping_criteria(current) + should_stop, reason = self._evalute_stopping_criteria(current, trainer) # stop every ddp process if any world process decides to stop should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop) @@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer) -> None: if reason and self.verbose: self._log_info(trainer, reason) - def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: + def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): @@ -229,7 +229,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]: f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}." " Signaling Trainer to stop." ) - elif self.monitor_op(current - self.min_delta, self.best_score): + elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)): should_stop = False reason = self._improvement_message(current) self.best_score = current