From 401c07f861b4f4145489ee5fa51c56dd7e546172 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 12 Nov 2020 01:23:47 +0530 Subject: [PATCH 01/11] fast_dev_run can be int --- .../trainer/connectors/debugging_connector.py | 21 ++++++-- pytorch_lightning/trainer/properties.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- tests/trainer/test_dataloaders.py | 49 ++++++++++++------- 6 files changed, 53 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 580d13a6df9fd..0f5e00ad58899 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -32,15 +32,26 @@ def on_init_start( fast_dev_run ): + if isinstance(fast_dev_run, int) and fast_dev_run < 0: + raise MisconfigurationException( + f'fast_dev_run={fast_dev_run} is not a' + ' valid configuration. It should be >= 0.' + ) + self.trainer.fast_dev_run = fast_dev_run - if self.trainer.fast_dev_run: - limit_train_batches = 1 - limit_val_batches = 1 - limit_test_batches = 1 + + if fast_dev_run is True: + fast_dev_run = 1 + + if fast_dev_run: + limit_train_batches = fast_dev_run + limit_val_batches = fast_dev_run + limit_test_batches = fast_dev_run self.trainer.num_sanity_val_steps = 0 self.trainer.max_epochs = 1 rank_zero_info( - 'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch' + 'Running in fast_dev_run mode: will run a full train,' + f' val and test loop using {fast_dev_run} batch(es)' ) self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches') diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index af06b1bbc1352..c7aa616adfb64 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -37,7 +37,7 @@ class TrainerProperties(ABC): logger_connector: LoggerConnector _state: TrainerState global_rank: int - fast_dev_run: bool + fast_dev_run: Union[int, bool] use_dp: bool use_ddp: bool use_ddp2: bool diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 46e4abbe584ae..c3efca5dc3825 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -104,7 +104,7 @@ def __init__( overfit_batches: Union[int, float] = 0.0, track_grad_norm: Union[int, float, str] = -1, check_val_every_n_epoch: int = 1, - fast_dev_run: bool = False, + fast_dev_run: Union[int, bool] = False, accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1, max_epochs: int = 1000, min_epochs: int = 1, @@ -191,7 +191,7 @@ def __init__( distributed_backend: deprecated. Please use 'accelerator' - fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). + fast_dev_run: runs n if set to n(int) else 1 batch(es) of train, test and val to find any bugs (ie: a sort of unit test). flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 67a4704b628fc..936efded09bd5 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -70,7 +70,7 @@ def scale_batch_size(trainer, or datamodule. 
""" if trainer.fast_dev_run: - rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning) + rank_zero_warn(f'Skipping batch size scaler since `fast_dev_run={trainer.fast_dev_run}`', UserWarning) return if not lightning_hasattr(model, batch_arg_name): diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index b6d8c8178093b..79335c7bce774 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -137,7 +137,7 @@ def lr_find( """ if trainer.fast_dev_run: - rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning) + rank_zero_warn(f'Skipping learning rate finder since `fast_dev_run={trainer.fast_dev_run}`', UserWarning) return save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1ca34101a9141..4f0ff84214a16 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -434,9 +434,11 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_dataloaders_with_fast_dev_run(tmpdir): - """Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True""" - +@pytest.mark.parametrize('fast_dev_run', [True, 3, -1]) +def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): + """ + Verify num_batches for train, val & test dataloaders passed with fast_dev_run + """ model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length model.test_dataloader = model.test_dataloader__multiple_mixed_length @@ -445,26 +447,39 @@ def test_dataloaders_with_fast_dev_run(tmpdir): model.test_step = model.test_step__multiple_dataloaders model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # train, multiple val and multiple test dataloaders passed with fast_dev_run = True - trainer = Trainer( + trainer_options = dict( default_root_dir=tmpdir, max_epochs=2, - fast_dev_run=True, + fast_dev_run=fast_dev_run, ) - assert trainer.max_epochs == 1 - assert trainer.num_sanity_val_steps == 0 - trainer.fit(model) - assert not trainer.disable_validation - assert trainer.num_training_batches == 1 - assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders) + if fast_dev_run == -1: + with pytest.raises(MisconfigurationException, match='should be >= 0'): + trainer = Trainer(**trainer_options) + return + else: + trainer = Trainer(**trainer_options) - trainer.test(ckpt_path=None) - assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders) + if fast_dev_run is True: + fast_dev_run = 1 - # verify sanity check batches match as expected - num_val_dataloaders = len(model.val_dataloader()) - assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders + assert trainer.limit_train_batches == fast_dev_run + assert trainer.limit_val_batches == fast_dev_run + assert trainer.limit_test_batches == fast_dev_run + assert trainer.num_sanity_val_steps == 0 + assert trainer.max_epochs == 1 + + trainer.fit(model) + assert not trainer.disable_validation + assert trainer.num_training_batches == fast_dev_run + assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) + + trainer.test(ckpt_path=None) + assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) + + # verify sanity check batches match as expected + 
num_val_dataloaders = len(model.val_dataloader()) + assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) From b9ecf4410abe0992554ece428600d14c6df45897 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 12 Nov 2020 01:27:11 +0530 Subject: [PATCH 02/11] pep --- pytorch_lightning/trainer/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c3efca5dc3825..ee77b2c0fe85e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -191,7 +191,8 @@ def __init__( distributed_backend: deprecated. Please use 'accelerator' - fast_dev_run: runs n if set to n(int) else 1 batch(es) of train, test and val to find any bugs (ie: a sort of unit test). + fast_dev_run: runs n if set to n(int) else 1 batch(es) of train, val + and test to find any bugs (ie: a sort of unit test). flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). From 638f8b4e5967101040768d9a959e295cf1d0bd78 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 12 Nov 2020 01:30:34 +0530 Subject: [PATCH 03/11] chlog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 797569f71128b..f111401327aac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675)) +- Updated `fast_dev_run` to accept integer representing num_batches ([#4629](https://github.com/PyTorchLightning/pytorch-lightning/pull/4629)) + + ### Changed - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) @@ -63,6 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added lambda closure to `manual_optimizer_step` ([#4618](https://github.com/PyTorchLightning/pytorch-lightning/pull/4618)) + ### Changed - Change Metrics `persistent` default mode to `False` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685)) From 2a36cefa02dd53bef84176d04d8b8fe48e573ecc Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Fri, 13 Nov 2020 00:07:57 +0530 Subject: [PATCH 04/11] add check and update docs --- .../trainer/connectors/debugging_connector.py | 7 ++++++- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/trainer/test_dataloaders.py | 8 +++++--- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 0f5e00ad58899..1ea4f733a2a2f 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -31,8 +31,13 @@ def on_init_start( overfit_batches, fast_dev_run ): + if not isinstance(fast_dev_run, (bool, int)): + raise MisconfigurationException( + f'fast_dev_run={fast_dev_run} is not a valid configuration.' + ' It should be either a bool or an int >= 0' + ) - if isinstance(fast_dev_run, int) and fast_dev_run < 0: + if isinstance(fast_dev_run, int) and (fast_dev_run < 0): raise MisconfigurationException( f'fast_dev_run={fast_dev_run} is not a' ' valid configuration. It should be >= 0.' 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ee77b2c0fe85e..028494e37edf2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -191,8 +191,8 @@ def __init__( distributed_backend: deprecated. Please use 'accelerator' - fast_dev_run: runs n if set to n(int) else 1 batch(es) of train, val - and test to find any bugs (ie: a sort of unit test). + fast_dev_run: runs n if set to n(int) else 1 if set to ``True`` batch(es) + of train, val and test to find any bugs (ie: a sort of unit test). flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 4f0ff84214a16..844dc31f745b6 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -434,7 +434,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.parametrize('fast_dev_run', [True, 3, -1]) +@pytest.mark.parametrize('fast_dev_run', [True, 3, -1, 'temp']) def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): """ Verify num_batches for train, val & test dataloaders passed with fast_dev_run @@ -453,10 +453,12 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): fast_dev_run=fast_dev_run, ) - if fast_dev_run == -1: + if fast_dev_run == 'temp': + with pytest.raises(MisconfigurationException, match='either a bool or an int'): + trainer = Trainer(**trainer_options) + elif fast_dev_run == -1: with pytest.raises(MisconfigurationException, match='should be >= 0'): trainer = Trainer(**trainer_options) - return else: trainer = Trainer(**trainer_options) From df664377672acea3ec6b1f23ce46cee009c06608 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 16 Nov 2020 01:10:38 +0530 Subject: [PATCH 05/11] logging with fdr --- pytorch_lightning/trainer/connectors/debugging_connector.py | 4 +--- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 1ea4f733a2a2f..b8703d28e9ba7 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -44,9 +44,7 @@ def on_init_start( ) self.trainer.fast_dev_run = fast_dev_run - - if fast_dev_run is True: - fast_dev_run = 1 + fast_dev_run = int(fast_dev_run) if fast_dev_run: limit_train_batches = fast_dev_run diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 9386d428b1f07..bcc1122451ceb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -614,7 +614,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): def log_train_step_metrics(self, batch_output): # when metrics should be logged - if self.should_update_logs or self.trainer.fast_dev_run: + if self.should_update_logs or (int(self.trainer.fast_dev_run) == 1): # logs user requested information to logger metrics = self.cached_results.get_latest_batch_log_metrics() grad_norm_dic = batch_output.grad_norm_dic From 4271d2f6be952c3d12ab711b0c0250615e6eb50f Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 
16 Nov 2020 03:08:36 +0530 Subject: [PATCH 06/11] update docs --- docs/source/debugging.rst | 11 +++++++---- docs/source/trainer.rst | 28 +++++++++++++++------------- pytorch_lightning/trainer/trainer.py | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 845f86a52b231..01de6de613a68 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -21,17 +21,20 @@ The following are flags that make debugging much easier. fast_dev_run ------------ -This flag runs a "unit test" by running 1 training batch and 1 validation batch. -The point is to detect any bugs in the training/validation loop without having to wait for -a full epoch to crash. +This flag runs a "unit test" by running n if set to ``n`` (int) else 1 if set to ``True`` training and validation batch(es). +The point is to detect any bugs in the training/validation loop without having to wait for a full epoch to crash. (See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run` argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) .. testcode:: - + + # runs 1 train, val, test batch and program ends trainer = Trainer(fast_dev_run=True) + # runs 7 train, val, test batches and program ends + trainer = Trainer(fast_dev_run=7) + ---------------- Inspect gradient norms diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 0f9fc1fd42572..f6859d2b3d2e5 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -245,7 +245,7 @@ Example:: # ddp2 = DistributedDataParallel + dp trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2') -.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core) +.. note:: This option does not apply to TPU. TPUs use ```ddp``` by default (over each core) You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs. @@ -623,17 +623,10 @@ fast_dev_run | -.. raw:: html - - - -| +Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test +to find any bugs (ie: a sort of unit test). -Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test). - -Under the hood the pseudocode looks like this: +Under the hood the pseudocode looks like this when running *fast_dev_run* with a single batch: .. code-block:: python @@ -658,6 +651,16 @@ Under the hood the pseudocode looks like this: # runs 1 train, val, test batch and program ends trainer = Trainer(fast_dev_run=True) + # runs 7 train, val, test batches and program ends + trainer = Trainer(fast_dev_run=7) + +.. note:: + + This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will + disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be + used only for debugging purpose. ``limit_train/val/test_batches`` serves their own purpose and won't + disable anything. + gpus ^^^^ @@ -1199,8 +1202,7 @@ Orders the progress bar. Useful when running multiple trainers on the same node. # default used by the Trainer trainer = Trainer(process_position=0) -Note: - This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. +.. note:: This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`. 
profiler ^^^^^^^^ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 028494e37edf2..d124e39352798 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -191,7 +191,7 @@ def __init__( distributed_backend: deprecated. Please use 'accelerator' - fast_dev_run: runs n if set to n(int) else 1 if set to ``True`` batch(es) + fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test to find any bugs (ie: a sort of unit test). flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). From 938fe0863a8460825cc0b32c6fa1f6f0eee3778d Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 17 Nov 2020 22:55:57 +0530 Subject: [PATCH 07/11] suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- docs/source/trainer.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index f6859d2b3d2e5..bfc0e41d2b097 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -658,7 +658,7 @@ Under the hood the pseudocode looks like this when running *fast_dev_run* with a This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will disable tuner, logger callbacks like ``LearningRateLogger`` and runs for only 1 epoch. This must be - used only for debugging purpose. ``limit_train/val/test_batches`` serves their own purpose and won't + used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't disable anything. gpus @@ -1704,4 +1704,3 @@ The metrics sent to the progress bar. progress_bar_metrics = trainer.progress_bar_metrics assert progress_bar_metrics['a_val'] == 2 - From 0ea9a6727ee43f703dda02957c4024fe5dfbf692 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 18 Nov 2020 00:45:19 +0530 Subject: [PATCH 08/11] fdr flush logs --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8af55f64715f2..4c5eafb904972 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -873,7 +873,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs - if should_flush_logs or self.trainer.fast_dev_run: + if should_flush_logs or (int(self.trainer.fast_dev_run) == 1): if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() From 4951bfd69c48b8c88fe037e02abfbde564bc4f1d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 19 Nov 2020 01:19:44 +0530 Subject: [PATCH 09/11] update trainer.fast_dev_run --- .../trainer/connectors/debugging_connector.py | 4 ++++ .../connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- pytorch_lightning/tuner/batch_size_scaling.py | 2 +- pytorch_lightning/tuner/lr_finder.py | 2 +- tests/trainer/flags/test_fast_dev_run.py | 4 ++-- tests/trainer/test_dataloaders.py | 8 +++++++- 7 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 
b8703d28e9ba7..61d7cbd189fde 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -46,6 +46,10 @@ def on_init_start( self.trainer.fast_dev_run = fast_dev_run fast_dev_run = int(fast_dev_run) + # set fast_dev_run=True when it is 1, used while logging + if fast_dev_run == 1: + self.trainer.fast_dev_run = True + if fast_dev_run: limit_train_batches = fast_dev_run limit_val_batches = fast_dev_run diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index bcc1122451ceb..4a35c8f0c1268 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -614,7 +614,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): def log_train_step_metrics(self, batch_output): # when metrics should be logged - if self.should_update_logs or (int(self.trainer.fast_dev_run) == 1): + if self.should_update_logs or self.trainer.fast_dev_run is True: # logs user requested information to logger metrics = self.cached_results.get_latest_batch_log_metrics() grad_norm_dic = batch_output.grad_norm_dic diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4c5eafb904972..f46279a1e04c4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -873,7 +873,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs - if should_flush_logs or (int(self.trainer.fast_dev_run) == 1): + if should_flush_logs or self.trainer.fast_dev_run is True: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 936efded09bd5..4d182e7180dc0 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -70,7 +70,7 @@ def scale_batch_size(trainer, or datamodule. 
""" if trainer.fast_dev_run: - rank_zero_warn(f'Skipping batch size scaler since `fast_dev_run={trainer.fast_dev_run}`', UserWarning) + rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning) return if not lightning_hasattr(model, batch_arg_name): diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 79335c7bce774..2982454d02f70 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -137,7 +137,7 @@ def lr_find( """ if trainer.fast_dev_run: - rank_zero_warn(f'Skipping learning rate finder since `fast_dev_run={trainer.fast_dev_run}`', UserWarning) + rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) return save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index cbe4d4012227a..00c62cdf48fce 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder']) -def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg): +def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): """ Test that tuner algorithms are skipped if fast dev run is enabled """ hparams = EvalModelTemplate.get_default_hparams() @@ -16,6 +16,6 @@ def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg): auto_lr_find=True if tuner_alg == 'learning rate finder' else False, fast_dev_run=True ) - expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`' + expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' with pytest.warns(UserWarning, match=expected_message): trainer.tune(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 844dc31f745b6..75799914404d3 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -434,7 +434,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -@pytest.mark.parametrize('fast_dev_run', [True, 3, -1, 'temp']) +@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp']) def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): """ Verify num_batches for train, val & test dataloaders passed with fast_dev_run @@ -462,6 +462,12 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): else: trainer = Trainer(**trainer_options) + # fast_dev_run is set to True when it is 1 + if fast_dev_run == 1: + fast_dev_run = True + + assert trainer.fast_dev_run is fast_dev_run + if fast_dev_run is True: fast_dev_run = 1 From 29807f6c963f6a7bb598c2e32dc7d1d8e0298894 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 8 Dec 2020 23:43:31 +0530 Subject: [PATCH 10/11] codefactor and pre-commit isort --- pytorch_lightning/trainer/properties.py | 14 +++++++------- pytorch_lightning/trainer/training_loop.py | 6 +++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 1c1cf66ba2f47..f315bf9df819c 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -57,19 +57,19 @@ class TrainerProperties(ABC): @property def log_dir(self): if self.checkpoint_callback is not None: - dir = self.checkpoint_callback.dirpath - dir = os.path.split(dir)[0] + dirpath = 
self.checkpoint_callback.dirpath + dirpath = os.path.split(dirpath)[0] elif self.logger is not None: if isinstance(self.logger, TensorBoardLogger): - dir = self.logger.log_dir + dirpath = self.logger.log_dir else: - dir = self.logger.save_dir + dirpath = self.logger.save_dir else: - dir = self._default_root_dir + dirpath = self._default_root_dir if self.accelerator_backend is not None: - dir = self.accelerator_backend.broadcast(dir) - return dir + dirpath = self.accelerator_backend.broadcast(dirpath) + return dirpath @property def use_amp(self) -> bool: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5e0877ccfa8f6..20dfb0f4b380f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -215,10 +215,14 @@ def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback if should_save and self.trainer.checkpoint_connector.has_trained: checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + if is_last and any(c.save_last for c in checkpoint_callbacks): rank_zero_info("Saving latest checkpoint...") + model = self.trainer.get_model() - [cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks] + + for callback in checkpoint_callbacks: + callback.on_validation_end(self.trainer, model) def on_train_epoch_start(self, epoch): From f8248e3f14de3af02d082cdbef31ac5e9abc89f9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 8 Dec 2020 23:47:28 +0530 Subject: [PATCH 11/11] tmp
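
A minimal usage sketch of the new fast_dev_run behaviour described in this series,
assuming a pytorch_lightning version that already includes these patches (the values
shown are illustrative only):

    from pytorch_lightning import Trainer

    # an int n limits train, val and test to n batches, forces max_epochs=1,
    # sets num_sanity_val_steps=0 and skips tuner algorithms such as the
    # batch size scaler and the learning rate finder
    trainer = Trainer(fast_dev_run=7)

    # True keeps the previous behaviour: a single batch of each loop
    trainer = Trainer(fast_dev_run=True)

    # invalid values raise MisconfigurationException at construction time,
    # e.g. fast_dev_run=-1 ("It should be >= 0.") or
    # fast_dev_run='temp' ("either a bool or an int")

Unlike limit_train/val/test_batches, which only limit the number of batches,
fast_dev_run is intended purely for debugging: per the docs changes above it also
disables logger callbacks such as LearningRateLogger and runs for only one epoch.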