Convert progress bar metrics to float (#5692)
* MetricsHolder(to_float=True)

* Update CHANGELOG

* Update tests/callbacks/test_progress_bar.py

* flake8

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
carmocca and Borda authored Feb 11, 2021
1 parent 7b00894 commit e8190e8
Showing 3 changed files with 32 additions and 4 deletions.
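The change is easiest to motivate outside Lightning: tqdm renders plain numbers compactly in its postfix, while a tensor falls back to its repr and clutters the bar. A minimal sketch of that difference, independent of this commit (the postfix strings in the comments follow from tqdm's default number formatting and torch's tensor repr):

    import torch
    from tqdm import tqdm

    bar = tqdm(total=10)
    bar.set_postfix({'foo': torch.tensor(0.123)})         # shown as: foo=tensor(0.1230)
    bar.set_postfix({'foo': float(torch.tensor(0.123))})  # shown as: foo=0.123
    bar.close()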
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -103,7 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
 
 
-- Changed `ModelCheckpoint` version suffixes to start at 1 ([5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))
+- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))
+
+
+- Progress bar metrics tensors are now converted to float ([#5692](https://github.com/PyTorchLightning/pytorch-lightning/pull/5692))
 
 
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -14,7 +14,7 @@
 import os
 from copy import deepcopy
 from pprint import pprint
-from typing import Any, Dict, Iterable, Union
+from typing import Dict, Iterable, Union
 
 import torch
 
@@ -37,7 +37,7 @@ def __init__(self, trainer):
         self._callback_metrics = MetricsHolder()
         self._evaluation_callback_metrics = MetricsHolder(to_float=True)
         self._logged_metrics = MetricsHolder()
-        self._progress_bar_metrics = MetricsHolder()
+        self._progress_bar_metrics = MetricsHolder(to_float=True)
         self.eval_loop_results = []
         self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
         self._cached_results[None] = EpochResultStore(trainer, None)
@@ -88,7 +88,7 @@ def get_metrics(self, key: str) -> Dict:
         )
         return metrics_holder.metrics
 
-    def set_metrics(self, key: str, val: Any) -> None:
+    def set_metrics(self, key: str, val: Dict) -> None:
         metrics_holder = getattr(self, f"_{key}", None)
         metrics_holder.reset(val)
 
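metrics_holder.py itself is untouched by this commit, so the following is only a rough sketch of what constructing the holder with to_float=True presumably enables when metrics are stored; the class and helper names are illustrative, not the library's actual API. The behavior it mimics is the one the new test pins down: a bare scalar tensor becomes a Python float, while values nested inside other containers pass through unchanged.

    from typing import Any, Dict

    import torch


    class MetricsHolderSketch:
        """Illustrative stand-in for MetricsHolder; names here are hypothetical."""

        def __init__(self, to_float: bool = False) -> None:
            self.to_float = to_float
            self.metrics: Dict[str, Any] = {}

        def reset(self, metrics: Dict[str, Any]) -> None:
            # Store a fresh set of metrics, converting top-level scalar
            # tensors to plain floats when requested.
            if self.to_float:
                metrics = {k: self._convert(v) for k, v in metrics.items()}
            self.metrics = metrics

        @staticmethod
        def _convert(value: Any) -> Any:
            # Only single-element tensors are converted; dicts, lists and
            # other values are left untouched.
            if isinstance(value, torch.Tensor) and value.numel() == 1:
                return float(value.item())
            return value


    # With to_float=True, a scalar tensor comes back as a float ...
    holder = MetricsHolderSketch(to_float=True)
    holder.reset({'foo': torch.tensor(0.123), 'bar': {'baz': torch.tensor([1])}})
    assert isinstance(holder.metrics['foo'], float)
    # ... while the tensor nested in a dict keeps its type, matching the test below.
    assert isinstance(holder.metrics['bar']['baz'], torch.Tensor)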
25 changes: 25 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -16,6 +16,7 @@
 from unittest.mock import call, Mock
 
 import pytest
+import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
@@ -349,3 +350,27 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
     )
     trainer.test(model)
     progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
+
+
+def test_tensor_to_float_conversion(tmpdir):
+    """Check tensor gets converted to float"""
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            self.log('foo', torch.tensor(0.123), prog_bar=True)
+            self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True)
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.fit(TestModel())
+
+    pbar = trainer.progress_bar_callback.main_progress_bar
+    actual = str(pbar.postfix)
+    assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")
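Note the asymmetry the new test pins down: `foo`, logged as a bare scalar tensor, ends up in the bar as the plain float `0.123`, while the tensor nested inside the `bar` dict keeps its `tensor([1])` repr. The diff above also shows that `_evaluation_callback_metrics` was already constructed with `to_float=True`; this commit simply applies the same treatment to `_progress_bar_metrics`, so scalar metrics logged with `prog_bar=True` surface as plain numbers.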