Enable logging hparams only if there are any (#11105)
rohitgr7 authored Dec 17, 2021
1 parent dbb7f56 commit 860959f
Showing 8 changed files with 15 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -290,6 +290,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Disabled logging hyperparameters in the logger if there are no hparams ([#11105](https://github.com/PyTorchLightning/pytorch-lightning/issues/11105))
+
+
-

2 changes: 1 addition & 1 deletion pytorch_lightning/core/mixins/hparams_mixin.py
@@ -28,7 +28,7 @@ class HyperparametersMixin:

    def __init__(self) -> None:
        super().__init__()
-        self._log_hyperparams = True
+        self._log_hyperparams = False

    def save_hyperparameters(
        self,
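This one-line flip is the heart of the commit: hyperparameter logging becomes opt-in instead of on-by-default. A minimal sketch of the intended behavior, assuming (as the test changes below suggest) that `save_hyperparameters()` is what re-enables the flag; everything except the `_log_hyperparams` attribute and the `save_hyperparameters` name is illustrative, not the actual pytorch-lightning internals:

class HyperparametersMixinSketch:
    def __init__(self) -> None:
        self._log_hyperparams = False  # nothing saved yet, so nothing to log
        self.hparams = {}

    def save_hyperparameters(self, **kwargs) -> None:
        # Assumption: saving any hparams opts the module back into logging.
        self.hparams.update(kwargs)
        self._log_hyperparams = True


def log_hyperparams_if_any(logger, module) -> None:
    # Hypothetical connector step mirroring this commit's fix:
    # only touch the logger when there are hparams to log.
    if module._log_hyperparams and module.hparams:
        logger.log_hyperparams(module.hparams)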
2 changes: 1 addition & 1 deletion tests/callbacks/test_gpu_stats_monitor.py
@@ -85,7 +85,7 @@ def test_gpu_stats_monitor_no_queries(tmpdir):
    with mock.patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics") as log_metrics_mock:
        trainer.fit(model)

-    assert log_metrics_mock.mock_calls[2:] == [
+    assert log_metrics_mock.mock_calls[1:] == [
        mock.call({"batch_time/intra_step (ms)": mock.ANY}, step=0),
        mock.call({"batch_time/inter_step (ms)": mock.ANY}, step=1),
        mock.call({"batch_time/intra_step (ms)": mock.ANY}, step=1),
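The slice start drops from 2 to 1, presumably because one fewer leading `log_metrics` call (the one tied to hyperparameter logging) now precedes the GPU-stats metrics. For reference, `mock_calls` records calls in order, so slicing skips the setup calls the test does not care about; a small self-contained illustration with made-up metrics:

from unittest import mock

log_metrics_mock = mock.Mock()
log_metrics_mock({"epoch": 0}, step=0)  # leading call the test skips over
log_metrics_mock({"batch_time/intra_step (ms)": 1.2}, step=0)

assert log_metrics_mock.mock_calls[1:] == [
    mock.call({"batch_time/intra_step (ms)": 1.2}, step=0)
]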
2 changes: 0 additions & 2 deletions tests/loggers/test_all.py
@@ -146,10 +146,8 @@ def log_metrics(self, metrics, step):
    log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
    if logger_class == TensorBoardLogger:
        expected = [
-            (0, ["hp_metric"]),
            (0, ["epoch", "train_some_val"]),
            (0, ["early_stop_on", "epoch", "val_loss"]),
-            (0, ["hp_metric"]),
            (1, ["epoch", "test_loss"]),
        ]
        assert log_metric_names == expected
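The two dropped `(0, ["hp_metric"])` entries are consistent with TensorBoard's `hp_metric` placeholder only being written when `log_hyperparams` runs; with the test model no longer saving hparams, that call disappears (this is a reading of the change, not stated in the diff). The history-comparison pattern the test uses is easy to reproduce with a stub logger; names below are illustrative:

class RecordingLogger:
    # Stub of the test's history-recording logger.
    def __init__(self):
        self.history = []

    def log_metrics(self, metrics, step):
        self.history.append((step, metrics))


logger = RecordingLogger()
logger.log_metrics({"epoch": 0, "train_some_val": 1.0}, step=0)
logger.log_metrics({"epoch": 1, "test_loss": 0.5}, step=1)

log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
assert log_metric_names == [(0, ["epoch", "train_some_val"]), (1, ["epoch", "test_loss"])]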
5 changes: 2 additions & 3 deletions tests/loggers/test_base.py
@@ -111,7 +111,6 @@ def training_step(self, batch, batch_idx):
    trainer = Trainer(max_steps=2, log_every_n_steps=1, logger=logger, default_root_dir=tmpdir)
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
-    assert logger.hparams_logged == model.hparams
    assert logger.metrics_logged != {}
    assert logger.after_save_checkpoint_called
    assert logger.finalized_status == "success"
@@ -133,11 +132,11 @@ def training_step(self, batch, batch_idx):
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

-    assert logger1.hparams_logged == model.hparams
+    assert logger1.hparams_logged is None
    assert logger1.metrics_logged != {}
    assert logger1.finalized_status == "success"

-    assert logger2.hparams_logged == model.hparams
+    assert logger2.hparams_logged is None
    assert logger2.metrics_logged != {}
    assert logger2.finalized_status == "success"

5 changes: 4 additions & 1 deletion tests/models/test_hparams.py
@@ -764,7 +764,10 @@ def test_adding_datamodule_hparams(tmpdir, model, data):
    # Merged hparams were logged
    merged_hparams = copy.deepcopy(org_model_hparams)
    merged_hparams.update(org_data_hparams)
-    mock_logger.log_hyperparams.assert_called_with(merged_hparams)
+    if merged_hparams:
+        mock_logger.log_hyperparams.assert_called_with(merged_hparams)
+    else:
+        mock_logger.log_hyperparams.assert_not_called()


def test_no_datamodule_for_hparams(tmpdir):
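The conditional replaces a blanket `assert_called_with`, since a model/datamodule pair with no hparams should now produce no `log_hyperparams` call at all. The pattern is plain `unittest.mock`; a self-contained sketch with made-up hparams dicts:

from unittest import mock


def check_hparams_logging(merged_hparams):
    mock_logger = mock.Mock()
    # Stand-in for the code under test: it only logs when there is something to log.
    if merged_hparams:
        mock_logger.log_hyperparams(merged_hparams)

    # The assertion pattern from the diff above.
    if merged_hparams:
        mock_logger.log_hyperparams.assert_called_with(merged_hparams)
    else:
        mock_logger.log_hyperparams.assert_not_called()


check_hparams_logging({"lr": 0.1})  # logged and asserted
check_hparams_logging({})           # nothing logged, nothing expected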
1 change: 0 additions & 1 deletion tests/trainer/logging_/test_distributed_logging.py
@@ -112,7 +112,6 @@ def on_fit_start(self, trainer, pl_module):

        def on_train_start(self, trainer, pl_module):
            assert trainer.logger.method_call
-            trainer.logger.log_hyperparams.assert_called_once()
            trainer.logger.log_graph.assert_called_once()

    logger = Mock()
4 changes: 4 additions & 0 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -510,6 +510,10 @@ class ExtendedModel(BoringModel):

    val_losses = []

+    def __init__(self, some_val=7):
+        super().__init__()
+        self.save_hyperparameters()
+
    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
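The new `__init__` gives the test model an actual hyperparameter (`some_val=7`) via `save_hyperparameters()`, presumably so the test still exercises hparams logging now that it is opt-in. The same pattern applies to user code; a sketch, assuming (per this commit) that a module without saved hparams logs none:

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, hidden_size=64):
        super().__init__()
        # After this commit, saving hparams is what opts the module into
        # logger.log_hyperparams(); omit it and nothing is logged.
        self.save_hyperparameters()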
