Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert progress bar metrics to float #5692

Merged
merged 5 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))


- Changed `ModelCheckpoint` version suffixes to start at 1 ([5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))
- Changed `ModelCheckpoint` version suffixes to start at 1 ([#5008](https://github.com/PyTorchLightning/pytorch-lightning/pull/5008))


- Progress bar metrics tensors are now converted to float ([#5692](https://github.com/PyTorchLightning/pytorch-lightning/pull/5692))


- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
from copy import deepcopy
from pprint import pprint
from typing import Any, Dict, Iterable, Union
from typing import Dict, Iterable, Union

import torch

Expand All @@ -37,7 +37,7 @@ def __init__(self, trainer):
self._callback_metrics = MetricsHolder()
self._evaluation_callback_metrics = MetricsHolder(to_float=True)
self._logged_metrics = MetricsHolder()
self._progress_bar_metrics = MetricsHolder()
self._progress_bar_metrics = MetricsHolder(to_float=True)
self.eval_loop_results = []
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
self._cached_results[None] = EpochResultStore(trainer, None)
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_metrics(self, key: str) -> Dict:
)
return metrics_holder.metrics

def set_metrics(self, key: str, val: Any) -> None:
def set_metrics(self, key: str, val: Dict) -> None:
metrics_holder = getattr(self, f"_{key}", None)
metrics_holder.reset(val)

Expand Down
23 changes: 23 additions & 0 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest.mock import call, Mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
Expand Down Expand Up @@ -349,3 +350,25 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, tes
)
trainer.test(model)
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])


def test_tensor_to_float_conversion(tmpdir):
"""Check tensor gets converted to float"""
class TestModel(BoringModel):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
def training_step(self, batch, batch_idx):
self.log('foo', torch.tensor(0.123), prog_bar=True)
self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True)
return super().training_step(batch, batch_idx)

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
logger=False,
checkpoint_callback=False,
)
trainer.fit(TestModel())

pbar = trainer.progress_bar_callback.main_progress_bar
actual = str(pbar.postfix)
assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")