diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4ba46ebdc8520..cc0cf47e1f9ef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
 
 
+- Updated `fast_dev_run` to accept integer representing num_batches ([#4629](https://github.com/PyTorchLightning/pytorch-lightning/pull/4629))
+
+
+- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))
+
+
 - Added `prefix` argument in loggers ([#4557](https://github.com/PyTorchLightning/pytorch-lightning/pull/4557))
 
 
@@ -156,6 +162,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added lambda closure to `manual_optimizer_step` ([#4618](https://github.com/PyTorchLightning/pytorch-lightning/pull/4618))
 
 
+
 ### Changed
 
 - Change Metrics `persistent` default mode to `False` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685))
diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst
index fea230d67d016..5eaf4303d3e4c 100644
--- a/docs/source/debugging.rst
+++ b/docs/source/debugging.rst
@@ -21,17 +21,20 @@ The following are flags that make debugging much easier.
 fast_dev_run
 ------------
 
-This flag runs a "unit test" by running 1 training batch and 1 validation batch.
-The point is to detect any bugs in the training/validation loop without having to wait for
-a full epoch to crash.
+This flag runs a "unit test" by running ``n`` (if set to an int ``n``) or 1 (if set to ``True``) training and validation batch(es).
+The point is to detect any bugs in the training/validation loop without having to wait for a full epoch to crash.
 
 (See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`
 argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. testcode::
-
+
+    # runs 1 train, val, test batch and program ends
     trainer = Trainer(fast_dev_run=True)
 
+    # runs 7 train, val, test batches and program ends
+    trainer = Trainer(fast_dev_run=7)
+
 ----------------
 
 Inspect gradient norms
diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst
index 92b6b290176f5..99f93bd02f0b4 100644
--- a/docs/source/trainer.rst
+++ b/docs/source/trainer.rst
@@ -245,7 +245,7 @@ Example::
     # ddp2 = DistributedDataParallel + dp
     trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2')
 
-.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
+.. note:: This option does not apply to TPU. TPUs use ``ddp`` by default (over each core)
 
 You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.
 
@@ -632,9 +632,10 @@ fast_dev_run
 
 |
 
-Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
+Runs ``n`` (if set to an int ``n``) or 1 (if set to ``True``) batch(es) of train, val and test
+to find any bugs (i.e. a sort of unit test).
 
-Under the hood the pseudocode looks like this:
+Under the hood the pseudocode looks like this when running *fast_dev_run* with a single batch:
 
 .. code-block:: python
 
@@ -659,6 +660,16 @@ Under the hood the pseudocode looks like this:
     # runs 1 train, val, test batch and program ends
     trainer = Trainer(fast_dev_run=True)
 
+    # runs 7 train, val, test batches and program ends
+    trainer = Trainer(fast_dev_run=7)
+
+.. note::
+
+    This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
+    disable the tuner and logger callbacks like ``LearningRateLogger``, and will run for only 1 epoch. It must
+    be used only for debugging purposes. ``limit_train/val/test_batches`` only limits the number of batches and won't
+    disable anything.
+
 gpus
 ^^^^
 
@@ -1200,8 +1211,7 @@ Orders the progress bar. Useful when running multiple trainers on the same node.
     # default used by the Trainer
     trainer = Trainer(process_position=0)
 
-Note:
-    This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
+.. note:: This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
 
 profiler
 ^^^^^^^^
diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py
index 580d13a6df9fd..61d7cbd189fde 100644
--- a/pytorch_lightning/trainer/connectors/debugging_connector.py
+++ b/pytorch_lightning/trainer/connectors/debugging_connector.py
@@ -31,16 +31,34 @@ def on_init_start(
         overfit_batches,
         fast_dev_run
     ):
+        if not isinstance(fast_dev_run, (bool, int)):
+            raise MisconfigurationException(
+                f'fast_dev_run={fast_dev_run} is not a valid configuration.'
+                ' It should be either a bool or an int >= 0'
+            )
+
+        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
+            raise MisconfigurationException(
+                f'fast_dev_run={fast_dev_run} is not a'
+                ' valid configuration. It should be >= 0.'
+            )
 
         self.trainer.fast_dev_run = fast_dev_run
-        if self.trainer.fast_dev_run:
-            limit_train_batches = 1
-            limit_val_batches = 1
-            limit_test_batches = 1
+        fast_dev_run = int(fast_dev_run)
+
+        # set fast_dev_run=True when it is 1, used while logging
+        if fast_dev_run == 1:
+            self.trainer.fast_dev_run = True
+
+        if fast_dev_run:
+            limit_train_batches = fast_dev_run
+            limit_val_batches = fast_dev_run
+            limit_test_batches = fast_dev_run
             self.trainer.num_sanity_val_steps = 0
             self.trainer.max_epochs = 1
             rank_zero_info(
-                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
+                'Running in fast_dev_run mode: will run a full train,'
+                f' val and test loop using {fast_dev_run} batch(es)'
             )
 
         self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 99a0c846fe86f..6fdd2f0d57b63 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -589,7 +589,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):
     def log_train_step_metrics(self, batch_output):
         _, batch_log_metrics = self.cached_results.update_logger_connector()
         # when metrics should be logged
-        if self.should_update_logs or self.trainer.fast_dev_run:
+        if self.should_update_logs or self.trainer.fast_dev_run is True:
             # logs user requested information to logger
             grad_norm_dic = batch_output.grad_norm_dic
             if grad_norm_dic is None:
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index ba94ec2d95abb..f315bf9df819c 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -38,7 +38,7 @@ class TrainerProperties(ABC):
     logger_connector: LoggerConnector
     _state: TrainerState
     global_rank: int
-    fast_dev_run: bool
+    fast_dev_run: Union[int, bool]
     use_dp: bool
     use_ddp: bool
     use_ddp2: bool
@@ -57,19 +57,19 @@ class TrainerProperties(ABC):
     @property
     def log_dir(self):
         if self.checkpoint_callback is not None:
-            dir = self.checkpoint_callback.dirpath
-            dir = os.path.split(dir)[0]
+            dirpath = self.checkpoint_callback.dirpath
+            dirpath = os.path.split(dirpath)[0]
         elif self.logger is not None:
             if isinstance(self.logger, TensorBoardLogger):
-                dir = self.logger.log_dir
+                dirpath = self.logger.log_dir
             else:
-                dir = self.logger.save_dir
+                dirpath = self.logger.save_dir
         else:
-            dir = self._default_root_dir
+            dirpath = self._default_root_dir
 
         if self.accelerator_backend is not None:
-            dir = self.accelerator_backend.broadcast(dir)
-        return dir
+            dirpath = self.accelerator_backend.broadcast(dirpath)
+        return dirpath
 
     @property
     def use_amp(self) -> bool:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6740406c5fe9a..ccb9f9418c838 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -98,7 +98,7 @@ def __init__(
         overfit_batches: Union[int, float] = 0.0,
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
-        fast_dev_run: bool = False,
+        fast_dev_run: Union[int, bool] = False,
         accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
         max_epochs: int = 1000,
         min_epochs: int = 1,
@@ -186,7 +186,8 @@ def __init__(
 
             distributed_backend: deprecated. Please use 'accelerator'
 
-            fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
+            fast_dev_run: runs ``n`` (if set to an int ``n``) or 1 (if set to ``True``) batch(es)
+                of train, val and test to find any bugs (i.e. a sort of unit test).
 
             flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 4e3b3ef9a8620..20dfb0f4b380f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -215,10 +215,14 @@ def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
         if should_save and self.trainer.checkpoint_connector.has_trained:
             checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+
             if is_last and any(c.save_last for c in checkpoint_callbacks):
                 rank_zero_info("Saving latest checkpoint...")
+
             model = self.trainer.get_model()
-            [cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks]
+
+            for callback in checkpoint_callbacks:
+                callback.on_validation_end(self.trainer, model)
 
     def on_train_epoch_start(self, epoch):
 
@@ -908,7 +912,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
 
     def save_loggers_on_train_batch_end(self):
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
-        if should_flush_logs or self.trainer.fast_dev_run:
+        if should_flush_logs or self.trainer.fast_dev_run is True:
             if self.trainer.is_global_zero and self.trainer.logger is not None:
                 self.trainer.logger.save()
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 1ed90ff7d49e4..52662f6172d8d 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -70,7 +70,7 @@ def scale_batch_size(trainer,
             or datamodule.
""" if trainer.fast_dev_run: - rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning) + rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning) return if not lightning_hasattr(model, batch_arg_name): diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index b6d8c8178093b..2982454d02f70 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -137,7 +137,7 @@ def lr_find( """ if trainer.fast_dev_run: - rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning) + rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning) return save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt') diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index cbe4d4012227a..00c62cdf48fce 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder']) -def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg): +def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): """ Test that tuner algorithms are skipped if fast dev run is enabled """ hparams = EvalModelTemplate.get_default_hparams() @@ -16,6 +16,6 @@ def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg): auto_lr_find=True if tuner_alg == 'learning rate finder' else False, fast_dev_run=True ) - expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`' + expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.' with pytest.warns(UserWarning, match=expected_message): trainer.tune(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f16ef22faa507..71e1c088ece14 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -436,9 +436,11 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_dataloaders_with_fast_dev_run(tmpdir): - """Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True""" - +@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp']) +def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): + """ + Verify num_batches for train, val & test dataloaders passed with fast_dev_run + """ model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length model.test_dataloader = model.test_dataloader__multiple_mixed_length @@ -447,26 +449,47 @@ def test_dataloaders_with_fast_dev_run(tmpdir): model.test_step = model.test_step__multiple_dataloaders model.test_epoch_end = model.test_epoch_end__multiple_dataloaders - # train, multiple val and multiple test dataloaders passed with fast_dev_run = True - trainer = Trainer( + trainer_options = dict( default_root_dir=tmpdir, max_epochs=2, - fast_dev_run=True, + fast_dev_run=fast_dev_run, ) - assert trainer.max_epochs == 1 - assert trainer.num_sanity_val_steps == 0 - trainer.fit(model) - assert not trainer.disable_validation - assert trainer.num_training_batches == 1 - assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders) + if fast_dev_run == 'temp': + with pytest.raises(MisconfigurationException, match='either a bool or an int'): + trainer = Trainer(**trainer_options) + elif fast_dev_run == -1: + with 
pytest.raises(MisconfigurationException, match='should be >= 0'): + trainer = Trainer(**trainer_options) + else: + trainer = Trainer(**trainer_options) - trainer.test(ckpt_path=None) - assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders) + # fast_dev_run is set to True when it is 1 + if fast_dev_run == 1: + fast_dev_run = True - # verify sanity check batches match as expected - num_val_dataloaders = len(model.val_dataloader()) - assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders + assert trainer.fast_dev_run is fast_dev_run + + if fast_dev_run is True: + fast_dev_run = 1 + + assert trainer.limit_train_batches == fast_dev_run + assert trainer.limit_val_batches == fast_dev_run + assert trainer.limit_test_batches == fast_dev_run + assert trainer.num_sanity_val_steps == 0 + assert trainer.max_epochs == 1 + + trainer.fit(model) + assert not trainer.disable_validation + assert trainer.num_training_batches == fast_dev_run + assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders) + + trainer.test(ckpt_path=None) + assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders) + + # verify sanity check batches match as expected + num_val_dataloaders = len(model.val_dataloader()) + assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
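
For quick reference, below is a minimal standalone sketch of the validation and coercion behaviour this diff introduces in ``DebuggingConnector.on_init_start``: reject anything that is not a bool or a non-negative int, treat ``fast_dev_run=1`` exactly like ``True``, and otherwise use the integer as the per-loop batch limit. The helper name ``resolve_fast_dev_run`` is hypothetical and ``ValueError`` stands in for Lightning's ``MisconfigurationException``; this is illustration only, not part of the change::

    from typing import Tuple, Union


    def resolve_fast_dev_run(fast_dev_run: Union[int, bool]) -> Tuple[Union[int, bool], int]:
        """Return (trainer_flag, batch_limit) mirroring the connector logic in this diff.

        Hypothetical helper for illustration; the real code lives in
        ``DebuggingConnector.on_init_start`` and raises ``MisconfigurationException``.
        """
        # reject anything that is not a bool or an int (e.g. 'temp' in the new test)
        if not isinstance(fast_dev_run, (bool, int)):
            raise ValueError(
                f"fast_dev_run={fast_dev_run} is not a valid configuration."
                " It should be either a bool or an int >= 0"
            )

        # reject negative ints (e.g. -1 in the new test)
        if isinstance(fast_dev_run, int) and fast_dev_run < 0:
            raise ValueError(
                f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
            )

        trainer_flag: Union[int, bool] = fast_dev_run
        batch_limit = int(fast_dev_run)

        # fast_dev_run=1 is treated exactly like fast_dev_run=True (used while logging)
        if batch_limit == 1:
            trainer_flag = True

        return trainer_flag, batch_limit


    if __name__ == "__main__":
        # True -> runs 1 batch; 7 -> runs 7 batches; 1 is coerced back to True
        assert resolve_fast_dev_run(True) == (True, 1)
        assert resolve_fast_dev_run(7) == (7, 7)
        assert resolve_fast_dev_run(1) == (True, 1)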