Skip to content

Commit

Permalink
Fix reference issues during epoch end result collection (#8621)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
carmocca and awaelchli authored Jul 30, 2021
1 parent 93784da commit 5789e9f
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed references for `ResultCollection.extra` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


-
- Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))


- Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def training_step(self, batch, batch_idx):
out = self(x)
# softmax uses only a portion of the batch in the denomintaor
# softmax uses only a portion of the batch in the denominator
loss = self.softmax(out)
loss = nce_loss(loss)
return loss
Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/loops/batch/training_batch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,16 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op

loss = None
hiddens = None
results.extra = {}

# handle dict return
if isinstance(training_step_output, dict):
loss = training_step_output.pop("loss", None)
hiddens = training_step_output.pop("hiddens", None)
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
loss = training_step_output.get("loss")
hiddens = training_step_output.get("hiddens")
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
results.extra = training_step_output
# use the setter instead of `dict.update` because it calls `detach` on the tensor items
results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

# handle scalar return
elif isinstance(training_step_output, Tensor):
Expand Down
10 changes: 3 additions & 7 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterator, List, Optional, Union

import torch
Expand Down Expand Up @@ -276,11 +275,7 @@ def _track_epoch_end_reduce_metrics(
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(batch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if (
isinstance(opt_outputs, list)
and len(opt_outputs) == 1
and not isinstance(opt_outputs[0], ResultCollection)
):
if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
opt_outputs = opt_outputs[0]

epoch_output[opt_idx].append(opt_outputs)
Expand Down Expand Up @@ -320,9 +315,10 @@ def _prepare_outputs(
batch_outputs = [batch_outputs]

for tbptt_output in batch_outputs:
out = tbptt_output.extra
out = {}
if tbptt_output.minimize is not None:
out["loss"] = tbptt_output.minimize.detach()
out.update(tbptt_output.extra)
processed_tbptt_outputs.append(out)

# if there was only one tbptt step then we can collapse that dimension
Expand Down
27 changes: 27 additions & 0 deletions tests/trainer/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,30 @@ def training_step_end(self, outputs):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

trainer.fit(model)


def test_prepare_outputs(tmpdir):
"""
Test that the `extra` field of the saved `ResultCollection` objects for
`training_epoch_end` doesn't get accidentally modified by reference.
"""

class TestModel(BoringModel):
on_train_batch_end_called = 0

def on_train_batch_end(self, outputs, *args, **kwargs):
epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output
epoch_outputs = epoch_outputs[0] # 1 optimizer
assert len(epoch_outputs) == self.on_train_batch_end_called
# `extra` should be empty for all `ResultCollection` objects
assert all(not out.extra for out in epoch_outputs)
self.on_train_batch_end_called += 1

def training_epoch_end(self, outputs) -> None:
# override so epoch outputs get stored
pass

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
trainer.fit(model)
assert model.on_train_batch_end_called == 2

0 comments on commit 5789e9f

Please sign in to comment.