Fix logging on_train_batch_end in a callback with multiple optimizers (#5521)

* Start with the failing test

* Then fix the failing test

* Update CHANGELOG
carmocca committed Jan 18, 2021
1 parent a56f745 commit 18d2ae8
Showing 3 changed files with 30 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))
 
 
+- Fixed logging on_train_batch_end in a callback with multiple optimizers ([#5521](https://github.com/PyTorchLightning/pytorch-lightning/pull/5521))
+
+
 - Fixed `reinit_scheduler_properties` with correct optimizer ([#5519](https://github.com/PyTorchLightning/pytorch-lightning/pull/5519))
 
 
@@ -203,13 +203,7 @@ def auto_reduce_results_on_epoch_end(self) -> None:
             epoch_metrics = self._internals[dl_idx]
 
             if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
-
-                num_opt_idx = len(self._internals[dl_idx]) - 1
-
-                # Make sure we didn't create key
-                assert num_opt_idx >= 0
-
-                for opt_idx in range(num_opt_idx + 1):
+                for opt_idx in list(epoch_metrics):
                     # TODO: Figure out to reduce memory
                     # TODO: How to start training in middle of epoch
                     opt_outputs = epoch_metrics[opt_idx]
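The removed loop assumed the per-optimizer results were keyed 0..N-1. Below is a minimal sketch, using plain dictionaries rather than Lightning's actual EpochResultStore, of the failure mode described in the test comment further down: when only a callback has logged during on_train_batch_end, the metrics sit under the index of the last optimizer that ran, so indexing by range() raises a KeyError while iterating the existing keys does not. The dictionary contents here are illustrative assumptions, not Lightning internals.

# Sketch of the failure mode fixed above: per-optimizer results are keyed by
# optimizer index, but the keys are not guaranteed to be contiguous from 0.
# Assumed scenario (per the test comment below): only a callback logged during
# `on_train_batch_end`, so its metrics sit under the last optimizer index (1).
epoch_metrics = {1: {"test_train_batch_end": 1.0}}  # no entry for optimizer 0

# Old logic: assumes keys 0..N-1 exist.
try:
    num_opt_idx = len(epoch_metrics) - 1
    for opt_idx in range(num_opt_idx + 1):  # iterates over [0]
        _ = epoch_metrics[opt_idx]          # KeyError: 0
except KeyError as err:
    print(f"old iteration fails: KeyError {err}")

# New logic: iterate only over the keys that actually exist.
for opt_idx in list(epoch_metrics):          # iterates over [1]
    opt_outputs = epoch_metrics[opt_idx]
    print(opt_idx, opt_outputs)

Iterating list(epoch_metrics) also keeps working when one optimizer logs nothing at all in an unbalanced schedule, which is the situation the updated test below exercises.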
45 changes: 26 additions & 19 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -22,23 +22,18 @@
 
 def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
     """
-    This tests ensures reduction works in un-balanced logging settings
+    This tests ensures reduction works in unbalanced logging settings,
+    even when a Callback also logs.
     """
     class TestModel(BoringModel):
-
-        loss_1 = []
-        loss_2 = []
+        actual = {0: [], 1: []}
 
         def training_step(self, batch, batch_idx, optimizer_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            if optimizer_idx == 0 and self.trainer.global_step > 10:
-                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
-                self.loss_1.append(loss.detach().clone())
-            elif optimizer_idx == 1:
-                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
-                self.loss_2.append(loss.detach().clone())
-            return {"loss": loss}
+            out = super().training_step(batch, batch_idx)
+            loss = out["loss"]
+            self.log(f"loss_{optimizer_idx}", loss, on_epoch=True)
+            self.actual[optimizer_idx].append(loss)
+            return out
 
         def configure_optimizers(self):
             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
@@ -48,16 +43,28 @@ def configure_optimizers(self):
     model = TestModel()
     model.training_epoch_end = None
 
+    class TestCallback(pl.Callback):
+        def on_train_batch_end(self, trainer, pl_module, output, batch, batch_idx, dl_idx):
+            # when this is called, the EpochResultStore state has not been reset yet because we are still
+            # "INSIDE_BATCH_TRAIN_LOOP" and the LoggerConnector runs its `on_train_batch_end` after the
+            # Callback (see `TrainLoop.on_train_batch_end`). For this reason, opt_idx here is the index
+            # of the last optimizer updated (the second, index 1). This produced a KeyError as reported in #5459
+            pl_module.log("test_train_batch_end", trainer.logger_connector.cached_results._opt_idx)
+
     # Initialize a trainer
     trainer = pl.Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
+        limit_train_batches=5,
+        limit_val_batches=5,
+        callbacks=[TestCallback()],
+        weights_summary=None,
     )
 
     trainer.fit(model)
 
-    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
-    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
-    # test loss are properly reduced
-    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
-    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
+    for k, v in model.actual.items():
+        assert torch.equal(trainer.callback_metrics[f"loss_{k}_step"], v[-1])
+        # test loss is properly reduced
+        torch.testing.assert_allclose(trainer.callback_metrics[f"loss_{k}_epoch"], torch.tensor(v).mean())
+
+    assert trainer.callback_metrics["test_train_batch_end"] == len(model.optimizers()) - 1
