Ref: Pull duplicate data interface definition up into DataHooks class #3344

Merged (3 commits) on Sep 4, 2020
2 changes: 1 addition & 1 deletion docs/source/lightning-module.rst
@@ -1239,5 +1239,5 @@ teardown
transfer_batch_to_device
~~~~~~~~~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.core.hooks.ModelHooks.transfer_batch_to_device
+.. autofunction:: pytorch_lightning.core.hooks.DataHooks.transfer_batch_to_device
:noindex:
137 changes: 9 additions & 128 deletions pytorch_lightning/core/datamodule.py
@@ -21,7 +21,8 @@
import torch
from torch.utils.data import DataLoader

-from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn
+from pytorch_lightning.core.hooks import DataHooks
+from pytorch_lightning.utilities import parsing, rank_zero_only


class _DataModuleWrapper(type):
@@ -87,7 +88,7 @@ def wrapped_fn(*args, **kwargs):
return wrapped_fn


-class LightningDataModule(object, metaclass=_DataModuleWrapper):  # pragma: no cover
+class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
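For orientation, a minimal sketch of a concrete datamodule built on this interface follows. It is illustrative only and not part of this diff; it assumes torchvision is installed, MNIST can be downloaded to a local ./data directory, and a 55,000/5,000 train/val split is acceptable.

    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    from pytorch_lightning.core.datamodule import LightningDataModule


    class MNISTDataModule(LightningDataModule):

        def prepare_data(self):
            # download only; called on a single process, so assign no state here
            MNIST("./data", train=True, download=True)
            MNIST("./data", train=False, download=True)

        def setup(self, stage=None):
            # called on every process; safe to assign state here
            transform = transforms.ToTensor()
            if stage == "fit" or stage is None:
                full = MNIST("./data", train=True, transform=transform)
                self.train_ds, self.val_ds = random_split(full, [55000, 5000])
            if stage == "test" or stage is None:
                self.test_ds = MNIST("./data", train=False, transform=transform)

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=32, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.val_ds, batch_size=32)

        def test_dataloader(self):
            return DataLoader(self.test_ds, batch_size=32)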
@@ -215,147 +216,27 @@ def has_setup_test(self):

@abstractmethod
def prepare_data(self, *args, **kwargs):
"""
Use this to download and prepare data.
In distributed (GPU, TPU), this will only be called once.

.. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.

Pseudocode::

dm.prepare_data()
dm.setup()

Example::

def prepare_data(self):
    download_imagenet()
    clean_imagenet()
    cache_imagenet()
"""
pass

@abstractmethod
def setup(self, stage: Optional[str] = None):
"""
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
This hook is called on every process when using DDP.

Example::

def setup(self, stage):
    data = load_data(...)
    self.train_ds, self.val_ds, self.test_ds = split_data(data)
"""
pass

@abstractmethod
def train_dataloader(self, *args, **kwargs) -> DataLoader:
"""
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.

Example::

def train_dataloader(self):
    dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset)
    return loader

"""
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
pass

@abstractmethod
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for validation.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders

Example::

def val_dataloader(self):
    dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
    return loader
"""
pass

@abstractmethod
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for testing.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders

Example::

def test_dataloader(self):
    dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
    loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
    return loader
"""
pass
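Together, these three hooks are what the Trainer pulls from a datamodule during fitting and testing. A brief usage sketch (illustrative only; it assumes the MNISTDataModule sketched earlier and a LightningModule instance named model):

    from pytorch_lightning import Trainer

    dm = MNISTDataModule()
    trainer = Trainer(max_epochs=1)

    # fit() runs prepare_data()/setup() on dm, then uses its train/val dataloaders
    trainer.fit(model, datamodule=dm)

    # test() uses dm's test dataloader
    trainer.test(datamodule=dm)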

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

- :class:`torch.Tensor` or anything that implements `.to(...)`
- :class:`list`
- :class:`dict`
- :class:`tuple`
- :class:`torchtext.data.batch.Batch`

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

Example::

def transfer_batch_to_device(self, batch, device):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    else:
        batch = super().transfer_batch_to_device(batch, device)
    return batch

Args:
batch: A batch of data that needs to be transferred to a new device.
device: The target device as defined in PyTorch.

Returns:
A reference to the data on the new device.

Note:
This hook should only transfer the data and not modify it, nor should it move the data to
any other device than the one passed in as argument (unless you know what you are doing).

Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
pass
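The example above refers to a CustomBatch type without defining it. A minimal sketch of such a wrapper and a matching override follows (illustrative only; CustomBatch and CustomBatchDataModule are hypothetical names, not part of Lightning):

    import torch

    from pytorch_lightning.core.datamodule import LightningDataModule


    class CustomBatch:
        # simple container a DataLoader's collate_fn might return
        def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
            self.samples = samples
            self.targets = targets


    class CustomBatchDataModule(LightningDataModule):

        def transfer_batch_to_device(self, batch, device):
            if isinstance(batch, CustomBatch):
                # move the wrapped tensors; do not otherwise modify the batch
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
                return batch
            # fall back to the default handling inherited from DataHooks
            return super().transfer_batch_to_device(batch, device)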

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: