Skip to content

Commit

Permalink
Fix for incorrect usage of detach(), cpu(), to() (#6216)
Browse files Browse the repository at this point in the history
* Fix for incorrect detach/cpu calls (#6214)

* Fix incorrect use of detach(), to(), and cpu(), #6214

* Fix incorrect use of detach() and cpu(), #6214

* update pr

* add typing

* chlog

* more...

* revert on module

* update on comments

* revert changes on model

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
3 people authored and lexierule committed Mar 5, 2021
1 parent 09b287a commit ad61624
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))


- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,20 +416,22 @@ def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_i

return result

def detach(self):
def detach(self) -> 'Result':
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__setitem__(k, v.detach())
return self

def to(self, *args, **kwargs):
def to(self, *args, **kwargs) -> 'Result':
"""Move all self attributes to the given device."""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__setitem__(k, v.to(*args, **kwargs))
return self

def cpu(self):
def cpu(self) -> 'Result':
"""Move all self attributes to CPU."""
self.to(torch.device("cpu"))
return self.to(torch.device("cpu"))

def __repr__(self):
self_copy = self.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def cache_result(self) -> None:
# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)

hook_result.detach()
hook_result = hook_result.detach()
if self.trainer.move_metrics_to_cpu:
hook_result.cpu()
hook_result = hook_result.cpu()
elif self.trainer._distrib_type == DistributedType.DP:
hook_result.to(torch.device("cuda", self.trainer.root_gpu))
hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))

self._internals[fx_name].append(hook_result, info)

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,9 +773,9 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
def track_output_for_epoch_end(self, outputs, output):
if output is not None:
if isinstance(output, Result):
output.detach()
output = output.detach()
if self.move_metrics_to_cpu:
output.cpu()
output = output.cpu()
elif isinstance(output, dict):
output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
is_result_obj = isinstance(training_step_output, Result)

if is_result_obj:
training_step_output.detach()
training_step_output = training_step_output.detach()
else:
training_step_output.batch_loss = training_step_output.batch_loss.detach()

Expand Down Expand Up @@ -397,9 +397,9 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):

# track metrics without grads for epoch reduction
training_step_output_for_epoch_end = copy(result)
training_step_output_for_epoch_end.detach()
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
if self.trainer.move_metrics_to_cpu:
training_step_output_for_epoch_end.cpu()
training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()

# what flows back into the system
training_step_output = result
Expand Down
6 changes: 3 additions & 3 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def training_step(self, batch, batch_idx):
output.update({"python scalar": 12.3})
return output

model = TestModel()
model.to(device)
model.running_stage = RunningStage.TRAINING
model = TestModel().to(device)
model.trainer = MagicMock()
model.trainer._running_stage = RunningStage.TRAINING
batch = torch.rand(2, 32).to(device)
batch_idx = 0

Expand Down

0 comments on commit ad61624

Please sign in to comment.