ref: fixes logging for eval steps (#3763)
* fixes logging for eval steps
williamFalcon authored Oct 1, 2020
1 parent 5ec00cc commit 7c61fc7
Showing 3 changed files with 10 additions and 11 deletions.
3 changes: 2 additions & 1 deletion pl_examples/basic_examples/image_classifier.py
@@ -54,13 +54,14 @@ def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
+        self.log('train_loss', loss, on_epoch=True)
         return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.backbone(x)
         loss = F.cross_entropy(y_hat, y)
-        self.log('valid_loss', loss)
+        self.log('valid_loss', loss, on_step=True)
 
     def test_step(self, batch, batch_idx):
         x, y = batch
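
For context, a minimal, self-contained sketch of a module using the logging calls from this diff. It is illustrative only: the class name, the stand-in backbone, and the optimizer choice are assumptions, not taken from the example file.

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl

class LitSketch(pl.LightningModule):
    """Illustrative module: per-epoch train loss, per-step validation loss."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(28 * 28, 10)  # assumed stand-in backbone

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x.view(x.size(0), -1)), y)
        # on_epoch=True also logs an epoch-level aggregate of the training loss
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x.view(x.size(0), -1)), y)
        # on_step=True logs the validation loss for every evaluation batch
        self.log('valid_loss', loss, on_step=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
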
15 changes: 5 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -296,6 +296,9 @@ def _save_model(self, filepath: str, trainer, pl_module):
             raise ValueError(".save_function() not set")
 
     def check_monitor_top_k(self, current) -> bool:
+        if current is None:
+            return False
+
         if self.save_top_k == -1:
             return True
 
@@ -421,7 +424,7 @@ def _add_backward_monitor_support(self, trainer):
         if self.monitor is None and 'checkpoint_on' in metrics:
             self.monitor = 'checkpoint_on'
 
-        if self.save_top_k is None:
+        if self.save_top_k is None and self.monitor is not None:
             self.save_top_k = 1
 
     def _validate_monitor_key(self, trainer):
@@ -486,15 +489,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
-        if current is None:
-            m = f"Can save best model only with {self.monitor} available, skipping."
-            if self.monitor == 'checkpoint_on':
-                m = (
-                    'No checkpoint_on found. HINT: Did you set it in '
-                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
-                )
-            rank_zero_warn(m, RuntimeWarning)
-        elif self.check_monitor_top_k(current):
+        if self.check_monitor_top_k(current):
             self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
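
The net effect of the model_checkpoint.py changes: check_monitor_top_k now returns False when the monitored value is missing, instead of warning later inside _save_top_k_checkpoints, and save_top_k only defaults to 1 when a monitor key is actually set. A hedged sketch of wiring ModelCheckpoint to a metric logged as above; the argument values are illustrative, and the exact Trainer wiring can differ between Lightning versions.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints, ranked by the 'valid_loss' value logged via self.log(...)
checkpoint_cb = ModelCheckpoint(monitor='valid_loss', save_top_k=3, mode='min')

# callbacks=[...] is one common form; older releases used checkpoint_callback=
trainer = Trainer(max_epochs=5, callbacks=[checkpoint_cb])
# trainer.fit(model, train_loader, val_loader)
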
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -157,6 +157,9 @@ def _log_on_evaluation_epoch_end_metrics(self):
             # track the final results for the dataloader
             self.eval_loop_results.append(deepcopy(self.callback_metrics))
+
+            # actually log
+            self.log_metrics(logger_metrics, {}, step=self.trainer.global_step)
 
     def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
         if num_loaders == 1:
             return metrics
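
With the added log_metrics call, metrics reduced at the end of an evaluation epoch are forwarded to the attached logger at the current global step. A minimal sketch of attaching a logger so those eval metrics become visible; the directory and experiment names are assumptions.

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir='lightning_logs', name='eval_logging_demo')

# validation metrics logged in validation_step now reach this logger
# at trainer.global_step when the eval epoch ends
trainer = Trainer(logger=logger, max_epochs=3)
# trainer.fit(model, train_loader, val_loader)
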
