Add typing for trainer.logger (#11114)
carmocca committed Dec 17, 2021
1 parent 5932f52 commit 4415677
Showing 6 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -285,7 +285,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))


--
+- Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))


 ## [1.5.6] - 2021-12-15
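For context, the CHANGELOG entry above refers to the configuration sketched below: with `logger=False`, `trainer.logger` is `None`, so the gradient-norm tracking path that read `trainer.logger.group_separator` could fail. The norm order `2` and `max_epochs=1` are illustrative values, not taken from the commit:

```python
# Illustrative sketch of the configuration fixed by this commit (values are assumptions).
import pytorch_lightning as pl

trainer = pl.Trainer(track_grad_norm=2, logger=False, max_epochs=1)
# Before this commit, running trainer.fit(model) with this configuration could fail,
# because _track_grad_norm dereferenced trainer.logger.group_separator while logger is None.
# trainer.fit(model)
```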
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -42,7 +42,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.accelerators.accelerator",
     "pytorch_lightning.accelerators.gpu",
     "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.lr_monitor",
@@ -106,7 +105,6 @@ module = [
     "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.enums",
     "pytorch_lightning.utilities.fetching",
-    "pytorch_lightning.utilities.imports",
     "pytorch_lightning.utilities.memory",
     "pytorch_lightning.utilities.meta",
     "pytorch_lightning.utilities.metrics",
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -59,6 +59,7 @@ def on_train_batch_start(

         device_stats = trainer.accelerator.get_device_stats(pl_module.device)
         prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
+        assert trainer.logger is not None
         trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

     def on_train_batch_end(
@@ -75,6 +76,7 @@ def on_train_batch_end(

         device_stats = trainer.accelerator.get_device_stats(pl_module.device)
         prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
+        assert trainer.logger is not None
         trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


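The added `assert trainer.logger is not None` lines exist because `trainer.logger` is now typed as `Optional[LightningLoggerBase]`: the assert narrows the type for mypy before `log_metrics` is called, and the same pattern is applied in `gpu_stats_monitor.py` below. A self-contained sketch of the idiom, using hypothetical `Logger`/`Monitor` classes rather than the real Lightning ones:

```python
from typing import Dict, Optional


class Logger:
    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        print(f"step {step}: {metrics}")


class Monitor:
    def __init__(self, logger: Optional[Logger] = None) -> None:
        # Optional attribute: mypy rejects unguarded calls such as self.logger.log_metrics(...).
        self.logger: Optional[Logger] = logger

    def on_train_batch_start(self, device_stats: Dict[str, float], step: int) -> None:
        # The assert narrows Optional[Logger] to Logger for the type checker,
        # and fails loudly at runtime if the logger is unexpectedly missing.
        assert self.logger is not None
        self.logger.log_metrics(device_stats, step=step)


Monitor(Logger()).on_train_batch_start({"gpu_mem_mb": 1024.0}, step=0)
```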
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -161,6 +161,7 @@ def on_train_batch_start(
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

+        assert trainer.logger is not None
         trainer.logger.log_metrics(logs, step=trainer.global_step)

     @rank_zero_only
@@ -185,6 +186,7 @@ def on_train_batch_end(
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

+        assert trainer.logger is not None
         trainer.logger.log_metrics(logs, step=trainer.global_step)

     @staticmethod
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -155,7 +155,8 @@ def optimizer_step(
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
+        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
             trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
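This is where the reported failure is actually fixed: the `group_separator` keyword is only forwarded when a logger exists, otherwise `grad_norm` falls back to its own default separator instead of dereferencing `None`. A hedged sketch of the pattern, where the stub below stands in for the real `grad_norm` utility and its default separator is an assumption:

```python
from typing import Dict, Optional


def grad_norm_stub(norm_type: float, group_separator: str = "/") -> Dict[str, float]:
    # Stand-in for the real grad_norm utility; only the keyword-forwarding pattern matters here.
    return {f"grad_{norm_type}_norm{group_separator}total": 0.0}


def track_grad_norm(norm_type: float, separator: Optional[str]) -> Dict[str, float]:
    # Build the kwargs dict only when a separator is available, so the callee's
    # default is used instead of reading an attribute off a missing logger.
    kwargs = {"group_separator": separator} if separator is not None else {}
    return grad_norm_stub(norm_type, **kwargs)


print(track_grad_norm(2.0, None))  # falls back to the stub's default separator
print(track_grad_norm(2.0, "."))   # uses the provided separator
```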
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -570,6 +570,7 @@ def __init__(
         self.__init_profiler(profiler)

         # init logger flags
+        self.logger: Optional[LightningLoggerBase]
         self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

         # init debugging flags
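The line added to `trainer.py` is a bare annotation: it declares the type of `self.logger` for mypy without assigning a value, because the actual assignment happens later inside the logger connector's `on_trainer_init`. A minimal sketch of the idiom with hypothetical names:

```python
from typing import Optional


class Logger:
    pass


class LoggerConnector:
    def __init__(self, trainer: "Trainer") -> None:
        self.trainer = trainer

    def on_trainer_init(self, logger: Optional[Logger]) -> None:
        # The connector, not Trainer.__init__, performs the real assignment.
        self.trainer.logger = logger


class Trainer:
    def __init__(self, logger: Optional[Logger] = None) -> None:
        # Bare annotation: declares the attribute's type for the checker without binding it yet.
        self.logger: Optional[Logger]
        self.logger_connector = LoggerConnector(self)
        self.logger_connector.on_trainer_init(logger)


print(type(Trainer(Logger()).logger).__name__)  # -> Logger
```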
