From fc1861d7863123a44092d12fd4c9cef5e431f363 Mon Sep 17 00:00:00 2001
From: chaton
Date: Wed, 16 Dec 2020 22:07:35 +0100
Subject: [PATCH] [bugfix] remove nan loss in manual optimization (#5121)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* remove nan loss when missing

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Carlos Mocholí

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí
Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/lightning.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 4e0d366e8a7c5..49525d2022fc9 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1392,12 +1392,15 @@ def get_progress_bar_dict(self):
         """
         # call .item() only once but store elements without graphs
         running_train_loss = self.trainer.train_loop.running_loss.mean()
-        avg_training_loss = (
-            running_train_loss.cpu().item()
-            if running_train_loss is not None
-            else float("NaN")
-        )
-        tqdm_dict = {"loss": "{:.3g}".format(avg_training_loss)}
+        avg_training_loss = None
+        if running_train_loss is not None:
+            avg_training_loss = running_train_loss.cpu().item()
+        elif self.trainer.train_loop.automatic_optimization:
+            avg_training_loss = float('NaN')
+
+        tqdm_dict = {}
+        if avg_training_loss is not None:
+            tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
 
         if self.trainer.truncated_bptt_steps is not None:
             tqdm_dict["split_idx"] = self.trainer.split_idx
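
Note (illustration, not part of the patch): the sketch below is a minimal, self-contained reproduction of the branching logic this patch introduces in get_progress_bar_dict. With automatic optimization a missing running loss is still reported as "nan"; with manual optimization the "loss" entry is simply omitted from the progress bar dict. The helper name progress_bar_dict and its parameters are hypothetical stand-ins, not the Lightning API.

# Illustrative sketch of the patched behaviour; `progress_bar_dict` and its
# arguments are hypothetical stand-ins, not part of pytorch_lightning.
from typing import Dict, Optional

import torch


def progress_bar_dict(running_train_loss: Optional[torch.Tensor],
                      automatic_optimization: bool) -> Dict[str, str]:
    avg_training_loss = None
    if running_train_loss is not None:
        # only call .item() when a running loss actually exists
        avg_training_loss = running_train_loss.cpu().item()
    elif automatic_optimization:
        # automatic optimization keeps the old behaviour: display NaN
        avg_training_loss = float('NaN')

    tqdm_dict = {}
    if avg_training_loss is not None:
        tqdm_dict["loss"] = f"{avg_training_loss:.3g}"
    return tqdm_dict


print(progress_bar_dict(None, automatic_optimization=True))    # {'loss': 'nan'}
print(progress_bar_dict(None, automatic_optimization=False))   # {}
print(progress_bar_dict(torch.tensor(0.1234), True))           # {'loss': '0.123'}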