From 4415677994859ba4b3302468d7118b0f8551732f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 17 Dec 2021 13:34:18 +0100
Subject: [PATCH] Add typing for `trainer.logger` (#11114)

---
 CHANGELOG.md                                             | 2 +-
 pyproject.toml                                           | 2 --
 pytorch_lightning/callbacks/device_stats_monitor.py      | 2 ++
 pytorch_lightning/callbacks/gpu_stats_monitor.py         | 2 ++
 pytorch_lightning/plugins/precision/precision_plugin.py  | 3 ++-
 pytorch_lightning/trainer/trainer.py                     | 1 +
 6 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a1c0551703af2..fe6deeaca000e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -285,7 +285,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
 
--
+- Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))
 
 
 ## [1.5.6] - 2021-12-15
diff --git a/pyproject.toml b/pyproject.toml
index 5adc3b444e5c1..346370dd506bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.accelerators.accelerator",
     "pytorch_lightning.accelerators.gpu",
     "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.lr_monitor",
@@ -106,7 +105,6 @@ module = [
     "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.enums",
     "pytorch_lightning.utilities.fetching",
-    "pytorch_lightning.utilities.imports",
     "pytorch_lightning.utilities.memory",
     "pytorch_lightning.utilities.meta",
     "pytorch_lightning.utilities.metrics",
diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py
index b743ed3e1bbeb..016d2015a81e1 100644
--- a/pytorch_lightning/callbacks/device_stats_monitor.py
+++ b/pytorch_lightning/callbacks/device_stats_monitor.py
@@ -59,6 +59,7 @@ def on_train_batch_start(
 
         device_stats = trainer.accelerator.get_device_stats(pl_module.device)
         prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
+        assert trainer.logger is not None
         trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
 
     def on_train_batch_end(
@@ -75,6 +76,7 @@ def on_train_batch_end(
 
         device_stats = trainer.accelerator.get_device_stats(pl_module.device)
         prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
+        assert trainer.logger is not None
         trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
 
 
diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index 088c8e650074c..98a83000170e1 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -161,6 +161,7 @@ def on_train_batch_start(
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
 
+        assert trainer.logger is not None
         trainer.logger.log_metrics(logs, step=trainer.global_step)
 
     @rank_zero_only
@@ -185,6 +186,7 @@ def on_train_batch_end(
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
 
+        assert trainer.logger is not None
         trainer.logger.log_metrics(logs, step=trainer.global_step)
 
     @staticmethod
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 0472ab42c6918..1f5b076e491a1 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -155,7 +155,8 @@ def optimizer_step(
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
+        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
             trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index eb7a144ad61b5..fe19accb7dbc0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -570,6 +570,7 @@ def __init__(
         self.__init_profiler(profiler)
 
         # init logger flags
+        self.logger: Optional[LightningLoggerBase]
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
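
Note (not part of the patch): below is a minimal, self-contained sketch of the Optional-narrowing patterns the diff relies on, using toy Trainer/Logger classes that are only illustrative stand-ins, not PyTorch Lightning's API. Once `trainer.logger` is typed as `Optional[LightningLoggerBase]`, mypy rejects bare attribute access on it; the callbacks narrow with `assert trainer.logger is not None`, while `_track_grad_norm` builds keyword arguments conditionally so a missing logger falls back to the callee's default `group_separator`.

from typing import Dict, Optional


class Logger:
    """Toy stand-in exposing the two members the patch touches."""

    group_separator = "/"

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        print(f"step {step}: {metrics}")


class Trainer:
    """Toy stand-in: `logger` is Optional, like `Trainer.logger: Optional[LightningLoggerBase]`."""

    def __init__(self, logger: Optional[Logger]) -> None:
        self.logger = logger


def log_device_stats(trainer: Trainer, stats: Dict[str, float], step: int) -> None:
    # Pattern 1 (callbacks): assert-based narrowing. mypy accepts the attribute
    # access below because the assert rules out None on this code path.
    assert trainer.logger is not None
    trainer.logger.log_metrics(stats, step=step)


def separator_kwargs(trainer: Trainer) -> Dict[str, str]:
    # Pattern 2 (_track_grad_norm): build kwargs only when a logger exists, so
    # the callee keeps its default separator when logging is disabled.
    return {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}


if __name__ == "__main__":
    log_device_stats(Trainer(Logger()), {"gpu_mem_mb": 123.0}, step=0)
    print(separator_kwargs(Trainer(None)))  # -> {}

Running the sketch prints the logged metrics for the first call and an empty dict for the second, mirroring how `Trainer(track_grad_norm=..., logger=False)` no longer fails after this change.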