Skip to content

Commit

Permalink
Add default val to transfer_batch_to_device hook
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNaren committed Feb 18, 2021
1 parent 1f324a4 commit 5c6a68f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
5 changes: 4 additions & 1 deletion docs/source/extensions/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,16 @@ Override to define how you want to move an arbitrary batch to a device.
.. testcode::

class MNISTDataModule(LightningDataModule):
def transfer_batch_to_device(self, batch, device):
def transfer_batch_to_device(self, batch, device, dataloader_idx):
x = batch['x']
x = CustomDataWrapper(x)
batch['x'] = x.to(device)
return batch


.. warning::
Currently dataloader_idx always returns 0 and will be updated to support the true idx in the future.

.. note:: This hook only runs on single GPU training and DDP (no data-parallel).


Expand Down
9 changes: 6 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,13 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
will have an argument ``dataloader_idx`` which matches the order here.
"""

def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx=0) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
.. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:`torch.Tensor` or anything that implements `.to(...)`
Expand All @@ -594,19 +596,20 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
dataloader_idx: DataLoader idx for batch (Default: 0)
Returns:
A reference to the data on the new device.
Example::
def transfer_batch_to_device(self, batch, device):
def transfer_batch_to_device(self, batch, device, dataloader_idx):
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
else:
batch = super().transfer_batch_to_device(data, device)
batch = super().transfer_batch_to_device(data, device, dataloader_idx)
return batch
See Also:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def logger(self):

def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0):
batch = self.on_before_batch_transfer(batch, dataloader_idx)
batch = self.transfer_batch_to_device(batch, device)
batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
batch = self.on_after_batch_transfer(batch, dataloader_idx)
return batch

Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch.targets *= 2
return batch

def transfer_batch_to_device(self, batch, device):
def transfer_batch_to_device(self, batch, device, dataloader_idx):
self.transfer_batch_to_device_hook_rank = self.rank
self.rank += 1
batch.samples = batch.samples.to(device)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch.targets *= 2
return batch

def transfer_batch_to_device(self, batch, device):
def transfer_batch_to_device(self, batch, device, dataloader_idx):
self.transfer_batch_to_device_hook_rank = self.rank
self.rank += 1
batch.samples = batch.samples.to(device)
Expand Down

0 comments on commit 5c6a68f

Please sign in to comment.