move show_progress_bar to deprecated 0.9 api
gbkh2015 committed Mar 26, 2020
1 parent 03c9d0f commit db9dcdf
Showing 4 changed files with 30 additions and 13 deletions.
20 changes: 20 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -87,3 +87,23 @@ def nb_sanity_val_steps(self, nb):
"`num_sanity_val_steps` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
self.num_sanity_val_steps = nb


class TrainerDeprecatedAPITillVer0_9(ABC):

def __init__(self):
super().__init__() # mixin calls super too

@property
def show_progress_bar(self):
"""Back compatibility, will be removed in v0.9.0"""
warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
" and this method will be removed in v0.9.0", DeprecationWarning)
return self.progress_bar_refresh_rate >= 1

@show_progress_bar.setter
def show_progress_bar(self, tf):
"""Back compatibility, will be removed in v0.9.0"""
warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
" and this method will be removed in v0.9.0", DeprecationWarning)
self.progress_bar_refresh_rate = 1 if tf else 0  # map the legacy boolean onto the new refresh-rate setting
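
For reference, here is a minimal standalone sketch of the back-compat pattern this mixin implements (the `_ProgressBarShim` class and its attribute names are hypothetical illustrations, not Lightning's own code): reading or writing the legacy boolean emits a DeprecationWarning and is mapped onto the new integer `progress_bar_refresh_rate` setting.

import warnings

class _ProgressBarShim:
    """Hypothetical sketch of the deprecation-shim pattern, not the Lightning mixin."""

    def __init__(self, progress_bar_refresh_rate: int = 1):
        self.progress_bar_refresh_rate = progress_bar_refresh_rate

    @property
    def show_progress_bar(self) -> bool:
        warnings.warn("`show_progress_bar` is deprecated, use `progress_bar_refresh_rate`",
                      DeprecationWarning)
        return self.progress_bar_refresh_rate >= 1

    @show_progress_bar.setter
    def show_progress_bar(self, tf: bool):
        warnings.warn("`show_progress_bar` is deprecated, use `progress_bar_refresh_rate`",
                      DeprecationWarning)
        # map the legacy boolean onto the new integer setting
        self.progress_bar_refresh_rate = 1 if tf else 0

shim = _ProgressBarShim()
shim.show_progress_bar = False           # warns, then sets the refresh rate to 0
assert shim.progress_bar_refresh_rate == 0
assert shim.show_progress_bar is False   # warns again on read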
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -278,7 +278,7 @@ def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_m
dl_outputs.append(output)

# batch done
if self.show_progress_bar and batch_idx % self.progress_bar_refresh_rate == 0:
if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
if test_mode:
self.test_progress_bar.update(self.progress_bar_refresh_rate)
else:
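
The same guard is applied in the training loop further down: the loops now check the integer `progress_bar_refresh_rate` directly instead of the deprecated `show_progress_bar` property, which would otherwise emit a DeprecationWarning on every batch. As a rough, self-contained sketch of how this gate throttles progress-bar updates (plain tqdm and a hypothetical batch loop, not Lightning's evaluation code):

from tqdm import tqdm

progress_bar_refresh_rate = 2   # assumed setting; 0 disables the bar entirely
num_batches = 10

bar = tqdm(total=num_batches) if progress_bar_refresh_rate >= 1 else None
for batch_idx in range(num_batches):
    # ... run evaluation on the batch here ...
    if progress_bar_refresh_rate >= 1 and batch_idx % progress_bar_refresh_rate == 0:
        bar.update(progress_bar_refresh_rate)
if bar is not None:
    bar.close()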
19 changes: 8 additions & 11 deletions pytorch_lightning/trainer/trainer.py
@@ -24,7 +24,8 @@
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
from pytorch_lightning.trainer.deprecated_api import (TrainerDeprecatedAPITillVer0_8,
TrainerDeprecatedAPITillVer0_9)
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
@@ -66,12 +67,13 @@ class Trainer(
TrainerCallbackConfigMixin,
TrainerCallbackHookMixin,
TrainerDeprecatedAPITillVer0_8,
TrainerDeprecatedAPITillVer0_9,
):
DEPRECATED_IN_0_8 = (
'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
'add_row_log_interval', 'nb_sanity_val_steps'
)
DEPRECATED_IN_0_9 = ('use_amp',)
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar')

def __init__(
self,
@@ -88,7 +90,7 @@ def __init__(
gpus: Optional[Union[List[int], str, int]] = None,
num_tpu_cores: Optional[int] = None,
log_gpu_memory: Optional[str] = None,
show_progress_bar=None, # backward compatible, todo: remove in v0.8.0
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
progress_bar_refresh_rate: int = 1,
overfit_pct: float = 0.0,
track_grad_norm: int = -1,
@@ -416,12 +418,11 @@ def __init__(
# nvidia setup
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

# Backward compatibility, TODO: remove in v0.8.0
if show_progress_bar is not None:
warnings.warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.1"
" and this method will be removed in v0.8.0", DeprecationWarning)
# can't init progress bar here because starting a new process
# means the progress_bar won't survive pickling
# backward compatibility
if show_progress_bar is not None:
self.show_progress_bar = show_progress_bar

# logging
self.log_save_interval = log_save_interval
@@ -567,10 +568,6 @@ def from_argparse_args(cls, args):
params = vars(args)
return cls(**params)

@property
def show_progress_bar(self) -> bool:
return self.progress_bar_refresh_rate >= 1

@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
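
A hedged usage sketch of what this change means for callers, assuming pytorch_lightning at this revision is installed and the setter maps the boolean onto the refresh rate as sketched above: passing the legacy `show_progress_bar` argument still works but now routes through the v0.9 deprecation mixin and surfaces a DeprecationWarning; new code should pass `progress_bar_refresh_rate` instead.

import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer = Trainer(show_progress_bar=False)  # legacy flag, kept for back compatibility

# the legacy flag should now be reported as deprecated and mapped onto the new setting
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(trainer.progress_bar_refresh_rate)        # expected: 0

# preferred, non-deprecated equivalent
trainer = Trainer(progress_bar_refresh_rate=0)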
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -606,7 +606,7 @@ def optimizer_closure():
self.get_model().on_batch_end()

# update progress bar
if self.show_progress_bar and batch_idx % self.progress_bar_refresh_rate == 0:
if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
self.main_progress_bar.update(self.progress_bar_refresh_rate)
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

