Speed up batched PyTorch DataLoader #5512

Merged
merged 7 commits on Feb 19, 2023
7 changes: 5 additions & 2 deletions docs/source/use_with_pytorch.mdx
@@ -170,9 +170,10 @@ Reloading the dataset inside a worker doesn't fill up your RAM, since it simply
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
```

#### Use a BatchSampler
#### Use a BatchSampler (torch<=1.12.1)

By default, the PyTorch `DataLoader` loads batches of data from a dataset one by one like this:
For old versions of PyTorch, using a `BatchSampler` can speed up data loading.
Indeed, if you are using `torch<=1.12.1`, the PyTorch `DataLoader` loads batches of data from a dataset one by one like this:

```py
batch = [dataset[idx] for idx in range(start, end)]
@@ -198,6 +199,8 @@ For the PyTorch `DataLoader` to query batches using a list, you can use a `Batch
Moreover, this is particularly useful if you used [`set_transform`] to apply a transform on-the-fly when examples are accessed.
You must use a `BatchSampler` if you want the transform to receive full batches instead of `batch_size` individual examples, one at a time.

Recent versions of PyTorch query the dataset with a list of indices directly, so a `BatchSampler` is not needed to get the best speed, even if you used [`set_transform`].
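
For example, with a recent version of PyTorch a plain `DataLoader` is already enough. A minimal sketch, assuming `ds` is the dataset from the examples above:

```py
>>> from torch.utils.data import DataLoader
>>> # on recent PyTorch, each batch is fetched with a single ds[list_of_indices]
>>> # query, so a BatchSampler is not needed to avoid per-example reads
>>> dataloader = DataLoader(ds, batch_size=32)
```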

### Stream data

Stream a dataset by loading it as an [`IterableDataset`]. This allows you to progressively iterate over a remote dataset without downloading it on disk, or over local data files.
10 changes: 7 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -2594,9 +2594,13 @@ def __getitem__(self, key: str) -> List: # noqa: F811

    def __getitem__(self, key):  # noqa: F811
        """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
        return self._getitem(
            key,
        )
        return self._getitem(key)

    def __getitems__(self, keys: List) -> List:
        """Can be used to get a batch using a list of integer indices."""
        batch = self.__getitem__(keys)
        n_examples = len(batch[next(iter(batch))])
        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

    def cleanup_cache_files(self) -> int:
        """Clean up all cache files in the dataset cache directory, except the currently used cache file if there is
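As a side note on the new `__getitems__` method above: it returns a list of per-example dicts, which is the shape the PyTorch `DataLoader` expects from a batched fetch, while `__getitem__` with a list of indices returns a single dict of columns. A small illustrative sketch, using a hypothetical two-column dataset that is not part of this diff:

```py
from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# __getitem__ with a list of indices returns one dict of columns (a single query)
print(ds[[0, 1]])               # {'a': [1, 2], 'b': ['x', 'y']}

# __getitems__ re-splits that batch into per-example dicts, the shape the
# DataLoader's default collate function expects from a batched fetch
print(ds.__getitems__([0, 1]))  # [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}]
```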
18 changes: 18 additions & 0 deletions tests/test_arrow_dataset.py
@@ -18,6 +18,7 @@
import pyarrow as pa
import pytest
from absl.testing import parameterized
from packaging import version

import datasets.arrow_dataset
from datasets import concatenate_datasets, interleave_datasets, load_from_disk
@@ -4203,6 +4204,23 @@ def test_dataset_to_iterable_dataset(dataset):
dataset.to_iterable_dataset(num_shards=len(dataset) + 1)


@pytest.mark.parametrize("batch_size", [1, 4])
@require_torch
def test_dataset_with_torch_dataloader(dataset, batch_size):
    from torch.utils.data import DataLoader

    from datasets import config

    dataloader = DataLoader(dataset, batch_size=batch_size)
    with patch.object(dataset, "_getitem", wraps=dataset._getitem) as mock_getitem:
        out = list(dataloader)
        getitem_call_count = mock_getitem.call_count
    assert len(out) == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
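    # (the expected number of batches is ceil(len(dataset) / batch_size), e.g. 10 rows with batch_size=4 -> 3 batches)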
    # calling dataset[list_of_indices] is much more efficient than [dataset[idx] for idx in list_of_indices]
    if config.TORCH_VERSION >= version.parse("1.13.0"):
        assert getitem_call_count == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)


@pytest.mark.parametrize("return_lazy_dict", [True, False, "mix"])
def test_map_cases(return_lazy_dict):
    def f(x):