Add batch method to Dataset class (#7064)
* feat: add `batch` method to `Dataset` class

* feat: add `num_proc` arg from `map` to `batch`

* test: add test for `Dataset.batch()`

* style: formatting...

* docs: move `Dataset.batch()` documentation to `process.mdx`

* docs: add `num_proc` to docs

* Apply suggestions from code review

---------

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
lappemic and lhoestq authored Jul 25, 2024
1 parent 5142a8c commit 9c98e06
Showing 3 changed files with 128 additions and 0 deletions.
26 changes: 26 additions & 0 deletions docs/source/process.mdx
@@ -546,6 +546,32 @@ The following example shows how you can use `torch.distributed.barrier` to synch
... torch.distributed.barrier()
```

## Batch

The [`~Dataset.batch`] method groups samples from the dataset into batches. This is useful when you want fixed-size batches for training or evaluation, for example when feeding data to a deep learning model.

Here's an example of how to use the `batch()` method:

```python
>>> from datasets import load_dataset
>>> dataset = load_dataset("rotten_tomatoes", split="train")
>>> batched_dataset = dataset.batch(batch_size=4)
>>> batched_dataset[0]
{'text': ['the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',
'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .',
'effective but too-tepid biopic',
'if you sometimes like to go to the movies to have fun , wasabi is a good place to start .'],
'label': [1, 1, 1, 1]}
```
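
Each item of the returned dataset is one batch. With the default `drop_last_batch=False`, the last, possibly smaller batch is kept, so the number of items is the ceiling of the dataset size divided by the batch size:

```python
>>> import math
>>> len(batched_dataset) == math.ceil(len(dataset) / 4)
True
```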

The `batch()` method accepts the following parameters (a short sketch using them follows the list):

- `batch_size` (`int`): The number of samples in each batch.
- `drop_last_batch` (`bool`, defaults to `False`): Whether to drop the last incomplete batch if the dataset size is not divisible by the batch size.
- `num_proc` (`int`, optional, defaults to `None`): The number of processes to use for multiprocessing. If `None`, no multiprocessing is used. This can significantly speed up batching for large datasets.
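
For example, a minimal sketch combining these options on the dataset above (`num_proc=2` is just an illustrative value):

```python
>>> batched = dataset.batch(batch_size=4, drop_last_batch=True, num_proc=2)
>>> all(len(labels) == 4 for labels in batched["label"])  # every remaining batch is full
True
```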

Note that `Dataset.batch()` returns a new [`Dataset`] where each item is a batch of multiple samples from the original dataset. If you want to process data in batches without materializing them, use a batched [`~Dataset.map`] directly: it applies a function to batches, but the output dataset remains unbatched.
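
For contrast, a minimal sketch of a batched [`~Dataset.map`]: the function receives batches of 4 examples, yet each item of the result is still a single example:

```python
>>> upper = dataset.map(
...     lambda batch: {"text": [text.upper() for text in batch["text"]]},
...     batched=True,
...     batch_size=4,
... )
>>> isinstance(upper[0]["text"], str)  # one example per item
True
>>> isinstance(batched_dataset[0]["text"], list)  # one batch per item
True
```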

## Concatenate

Separate datasets can be concatenated if they share the same column types. Concatenate datasets with [`concatenate_datasets`]:
51 changes: 51 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -3611,6 +3611,57 @@ def init_buffer_and_writer():
                else:
                    yield rank, True, shard

    @transmit_format
    @fingerprint_transform(inplace=False)
    def batch(
        self,
        batch_size: int,
        drop_last_batch: bool = False,
        num_proc: Optional[int] = None,
        new_fingerprint: Optional[str] = None,
    ) -> "Dataset":
        """
        Group samples from the dataset into batches.

        Args:
            batch_size (`int`):
                The number of samples in each batch.
            drop_last_batch (`bool`, defaults to `False`):
                Whether to drop the last incomplete batch.
            num_proc (`int`, *optional*, defaults to `None`):
                Max number of processes when generating cache. Already cached shards are loaded sequentially.
            new_fingerprint (`str`, *optional*, defaults to `None`):
                The new fingerprint of the dataset after transform.
                If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments.

        Returns:
            [`Dataset`]: A new Dataset where each item is a batch of multiple samples from the original dataset.

        Example:

        ```py
        >>> from datasets import load_dataset
        >>> ds = load_dataset("rotten_tomatoes", split="train")
        >>> batched_ds = ds.batch(batch_size=4)
        >>> batched_ds[0]
        {'text': ['compassionately explores the seemingly irreconcilable situation...', ...], # 4 items
         'label': [1, 1, 1, 1]}
        ```
        """

        def batch_fn(example):
            # `map` is called with `batched=True`, so `example` is in fact a batch:
            # each value is a list of `batch_size` items. Wrapping that list in
            # another list makes `map` emit one output row per input batch.
            return {k: [v] for k, v in example.items()}

        return self.map(
            batch_fn,
            batched=True,
            batch_size=batch_size,
            drop_last_batch=drop_last_batch,
            num_proc=num_proc,
            new_fingerprint=new_fingerprint,
            desc="Batching examples",
        )

    @transmit_format
    @fingerprint_transform(
        inplace=False, ignore_kwargs=["load_from_cache_file", "cache_file_name", "desc"], version="2.0.1"
51 changes: 51 additions & 0 deletions tests/test_arrow_dataset.py
@@ -4848,3 +4848,54 @@ def test_categorical_dataset(tmpdir):

    # Categorical types get transparently converted to string
    assert entry["animals"] == "Flamingo"


def test_dataset_batch():
    # Create a simple Dataset
    data = {"id": list(range(10)), "text": [f"Text {i}" for i in range(10)]}
    ds = Dataset.from_dict(data)

    # Test with batch_size=3, drop_last_batch=False
    batched_ds = ds.batch(batch_size=3, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 4  # 3 full batches and 1 partial batch
    for i, batch in enumerate(batches[:3]):  # Check full batches
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Check last partial batch
    assert len(batches[3]["id"]) == 1
    assert len(batches[3]["text"]) == 1
    assert batches[3]["id"] == [9]
    assert batches[3]["text"] == ["Text 9"]

    # Test with batch_size=3, drop_last_batch=True
    batched_ds = ds.batch(batch_size=3, drop_last_batch=True)
    batches = list(batched_ds)

    assert len(batches) == 3  # Only full batches
    for i, batch in enumerate(batches):
        assert len(batch["id"]) == 3
        assert len(batch["text"]) == 3
        assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
        assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]

    # Test with batch_size=4 (doesn't evenly divide dataset size)
    batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
    batches = list(batched_ds)

    assert len(batches) == 3  # 2 full batches and 1 partial batch
    for i, batch in enumerate(batches[:2]):  # Check full batches
        assert len(batch["id"]) == 4
        assert len(batch["text"]) == 4
        assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
        assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]

    # Check last partial batch
    assert len(batches[2]["id"]) == 2
    assert len(batches[2]["text"]) == 2
    assert batches[2]["id"] == [8, 9]
    assert batches[2]["text"] == ["Text 8", "Text 9"]
