Add batch method to Dataset class #7064

Merged 8 commits on Jul 25, 2024
31 changes: 31 additions & 0 deletions docs/source/process.mdx
@@ -546,6 +546,37 @@ The following example shows how you can use `torch.distributed.barrier` to synch
... torch.distributed.barrier()
```

## Batch

The [`~Dataset.batch`] method allows you to group samples from the dataset into batches. This is particularly useful when you want to create batches of data for training or evaluation, especially when working with deep learning models.

Here's an example of how to use the `batch()` method:

```python
>>> from datasets import load_dataset
>>> dataset = load_dataset("rotten_tomatoes", split="train")
>>> batched_dataset = dataset.batch(batch_size=4)
>>> next(iter(batched_dataset))
{'text': ['the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',
'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .',
'effective but too-tepid biopic',
'if you sometimes like to go to the movies to have fun , wasabi is a good place to start .'],
'label': [1, 1, 1, 1]}
```

The `batch()` method accepts the following parameters:

- `batch_size` (`int`): The number of samples in each batch.
- `drop_last_batch` (`bool`, defaults to `False`): Whether to drop the last incomplete batch if the dataset size is not divisible by the batch size.
- `num_proc` (`int`, *optional*, defaults to `None`): The number of processes to use for multiprocessing. If `None`, no multiprocessing is used. This can significantly speed up batching for large datasets. An example combining these options is shown below.
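
For instance, here is a minimal sketch combining these options (the batch size and `num_proc=4` are illustrative values; multiprocessing assumes spare CPU cores):

```python
>>> from datasets import load_dataset
>>> dataset = load_dataset("rotten_tomatoes", split="train")
>>> # Drop the trailing incomplete batch and build the batches with 4 worker processes
>>> batched_dataset = dataset.batch(batch_size=16, drop_last_batch=True, num_proc=4)
>>> len(batched_dataset[0]["text"])  # each item now holds 16 samples
16
```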

Note that `Dataset.batch()` returns a new [`Dataset`] where each item is a batch of multiple samples from the original dataset. This is different from the batched processing in [`~Dataset.map`], which applies a function to batches but returns a dataset with the same number of items as the original.
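
To make the difference concrete, here is a small sketch using a toy dataset (the names and sizes are illustrative):

```python
>>> from datasets import Dataset
>>> ds = Dataset.from_dict({"id": list(range(10))})
>>> len(ds.batch(batch_size=4))  # one row per batch: 2 full batches + 1 partial batch
3
>>> len(ds.map(lambda batch: batch, batched=True, batch_size=4))  # same number of rows as the original
10
```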

<Tip>
The `batch()` method is also available for [`IterableDataset`] objects. While `Dataset.batch()` returns a new [`Dataset`] with batched examples, `IterableDataset.batch()` returns an [`IterableDataset`] that yields batches on the fly. This makes `IterableDataset.batch()` more memory-efficient for very large datasets. For more information on batching with [`IterableDataset`], refer to the [Batch section in Stream](https://huggingface.co/docs/datasets/main/en/stream#batch) documentation.
</Tip>
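
As a sketch of the streaming case (assuming the dataset is loaded with `streaming=True`; the output shown is illustrative):

```python
>>> from datasets import load_dataset
>>> iterable_dataset = load_dataset("rotten_tomatoes", split="train", streaming=True)
>>> batched_iterable = iterable_dataset.batch(batch_size=4)  # batches are produced lazily
>>> next(iter(batched_iterable))["label"]
[1, 1, 1, 1]
```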

## Concatenate

Separate datasets can be concatenated if they share the same column types. Concatenate datasets with [`concatenate_datasets`]:
51 changes: 51 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -3611,6 +3611,57 @@ def init_buffer_and_writer():
else:
yield rank, True, shard

    @transmit_format
    @fingerprint_transform(inplace=False)
    def batch(
        self,
        batch_size: int,
        drop_last_batch: bool = False,
        num_proc: Optional[int] = None,
        new_fingerprint: Optional[str] = None,
    ) -> "Dataset":
        """
        Group samples from the dataset into batches.

        Args:
            batch_size (`int`):
                The number of samples in each batch.
            drop_last_batch (`bool`, defaults to `False`):
                Whether to drop the last incomplete batch.
            num_proc (`int`, *optional*, defaults to `None`):
                Max number of processes when generating cache. Already cached shards are loaded sequentially.
            new_fingerprint (`str`, *optional*, defaults to `None`):
                The new fingerprint of the dataset after transform.
                If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

        Returns:
            [`Dataset`]: A new Dataset where each item is a batch of multiple samples from the original dataset.

        Example:

        ```py
        >>> from datasets import load_dataset
        >>> ds = load_dataset("rotten_tomatoes", split="train")
        >>> batched_ds = ds.batch(batch_size=4)
        >>> batched_ds[0]
        {'text': ['compassionately explores the seemingly irreconcilable situation...', ...], # 4 items
         'label': [1, 1, 1, 1]}
        ```
        """

        def batch_fn(example):
            # In batched mode, `example` maps each column name to a list of `batch_size`
            # values; wrapping each list in another list turns the whole batch into a
            # single row of the returned dataset.
            return {k: [v] for k, v in example.items()}

        return self.map(
            batch_fn,
            batched=True,
            batch_size=batch_size,
            drop_last_batch=drop_last_batch,
            num_proc=num_proc,
            new_fingerprint=new_fingerprint,
            desc="Batching examples",
        )

    @transmit_format
    @fingerprint_transform(
        inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name", "desc"], version="2.0.1"
51 changes: 51 additions & 0 deletions tests/test_arrow_dataset.py
@@ -4848,3 +4848,54 @@ def test_categorical_dataset(tmpdir):

    # Categorical types get transparently converted to string
    assert entry["animals"] == "Flamingo"


def test_dataset_batch():
    # Create a simple Dataset
    data = {"id": list(range(10)), "text": [f"Text {i}" for i in range(10)]}
    ds = Dataset.from_dict(data)

    # Test with batch_size=3, drop_last_batch=False
    batched_ds = ds.batch(batch_size=3, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 4  # 3 full batches and 1 partial batch
    for i, batch in enumerate(batches[:3]):  # Check full batches
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Check last partial batch
    assert len(batches[3]["id"]) == 1
    assert len(batches[3]["text"]) == 1
    assert batches[3]["id"] == [9]
    assert batches[3]["text"] == ["Text 9"]

    # Test with batch_size=3, drop_last_batch=True
    batched_ds = ds.batch(batch_size=3, drop_last_batch=True)
    batches = list(batched_ds)

    assert len(batches) == 3  # Only full batches
    for i, batch in enumerate(batches):
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Test with batch_size=4 (doesn't evenly divide dataset size)
    batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 3  # 2 full batches and 1 partial batch
    for i, batch in enumerate(batches[:2]):  # Check full batches
        assert len(batch["id"]) == 4
        assert len(batch["text"]) == 4
        assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
        assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]

    # Check last partial batch
    assert len(batches[2]["id"]) == 2
    assert len(batches[2]["text"]) == 2
    assert batches[2]["id"] == [8, 9]
    assert batches[2]["text"] == ["Text 8", "Text 9"]