Ref: Pull duplicate data interface definition up into DataHooks class (#3344)

* pull data hooks up into a common interface

* fix multiple inheritance ordering

* docs reference datahooks
awaelchli committed Sep 4, 2020
1 parent 24809b0 commit 7bd2f94
Showing 4 changed files with 225 additions and 355 deletions.
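For readers skimming the diff, the shape of the refactor is roughly the following (an illustrative sketch only, not the actual source; the real DataHooks class lives in pytorch_lightning/core/hooks.py, whose changes are not expanded on this page):

```python
from typing import Any, List, Optional, Union

import torch
from torch.utils.data import DataLoader


class DataHooks:
    """Single home for the data interface shared by LightningModule and LightningDataModule."""

    def prepare_data(self, *args, **kwargs) -> None:
        """Download/prepare data; called only once per node in distributed settings."""

    def setup(self, stage: Optional[str] = None) -> None:
        """Load and split data; called on every process when using DDP."""

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        raise NotImplementedError

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        raise NotImplementedError

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        raise NotImplementedError

    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
        """Move a (possibly custom) batch structure to the target device."""
        return batch
```

Both LightningModule and LightningDataModule then inherit this one interface instead of each carrying its own copy of the hook definitions.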
2 changes: 1 addition & 1 deletion docs/source/lightning-module.rst
@@ -1239,5 +1239,5 @@ teardown
transfer_batch_to_device
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.transfer_batch_to_device
.. autofunction:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
:noindex:
137 changes: 9 additions & 128 deletions pytorch_lightning/core/datamodule.py
@@ -21,7 +21,8 @@
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn
from pytorch_lightning.core.hooks import DataHooks
from pytorch_lightning.utilities import parsing, rank_zero_only


class _DataModuleWrapper(type):
@@ -87,7 +88,7 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn

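The wrapped_fn above belongs to _DataModuleWrapper's call-tracking machinery. A simplified, self-contained sketch of that pattern follows; _TrackingWrapper, _track_call and the flag names are placeholders, not the actual Lightning identifiers:

```python
import functools


def _track_call(bound_method, obj, flag_name):
    """Wrap an already-bound hook so that calling it flips a flag on the instance."""
    @functools.wraps(bound_method)
    def wrapped_fn(*args, **kwargs):
        setattr(obj, flag_name, True)
        return bound_method(*args, **kwargs)
    return wrapped_fn


class _TrackingWrapper(type):
    """Metaclass sketch: intercept instantiation and wrap the data hooks."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)
        obj.prepare_data = _track_call(obj.prepare_data, obj, "_has_prepared_data")
        obj.setup = _track_call(obj.setup, obj, "_has_setup")
        return obj
```

In Lightning, properties such as has_setup_test (visible in the next hunk) are backed by this kind of tracking.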

class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no cover
class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
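A quick check of the effect of the new base class, assuming a checkout that contains this commit:

```python
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.hooks import DataHooks

# the shared data interface now resolves from the common DataHooks base
assert issubclass(LightningDataModule, DataHooks)
assert DataHooks in LightningDataModule.__mro__
```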
@@ -215,147 +216,27 @@ def has_setup_test(self):

@abstractmethod
def prepare_data(self, *args, **kwargs):
"""
Use this to download and prepare data.
In distributed (GPU, TPU), this will only be called once.
.. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.
Pseudocode::
dm.prepare_data()
dm.setup()
Example::
def prepare_data(self):
download_imagenet()
clean_imagenet()
cache_imagenet()
"""
pass

@abstractmethod
def setup(self, stage: Optional[str] = None):
"""
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
This hook is called on every process when using DDP.
Example::
def setup(self, stage):
data = load_data(...)
self.train_ds, self.val_ds, self.test_ds = split_data(data)
"""
pass
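A slightly fuller, hedged illustration of prepare_data and setup together; the dataset, paths and split sizes below are invented for the example and are not part of this commit:

```python
import os

import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    """Minimal sketch: download in prepare_data, assign state in setup."""

    def __init__(self, data_dir: str = os.getcwd()):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # called once per node: download only, assign nothing to self here
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # called on every process: safe to assign datasets to self
        full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.train_ds, self.val_ds = random_split(full, [55_000, 5_000])
        self.test_ds = MNIST(self.data_dir, train=False, transform=self.transform)
```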

@abstractmethod
def train_dataloader(self, *args, **kwargs) -> DataLoader:
"""
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Example::
def train_dataloader(self):
dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset)
return loader
"""
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
pass

@abstractmethod
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for validation.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders
Example::
def val_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""
pass

@abstractmethod
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for testing.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders
Example::
def test_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""
pass
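Continuing the hypothetical MNISTDataModule sketch from above, the three loader hooks usually just wrap the datasets assigned in setup(); the batch size and shuffle settings here are arbitrary:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    # __init__, prepare_data and setup as in the earlier sketch

    def train_dataloader(self):
        # Lightning injects a distributed sampler when needed, so none is set here
        return DataLoader(self.train_ds, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=32, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=32, shuffle=False)
```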

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- :class:`torch.Tensor` or anything that implements `.to(...)`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- :class:`torchtext.data.batch.Batch`
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
Example::
def transfer_batch_to_device(self, batch, device):
if isinstance(batch, CustomBatch):
# move all tensors in your custom data structure to the device
batch.samples = batch.samples.to(device)
batch.targets = batch.targets.to(device)
else:
batch = super().transfer_batch_to_device(batch, device)
return batch
Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.
Returns:
A reference to the data on the new device.
Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
pass
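A runnable variant of the docstring example above; CustomBatch and its field names are invented for illustration, and the same override works on a LightningDataModule now that both classes share DataHooks:

```python
from dataclasses import dataclass

import pytorch_lightning as pl
import torch


@dataclass
class CustomBatch:
    # hypothetical structure returned by a custom collate_fn
    samples: torch.Tensor
    targets: torch.Tensor


class MyModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move every tensor inside the custom structure to the target device
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # fall back to Lightning's default handling for the supported types
        return super().transfer_batch_to_device(batch, device)
```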

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
(The remaining changed files are not expanded on this page.)
