Skip to content

Commit

Permalink
fix collecting training_step outputs (#8613)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Aug 3, 2021
1 parent 005cd82 commit 82799ae
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from collections import OrderedDict
from contextlib import contextmanager
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple

Expand Down Expand Up @@ -146,12 +147,12 @@ def advance(self, batch, batch_idx, dataloader_idx):

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
self.batch_outputs[opt_idx].append(copy(result.training_step_output))
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
if result:
self.batch_outputs[0].append(result.training_step_output)
self.batch_outputs[0].append(copy(result.training_step_output))

def teardown(self) -> None:
# release memory
Expand Down
12 changes: 8 additions & 4 deletions tests/trainer/loops/test_training_loop_flow_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,21 @@ def training_step(self, batch, batch_idx):
acc = acc + batch_idx

self.training_step_called = True
out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
return out

def training_epoch_end(self, outputs):
self.training_epoch_end_called = True

# verify we saw the current num of batches
assert len(outputs) == 2
assert len({id(output) for output in outputs}) == 2
assert [output["batch_idx"] for output in outputs] == [0, 1]

for b in outputs:
assert isinstance(b, dict)
assert self.count_num_graphs(b) == 0
assert {"random_things", "loss"} == set(b.keys())
assert {"random_things", "loss", "batch_idx"} == set(b.keys())

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)
Expand Down Expand Up @@ -155,7 +157,7 @@ def training_step(self, batch, batch_idx):
acc = acc + batch_idx

self.training_step_called = True
self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)]}
self.out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx}
return self.out

def training_step_end(self, tr_step_output):
Expand All @@ -169,11 +171,13 @@ def training_epoch_end(self, outputs):

# verify we saw the current num of batches
assert len(outputs) == 2
assert len({id(output) for output in outputs}) == 2
assert [output["batch_idx"] for output in outputs] == [0, 1]

for b in outputs:
assert isinstance(b, dict)
assert self.count_num_graphs(b) == 0
assert {"random_things", "loss"} == set(b.keys())
assert {"random_things", "loss", "batch_idx"} == set(b.keys())

def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)
Expand Down
16 changes: 13 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import math
import os
Expand Down Expand Up @@ -1872,13 +1873,22 @@ def on_epoch_start(self, trainer, *_):
assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")

# before measuring the memory force release any leftover allocations, including CUDA tensors
gc.collect()
memory_1 = torch.cuda.memory_allocated(0)
assert memory_1 == initial

deepcopy(trainer)

# before measuring the memory force release any leftover allocations, including CUDA tensors
gc.collect()
memory_2 = torch.cuda.memory_allocated(0)
assert memory_1 == memory_2 == initial
assert memory_2 == initial

trainer_2 = Trainer(**trainer_kwargs)
trainer_2.fit(model)
memory_3 = torch.cuda.memory_allocated(0)

assert initial == memory_1 == memory_3
# before measuring the memory force release any leftover allocations, including CUDA tensors
gc.collect()
memory_3 = torch.cuda.memory_allocated(0)
assert memory_3 == initial

0 comments on commit 82799ae

Please sign in to comment.