Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix collecting training_step outputs #8613

Merged
merged 13 commits into from
Jul 30, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed an issue with `training_step` outputs not getting collected correctly for `training_epoch_end` ([#8613](https://github.com/PyTorchLightning/pytorch-lightning/pull/8613))


-
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -143,12 +144,12 @@ def advance(self, batch, batch_idx, dataloader_idx):

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
self.batch_outputs[opt_idx].append(copy(result.training_step_output))
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
self.batch_outputs[0].append(result.training_step_output)
self.batch_outputs[0].append(copy(result.training_step_output))

def teardown(self) -> None:
# release memory
Expand Down
12 changes: 8 additions & 4 deletions tests/trainer/loops/test_training_loop_flow_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,21 @@ def training_step(self, batch, batch_idx):
acc = acc + batch_idx

self.training_step_called = True
out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
return out

def training_epoch_end(self, outputs):
self.training_epoch_end_called = True

# verify we saw the current num of batches
assert len(outputs) == 2
assert len({id(output) for output in outputs}) == 2
assert [output["batch_idx"] for output in outputs] == [0, 1]

for b in outputs:
assert isinstance(b, dict)
assert self.count_num_graphs(b) == 0
assert {"random_things", "loss"} == set(b.keys())
assert {"random_things", "loss", "batch_idx"} == set(b.keys())

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)
Expand Down Expand Up @@ -155,7 +157,7 @@ def training_step(self, batch, batch_idx):
acc = acc + batch_idx

self.training_step_called = True
self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
return self.out

def training_step_end(self, tr_step_output):
Expand All @@ -169,11 +171,13 @@ def training_epoch_end(self, outputs):

# verify we saw the current num of batches
assert len(outputs) == 2
assert len({id(output) for output in outputs}) == 2
assert [output["batch_idx"] for output in outputs] == [0, 1]

for b in outputs:
assert isinstance(b, dict)
assert self.count_num_graphs(b) == 0
assert {"random_things", "loss"} == set(b.keys())
assert {"random_things", "loss", "batch_idx"} == set(b.keys())

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)
Expand Down