Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent flickering progress bar #6009

Merged
merged 14 commits into from
Feb 17, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923))


- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))


## [1.1.8] - 2021-02-08

### Fixed
Expand Down
24 changes: 22 additions & 2 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,32 @@
from typing import Optional, Union

if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
from tqdm.auto import tqdm as _tqdm
else:
from tqdm import tqdm
from tqdm import tqdm as _tqdm

from pytorch_lightning.callbacks import Callback

_PAD_SIZE = 5


class tqdm(_tqdm):
"""
Custom tqdm progressbar where we append 0 to floating points/strings to
prevent the progress bar from flickering
"""
@staticmethod
def format_num(n) -> str:
""" Add additional padding to the formatted numbers """
should_be_padded = isinstance(n, (float, str))
if not isinstance(n, str):
n = _tqdm.format_num(n)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if should_be_padded and 'e' not in n:
if '.' not in n:
n += '.'
n += "0" * (_PAD_SIZE - len(n))
return n


class ProgressBarBase(Callback):
r"""
Expand Down
14 changes: 14 additions & 0 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.progress import tqdm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
Expand Down Expand Up @@ -372,3 +373,16 @@ def training_step(self, batch, batch_idx):
pbar = trainer.progress_bar_callback.main_progress_bar
actual = str(pbar.postfix)
assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")


@pytest.mark.parametrize("input_num, expected", [
[1, '1'],
[1.0, '1.000'],
[0.1, '0.100'],
[1e-3, '0.001'],
[1e-5, '1e-5'],
['1.0', '1.000']
])
tchaton marked this conversation as resolved.
Show resolved Hide resolved
def test_tqdm_format_num(input_num, expected):
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
assert tqdm.format_num(input_num) == expected