Add typing for trainer.logger (#11114)
awaelchli authored and lexierule committed Dec 21, 2021
1 parent 327cba5 commit 2264082
Showing 5 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116))
- Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
- Fixed an `AttributeError` occurring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111))
- Fixed bug where `Trainer(track_grad_norm=..., logger=False)` would fail ([#11114](https://github.com/PyTorchLightning/pytorch-lightning/pull/11114))

### Changed

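For context, the failure described in that changelog entry can be reproduced with a configuration along these lines. This is an illustrative sketch only: the `BoringModel` below is a throwaway placeholder, not part of the commit, and the exact traceback depends on the installed version.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    """Throwaway placeholder model, just enough to run fit()."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    # Before this fix, combining track_grad_norm with logger=False failed,
    # since _track_grad_norm read `trainer.logger.group_separator` while
    # `trainer.logger` was None.
    trainer = pl.Trainer(track_grad_norm=2, logger=False, max_epochs=1)
    trainer.fit(BoringModel(), train_dataloaders=data)
```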
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -59,6 +59,7 @@ def on_train_batch_start(

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
assert trainer.logger is not None
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
@@ -75,6 +76,7 @@ def on_train_batch_end(

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
assert trainer.logger is not None
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


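Because `Trainer.logger` is now annotated as `Optional[LightningLoggerBase]`, call sites such as these callbacks have to narrow the type before dereferencing it, which is what the added `assert trainer.logger is not None` lines do (the same change appears in `gpu_stats_monitor.py` below). A minimal standalone sketch of that narrowing idiom, using stand-in `Logger`/`Trainer` classes rather than the real PyTorch Lightning ones:

```python
from typing import Dict, Optional


class Logger:
    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        print(step, metrics)


class Trainer:
    # Declared Optional: the attribute is None when logging is disabled.
    logger: Optional[Logger] = None


def log_stats(trainer: Trainer, stats: Dict[str, float], step: int) -> None:
    # Without this assert, mypy reports:
    #   Item "None" of "Optional[Logger]" has no attribute "log_metrics"
    # The assert narrows trainer.logger from Optional[Logger] to Logger.
    assert trainer.logger is not None
    trainer.logger.log_metrics(stats, step=step)


trainer = Trainer()
trainer.logger = Logger()
log_stats(trainer, {"gpu_mem_mb": 0.0}, step=0)
```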
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -161,6 +161,7 @@ def on_train_batch_start(
# First log at beginning of second step
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)

@rank_zero_only
@@ -185,6 +186,7 @@ def on_train_batch_end(
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)

@staticmethod
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -165,7 +165,8 @@ def optimizer_step(
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name
trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
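The new lines only forward `group_separator` when a logger is actually configured; otherwise `grad_norm` keeps its own default separator instead of receiving `None`. A small sketch of that conditional-kwargs idiom, with a stand-in `grad_norm_sketch` function and `FakeLogger` class (not the real Lightning APIs):

```python
from typing import Dict, Optional


def grad_norm_sketch(norm_type: float, group_separator: str = "/") -> Dict[str, float]:
    # Stand-in for the real grad_norm utility; only the keyword default matters here.
    return {f"grad_{norm_type}_norm{group_separator}weight": 0.0}


class FakeLogger:
    group_separator = "."


def track(logger: Optional[FakeLogger]) -> Dict[str, float]:
    # Forward the separator only when a logger exists, so the callee's own
    # default is used instead of an explicit group_separator=None.
    kwargs = {"group_separator": logger.group_separator} if logger is not None else {}
    return grad_norm_sketch(2.0, **kwargs)


print(track(FakeLogger()))  # keys use "." from the logger
print(track(None))          # keys fall back to the default "/"
```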
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -569,6 +569,7 @@ def __init__(
self.__init_profiler(profiler)

# init logger flags
self.logger: Optional[LightningLoggerBase]
self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

# init debugging flags
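The added line in `__init__` is an annotation-only declaration: it gives mypy a type for `self.logger` without assigning a value, because the actual assignment happens later inside `logger_connector.on_trainer_init`. A minimal sketch of that pattern, with placeholder classes rather than the real Trainer and connector:

```python
from typing import Optional


class LoggerBase:
    """Placeholder for a logger base class."""


class LoggerConnector:
    def __init__(self, trainer: "TrainerSketch") -> None:
        self.trainer = trainer

    def on_trainer_init(self, logger: Optional[LoggerBase]) -> None:
        # The connector, not Trainer.__init__, performs the assignment.
        self.trainer.logger = logger


class TrainerSketch:
    def __init__(self, logger: Optional[LoggerBase] = None) -> None:
        self.logger_connector = LoggerConnector(self)
        # Annotation-only declaration: no value is assigned on this line, but
        # mypy now knows the attribute exists and may be None.
        self.logger: Optional[LoggerBase]
        self.logger_connector.on_trainer_init(logger)


trainer = TrainerSketch(logger=LoggerBase())
assert isinstance(trainer.logger, LoggerBase)
```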
