Convert progress bar metrics to float
carmocca committed Jan 28, 2021
1 parent 817a41d commit 327531a
Showing 4 changed files with 31 additions and 9 deletions.
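In short: the logger connector's `_progress_bar_metrics` holder is now built with `to_float=True`, so scalar tensors logged with `prog_bar=True` reach the bar as plain Python floats. A minimal sketch of the resulting holder behavior (the import path is assumed from the class name, not verified against this revision):

    import torch

    from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder

    holder = MetricsHolder(to_float=True)
    holder.update({"loss": torch.tensor(0.123)})
    # With to_float=True, convert() unwraps scalar tensors into plain floats;
    # per the diff below, the use_tpu/device arguments are ignored on this branch.
    holder.convert(use_tpu=False, device=torch.device("cpu"))
    assert isinstance(holder.metrics["loss"], float)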
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -293,7 +293,6 @@ def init_validation_tqdm(self) -> tqdm:

     def init_test_tqdm(self, trainer=None) -> tqdm:
         """ Override this to customize the tqdm bar for testing. """
-        desc = "Testing"
         desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
         bar = tqdm(
             desc=desc,
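The docstring marks `init_test_tqdm` as an override point; here is a minimal sketch of customizing it (the subclass name and description string are illustrative, not part of this commit):

    from tqdm import tqdm

    from pytorch_lightning.callbacks import ProgressBar

    class CustomProgressBar(ProgressBar):  # hypothetical subclass
        def init_test_tqdm(self, trainer=None) -> tqdm:
            # Reuse the stock bar and only adjust its description.
            bar = super().init_test_tqdm(trainer=trainer)
            bar.set_description("Evaluating")
            return bar

Passing an instance via `Trainer(callbacks=[CustomProgressBar()])` should replace the default bar.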
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -36,7 +36,7 @@ def __init__(self, trainer):
         self._callback_metrics = MetricsHolder()
         self._evaluation_callback_metrics = MetricsHolder(to_float=True)
         self._logged_metrics = MetricsHolder()
-        self._progress_bar_metrics = MetricsHolder()
+        self._progress_bar_metrics = MetricsHolder(to_float=True)
         self.eval_loop_results = []
         self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
         self._cached_results[None] = EpochResultStore(trainer, None)
@@ -87,7 +87,7 @@ def get_metrics(self, key: str) -> Dict:
         )
         return metrics_holder.metrics

-    def set_metrics(self, key: str, val: Any) -> None:
+    def set_metrics(self, key: str, val: Dict) -> None:
         metrics_holder = getattr(self, f"_{key}", None)
         metrics_holder.reset(val)

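`get_metrics` and `set_metrics` resolve their target holder by name, so a key such as "progress_bar_metrics" maps to the `_progress_bar_metrics` attribute. A standalone sketch of that dispatch pattern (a stand-in class, not the connector itself):

    class MiniConnector:
        # Illustrates the f"_{key}" attribute dispatch used above.
        def __init__(self):
            self._progress_bar_metrics = {}

        def set_metrics(self, key, val):
            setattr(self, f"_{key}", val)  # the real connector calls holder.reset(val)

        def get_metrics(self, key):
            return getattr(self, f"_{key}", None)

    conn = MiniConnector()
    conn.set_metrics("progress_bar_metrics", {"loss": 0.5})
    assert conn.get_metrics("progress_bar_metrics") == {"loss": 0.5}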
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numbers
-from typing import Any
+from typing import Any, Dict

 import torch

@@ -33,13 +33,13 @@ def __init__(self, to_float: bool = False):
         self.metrics = {}
         self._to_float = to_float

-    def update(self, metrics):
+    def update(self, metrics: Dict):
         self.metrics.update(metrics)

-    def pop(self, key, default):
+    def pop(self, key: Any, default: Any):
         return self.metrics.pop(key, default)

-    def reset(self, metrics):
+    def reset(self, metrics: Dict):
         self.metrics = metrics

     def convert(self, use_tpu: bool, device: torch.device):
@@ -48,10 +48,10 @@ def convert(self, use_tpu: bool, device: torch.device):

     def _convert(self, current: Any, use_tpu: bool, device: torch.device):
         if self._to_float:
-            return self._convert_to_float(current, use_tpu, device)
+            return self._convert_to_float(current)
         return self._convert_to_tensor(current, use_tpu, device)

-    def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
+    def _convert_to_float(self, current):
         if isinstance(current, Metric):
             current = current.compute().detach()

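The body of `_convert_to_float` is truncated above; its effect can be approximated with the standalone sketch below (a re-implementation under assumptions; the real method additionally resolves `Metric` objects via `compute().detach()` first, as the diff shows):

    import numbers

    import torch

    def convert_to_float(current):
        # Scalar tensors are unwrapped into plain Python floats.
        if isinstance(current, torch.Tensor):
            return float(current.item())
        # Plain numbers are normalized to float; anything else passes through,
        # which is why the nested dict in the test below keeps its tensor.
        if isinstance(current, numbers.Number):
            return float(current)
        return current

    assert isinstance(convert_to_float(torch.tensor(0.123)), float)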
23 changes: 23 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -16,6 +16,7 @@
 from unittest.mock import call, Mock

 import pytest
+import torch

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
@@ -340,3 +341,25 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
     )
     trainer.test(model)
     progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
+
+
+def test_progress_bar_metrics_converted_to_float(tmpdir):
+    """Check that scalar tensor metrics shown in the progress bar are converted to float."""
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log('foo', torch.tensor(0.123), prog_bar=True)
+            self.log('bar', {"baz": torch.tensor([1])}, prog_bar=True)
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.fit(TestModel())
+
+    pbar = trainer.progress_bar_callback.main_progress_bar
+    actual = str(pbar.postfix)
+    assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")