From a6da5d7c7c7ef9a5a6571aa84671982c7036d4dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 16 Apr 2020 06:04:09 +0200 Subject: [PATCH 01/17] squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation --- CHANGELOG.md | 2 + docs/source/callbacks.rst | 6 + docs/source/trainer.rst | 1 + pytorch_lightning/callbacks/__init__.py | 3 + pytorch_lightning/callbacks/base.py | 24 ++ pytorch_lightning/callbacks/progress.py | 380 ++++++++++++++++++ pytorch_lightning/core/lightning.py | 17 +- pytorch_lightning/trainer/__init__.py | 9 +- pytorch_lightning/trainer/callback_config.py | 27 +- pytorch_lightning/trainer/callback_hook.py | 34 +- .../trainer/distrib_data_parallel.py | 6 +- pytorch_lightning/trainer/distrib_parts.py | 4 +- pytorch_lightning/trainer/evaluation_loop.py | 61 +-- pytorch_lightning/trainer/logging.py | 12 +- pytorch_lightning/trainer/lr_finder.py | 7 +- pytorch_lightning/trainer/trainer.py | 65 ++- pytorch_lightning/trainer/training_loop.py | 44 +- tests/base/utils.py | 2 +- tests/trainer/test_callbacks.py | 210 +++++++++- 19 files changed, 777 insertions(+), 137 deletions(-) create mode 100644 pytorch_lightning/callbacks/progress.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ce2fe73c1fc9..4388a48ced123 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475)) +- Decoupled the progress bar from trainer. It is a callback now and can be customized or even be replaced entirely ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)). 
- Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass ([#1476](https://github.com/PyTorchLightning/pytorch-lightning/issues/1476))
@@ -39,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
+- Deprecated `training_tqdm_dict` in favor of `progress_bar_dict` ([#1450](https://github.com/PyTorchLightning/pytorch-lightning/pull/1450)).
### Removed
diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index ffb7671b7211d..10323472facd8 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -78,3 +78,9 @@ We successfully extended functionality without polluting our super clean
_save_model,
_abc_impl,
check_monitor_top_k,
+
+---------
+
+.. automodule:: pytorch_lightning.callbacks.progress
+ :noindex:
+ :exclude-members:
diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index b160cfae3cf90..19c394db4854b 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -19,5 +19,6 @@ Trainer
slurm_job_id,
tng_tqdm_dic,
training_tqdm_dict,
+ progress_bar_dict,
init_optimizers,
configure_schedulers
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index b1c47767339cc..c232060ca4ecb 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -2,10 +2,13 @@
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar
__all__ = [
'Callback',
'EarlyStopping',
'ModelCheckpoint',
'GradientAccumulationScheduler',
+ 'ProgressBarBase',
+ 'ProgressBar',
]
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 9bf576b0c1926..50ea061df615e 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -22,6 +22,14 @@ def on_init_end(self, trainer):
"""Called when the trainer initialization ends, model has not yet been set."""
pass
+ def on_sanity_check_start(self, trainer, pl_module):
+ """Called when the validation sanity check starts."""
+ pass
+
+ def on_sanity_check_end(self, trainer, pl_module):
+ """Called when the validation sanity check ends."""
+ pass
+
def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass
@@ -34,6 +42,22 @@ def on_batch_start(self, trainer, pl_module):
"""Called when the training batch begins."""
pass
+ def on_validation_batch_start(self, trainer, pl_module):
+ """Called when the validation batch begins."""
+ pass
+
+ def on_validation_batch_end(self, trainer, pl_module):
+ """Called when the validation batch ends."""
+ pass
+
+ def on_test_batch_start(self, trainer, pl_module):
+ """Called when the test batch begins."""
+ pass
+
+ def on_test_batch_end(self, trainer, pl_module):
+ """Called when the test batch ends."""
+ pass
+
def on_batch_end(self, trainer, pl_module):
"""Called when the training batch ends."""
pass
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
new file mode 100644
index 0000000000000..afd3b84e3f1c1
--- /dev/null
+++ b/pytorch_lightning/callbacks/progress.py
@@ -0,0 +1,380 @@
+"""
+Progress Bars
+=============
+
+Use or override one of the progress bar callbacks.
+
+"""
+import sys
+from typing import Optional
+
+from tqdm.auto import tqdm
+
+from pytorch_lightning.callbacks import Callback
+
+
+class ProgressBarBase(Callback):
+ r"""
+ The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
+ that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+ You should implement your highly custom progress bars with this as the base class.
+
+ Example::
+
+ class LitProgressBar(ProgressBarBase):
+
+ def __init__(self):
+ super().__init__() # don't forget this :)
+ self.enabled = True
+
+ def disable(self):
+ self.enabled = False
+
+ def on_batch_end(self, trainer, pl_module):
+ super().on_batch_end(trainer, pl_module) # don't forget this :)
+ percent = (self.train_batch_idx / self.total_train_batches) * 100
+ sys.stdout.flush()
+ sys.stdout.write(f'{percent:.01f} percent complete \r')
+
+ bar = LitProgressBar()
+ trainer = Trainer(callbacks=[bar])
+
+ """
+ def __init__(self):
+
+ self._trainer = None
+ self._train_batch_idx = 0
+ self._val_batch_idx = 0
+ self._test_batch_idx = 0
+
+ @property
+ def trainer(self):
+ return self._trainer
+
+ @property
+ def train_batch_idx(self) -> int:
+ """
+ The current batch index being processed during training.
+ Use this to update your progress bar.
+ """
+ return self._train_batch_idx
+
+ @property
+ def val_batch_idx(self) -> int:
+ """
+ The current batch index being processed during validation.
+ Use this to update your progress bar.
+ """
+ return self._val_batch_idx
+
+ @property
+ def test_batch_idx(self) -> int:
+ """
+ The current batch index being processed during testing.
+ Use this to update your progress bar.
+ """
+ return self._test_batch_idx
+
+ @property
+ def total_train_batches(self) -> Optional[int]:
+ """
+ The total number of training batches during training, which may change from epoch to epoch.
+ Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+ training dataloader is of infinite size.
+ """
+ if self.trainer.fast_dev_run:
+ total_train_batches = 1
+ else:
+ total_train_batches = self.trainer.num_training_batches
+ return total_train_batches
+
+ @property
+ def total_val_batches(self) -> Optional[int]:
+ """
+ The total number of validation batches, which may change from epoch to epoch.
+ Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+ validation dataloader is of infinite size.
+ """
+ trainer = self.trainer
+ total_val_batches = 0
+ if trainer.fast_dev_run:
+ total_val_batches = len(trainer.val_dataloaders)
+ elif not self.trainer.disable_validation:
+ is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
+ total_val_batches = trainer.num_val_batches if is_val_epoch else 0
+ return total_val_batches
+
+ @property
+ def total_test_batches(self) -> Optional[int]:
+ """
+ The total number of test batches, which may change from epoch to epoch.
+ Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
+ test dataloader is of infinite size.
+ """
+ if self.trainer.fast_dev_run:
+ total_test_batches = len(self.trainer.test_dataloaders)
+ else:
+ total_test_batches = self.trainer.num_test_batches
+ return total_test_batches
+
+ def disable(self):
+ """
+ You should provide a way to disable the progress bar.
+ The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the + output on processes that have a rank different from 0, e.g., in multi-node training. + """ + raise NotImplementedError + + def enable(self): + """ + You should provide a way to enable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training + routines like the `learning rate finder `_ to temporarily enable and + disable the main progress bar. + """ + raise NotImplementedError + + def on_init_end(self, trainer): + self._trainer = trainer + + def on_train_start(self, trainer, pl_module): + self._train_batch_idx = trainer.batch_idx + + def on_epoch_start(self, trainer, pl_module): + self._train_batch_idx = 0 + + def on_batch_end(self, trainer, pl_module): + self._train_batch_idx += 1 + + def on_epoch_end(self, trainer, pl_module): + self._train_batch_idx = 0 + + def on_validation_start(self, trainer, pl_module): + self._val_batch_idx = 0 + + def on_validation_batch_end(self, trainer, pl_module): + self._val_batch_idx += 1 + + def on_validation_end(self, trainer, pl_module): + self._val_batch_idx = 0 + + def on_test_start(self, trainer, pl_module): + self._test_batch_idx = 0 + + def on_test_batch_end(self, trainer, pl_module): + self._test_batch_idx += 1 + + def on_test_end(self, trainer, pl_module): + self._test_batch_idx = 0 + + +class ProgressBar(ProgressBarBase): + r""" + This is the default progress bar used by Lightning. It prints to `stdout` using the + :mod:`tqdm` package and shows up to four different bars: + + - **sanity check progress:** the progress during the sanity check run + - **main progress:** shows training + validation progress combined. It also accounts for + multiple validation runs during training when + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. + - **validation progress:** only visible during validation; + shows total progress over all validation datasets. + - **test progress:** only active when testing; shows total progress over all test datasets. + + For infinite datasets, the progress bar never ends. + + If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override + specific methods of the callback class and pass your custom implementation to the + :class:`~pytorch_lightning.trainer.trainer.Trainer`: + + Example:: + + class LitProgressBar(ProgressBar): + + def init_validation_tqdm(self): + bar = super().init_validation_tqdm() + bar.set_description('running validation ...') + return bar + + bar = LitProgressBar() + trainer = Trainer(callbacks=[bar]) + + Args: + refresh_rate: + Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the + :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress + bar and sets the refresh rate to the value provided to the + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. + process_position: + Set this to a value greater than ``0`` to offset the progress bars by this many lines. + This is useful when you have progress bars defined elsewhere and want to show all of them + together. This corresponds to + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. 
+ + """ + def __init__(self, refresh_rate: int = 1, process_position: int = 0): + super().__init__() + self._refresh_rate = refresh_rate + self._process_position = process_position + self._enabled = True + self.main_progress_bar = None + self.val_progress_bar = None + self.test_progress_bar = None + + def __getstate__(self): + # can't pickle the tqdm objects + state = self.__dict__.copy() + state['main_progress_bar'] = None + state['val_progress_bar'] = None + state['test_progress_bar'] = None + return state + + @property + def refresh_rate(self) -> int: + return self._refresh_rate + + @property + def process_position(self) -> int: + return self._process_position + + @property + def enabled(self) -> bool: + return self._enabled and self.refresh_rate > 0 + + @property + def disabled(self) -> bool: + return not self.enabled + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + + def init_sanity_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for the validation sanity run. """ + bar = tqdm( + desc='Validation sanity check', + position=(2 * self.process_position), + disable=self.disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout, + ) + return bar + + def init_train_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for training. """ + bar = tqdm( + desc='Training', + initial=self.train_batch_idx, + position=(2 * self.process_position), + disable=self.disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout, + smoothing=0, + ) + return bar + + def init_validation_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for validation. """ + bar = tqdm( + desc='Validating', + position=(2 * self.process_position + 1), + disable=self.disabled, + leave=False, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def init_test_tqdm(self) -> tqdm: + """ Override this to customize the tqdm bar for testing. 
""" + bar = tqdm( + desc='Testing', + position=(2 * self.process_position), + disable=self.disabled, + leave=True, + dynamic_ncols=True, + file=sys.stdout + ) + return bar + + def on_sanity_check_start(self, trainer, pl_module): + super().on_sanity_check_start(trainer, pl_module) + self.val_progress_bar = self.init_sanity_tqdm() + self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders) + self.main_progress_bar = tqdm(disable=True) # dummy progress bar + + def on_sanity_check_end(self, trainer, pl_module): + super().on_sanity_check_end(trainer, pl_module) + self.main_progress_bar.close() + self.val_progress_bar.close() + + def on_train_start(self, trainer, pl_module): + super().on_train_start(trainer, pl_module) + self.main_progress_bar = self.init_train_tqdm() + + def on_epoch_start(self, trainer, pl_module): + super().on_epoch_start(trainer, pl_module) + total_train_batches = self.total_train_batches + total_val_batches = self.total_val_batches + if total_train_batches != float('inf'): + # val can be checked multiple times per epoch + val_checks_per_epoch = total_train_batches // trainer.val_check_batch + total_val_batches = total_val_batches * val_checks_per_epoch + total_batches = total_train_batches + total_val_batches + if not self.main_progress_bar.disable: + self.main_progress_bar.reset(convert_inf(total_batches)) + self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}') + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + if self.enabled and self.train_batch_idx % self.refresh_rate == 0: + self.main_progress_bar.update(self.refresh_rate) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + + def on_validation_start(self, trainer, pl_module): + super().on_validation_start(trainer, pl_module) + self.val_progress_bar = self.init_validation_tqdm() + self.val_progress_bar.total = convert_inf(self.total_val_batches) + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if self.enabled and self.val_batch_idx % self.refresh_rate == 0: + self.val_progress_bar.update(self.refresh_rate) + self.main_progress_bar.update(self.refresh_rate) + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) + self.val_progress_bar.close() + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + self.main_progress_bar.close() + + def on_test_start(self, trainer, pl_module): + super().on_test_start(trainer, pl_module) + self.test_progress_bar = self.init_test_tqdm() + self.test_progress_bar.total = convert_inf(self.total_test_batches) + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if self.enabled and self.test_batch_idx % self.refresh_rate == 0: + self.test_progress_bar.update(self.refresh_rate) + + def on_test_end(self, trainer, pl_module): + super().on_test_end(trainer, pl_module) + self.test_progress_bar.close() + + +def convert_inf(x): + """ The tqdm doesn't support inf values. We have to convert it to None. 
""" + if x == float('inf'): + return None + return x diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bde43a6a0f8f6..83849cc8c9fc5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1623,7 +1623,7 @@ def on_save_checkpoint(self, checkpoint): """ - def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: r""" Additional items to be displayed in the progress bar. @@ -1644,3 +1644,18 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: tqdm_dict['v_num'] = self.trainer.logger.version return tqdm_dict + + def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: + """ + Additional items to be displayed in the progress bar. + + Return: + Dictionary with the items to be displayed in the progress bar. + + Warning: + Deprecated since v0.7.3. + Use :meth:`get_progress_bar_dict` instead. + """ + rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" + " and this method will be removed in v1.0.0", DeprecationWarning) + return self.get_progress_bar_dict() diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 0bd161bf0a56d..41b7b1b999a75 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -652,14 +652,16 @@ def on_train_end(self): process_position ^^^^^^^^^^^^^^^^ -Orders the tqdm bar. Useful when running multiple trainers -on the same node. +Orders the progress bar. Useful when running multiple trainers on the same node. Example:: # default used by the Trainer trainer = Trainer(process_position=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + profiler ^^^^^^^^ To profile individual steps during training and assist in identifying bottlenecks. @@ -698,6 +700,9 @@ def on_train_end(self): # disable progress bar trainer = Trainer(progress_bar_refresh_rate=0) +Note: + This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. + reload_dataloaders_every_epoch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Set to True to reload dataloaders every epoch. 
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 7ace0fb20a255..39c67963169e4 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -1,20 +1,25 @@ import os from abc import ABC, abstractmethod -from typing import Union +from typing import Union, List -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping + +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerCallbackConfigMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class + callbacks: List[Callback] default_root_dir: str logger: Union[LightningLoggerBase, bool] weights_save_path: str ckpt_path: str checkpoint_callback: ModelCheckpoint + progress_bar_refresh_rate: int + process_position: int @property @abstractmethod @@ -101,3 +106,21 @@ def configure_early_stopping(self, early_stop_callback): else: self.early_stop_callback = early_stop_callback self.enable_early_stop = True + + def configure_progress_bar(self): + progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] + if len(progress_bars) > 1: + raise MisconfigurationException( + 'You added multiple progress bar callbacks to the Trainer, but currently only one' + ' progress bar is supported.' + ) + elif len(progress_bars) == 1: + self.progress_bar_callback = progress_bars[0] + elif self.progress_bar_refresh_rate > 0: + self.progress_bar_callback = ProgressBar( + refresh_rate=self.progress_bar_refresh_rate, + process_position=self.process_position, + ) + self.callbacks.append(self.progress_bar_callback) + else: + self.progress_bar_callback = None diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 48d703b84ebe0..37f56e6941039 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Callable +from typing import Callable, List from pytorch_lightning.callbacks import Callback @@ -9,7 +9,7 @@ class TrainerCallbackHookMixin(ABC): def __init__(self): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - self.callbacks: list[Callback] = [] + self.callbacks: List[Callback] = [] self.get_model: Callable = ... 
def on_init_start(self): @@ -22,6 +22,16 @@ def on_init_end(self): for callback in self.callbacks: callback.on_init_end(self) + def on_sanity_check_start(self): + """Called when the validation sanity check starts.""" + for callback in self.callbacks: + callback.on_sanity_check_start(self, self.get_model()) + + def on_sanity_check_end(self): + """Called when the validation sanity check ends.""" + for callback in self.callbacks: + callback.on_sanity_check_end(self, self.get_model()) + def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: @@ -52,6 +62,26 @@ def on_batch_end(self): for callback in self.callbacks: callback.on_batch_end(self, self.get_model()) + def on_validation_batch_start(self): + """Called when the validation batch begins.""" + for callback in self.callbacks: + callback.on_validation_batch_start(self, self.get_model()) + + def on_validation_batch_end(self): + """Called when the validation batch ends.""" + for callback in self.callbacks: + callback.on_validation_batch_end(self, self.get_model()) + + def on_test_batch_start(self): + """Called when the test batch begins.""" + for callback in self.callbacks: + callback.on_test_batch_start(self, self.get_model()) + + def on_test_batch_end(self): + """Called when the test batch ends.""" + for callback in self.callbacks: + callback.on_test_batch_end(self, self.get_model()) + def on_validation_start(self): """Called when the validation loop begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index bfc85ee883f6e..3a46717161a5c 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -151,6 +151,7 @@ class TrainerDDPMixin(ABC): amp_level: str use_tpu: bool default_root_dir: str + progress_bar_callback: ... @property @abstractmethod @@ -309,9 +310,8 @@ def ddp_train(self, process_idx, model): self.node_rank = 0 # show progressbar only on progress_rank 0 - self.progress_bar_refresh_rate = ( - self.progress_bar_refresh_rate if self.node_rank == 0 and process_idx == 0 else 0 - ) + if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # determine which process we are and world size if self.use_ddp: diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 7ce61bbfb77e6..279be0b159f7e 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -396,6 +396,7 @@ class TrainerDPMixin(ABC): use_tpu: bool data_parallel_device_ids: ... logger: Union[LightningLoggerBase, bool] + progress_bar_callback: ... 
@property @abstractmethod @@ -497,7 +498,8 @@ def tpu_train(self, tpu_core_idx, model): self.tpu_global_core_rank = xm.get_ordinal() # avoid duplicating progress bar - self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if self.tpu_global_core_rank == 0 else 0 + if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None: + self.progress_bar_callback.disable() # track current tpu self.current_tpu_idx = tpu_core_idx diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a996bd7a60d70..d9c7a59597549 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -123,14 +123,12 @@ """ -import sys from abc import ABC, abstractmethod from pprint import pprint from typing import Callable import torch from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel @@ -157,9 +155,6 @@ class TrainerEvaluationLoopMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class - test_progress_bar: ... - val_progress_bar: ... - main_progress_bar: ... on_gpu: bool use_ddp: bool use_dp: bool @@ -171,9 +166,8 @@ class TrainerEvaluationLoopMixin(ABC): num_test_batches: int num_val_batches: int fast_dev_run: ... - process_position: ... process_output: ... - training_tqdm_dict: ... + progress_bar_dict: ... proc_rank: int current_epoch: int callback_metrics: ... @@ -181,9 +175,12 @@ class TrainerEvaluationLoopMixin(ABC): val_dataloaders: DataLoader use_tpu: bool reload_dataloaders_every_epoch: ... - progress_bar_refresh_rate: ... 
# Callback system + on_validation_batch_start: Callable + on_validation_batch_end: Callable + on_test_batch_start: Callable + on_test_batch_end: Callable on_validation_start: Callable on_validation_end: Callable on_test_start: Callable @@ -210,7 +207,7 @@ def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -265,6 +262,12 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ if batch_idx >= max_batches: break + # callbacks + if test_mode: + self.on_test_batch_start() + else: + self.on_validation_batch_start() + # ----------------- # RUN EVALUATION STEP # ----------------- @@ -276,22 +279,17 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ model_ref = self.get_model() with self.profiler.profile('test_step_end'): output = model_ref.test_step_end(output) + self.on_test_batch_end() else: if self.is_overriden('validation_step_end'): model_ref = self.get_model() with self.profiler.profile('validation_step_end'): output = model_ref.validation_step_end(output) + self.on_validation_batch_end() # track outputs for collation dl_outputs.append(output) - # batch done - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - if test_mode: - self.test_progress_bar.update(self.progress_bar_refresh_rate) - else: - self.val_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.update(self.progress_bar_refresh_rate) outputs.append(dl_outputs) eval_results = {} @@ -339,12 +337,6 @@ def run_evaluation(self, test_mode: bool = False): "You called `.test()` without defining model's `.test_step()`." 
" Please define and try again") - # Validation/Test begin callbacks - if test_mode: - self.on_test_start() - else: - self.on_validation_start() - # hook model = self.get_model() model.on_pre_performance_check() @@ -368,21 +360,18 @@ def run_evaluation(self, test_mode: bool = False): if self.fast_dev_run: max_batches = 1 - # init validation or test progress bar - # main progress bar will already be closed when testing so initial position is free - position = 2 * self.process_position + (not test_mode) - desc = 'Testing' if test_mode else 'Validating' - total = max_batches if max_batches != float('inf') else None - pbar = tqdm(desc=desc, total=total, leave=test_mode, position=position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True, file=sys.stdout) - setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar) + # Validation/Test begin callbacks + if test_mode: + self.on_test_start() + else: + self.on_validation_start() # run evaluation eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) # add metrics to prog bar - self.add_tqdm_metrics(prog_bar_metrics) + self.add_progress_bar_metrics(prog_bar_metrics) # log results of test if test_mode and self.proc_rank == 0: @@ -400,16 +389,6 @@ def run_evaluation(self, test_mode: bool = False): # hook model.on_post_performance_check() - # add model specific metrics - if not test_mode: - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - - # close progress bar - if test_mode: - self.test_progress_bar.close() - else: - self.val_progress_bar.close() - # eventual dataset reloading if test_mode: if self.reload_dataloaders_every_epoch: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index dbe05e4aa2818..5c05833c8d57d 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -16,7 +16,7 @@ class TrainerLoggingMixin(ABC): on_gpu: bool log_gpu_memory: ... logger: Union[LightningLoggerBase, bool] - tqdm_metrics: ... + progress_bar_metrics: ... global_step: int proc_rank: int use_dp: bool @@ -75,12 +75,12 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.save() - def add_tqdm_metrics(self, metrics): + def add_progress_bar_metrics(self, metrics): for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() - self.tqdm_metrics[k] = v + self.progress_bar_metrics[k] = v def metrics_to_scalars(self, metrics): new_metrics = {} @@ -98,7 +98,7 @@ def metrics_to_scalars(self, metrics): def process_output(self, output, train=False): """Reduces output according to the training mode. 
- Separates loss from logging and tqdm metrics + Separates loss from logging and progress bar metrics """ # --------------- # EXTRACT CALLBACK KEYS @@ -119,7 +119,7 @@ def process_output(self, output, train=False): try: progress_output = output['progress_bar'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus progress_output = self.reduce_distributed_output(progress_output, num_gpus) @@ -135,7 +135,7 @@ def process_output(self, output, train=False): try: log_output = output['log'] - # reduce progress metrics for tqdm when using dp + # reduce progress metrics for progress bar when using dp if train and (self.use_dp or self.use_ddp2): num_gpus = self.num_gpus log_output = self.reduce_distributed_output(log_output, num_gpus) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index ff14e93a49e1b..23eab617001b0 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -123,7 +123,8 @@ def lr_find(self, self.max_steps = num_training # Disable standard progress bar for fit - self.progress_bar_refresh_rate = False + if self.progress_bar_callback: + self.progress_bar_callback.disable() # Accumulation of gradients self.accumulate_grad_batches = num_accumulation_steps @@ -165,6 +166,8 @@ def lr_find(self, # Finish by resetting variables so trainer is ready to fit model self._restore_params(model) + if self.progress_bar_callback: + self.progress_bar_callback.enable() return lr_finder @@ -178,6 +181,7 @@ def _dump_params(self, model): 'progress_bar_refresh_rate': self.progress_bar_refresh_rate, 'accumulate_grad_batches': self.accumulate_grad_batches, 'checkpoint_callback': self.checkpoint_callback, + 'progress_bar_callback': self.progress_bar_callback, 'configure_optimizers': model.configure_optimizers, } @@ -189,6 +193,7 @@ def _restore_params(self, model): self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate'] self.accumulate_grad_batches = self._params['accumulate_grad_batches'] self.checkpoint_callback = self._params['checkpoint_callback'] + self.progress_bar_callback = self._params['progress_bar_callback'] model.configure_optimizers = self._params['configure_optimizers'] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 20ef14ca3cb50..531efdba6ae95 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,7 +1,6 @@ import distutils import inspect import os -import sys from argparse import ArgumentParser from typing import Union, Optional, List, Dict, Tuple, Iterable, Any @@ -9,10 +8,9 @@ import torch.distributed as torch_distrib import torch.multiprocessing as mp from torch.utils.data import DataLoader -from tqdm.auto import tqdm from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler @@ -124,6 +122,7 @@ def __init__( reload_dataloaders_every_epoch: bool = False, auto_lr_find: Union[bool, str] = False, replace_sampler_ddp: bool = True, + progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True, 
default_save_path=None, # backward compatible, todo: remove in v0.8.0 gradient_clip=None, # backward compatible, todo: remove in v0.8.0 nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0 @@ -162,7 +161,7 @@ def __init__( Use `gradient_clip_val` instead. Will remove 0.9.0. - process_position: orders the tqdm bar when running multiple models on same machine. + process_position: orders the progress bar when running multiple models on same machine. num_nodes: number of GPU nodes for distributed training. @@ -190,6 +189,7 @@ def __init__( Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. + Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. overfit_pct: How much of training-, validation-, and test dataset to check. @@ -312,7 +312,6 @@ def __init__( " and this method will be removed in v0.8.0", DeprecationWarning) self.gradient_clip = gradient_clip - self.progress_bar_refresh_rate = progress_bar_refresh_rate self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False @@ -390,7 +389,7 @@ def __init__( self.total_batch_idx = 0 self.running_loss = TensorRunningAccum(window_length=20) self.batch_idx = 0 - self.tqdm_metrics = {} + self.progress_bar_metrics = {} self.callback_metrics = {} self.num_val_batches = 0 self.num_training_batches = 0 @@ -408,7 +407,6 @@ def __init__( self.optimizer_frequencies = [] self.global_step = 0 self.current_epoch = 0 - self.total_batches = 0 self.interrupted = False # configure logger @@ -464,12 +462,14 @@ def __init__( # nvidia setup self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids) - # can't init progress bar here because starting a new process - # means the progress_bar won't survive pickling # backward compatibility if show_progress_bar is not None: self.show_progress_bar = show_progress_bar + self.progress_bar_refresh_rate = progress_bar_refresh_rate + self.progress_bar_callback = None + self.configure_progress_bar() + # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval @@ -635,17 +635,27 @@ def data_parallel(self) -> bool: return self.use_dp or self.use_ddp or self.use_ddp2 @property - def training_tqdm_dict(self) -> dict: - """Read-only for tqdm metrics. - :return: - """ + def progress_bar_dict(self) -> dict: + """ Read-only for progress bar metrics. """ ref_model = self.model if not self.data_parallel else self.model.module + return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) - return dict(**ref_model.get_tqdm_dict(), **self.tqdm_metrics) + @property + def training_tqdm_dict(self): + """Read-only for progress bar metrics. + + Warning: + Deprecated since v0.7.3. + Use :meth:`progress_bar_dict` instead. + + """ + rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" + " and this method will be removed in v1.0.0", DeprecationWarning) + return self.progress_bar_dict @property def tng_tqdm_dic(self): - """Read-only for tqdm metrics. + """Read-only for progress bar metrics. .. warning:: .. 
deprecated:: 0.5.0 @@ -654,7 +664,7 @@ def tng_tqdm_dic(self): """ rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" " and this method will be removed in v0.8.0", DeprecationWarning) - return self.training_tqdm_dict + return self.progress_bar_dict # ----------------------------- # MODEL TRAINING @@ -873,17 +883,12 @@ def run_pretrain_routine(self, model: LightningModule): # run tiny validation (if validation defined) # to make sure program won't crash during val - ref_model.on_sanity_check_start() if not self.disable_validation and self.num_sanity_val_steps > 0: self.reset_val_dataloader(ref_model) - # init progress bars for validation sanity check - pbar = tqdm(desc='Validation sanity check', - total=self.num_sanity_val_steps * len(self.val_dataloaders), - leave=False, position=2 * self.process_position, - disable=not self.progress_bar_refresh_rate, dynamic_ncols=True) - self.main_progress_bar = pbar - # dummy validation progress bar - self.val_progress_bar = tqdm(disable=True) + + # hook and callback + ref_model.on_sanity_check_start() + self.on_sanity_check_start() eval_results = self._evaluate(model, self.val_dataloaders, @@ -891,20 +896,12 @@ def run_pretrain_routine(self, model: LightningModule): False) _, _, _, callback_metrics, _ = self.process_output(eval_results) - # close progress bars - self.main_progress_bar.close() - self.val_progress_bar.close() + self.on_sanity_check_end() # verify that early stop has conditioned on a metric that exists if self.enable_early_stop: self.early_stop_callback._validate_condition_metric(callback_metrics) - # init progress bar - pbar = tqdm(leave=True, position=2 * self.process_position, - disable=not self.show_progress_bar, dynamic_ncols=True, - file=sys.stdout, smoothing=0) - self.main_progress_bar = pbar - # clear cache before training if self.on_gpu: torch.cuda.empty_cache() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5b3d13c72b5f1..b1b61785ee413 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -202,7 +202,6 @@ class TrainerTrainLoopMixin(ABC): num_val_batches: int disable_validation: bool fast_dev_run: ... - main_progress_bar: ... accumulation_scheduler: ... lr_schedulers: ... enable_early_stop: ... @@ -214,7 +213,6 @@ class TrainerTrainLoopMixin(ABC): log_save_interval: float proc_rank: int row_log_interval: float - total_batches: int truncated_bptt_steps: ... optimizers: ... optimizer_frequencies: ... @@ -223,14 +221,13 @@ class TrainerTrainLoopMixin(ABC): model: LightningModule interrupted: bool running_loss: ... - training_tqdm_dict: ... + progress_bar_dict: ... reduce_lr_on_plateau_scheduler: ... profiler: ... batch_idx: int precision: ... train_dataloader: DataLoader reload_dataloaders_every_epoch: bool - progress_bar_refresh_rate: ... 
max_steps: int min_steps: int total_batch_idx: int @@ -280,7 +277,7 @@ def is_overriden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def add_tqdm_metrics(self, *args): + def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -341,18 +338,6 @@ def train(self): model.current_epoch = epoch self.current_epoch = epoch - total_val_batches = 0 - is_val_epoch = False - if not self.disable_validation and self.num_training_batches != float('inf'): - # val can be checked multiple times in epoch - is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 - val_checks_per_epoch = self.num_training_batches // self.val_check_batch - val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 - total_val_batches = self.num_val_batches * val_checks_per_epoch - - # total batches includes multiple val checks - self.total_batches = self.num_training_batches + total_val_batches - # changing gradient according accumulation_scheduler self.accumulation_scheduler.on_epoch_start(self, self.get_model()) @@ -361,22 +346,6 @@ def train(self): window_length=self.accumulate_grad_batches ) - if self.fast_dev_run: - # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run - num_iterations = 2 - elif self.total_batches == float('inf'): - # for infinite train or val loader, the progress bar never ends - num_iterations = None - else: - num_iterations = self.total_batches - - # reset progress bar - # .reset() doesn't work on disabled progress bar so we should check - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(num_iterations) - desc = f'Epoch {epoch + 1}' - self.main_progress_bar.set_description(desc) - # ----------------- # RUN TNG EPOCH # ----------------- @@ -609,7 +578,7 @@ def optimizer_closure(): all_callback_metrics.append(callback_metrics) # track progress bar metrics - self.add_tqdm_metrics(progress_bar_metrics) + self.add_progress_bar_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) if self.use_horovod: @@ -669,11 +638,6 @@ def optimizer_closure(): if self.is_function_implemented('on_batch_end'): self.get_model().on_batch_end() - # update progress bar - if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0: - self.main_progress_bar.update(self.progress_bar_refresh_rate) - self.main_progress_bar.set_postfix(**self.training_tqdm_dict) - # collapse all metrics into one dict all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} @@ -696,8 +660,6 @@ def _get_optimizers_iterable(self): return [(opt_idx, self.optimizers[opt_idx])] def run_training_teardown(self): - self.main_progress_bar.close() - # Train end events with self.profiler.profile('on_train_end'): # callbacks diff --git a/tests/base/utils.py b/tests/base/utils.py index 1bb485f270ada..1f0d582ed6e01 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -220,7 +220,7 @@ def run_prediction(dataloader, trained_model, dp=False, min_acc=0.5): def assert_ok_model_acc(trainer, key='test_acc', thr=0.5): # this model should get 0.80+ acc - acc = trainer.training_tqdm_dict[key] + acc = trainer.progress_bar_dict[key] assert acc > thr, f"Model failed to get expected {thr} accuracy. 
{key} = {acc}" diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 4f77dabbd12a5..8c3b0625f9702 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,7 +1,10 @@ +import pytest + import tests.base.utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( LightTrainDataloader, LightTestMixin, @@ -33,10 +36,16 @@ def __init__(self): super().__init__() self.on_init_start_called = False self.on_init_end_called = False + self.on_sanity_check_start_called = False + self.on_sanity_check_end_called = False self.on_epoch_start_called = False self.on_epoch_end_called = False self.on_batch_start_called = False self.on_batch_end_called = False + self.on_validation_batch_start_called = False + self.on_validation_batch_end_called = False + self.on_test_batch_start_called = False + self.on_test_batch_end_called = False self.on_train_start_called = False self.on_train_end_called = False self.on_validation_start_called = False @@ -52,6 +61,14 @@ def on_init_end(self, trainer): assert isinstance(trainer, Trainer) self.on_init_end_called = True + def on_sanity_check_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_start_called = True + + def on_sanity_check_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_sanity_check_end_called = True + def on_epoch_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_epoch_start_called = True @@ -68,6 +85,22 @@ def on_batch_end(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_batch_end_called = True + def on_validation_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_start_called = True + + def on_validation_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_validation_batch_end_called = True + + def on_test_batch_start(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_start_called = True + + def on_test_batch_end(self, trainer, pl_module): + _check_args(trainer, pl_module) + self.on_test_batch_end_called = True + def on_train_start(self, trainer, pl_module): _check_args(trainer, pl_module) self.on_train_start_called = True @@ -104,10 +137,16 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_init_start_called assert not test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -121,10 +160,16 @@ def on_test_end(self, trainer, pl_module): assert trainer.callbacks[0] == test_callback assert 
test_callback.on_init_start_called assert test_callback.on_init_end_called + assert not test_callback.on_sanity_check_start_called + assert not test_callback.on_sanity_check_end_called assert not test_callback.on_epoch_start_called assert not test_callback.on_epoch_start_called assert not test_callback.on_batch_start_called assert not test_callback.on_batch_end_called + assert not test_callback.on_validation_batch_start_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_train_start_called assert not test_callback.on_train_end_called assert not test_callback.on_validation_start_called @@ -136,21 +181,36 @@ def on_test_end(self, trainer, pl_module): assert test_callback.on_init_start_called assert test_callback.on_init_end_called + assert test_callback.on_sanity_check_start_called + assert test_callback.on_sanity_check_end_called assert test_callback.on_epoch_start_called assert test_callback.on_epoch_start_called assert test_callback.on_batch_start_called assert test_callback.on_batch_end_called + assert test_callback.on_validation_batch_start_called + assert test_callback.on_validation_batch_end_called assert test_callback.on_train_start_called assert test_callback.on_train_end_called assert test_callback.on_validation_start_called assert test_callback.on_validation_end_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called assert not test_callback.on_test_start_called assert not test_callback.on_test_end_called - trainer.test() + test_callback = TestCallback() + trainer_options['callbacks'] = [test_callback] + trainer = Trainer(**trainer_options) + trainer.test(model) + assert test_callback.on_test_batch_start_called + assert test_callback.on_test_batch_end_called assert test_callback.on_test_start_called assert test_callback.on_test_end_called + assert not test_callback.on_validation_start_called + assert not test_callback.on_validation_end_called + assert not test_callback.on_validation_batch_end_called + assert not test_callback.on_validation_batch_start_called def test_early_stopping_no_val_step(tmpdir): @@ -183,6 +243,152 @@ def training_step(self, *args, **kwargs): assert trainer.current_epoch < trainer.max_epochs +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), +]) +def test_progress_bar_on(callbacks, refresh_rate): + """Test different ways the progress bar can be turned on.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + max_epochs=1, + overfit_pct=0.2, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] + # Trainer supports only a single progress bar callback at the moment + assert len(progress_bars) == 1 + assert progress_bars[0] is trainer.progress_bar_callback + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 0), + ([], False), + ([ModelCheckpoint('.')], 0), +]) +def test_progress_bar_off(callbacks, refresh_rate): + """Test different ways the progress bar can be turned off.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + assert 0 == len(progress_bars) + assert not trainer.progress_bar_callback + + +def 
test_progress_bar_misconfiguration(): + """Test that Trainer doesn't accept multiple progress bars.""" + callbacks = [ProgressBar(), ProgressBar()] + with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): + Trainer(callbacks=callbacks) + + +def test_progress_bar_totals(): + """Test that the progress finishes with the correct total steps processed.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + progress_bar_callback=True, + progress_bar_refresh_rate=1, + val_percent_check=1.0, + max_epochs=1, + ) + progress_bar = trainer.progress_bar_callback + assert 0 == progress_bar.total_train_batches + assert 0 == progress_bar.total_val_batches + assert 0 == progress_bar.total_test_batches + + trainer.fit(model) + assert progress_bar.total_train_batches == len(trainer.train_dataloader) + assert progress_bar.total_val_batches == progress_bar.val_progress_bar.total + assert progress_bar.total_val_batches == sum(len(loader) for loader in trainer.val_dataloaders) + assert 0 == progress_bar.total_test_batches + + trainer.test(model) + assert progress_bar.total_test_batches == progress_bar.test_progress_bar.total + assert progress_bar.total_test_batches == sum(len(loader) for loader in trainer.test_dataloaders) + + +@pytest.mark.parametrize('refresh_rate', [0, 1, 50]) +def test_progress_bar_progress_refresh(refresh_rate): + """Test that the three progress bars get correctly updated when using different refresh rates.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + class CurrentProgressBar(ProgressBar): + + train_batches_seen = 0 + val_batches_seen = 0 + test_batches_seen = 0 + + def on_batch_start(self, trainer, pl_module): + super().on_batch_start(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + 1 + if not self.disabled and self.train_batch_idx % self.refresh_rate == 0: + assert self.main_progress_bar.n == self.train_batch_idx + self.train_batches_seen += 1 + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if not self.disabled and self.val_batch_idx % self.refresh_rate == 0: + assert self.val_progress_bar.n == self.val_batch_idx + self.val_batches_seen += 1 + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if not self.disabled and self.test_batch_idx % self.refresh_rate == 0: + assert self.test_progress_bar.n == self.test_batch_idx + self.test_batches_seen += 1 + + progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + trainer = Trainer( + callbacks=[progress_bar], + progress_bar_refresh_rate=101, # should not matter if custom callback provided + train_percent_check=1.0, + num_sanity_val_steps=2, + max_epochs=3, + ) + assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate + + trainer.fit(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + + 
trainer.test(model) + assert progress_bar.test_batches_seen == progress_bar.total_test_batches + + def test_model_checkpoint_with_non_string_input(tmpdir): """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ From 0cd98b52f1b9708a49e2a4c9c6c223678290a8e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 20 Apr 2020 05:08:49 +0200 Subject: [PATCH 02/17] fix fast dev run total --- pytorch_lightning/callbacks/progress.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index afd3b84e3f1c1..24cf26dd9761b 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -146,27 +146,18 @@ def on_epoch_start(self, trainer, pl_module): def on_batch_end(self, trainer, pl_module): self._train_batch_idx += 1 - def on_epoch_end(self, trainer, pl_module): - self._train_batch_idx = 0 - def on_validation_start(self, trainer, pl_module): self._val_batch_idx = 0 def on_validation_batch_end(self, trainer, pl_module): self._val_batch_idx += 1 - def on_validation_end(self, trainer, pl_module): - self._val_batch_idx = 0 - def on_test_start(self, trainer, pl_module): self._test_batch_idx = 0 def on_test_batch_end(self, trainer, pl_module): self._test_batch_idx += 1 - def on_test_end(self, trainer, pl_module): - self._test_batch_idx = 0 - class ProgressBar(ProgressBarBase): r""" @@ -323,7 +314,7 @@ def on_epoch_start(self, trainer, pl_module): super().on_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches - if total_train_batches != float('inf'): + if total_train_batches != float('inf') and not trainer.fast_dev_run: # val can be checked multiple times per epoch val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch From 3968eb03cd8d5dc30ad1322b52da190fe1b19e0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 20 Apr 2020 05:09:46 +0200 Subject: [PATCH 03/17] more thorough testing --- tests/trainer/test_callbacks.py | 80 ++++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 8c3b0625f9702..14d2afb9a7c50 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -286,7 +286,7 @@ def test_progress_bar_off(callbacks, refresh_rate): def test_progress_bar_misconfiguration(): """Test that Trainer doesn't accept multiple progress bars.""" - callbacks = [ProgressBar(), ProgressBar()] + callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('.')] with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): Trainer(callbacks=callbacks) @@ -311,20 +311,78 @@ class CurrentTestModel( val_percent_check=1.0, max_epochs=1, ) - progress_bar = trainer.progress_bar_callback - assert 0 == progress_bar.total_train_batches - assert 0 == progress_bar.total_val_batches - assert 0 == progress_bar.total_test_batches + bar = trainer.progress_bar_callback + assert 0 == bar.total_train_batches + assert 0 == bar.total_val_batches + assert 0 == bar.total_test_batches trainer.fit(model) - assert progress_bar.total_train_batches == len(trainer.train_dataloader) - assert progress_bar.total_val_batches == progress_bar.val_progress_bar.total - assert progress_bar.total_val_batches == 
sum(len(loader) for loader in trainer.val_dataloaders) - assert 0 == progress_bar.total_test_batches + + # check main progress bar total + n = bar.total_train_batches + m = bar.total_val_batches + assert len(trainer.train_dataloader) == n + assert bar.main_progress_bar.total == n + m + + # check val progress bar total + assert sum(len(loader) for loader in trainer.val_dataloaders) == m + assert bar.val_progress_bar.total == m + + # main progress bar should have reached the end (train batches + val batches) + assert bar.main_progress_bar.n == n + m + assert bar.train_batch_idx == n + + # val progress bar should have reached the end + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + + # check that the test progress bar is off + assert 0 == bar.total_test_batches + assert bar.test_progress_bar is None trainer.test(model) - assert progress_bar.total_test_batches == progress_bar.test_progress_bar.total - assert progress_bar.total_test_batches == sum(len(loader) for loader in trainer.test_dataloaders) + + # check test progress bar total + k = bar.total_test_batches + assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert bar.test_progress_bar.total == k + + # test progress bar should have reached the end + assert bar.test_progress_bar.n == k + assert bar.test_batch_idx == k + + +def test_progress_bar_fast_dev_run(): + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + progress_bar_callback=True, + fast_dev_run=True, + ) + + progress_bar = trainer.progress_bar_callback + assert 1 == progress_bar.total_train_batches + # total val batches are known only after val dataloaders have reloaded + + trainer.fit(model) + + assert 1 == progress_bar.total_val_batches + assert 1 == progress_bar.train_batch_idx + assert 1 == progress_bar.val_batch_idx + assert 0 == progress_bar.test_batch_idx + + # the main progress bar should display 2 batches (1 train, 1 val) + assert 2 == progress_bar.main_progress_bar.total + assert 2 == progress_bar.main_progress_bar.n @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) From 6798d1268865492b3c45158629e2d71e097a9aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 20 Apr 2020 06:45:32 +0200 Subject: [PATCH 04/17] remove old args --- tests/trainer/test_callbacks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 14d2afb9a7c50..d9cb058f67d6a 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -306,7 +306,6 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer( - progress_bar_callback=True, progress_bar_refresh_rate=1, val_percent_check=1.0, max_epochs=1, @@ -365,7 +364,6 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer( - progress_bar_callback=True, fast_dev_run=True, ) From 172aff2077b54d418ae23b16eb23c233126c0cc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 19:16:12 +0200 Subject: [PATCH 05/17] fix merge --- tests/trainer/test_callbacks.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index d9cb058f67d6a..8ce971402b807 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -250,6 +250,32 @@ def 
training_step(self, *args, **kwargs): ([ProgressBar(refresh_rate=2)], 0), ([ProgressBar(refresh_rate=2)], 1), ]) + + +def test_model_checkpoint_with_non_string_input(tmpdir): + """ Test that None in checkpoint callback is valid and that chkp_path is + set correctly """ + tutils.reset_seed() + + class CurrentTestModel(LightTrainDataloader, TestModelBase): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1) + + trainer = Trainer(default_root_dir=tmpdir, + checkpoint_callback=checkpoint, + overfit_pct=0.20, + max_epochs=5 + ) + result = trainer.fit(model) + + # These should be different if the dirpath has be overridden + assert trainer.ckpt_path != trainer.default_root_dir + + def test_progress_bar_on(callbacks, refresh_rate): """Test different ways the progress bar can be turned on.""" From 5cb79c3d09819dba69f25d0488290b984bb28ada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 19:20:08 +0200 Subject: [PATCH 06/17] fix merge --- tests/trainer/test_callbacks.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 8ce971402b807..e484b90d21d96 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -243,15 +243,6 @@ def training_step(self, *args, **kwargs): assert trainer.current_epoch < trainer.max_epochs -@pytest.mark.parametrize('callbacks,refresh_rate', [ - ([], 1), - ([], 2), - ([ProgressBar(refresh_rate=1)], 0), - ([ProgressBar(refresh_rate=2)], 0), - ([ProgressBar(refresh_rate=2)], 1), -]) - - def test_model_checkpoint_with_non_string_input(tmpdir): """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ @@ -276,6 +267,13 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase): assert trainer.ckpt_path != trainer.default_root_dir +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), +]) def test_progress_bar_on(callbacks, refresh_rate): """Test different ways the progress bar can be turned on.""" From 0f3d21c4d380e2eb44fe022a74d60eca7d493480 Mon Sep 17 00:00:00 2001 From: "J. 
Borovec" Date: Thu, 23 Apr 2020 22:17:52 +0200 Subject: [PATCH 07/17] separate tests --- tests/callbacks/__init__.py | 0 .../{trainer => callbacks}/test_callbacks.py | 207 +---------------- tests/callbacks/test_progress_bar.py | 214 ++++++++++++++++++ 3 files changed, 215 insertions(+), 206 deletions(-) create mode 100644 tests/callbacks/__init__.py rename tests/{trainer => callbacks}/test_callbacks.py (60%) create mode 100644 tests/callbacks/test_progress_bar.py diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/trainer/test_callbacks.py b/tests/callbacks/test_callbacks.py similarity index 60% rename from tests/trainer/test_callbacks.py rename to tests/callbacks/test_callbacks.py index e484b90d21d96..312c95cbf1211 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,10 +1,7 @@ -import pytest - import tests.base.utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, ProgressBarBase, ProgressBar, ModelCheckpoint -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from tests.base import ( LightTrainDataloader, LightTestMixin, @@ -267,208 +264,6 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase): assert trainer.ckpt_path != trainer.default_root_dir -@pytest.mark.parametrize('callbacks,refresh_rate', [ - ([], 1), - ([], 2), - ([ProgressBar(refresh_rate=1)], 0), - ([ProgressBar(refresh_rate=2)], 0), - ([ProgressBar(refresh_rate=2)], 1), -]) -def test_progress_bar_on(callbacks, refresh_rate): - """Test different ways the progress bar can be turned on.""" - - trainer = Trainer( - callbacks=callbacks, - progress_bar_refresh_rate=refresh_rate, - max_epochs=1, - overfit_pct=0.2, - ) - - progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] - # Trainer supports only a single progress bar callback at the moment - assert len(progress_bars) == 1 - assert progress_bars[0] is trainer.progress_bar_callback - - -@pytest.mark.parametrize('callbacks,refresh_rate', [ - ([], 0), - ([], False), - ([ModelCheckpoint('.')], 0), -]) -def test_progress_bar_off(callbacks, refresh_rate): - """Test different ways the progress bar can be turned off.""" - - trainer = Trainer( - callbacks=callbacks, - progress_bar_refresh_rate=refresh_rate, - ) - - progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] - assert 0 == len(progress_bars) - assert not trainer.progress_bar_callback - - -def test_progress_bar_misconfiguration(): - """Test that Trainer doesn't accept multiple progress bars.""" - callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('.')] - with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): - Trainer(callbacks=callbacks) - - -def test_progress_bar_totals(): - """Test that the progress finishes with the correct total steps processed.""" - - class CurrentTestModel( - LightTrainDataloader, - LightTestMixin, - LightValidationMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_default_hparams() - model = CurrentTestModel(hparams) - - trainer = Trainer( - progress_bar_refresh_rate=1, - val_percent_check=1.0, - max_epochs=1, - ) - bar = trainer.progress_bar_callback - assert 0 == bar.total_train_batches - assert 0 == bar.total_val_batches - assert 0 == 
bar.total_test_batches - - trainer.fit(model) - - # check main progress bar total - n = bar.total_train_batches - m = bar.total_val_batches - assert len(trainer.train_dataloader) == n - assert bar.main_progress_bar.total == n + m - - # check val progress bar total - assert sum(len(loader) for loader in trainer.val_dataloaders) == m - assert bar.val_progress_bar.total == m - - # main progress bar should have reached the end (train batches + val batches) - assert bar.main_progress_bar.n == n + m - assert bar.train_batch_idx == n - - # val progress bar should have reached the end - assert bar.val_progress_bar.n == m - assert bar.val_batch_idx == m - - # check that the test progress bar is off - assert 0 == bar.total_test_batches - assert bar.test_progress_bar is None - - trainer.test(model) - - # check test progress bar total - k = bar.total_test_batches - assert sum(len(loader) for loader in trainer.test_dataloaders) == k - assert bar.test_progress_bar.total == k - - # test progress bar should have reached the end - assert bar.test_progress_bar.n == k - assert bar.test_batch_idx == k - - -def test_progress_bar_fast_dev_run(): - class CurrentTestModel( - LightTrainDataloader, - LightTestMixin, - LightValidationMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_default_hparams() - model = CurrentTestModel(hparams) - - trainer = Trainer( - fast_dev_run=True, - ) - - progress_bar = trainer.progress_bar_callback - assert 1 == progress_bar.total_train_batches - # total val batches are known only after val dataloaders have reloaded - - trainer.fit(model) - - assert 1 == progress_bar.total_val_batches - assert 1 == progress_bar.train_batch_idx - assert 1 == progress_bar.val_batch_idx - assert 0 == progress_bar.test_batch_idx - - # the main progress bar should display 2 batches (1 train, 1 val) - assert 2 == progress_bar.main_progress_bar.total - assert 2 == progress_bar.main_progress_bar.n - - -@pytest.mark.parametrize('refresh_rate', [0, 1, 50]) -def test_progress_bar_progress_refresh(refresh_rate): - """Test that the three progress bars get correctly updated when using different refresh rates.""" - - class CurrentTestModel( - LightTrainDataloader, - LightTestMixin, - LightValidationMixin, - TestModelBase, - ): - pass - - hparams = tutils.get_default_hparams() - model = CurrentTestModel(hparams) - - class CurrentProgressBar(ProgressBar): - - train_batches_seen = 0 - val_batches_seen = 0 - test_batches_seen = 0 - - def on_batch_start(self, trainer, pl_module): - super().on_batch_start(trainer, pl_module) - assert self.train_batch_idx == trainer.batch_idx - - def on_batch_end(self, trainer, pl_module): - super().on_batch_end(trainer, pl_module) - assert self.train_batch_idx == trainer.batch_idx + 1 - if not self.disabled and self.train_batch_idx % self.refresh_rate == 0: - assert self.main_progress_bar.n == self.train_batch_idx - self.train_batches_seen += 1 - - def on_validation_batch_end(self, trainer, pl_module): - super().on_validation_batch_end(trainer, pl_module) - if not self.disabled and self.val_batch_idx % self.refresh_rate == 0: - assert self.val_progress_bar.n == self.val_batch_idx - self.val_batches_seen += 1 - - def on_test_batch_end(self, trainer, pl_module): - super().on_test_batch_end(trainer, pl_module) - if not self.disabled and self.test_batch_idx % self.refresh_rate == 0: - assert self.test_progress_bar.n == self.test_batch_idx - self.test_batches_seen += 1 - - progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) - trainer = Trainer( - callbacks=[progress_bar], 
- progress_bar_refresh_rate=101, # should not matter if custom callback provided - train_percent_check=1.0, - num_sanity_val_steps=2, - max_epochs=3, - ) - assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate - - trainer.fit(model) - assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches - assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps - - trainer.test(model) - assert progress_bar.test_batches_seen == progress_bar.total_test_batches - - def test_model_checkpoint_with_non_string_input(tmpdir): """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py new file mode 100644 index 0000000000000..8a73c46e754fd --- /dev/null +++ b/tests/callbacks/test_progress_bar.py @@ -0,0 +1,214 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import ( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase +) + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 1), + ([], 2), + ([ProgressBar(refresh_rate=1)], 0), + ([ProgressBar(refresh_rate=2)], 0), + ([ProgressBar(refresh_rate=2)], 1), +]) +def test_progress_bar_on(callbacks, refresh_rate): + """Test different ways the progress bar can be turned on.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + max_epochs=1, + overfit_pct=0.2, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] + # Trainer supports only a single progress bar callback at the moment + assert len(progress_bars) == 1 + assert progress_bars[0] is trainer.progress_bar_callback + + +@pytest.mark.parametrize('callbacks,refresh_rate', [ + ([], 0), + ([], False), + ([ModelCheckpoint('../trainer')], 0), +]) +def test_progress_bar_off(callbacks, refresh_rate): + """Test different ways the progress bar can be turned off.""" + + trainer = Trainer( + callbacks=callbacks, + progress_bar_refresh_rate=refresh_rate, + ) + + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + assert 0 == len(progress_bars) + assert not trainer.progress_bar_callback + + +def test_progress_bar_misconfiguration(): + """Test that Trainer doesn't accept multiple progress bars.""" + callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint('../trainer')] + with pytest.raises(MisconfigurationException, match=r'^You added multiple progress bar callbacks'): + Trainer(callbacks=callbacks) + + +def test_progress_bar_totals(): + """Test that the progress finishes with the correct total steps processed.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + progress_bar_refresh_rate=1, + val_percent_check=1.0, + max_epochs=1, + ) + bar = trainer.progress_bar_callback + assert 0 == bar.total_train_batches + assert 0 == bar.total_val_batches + assert 0 == bar.total_test_batches + + trainer.fit(model) + + # check main progress bar total + n = bar.total_train_batches + m = bar.total_val_batches + assert len(trainer.train_dataloader) == n + assert 
bar.main_progress_bar.total == n + m + + # check val progress bar total + assert sum(len(loader) for loader in trainer.val_dataloaders) == m + assert bar.val_progress_bar.total == m + + # main progress bar should have reached the end (train batches + val batches) + assert bar.main_progress_bar.n == n + m + assert bar.train_batch_idx == n + + # val progress bar should have reached the end + assert bar.val_progress_bar.n == m + assert bar.val_batch_idx == m + + # check that the test progress bar is off + assert 0 == bar.total_test_batches + assert bar.test_progress_bar is None + + trainer.test(model) + + # check test progress bar total + k = bar.total_test_batches + assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert bar.test_progress_bar.total == k + + # test progress bar should have reached the end + assert bar.test_progress_bar.n == k + assert bar.test_batch_idx == k + + +def test_progress_bar_fast_dev_run(): + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + fast_dev_run=True, + ) + + progress_bar = trainer.progress_bar_callback + assert 1 == progress_bar.total_train_batches + # total val batches are known only after val dataloaders have reloaded + + trainer.fit(model) + + assert 1 == progress_bar.total_val_batches + assert 1 == progress_bar.train_batch_idx + assert 1 == progress_bar.val_batch_idx + assert 0 == progress_bar.test_batch_idx + + # the main progress bar should display 2 batches (1 train, 1 val) + assert 2 == progress_bar.main_progress_bar.total + assert 2 == progress_bar.main_progress_bar.n + + +@pytest.mark.parametrize('refresh_rate', [0, 1, 50]) +def test_progress_bar_progress_refresh(refresh_rate): + """Test that the three progress bars get correctly updated when using different refresh rates.""" + + class CurrentTestModel( + LightTrainDataloader, + LightTestMixin, + LightValidationMixin, + TestModelBase, + ): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + class CurrentProgressBar(ProgressBar): + + train_batches_seen = 0 + val_batches_seen = 0 + test_batches_seen = 0 + + def on_batch_start(self, trainer, pl_module): + super().on_batch_start(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + + def on_batch_end(self, trainer, pl_module): + super().on_batch_end(trainer, pl_module) + assert self.train_batch_idx == trainer.batch_idx + 1 + if not self.disabled and self.train_batch_idx % self.refresh_rate == 0: + assert self.main_progress_bar.n == self.train_batch_idx + self.train_batches_seen += 1 + + def on_validation_batch_end(self, trainer, pl_module): + super().on_validation_batch_end(trainer, pl_module) + if not self.disabled and self.val_batch_idx % self.refresh_rate == 0: + assert self.val_progress_bar.n == self.val_batch_idx + self.val_batches_seen += 1 + + def on_test_batch_end(self, trainer, pl_module): + super().on_test_batch_end(trainer, pl_module) + if not self.disabled and self.test_batch_idx % self.refresh_rate == 0: + assert self.test_progress_bar.n == self.test_batch_idx + self.test_batches_seen += 1 + + progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + trainer = Trainer( + callbacks=[progress_bar], + progress_bar_refresh_rate=101, # should not matter if custom callback provided + train_percent_check=1.0, + num_sanity_val_steps=2, + max_epochs=3, + ) + assert 
trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate + + trainer.fit(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + + trainer.test(model) + assert progress_bar.test_batches_seen == progress_bar.total_test_batches \ No newline at end of file From 01df5a64b6af1088a922e7d008d9fa680f1bd3fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 22:54:37 +0200 Subject: [PATCH 08/17] type hint total batches --- pytorch_lightning/callbacks/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 24cf26dd9761b..c2aa2cebb54bf 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -76,7 +76,7 @@ def test_batch_idx(self) -> int: return self._test_batch_idx @property - def total_train_batches(self) -> Optional[int]: + def total_train_batches(self) -> int: """ The total number of training batches during training, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the @@ -89,7 +89,7 @@ def total_train_batches(self) -> Optional[int]: return total_train_batches @property - def total_val_batches(self) -> Optional[int]: + def total_val_batches(self) -> int: """ The total number of training batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the @@ -105,7 +105,7 @@ def total_val_batches(self) -> Optional[int]: return total_val_batches @property - def total_test_batches(self) -> Optional[int]: + def total_test_batches(self) -> int: """ The total number of training batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the From 1aca00a31f61a5c698981299eb695d077eba835c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 22:56:13 +0200 Subject: [PATCH 09/17] reduce if Co-Authored-By: Jirka Borovec --- pytorch_lightning/callbacks/progress.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index c2aa2cebb54bf..9ea96f155cf6e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -82,10 +82,7 @@ def total_train_batches(self) -> int: Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training dataloader is of infinite size. 
""" - if self.trainer.fast_dev_run: - total_train_batches = 1 - else: - total_train_batches = self.trainer.num_training_batches + total_train_batches = 1 if self.trainer.fast_dev_run else self.trainer.num_training_batches return total_train_batches @property From 9693948acbb530de6c1eb3b6e1f52f36f9d535c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 22:56:26 +0200 Subject: [PATCH 10/17] is_disabled Co-Authored-By: Jirka Borovec --- pytorch_lightning/callbacks/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 9ea96f155cf6e..5c564712776ab 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -233,7 +233,7 @@ def enabled(self) -> bool: return self._enabled and self.refresh_rate > 0 @property - def disabled(self) -> bool: + def is_disabled(self) -> bool: return not self.enabled def disable(self) -> None: From 293e4e19c45321301fef4e88de5def7a5b762d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 22:56:45 +0200 Subject: [PATCH 11/17] is_enabled Co-Authored-By: Jirka Borovec --- pytorch_lightning/callbacks/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 5c564712776ab..e041898bb4dcc 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -229,7 +229,7 @@ def process_position(self) -> int: return self._process_position @property - def enabled(self) -> bool: + def is_enabled(self) -> bool: return self._enabled and self.refresh_rate > 0 @property From 1948beebae57fd8fb20e580b15875ebdeac72368 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 23:00:26 +0200 Subject: [PATCH 12/17] rename enabled/disabled --- pytorch_lightning/callbacks/progress.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index e041898bb4dcc..a397c0d2c8d56 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -6,7 +6,6 @@ """ import sys -from typing import Optional from tqdm.auto import tqdm @@ -28,7 +27,7 @@ def __init__(self): self.enabled = True def disable(self): - self.enabled = False + self.enableenabled = False def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) # don't forget this :) @@ -234,7 +233,7 @@ def is_enabled(self) -> bool: @property def is_disabled(self) -> bool: - return not self.enabled + return not self.is_enabled def disable(self) -> None: self._enabled = False @@ -247,7 +246,7 @@ def init_sanity_tqdm(self) -> tqdm: bar = tqdm( desc='Validation sanity check', position=(2 * self.process_position), - disable=self.disabled, + disable=self.is_disabled, leave=False, dynamic_ncols=True, file=sys.stdout, @@ -260,7 +259,7 @@ def init_train_tqdm(self) -> tqdm: desc='Training', initial=self.train_batch_idx, position=(2 * self.process_position), - disable=self.disabled, + disable=self.is_disabled, leave=True, dynamic_ncols=True, file=sys.stdout, @@ -273,7 +272,7 @@ def init_validation_tqdm(self) -> tqdm: bar = tqdm( desc='Validating', position=(2 * self.process_position + 1), - disable=self.disabled, + disable=self.is_disabled, leave=False, dynamic_ncols=True, file=sys.stdout @@ 
-285,7 +284,7 @@ def init_test_tqdm(self) -> tqdm: bar = tqdm( desc='Testing', position=(2 * self.process_position), - disable=self.disabled, + disable=self.is_disabled, leave=True, dynamic_ncols=True, file=sys.stdout @@ -322,7 +321,7 @@ def on_epoch_start(self, trainer, pl_module): def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) - if self.enabled and self.train_batch_idx % self.refresh_rate == 0: + if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0: self.main_progress_bar.update(self.refresh_rate) self.main_progress_bar.set_postfix(**trainer.progress_bar_dict) @@ -333,7 +332,7 @@ def on_validation_start(self, trainer, pl_module): def on_validation_batch_end(self, trainer, pl_module): super().on_validation_batch_end(trainer, pl_module) - if self.enabled and self.val_batch_idx % self.refresh_rate == 0: + if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0: self.val_progress_bar.update(self.refresh_rate) self.main_progress_bar.update(self.refresh_rate) @@ -353,7 +352,7 @@ def on_test_start(self, trainer, pl_module): def on_test_batch_end(self, trainer, pl_module): super().on_test_batch_end(trainer, pl_module) - if self.enabled and self.test_batch_idx % self.refresh_rate == 0: + if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0: self.test_progress_bar.update(self.refresh_rate) def on_test_end(self, trainer, pl_module): From 40602ee99df8853568335fd3f975de2841e9aeb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 23:20:03 +0200 Subject: [PATCH 13/17] move deprecated api --- pytorch_lightning/trainer/deprecated_api.py | 14 ++++++++++ pytorch_lightning/trainer/trainer.py | 30 ++------------------- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index d9f461d5d039a..2705c4f160464 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -103,6 +103,13 @@ def default_save_path(self, path): " and this method will be removed in v0.8.0", DeprecationWarning) self.default_root_dir = path + @property + def tng_tqdm_dic(self): + """Back compatibility, will be removed in v0.8.0""" + rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" + " and this method will be removed in v0.8.0", DeprecationWarning) + return self.progress_bar_dict + class TrainerDeprecatedAPITillVer0_9(ABC): @@ -121,3 +128,10 @@ def show_progress_bar(self, tf): """Back compatibility, will be removed in v0.9.0""" rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2" " and this method will be removed in v0.9.0", DeprecationWarning) + + @property + def training_tqdm_dict(self): + """Back compatibility, will be removed in v0.9.0""" + rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" + " and this method will be removed in v0.9.0", DeprecationWarning) + return self.progress_bar_dict diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8700ad1944f19..ce8b9ca090d26 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -72,9 +72,9 @@ class Trainer( ): DEPRECATED_IN_0_8 = ( 'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', - 'add_row_log_interval', 'nb_sanity_val_steps' + 'add_row_log_interval', 'nb_sanity_val_steps', 'tng_tqdm_dic', ) - DEPRECATED_IN_0_9 = 
('use_amp', 'show_progress_bar') + DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict') def __init__( self, @@ -637,32 +637,6 @@ def progress_bar_dict(self) -> dict: ref_model = self.model if not self.data_parallel else self.model.module return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics) - @property - def training_tqdm_dict(self): - """Read-only for progress bar metrics. - - Warning: - Deprecated since v0.7.3. - Use :meth:`progress_bar_dict` instead. - - """ - rank_zero_warn("`training_tqdm_dict` was renamed to `progress_bar_dict` in v0.7.3" - " and this method will be removed in v1.0.0", DeprecationWarning) - return self.progress_bar_dict - - @property - def tng_tqdm_dic(self): - """Read-only for progress bar metrics. - - .. warning:: .. deprecated:: 0.5.0 - - Use `training_tqdm_dict` instead. Will remove 0.8.0. - - """ - rank_zero_warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0" - " and this method will be removed in v0.8.0", DeprecationWarning) - return self.progress_bar_dict - # ----------------------------- # MODEL TRAINING # ----------------------------- From 1d825ed843a1584b6ae576a60954d47368488ce3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 23:21:18 +0200 Subject: [PATCH 14/17] remove duplicated test from merge --- tests/callbacks/test_callbacks.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 312c95cbf1211..4731d4351679e 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -262,27 +262,3 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase): # These should be different if the dirpath has be overridden assert trainer.ckpt_path != trainer.default_root_dir - - -def test_model_checkpoint_with_non_string_input(tmpdir): - """ Test that None in checkpoint callback is valid and that chkp_path is - set correctly """ - tutils.reset_seed() - - class CurrentTestModel(LightTrainDataloader, TestModelBase): - pass - - hparams = tutils.get_default_hparams() - model = CurrentTestModel(hparams) - - checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1) - - trainer = Trainer(default_root_dir=tmpdir, - checkpoint_callback=checkpoint, - overfit_pct=0.20, - max_epochs=5 - ) - result = trainer.fit(model) - - # These should be different if the dirpath has be overridden - assert trainer.ckpt_path != trainer.default_root_dir From e75782383f0735e6bb855b90ab42e50bb9e3218a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 23 Apr 2020 23:23:57 +0200 Subject: [PATCH 15/17] fix rename is_disabled --- tests/callbacks/test_progress_bar.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8a73c46e754fd..07f4ad8c6a688 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -180,19 +180,19 @@ def on_batch_start(self, trainer, pl_module): def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) assert self.train_batch_idx == trainer.batch_idx + 1 - if not self.disabled and self.train_batch_idx % self.refresh_rate == 0: + if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx self.train_batches_seen += 1 def on_validation_batch_end(self, trainer, pl_module): 
         super().on_validation_batch_end(trainer, pl_module)
-        if not self.disabled and self.val_batch_idx % self.refresh_rate == 0:
+        if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0:
             assert self.val_progress_bar.n == self.val_batch_idx
         self.val_batches_seen += 1
 
     def on_test_batch_end(self, trainer, pl_module):
         super().on_test_batch_end(trainer, pl_module)
-        if not self.disabled and self.test_batch_idx % self.refresh_rate == 0:
+        if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0:
             assert self.test_progress_bar.n == self.test_batch_idx
         self.test_batches_seen += 1

From 334bc584c74f18bd2d35a7f2e01d76996e97cb80 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 23 Apr 2020 23:25:37 +0200
Subject: [PATCH 16/17] newline

---
 tests/callbacks/test_progress_bar.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 07f4ad8c6a688..226b3088da71d 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -211,4 +211,4 @@ def on_test_batch_end(self, trainer, pl_module):
     assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps
 
     trainer.test(model)
-    assert progress_bar.test_batches_seen == progress_bar.total_test_batches
\ No newline at end of file
+    assert progress_bar.test_batches_seen == progress_bar.total_test_batches

From 30eaad570ffb90316ba2fd9429a663646f271795 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 23 Apr 2020 23:33:04 +0200
Subject: [PATCH 17/17] test also testprogress for fast dev run

---
 tests/callbacks/test_progress_bar.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 226b3088da71d..7cd5d5435adef 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -151,6 +151,13 @@ class CurrentTestModel(
     assert 2 == progress_bar.main_progress_bar.total
     assert 2 == progress_bar.main_progress_bar.n
 
+    trainer.test(model)
+
+    # the test progress bar should display 1 batch
+    assert 1 == progress_bar.test_batch_idx
+    assert 1 == progress_bar.test_progress_bar.total
+    assert 1 == progress_bar.test_progress_bar.n
+
 
 @pytest.mark.parametrize('refresh_rate', [0, 1, 50])
 def test_progress_bar_progress_refresh(refresh_rate):
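
For reference, a minimal sketch of how the callback-based progress bar introduced in this series might be used or customized. It relies only on the API visible in the patches above (``ProgressBarBase``, ``ProgressBar(refresh_rate=...)``, ``Trainer(callbacks=[...])``); the ``LitProgressBar`` class and its plain-percentage output are illustrative assumptions, not part of the patches:

    import sys

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ProgressBar, ProgressBarBase


    class LitProgressBar(ProgressBarBase):
        """Illustrative bar that prints a plain percentage instead of rendering tqdm bars."""

        def __init__(self):
            super().__init__()  # don't forget this :)
            self.enabled = True

        def disable(self):
            self.enabled = False

        def on_batch_end(self, trainer, pl_module):
            super().on_batch_end(trainer, pl_module)  # keeps self.train_batch_idx up to date
            percent = (self.train_batch_idx / self.total_train_batches) * 100
            sys.stdout.write(f'{percent:.01f} percent complete \r')
            sys.stdout.flush()


    # pass a fully custom bar ...
    trainer = Trainer(callbacks=[LitProgressBar()])

    # ... or tune the built-in tqdm-based one (a refresh_rate of 0 leaves it disabled)
    trainer = Trainer(callbacks=[ProgressBar(refresh_rate=2)])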