Prevent flickering progress bar (#6009)
* add padding

* fix

* fix

* Update pytorch_lightning/callbacks/progress.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* updated based on suggestion

* changelog

* add test

* fix pep8

* resolve test

* fix code format

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: tchaton <thomas@grid.ai>
3 people authored Feb 17, 2021
1 parent ad36c7b commit 68fd308
Showing 3 changed files with 40 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
 
 
+- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))
+
+
 - Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))
 
 
29 changes: 27 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -26,12 +26,37 @@
 from typing import Optional, Union
 
 if importlib.util.find_spec('ipywidgets') is not None:
-    from tqdm.auto import tqdm
+    from tqdm.auto import tqdm as _tqdm
 else:
-    from tqdm import tqdm
+    from tqdm import tqdm as _tqdm
 
 from pytorch_lightning.callbacks import Callback
 
+_PAD_SIZE = 5
+
+
+class tqdm(_tqdm):
+    """
+    Custom tqdm progressbar where we append 0 to floating points/strings to
+    prevent the progress bar from flickering
+    """
+
+    @staticmethod
+    def format_num(n) -> str:
+        """ Add additional padding to the formatted numbers """
+        should_be_padded = isinstance(n, (float, str))
+        if not isinstance(n, str):
+            n = _tqdm.format_num(n)
+        if should_be_padded and 'e' not in n:
+            if '.' not in n and len(n) < _PAD_SIZE:
+                try:
+                    _ = float(n)
+                except ValueError:
+                    return n
+                n += '.'
+            n += "0" * (_PAD_SIZE - len(n))
+        return n
+
+
 class ProgressBarBase(Callback):
     r"""
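For a sense of what the padding buys, here is a minimal sketch (not part of the diff, assuming a checkout with this commit applied so the module-level `_tqdm` alias is importable): stock tqdm formats a number to whatever width it needs, so a postfix oscillating between values such as 0.1 and 0.123 changes width on every refresh and the bar flickers, while the subclass zero-pads floats and numeric strings to `_PAD_SIZE` characters.

from pytorch_lightning.callbacks.progress import tqdm, _tqdm

# Stock tqdm: the formatted width tracks the value, so consecutive
# refreshes of e.g. loss=0.1 -> loss=0.123 shift the whole line.
assert _tqdm.format_num(0.1) == '0.1'      # 3 characters
assert _tqdm.format_num(0.123) == '0.123'  # 5 characters

# Padded subclass: floats and numeric strings are zero-padded to
# _PAD_SIZE (5) characters, so the rendered width stays stable.
assert tqdm.format_num(0.1) == '0.100'
assert tqdm.format_num(1.0) == '1.000'
assert tqdm.format_num(1e-5) == '1e-5'  # scientific notation is left alone
assert tqdm.format_num('abc') == 'abc'  # non-numeric strings pass through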
10 changes: 10 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -20,6 +20,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks.progress import tqdm
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
@@ -371,3 +372,12 @@ def training_step(self, batch, batch_idx):
     pbar = trainer.progress_bar_callback.main_progress_bar
     actual = str(pbar.postfix)
     assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")
+
+
+@pytest.mark.parametrize(
+    "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
+                            ['10000', '10000'], ['abc', 'abc']]
+)
+def test_tqdm_format_num(input_num, expected):
+    """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
+    assert tqdm.format_num(input_num) == expected
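As a quick standalone demo (hypothetical, not part of this commit): tqdm's `set_postfix` passes numeric values through `format_num`, so with the padded subclass an oscillating metric keeps a constant width and the bar no longer jitters.

import time

from pytorch_lightning.callbacks.progress import tqdm

bar = tqdm(range(100))
for i in bar:
    # 0.1 renders as '0.100' and 0.123 as '0.123': same width on every
    # refresh, so the postfix no longer shifts back and forth
    bar.set_postfix(loss=0.1 if i % 2 else 0.123)
    time.sleep(0.01)
bar.close()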
