Commit
Add support for missing return obj from `to` function on custom batch objects (#8433)

* resolve bug

* update

* add changelog

* Update tests/utilities/test_apply_func.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
tchaton and awaelchli authored Jul 19, 2021
1 parent 7bb810f commit 257fabd
Showing 3 changed files with 29 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -470,6 +470,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))


- Fixed `move_data_to_device` to return the batch if the object's `to` function didn't return `self` ([#8433](https://github.com/PyTorchLightning/pytorch-lightning/pull/8433))


- Fixed progress bar updates for Pod Training ([#8258](https://github.com/PyTorchLightning/pytorch-lightning/pull/8258))


6 changes: 5 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
@@ -245,7 +245,11 @@ def batch_to(data):
             return device_data
 
         kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
-        return data.to(device, **kwargs)
+        data_output = data.to(device, **kwargs)
+        if data_output is not None:
+            return data_output
+        # the user incorrectly implemented ``TransferableDataType`` and forgot to return ``self``
+        return data
 
     dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
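
For illustration only (not part of this commit), a minimal sketch of the behaviour the fallback enables, assuming a hypothetical user-defined `CustomBatch` whose `to` method moves data in place and forgets to return `self`:

import torch

from pytorch_lightning.utilities.apply_func import move_data_to_device


class CustomBatch:
    """Hypothetical user batch whose ``to`` forgets to return ``self``."""

    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    def to(self, device):
        # moves the wrapped tensor but (incorrectly) returns ``None``
        self.tensor = self.tensor.to(device)


batch = CustomBatch(torch.rand(2, 3))
moved = move_data_to_device(batch, torch.device("cpu"))
# before this fix ``moved`` would have been ``None``; now the original object comes back
assert moved is batch

The parametrized test added below exercises the same scenario with and without the `return self`.
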
22 changes: 21 additions & 1 deletion tests/utilities/test_apply_func.py
@@ -20,7 +20,7 @@
 import pytest
 import torch
 
-from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 
 
 def test_recursive_application_to_collection():
@@ -209,3 +209,23 @@ def fn(a, b):
     assert reduced1 == reduced2 == [1, 4, 9]
     reduced = apply_to_collections(None, None, int, lambda x: x * x)
     assert reduced is None
+
+
+@pytest.mark.parametrize('should_return', [False, True])
+def test_wrongly_implemented_transferable_data_type(should_return):
+
+    class TensorObject:
+
+        def __init__(self, tensor: torch.Tensor, should_return: bool = True):
+            self.tensor = tensor
+            self.should_return = should_return
+
+        def to(self, device):
+            self.tensor.to(device)
+            # simulate a user forgetting to return ``self``
+            if self.should_return:
+                return self
+
+    tensor = torch.tensor(0.1)
+    obj = TensorObject(tensor, should_return)
+    assert obj == move_data_to_device(obj, torch.device("cpu"))
