From 8fe5dd41e951ed7f41f9c491615d9e23d95a9f15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 9 Mar 2021 09:49:59 +0000
Subject: [PATCH] fix logger creating directory structure too early in DDP (#6380)

* fix

* add simple test

* fix imports

* add changelog

* tighter test with on_fit_start hook closer to the dispatch call

* move class inside test function

* add a comment

(cherry picked from commit fc6d4027334b8869f02a3bdca0a0846f1cf79928)
---
 CHANGELOG.md                              | 135 +++++++++++++++++-
 pytorch_lightning/trainer/trainer.py      |  25 ++--
 .../logging_/test_distributed_logging.py  |  36 +++++
 3 files changed, 176 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 39e927106a5a2..2f9d90faaef4d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,138 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [UnReleased] - 2021-MM-DD
+
+### Added
+
+- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
+
+
+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+
+- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+
+- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
+
+
+- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
+### Changed
+
+- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
+
+
+- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
+### Deprecated
+
+
+- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+### Removed
+
+- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
+
+
+- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
+- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))
+
+
+- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))
+
+
+- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
+    * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
+    * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
+
+
+- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
+
+
+- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
+
+
+- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
+
+
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
+### Fixed
+
+- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+
+- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
+
+
+- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
+
+
+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
+- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
+
+
+- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221))
+
+
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+
+- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+
+
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+
+
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
+
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
+
+
+- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
+- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+
+- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
+
+
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
+
+
+- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
+
+
+- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
+
+
 
 ## [1.2.3] - 2021-03-09
 
@@ -23,9 +155,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
-- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
-
-
 ## [1.2.2] - 2021-03-02
 
 ### Added
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e123c1af5a5d0..ebf8ddb1f07ea 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -407,21 +407,6 @@ def __init__(
         # Callback system
         self.on_init_end()
 
-    def setup_trainer(self, model: LightningModule):
-        """
-        Sanity check a few things before starting actual training or testing.
-
-        Args:
-            model: The model to run sanity test on.
-        """
-
-        # log hyper-parameters
-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            self.logger.log_hyperparams(model.hparams_initial)
-            self.logger.log_graph(model)
-            self.logger.save()
-
     def fit(
         self,
         model: LightningModule,
@@ -471,8 +456,7 @@ def fit(
         # ----------------------------
         self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
-        self.accelerator.setup(self, model)
-        self.setup_trainer(model)
+        self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
         # ----------------------------
         # INSPECT THE CORE LOOPS
@@ -539,6 +523,13 @@ def fit(
     def pre_dispatch(self):
         self.accelerator.pre_dispatch()
 
+        # log hyper-parameters
+        if self.logger is not None:
+            # save exp to get started (this is where the first experiment logs are written)
+            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
+            self.logger.log_graph(self.lightning_module)
+            self.logger.save()
+
     def post_dispatch(self):
         self.accelerator.post_dispatch()
         self.accelerator.teardown()
diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py
index dffb511614cf6..4456500147d18 100644
--- a/tests/trainer/logging_/test_distributed_logging.py
+++ b/tests/trainer/logging_/test_distributed_logging.py
@@ -69,3 +69,39 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         weights_summary=None,
     )
     trainer.fit(model)
+
+
+def test_first_logger_call_in_subprocess(tmpdir):
+    """
+    Test that the Trainer does not call the logger too early. Only when the worker processes are initialized
+    do we have access to the rank and know which one is the main process.
+    """
+
+    class LoggerCallsObserver(Callback):
+
+        def on_fit_start(self, trainer, pl_module):
+            # this hook is executed directly before Trainer.pre_dispatch
+            # logger should not write any logs until this point
+            assert not trainer.logger.method_calls
+            assert not os.listdir(trainer.logger.save_dir)
+
+        def on_train_start(self, trainer, pl_module):
+            assert trainer.logger.method_calls
+            trainer.logger.log_hyperparams.assert_called_once()
+            trainer.logger.log_graph.assert_called_once()
+
+    logger = Mock()
+    logger.version = "0"
+    logger.name = "name"
+    logger.save_dir = tmpdir
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=logger,
+        callbacks=[LoggerCallsObserver()]
+    )
+    trainer.fit(model)
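
Note on the change (an illustrative sketch, not part of the patch): with the logger calls moved from `setup_trainer` into `pre_dispatch`, the logger's first filesystem writes now happen after `on_fit_start`, which, per the comment in the new test, runs directly before `Trainer.pre_dispatch`. The snippet below shows the same timing check against a real `TensorBoardLogger` instead of a `Mock`. The callback and function names are hypothetical, `BoringModel` comes from the repo's test helpers, and it assumes `TensorBoardLogger` only creates its `log_dir` lazily, on the first write.

import os

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from tests.helpers import BoringModel  # repo test helper; any minimal LightningModule works


class LoggerTimingCheck(Callback):
    # hypothetical helper, not part of the patch

    def on_fit_start(self, trainer, pl_module):
        # pre_dispatch has not run yet, so with this patch the logger
        # should not have created its directory structure
        assert not os.path.exists(trainer.logger.log_dir)

    def on_train_start(self, trainer, pl_module):
        # pre_dispatch has run: hyperparameters were logged and the
        # experiment directory now exists
        assert os.path.exists(trainer.logger.log_dir)


def run_logger_timing_check(tmpdir):
    logger = TensorBoardLogger(save_dir=str(tmpdir), name="timing_check")
    trainer = Trainer(
        default_root_dir=str(tmpdir),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        logger=logger,
        callbacks=[LoggerTimingCheck()],
    )
    trainer.fit(BoringModel())

Run against a throwaway directory; before this patch the `on_fit_start` assertion would fail, because `setup_trainer` had already logged the hyperparameters (and hence created the directory) right after `accelerator.setup`.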