diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst
index 4a4503621e18e..57dc337ddf56a 100644
--- a/docs/source/datamodules.rst
+++ b/docs/source/datamodules.rst
@@ -290,6 +290,22 @@ Use this method to generate the test dataloader. This is also a good place to pl
         ])
         return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
 
+transfer_batch_to_device
+^^^^^^^^^^^^^^^^^^^^^^^^
+Override to define how you want to move an arbitrary batch to a device.
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def transfer_batch_to_device(self, batch, device):
+            x = batch['x']
+            x = CustomDataWrapper(x)
+            batch['x'] = x.to(device)
+            return batch
+
 ------------------
 
 Using a DataModule
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 657b35e9fe011..cb166663e7c3a 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -18,6 +18,7 @@
 from argparse import ArgumentParser, Namespace
 from typing import Any, List, Optional, Tuple, Union
 
+import torch
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn
@@ -306,6 +307,56 @@ def test_dataloader(self):
             return loader
         """
 
+    @abstractmethod
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
+
+        - :class:`torch.Tensor` or anything that implements `.to(...)`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - :class:`torchtext.data.batch.Batch`
+
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+
+        Note:
+            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
+            for your custom batch objects, you need to define your custom
+            :class:`~torch.nn.parallel.DistributedDataParallel` or
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+        """
+
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         r"""Extends existing argparse by default `LightningDataModule` attributes.
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 98d93ad6357fa..97da8757fc266 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1112,11 +1112,6 @@ def __attach_datamodule(self, model, datamodule, stage):
 
         # If we have a datamodule, attach necessary hooks + dataloaders
         if datamodule:
-            # If datamodule.setup('test') has not been called yet, call it
-            # if stage == 'test':
-            #     if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
-            #         datamodule.setup('test')
-
             # Override loader hooks
             if self.is_overridden('train_dataloader', datamodule):
                 model.train_dataloader = datamodule.train_dataloader
@@ -1125,6 +1120,10 @@ def __attach_datamodule(self, model, datamodule, stage):
             if self.is_overridden('test_dataloader', datamodule):
                 model.test_dataloader = datamodule.test_dataloader
 
+            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
+            if self.is_overridden('transfer_batch_to_device', datamodule):
+                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
+
             self.datamodule = datamodule
 
     def run_pretrain_routine(self, model: LightningModule):
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 3de920cf8b35d..1640c9bad2bbf 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -1,10 +1,11 @@
 import pickle
 from argparse import ArgumentParser
+from unittest.mock import MagicMock
 
 import pytest
 import torch
 
-from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning import LightningDataModule, Trainer, seed_everything
 from tests.base import EvalModelTemplate
 from tests.base.datamodules import TrialMNISTDataModule
 from tests.base.develop_utils import reset_seed
@@ -317,3 +318,40 @@ def test_full_loop_ddp_spawn(tmpdir):
     result = trainer.test(datamodule=dm)
     result = result[0]
     assert result['test_acc'] > 0.8
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+def test_dm_transfer_batch_to_device(tmpdir):
+    class CustomBatch:
+
+        def __init__(self, data):
+            self.samples = data[0]
+            self.targets = data[1]
+
+    class CurrentTestDM(LightningDataModule):
+
+        hook_called = False
+
+        def transfer_batch_to_device(self, data, device):
+            self.hook_called = True
+            if isinstance(data, CustomBatch):
+                data.samples = data.samples.to(device)
+                data.targets = data.targets.to(device)
+            else:
+                data = super().transfer_batch_to_device(data, device)
+            return data
+
+    model = EvalModelTemplate()
+    dm = CurrentTestDM()
+    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))
+
+    trainer = Trainer()
+    # running .fit() would require us to implement custom data loaders; we mock the model reference instead
+    trainer.get_model = MagicMock(return_value=model)
+    if trainer.is_overridden('transfer_batch_to_device', dm):
+        model.transfer_batch_to_device = dm.transfer_batch_to_device
+
+    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    expected = torch.device('cuda', 0)
+    assert dm.hook_called
+    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
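
A minimal usage sketch of the hook this patch introduces, assuming a DataLoader that yields a custom batch container. The names ``CustomBatch``, ``collate_custom``, ``MyDataModule``, and ``LitModel`` are hypothetical placeholders chosen only for illustration; everything else uses existing ``torch`` and ``pytorch_lightning`` APIs::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class CustomBatch:
        """A batch container that Lightning cannot move to a device automatically."""

        def __init__(self, samples, targets):
            self.samples = samples
            self.targets = targets


    def collate_custom(items):
        # hypothetical collate_fn: wrap the (x, y) pairs in the custom container
        samples = torch.stack([x for x, _ in items])
        targets = torch.stack([y for _, y in items])
        return CustomBatch(samples, targets)


    class MyDataModule(pl.LightningDataModule):
        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(64, 28), torch.randint(0, 10, (64,)))
            return DataLoader(dataset, batch_size=8, collate_fn=collate_custom)

        def transfer_batch_to_device(self, batch, device):
            # every batch produced by collate_custom is a CustomBatch,
            # so move its tensors to the target device explicitly
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch


    if __name__ == "__main__":
        dm = MyDataModule()
        batch = next(iter(dm.train_dataloader()))
        device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
        moved = dm.transfer_batch_to_device(batch, device)
        assert moved.samples.device == moved.targets.device == device
        # in a full run, the Trainer picks the hook up once the datamodule is attached, e.g.
        # pl.Trainer(gpus=1).fit(LitModel(), datamodule=dm)  # LitModel: any LightningModule

Calling the hook directly keeps the sketch independent of GPU availability; inside a Trainer run the same method is invoked through the ``model.transfer_batch_to_device`` reference that ``__attach_datamodule`` installs in the trainer change above.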