Support for unstructured text corpus datasets for CPT #868

Merged 16 commits on May 21, 2024
4 changes: 4 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -21,6 +21,8 @@ torchtune supports several widely used datasets to help quickly bootstrap your fine-tuning
grammar_dataset
samsum_dataset
slimorca_dataset
cnn_dailymail_articles_dataset
wikitext_dataset

Generic dataset builders
------------------------
@@ -34,6 +36,7 @@ These are especially useful for specifying from a YAML config.

instruct_dataset
chat_dataset
text_completion_dataset

Generic dataset classes
-----------------------
@@ -46,4 +49,5 @@ Class representations for the above dataset builders.

InstructDataset
ChatDataset
TextCompletionDataset
ConcatDataset
12 changes: 7 additions & 5 deletions recipes/configs/llama3/8B_full.yaml
@@ -25,8 +25,8 @@ tokenizer:

# Dataset
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
_component_: torchtune.datasets.wikitext_dataset
max_seq_len: 8192
seed: null
shuffle: True

@@ -46,8 +46,8 @@ checkpointer:
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
batch_size: 4
epochs: 1

optimizer:
_component_: torch.optim.AdamW
@@ -72,8 +72,10 @@ dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
_component_: torchtune.utils.metric_logging.WandBLogger
log_dir: ${output_dir}
project: torchtune
name: wikitext
output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
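
For reference, the recipe instantiates the `dataset` block above by calling the `_component_` with the remaining config keys as kwargs (the tokenizer is passed in by the recipe). A rough Python equivalent is sketched below; the wikitext builder itself is not part of this diff, so the exact call is an assumption, and the tokenizer import path and file path are illustrative only.

# Rough Python equivalent of the YAML `dataset` block above (sketch only).
# Assumption: wikitext_dataset(tokenizer, max_seq_len=...) wraps TextCompletionDataset
# over the "text" column; the builder lives in torchtune/datasets/_wikitext.py,
# which this diff does not show.
from torchtune.datasets import wikitext_dataset
from torchtune.models.llama3 import llama3_tokenizer  # import path assumed from the llama3 configs

tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B/original/tokenizer.model")  # illustrative path
ds = wikitext_dataset(tokenizer=tokenizer, max_seq_len=8192)
tokens, labels = ds[0]  # unshifted token and label id lists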
51 changes: 51 additions & 0 deletions tests/torchtune/datasets/test_cnn_dailymail_dataset.py
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import cnn_dailymail_articles_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestCNNDailyMailArticlesDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._text_completion.load_dataset")
@pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
# Sample data from CNN / DailyMail dataset
load_dataset.return_value = [
{
"article": "(CNN) -- An American woman died aboard a cruise ship "
"that docked at Rio de Janeiro on Tuesday, the same ship on which "
"86 passengers previously fell ill, according to the state-run "
"Brazilian news agency, Agencia Brasil. The American tourist died "
"aboard the MS Veendam, owned by cruise operator Holland America. "
"Federal Police told Agencia Brasil that forensic doctors were "
"investigating her death. The ship's doctors told police that the "
"woman was elderly and suffered from diabetes and hypertension, "
"according the agency. The other passengers came down with diarrhea "
"prior to her death during an earlier part of the trip, the ship's "
"doctors said. The Veendam left New York 36 days ago for a South "
"America tour.",
}
]
ds = cnn_dailymail_articles_dataset(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
input, label = ds[0]
assert len(input) <= max_seq_len
assert len(label) <= max_seq_len
assert len(input) == len(label)
assert input[0] == tokenizer.bos_id
47 changes: 47 additions & 0 deletions tests/torchtune/datasets/test_text_completion_dataset.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

from tests.test_utils import DummyTokenizer

from torchtune.datasets import TextCompletionDataset


class TestTextCompletionDataset:
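# DummyTokenizer (tests.test_utils) encodes each word as its character length; with bos/eos
# added by the dataset's encode call, "This is an example text." -> [0, 4, 2, 2, 7, 5, -1].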
expected_tokenized_prompts = [
[0, 4, 2, 2, 7, 5, -1],
[0, 4, 2, 7, 7, 5, -1],
]

def get_samples(self):
return [
{
"text": "This is an example text.",
},
{
"text": "This is another example text.",
},
]

@mock.patch("torchtune.datasets._text_completion.load_dataset")
def test_get_item(self, mock_load_dataset):
mock_load_dataset.return_value = self.get_samples()
expected_labels = self.expected_tokenized_prompts

dataset = TextCompletionDataset(
tokenizer=DummyTokenizer(),
source="iam/agoofy/goober",
column="text",
max_seq_len=100,
)
assert len(dataset) == 2
mock_load_dataset.assert_called_once()

for i in range(len(dataset)):
prompt, label = dataset[i]
assert prompt == self.expected_tokenized_prompts[i]
assert label == expected_labels[i]
44 changes: 44 additions & 0 deletions tests/torchtune/datasets/test_wikitext_dataset.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import wikitext_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestWikiTextDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained with the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._text_completion.load_dataset")
@pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
# Sample data from wikitext dataset
load_dataset.return_value = [
{
"text": "Bart , like the rest of his family , has yellow skin . "
"Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . "
"When the Simpson family goes to church in the episodes , or to school "
"events or shows , Bart wears a blue suit with a white shirt , a purple "
"tie , blue shorts and a blue jacket .",
}
]
ds = wikitext_dataset(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
input, label = ds[0]
assert len(input) <= max_seq_len
assert len(label) <= max_seq_len
assert len(input) == len(label)
assert input[0] == tokenizer.bos_id
6 changes: 3 additions & 3 deletions torchtune/data/_utils.py
@@ -4,18 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List
from typing import Any, List, Optional

from torchtune.data._types import Message


def truncate(
tokens: List[Any],
max_seq_len: int,
eos_id: Any,
eos_id: Optional[Any] = None,
) -> List[Any]:
tokens_truncated = tokens[:max_seq_len]
if tokens_truncated[-1] != eos_id:
if eos_id is not None and tokens_truncated[-1] != eos_id:
tokens_truncated[-1] = eos_id
return tokens_truncated
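
A quick sketch of how the relaxed signature behaves, using a toy token list and -1 as a stand-in EOS id (values are hypothetical, for illustration only):

from torchtune.data import truncate

tokens = [0, 4, 2, 2, 7, 5, -1]  # bos ... eos

# With an eos_id, truncation re-places the EOS id at the cut point.
assert truncate(tokens, max_seq_len=4, eos_id=-1) == [0, 4, 2, -1]

# With the new optional eos_id left as None, the tail is simply dropped.
assert truncate(tokens, max_seq_len=4) == [0, 4, 2, 2]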

10 changes: 10 additions & 0 deletions torchtune/datasets/__init__.py
@@ -6,12 +6,18 @@

from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
from torchtune.datasets._chat import chat_dataset, ChatDataset
from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._grammar import grammar_dataset
from torchtune.datasets._instruct import instruct_dataset, InstructDataset
from torchtune.datasets._samsum import samsum_dataset
from torchtune.datasets._slimorca import slimorca_dataset
from torchtune.datasets._stack_exchanged_paired import stack_exchanged_paired_dataset
from torchtune.datasets._text_completion import (
text_completion_dataset,
TextCompletionDataset,
)
from torchtune.datasets._wikitext import wikitext_dataset

__all__ = [
"alpaca_dataset",
@@ -24,5 +30,9 @@
"ChatDataset",
"instruct_dataset",
"chat_dataset",
"text_completion_dataset",
"TextCompletionDataset",
"cnn_dailymail_articles_dataset",
"ConcatDataset",
"wikitext_dataset",
]
45 changes: 45 additions & 0 deletions torchtune/datasets/_cnn_dailymail.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

from torchtune.datasets._text_completion import TextCompletionDataset
from torchtune.modules.tokenizers import Tokenizer


def cnn_dailymail_articles_dataset(
tokenizer: Tokenizer,
source: str = "ccdv/cnn_dailymail",
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
"""
Support for a family of datasets similar to `CNN / DailyMail <https://huggingface.co/datasets/ccdv/cnn_dailymail>`_,
a corpus of news articles. This builder extracts only the articles, not the highlights, for use in
general text completion tasks.

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an `encode` and `decode` method.
source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value you can fit in memory
that is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.

Returns:
TextCompletionDataset: the configured TextCompletionDataset
"""

return TextCompletionDataset(
tokenizer=tokenizer,
source=source,
column="article",
max_seq_len=max_seq_len,
split="train",
name="3.0.0",
Contributor:
What is "name"?

Contributor Author:
let me add a comment, but it's basically specifying the subset of data

**load_dataset_kwargs,
)
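
A quick usage sketch of this builder, borrowing the word-length `DummyTokenizer` helper from this PR's tests (any tokenizer exposing an `encode` method works). The `name="3.0.0"` kwarg baked into the builder selects the Hugging Face dataset subset, per the review exchange above.

# Sketch only: the dataset is downloaded from the Hugging Face Hub on first use.
from tests.test_utils import DummyTokenizer
from torchtune.datasets import cnn_dailymail_articles_dataset

ds = cnn_dailymail_articles_dataset(tokenizer=DummyTokenizer(), max_seq_len=4096)
tokens, labels = ds[0]
assert tokens == labels      # labels are an unshifted copy; shifting happens in the recipe
assert len(tokens) <= 4096   # truncated when an article exceeds max_seq_len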
96 changes: 96 additions & 0 deletions torchtune/datasets/_text_completion.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping, Optional, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data import truncate
from torchtune.modules.tokenizers import Tokenizer


class TextCompletionDataset(Dataset):
"""
Freeform dataset for any unstructured text corpus. Quickly load any dataset
from Hugging Face or local disk and tokenize it for your model.

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an `encode` and `decode` method.
source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
column (str): name of column in the sample that contains the text data
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value you can fit in memory
that is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.
"""

def __init__(
self,
tokenizer: Tokenizer,
source: str,
column: str,
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
self._tokenizer = tokenizer
self._data = load_dataset(source, **load_dataset_kwargs)
self.max_seq_len = max_seq_len
self._column = column

def __len__(self):
return len(self._data)

def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
sample = self._data[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
prompt = sample[self._column]
tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=True)

# Truncate if needed, but don't coerce EOS id
if self.max_seq_len is not None:
tokens = truncate(tokens, self.max_seq_len - 1)

# No need to offset labels by 1 - happens in the recipe
labels = tokens.copy()

return tokens, labels


def text_completion_dataset(
Contributor:
Sorry one last nit: Can you add an examples section like this?

Contributor:
It'll make the docs and usage much clearer

tokenizer: Tokenizer,
source: str,
column: str,
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
"""
Build a configurable dataset from a freeform, unstructured text corpus. This builder should be
used to configure a custom text dataset from the yaml config instead of
using `TextCompletionDataset` directly, as it is made to be config friendly.

Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an `encode` and `decode` method.
source (str): path string of dataset, anything supported by Hugging Face's `load_dataset`
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
column (str): name of column in the sample that contains the text data
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value you can fit in memory
that is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to `load_dataset`.

Returns:
TextCompletionDataset: the configured TextCompletionDataset
"""
return TextCompletionDataset(
tokenizer=tokenizer,
source=source,
column=column,
max_seq_len=max_seq_len,
**load_dataset_kwargs,
)
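
Picking up the reviewer's request above, a minimal example of using this builder directly. The stub tokenizer below is hypothetical (it mirrors the word-length DummyTokenizer used in the tests), and the wikitext source/name/split arguments are just one plausible corpus choice, forwarded to `load_dataset`. From a YAML config, the same arguments would sit under a `_component_: torchtune.datasets.text_completion_dataset` block.

# Minimal sketch: configure a text-completion dataset over an arbitrary HF corpus.
from typing import List

from torchtune.datasets import text_completion_dataset


class WordLenTokenizer:
    """Hypothetical stand-in tokenizer: one token per word, valued by its length."""

    bos_id = 0
    eos_id = -1

    def encode(self, text: str, add_bos: bool = True, add_eos: bool = True) -> List[int]:
        tokens = [len(word) for word in text.split()]
        if add_bos:
            tokens = [self.bos_id] + tokens
        if add_eos:
            tokens = tokens + [self.eos_id]
        return tokens


ds = text_completion_dataset(
    tokenizer=WordLenTokenizer(),
    source="wikitext",           # anything accepted by datasets.load_dataset
    column="text",               # column that holds the raw text
    max_seq_len=512,
    name="wikitext-2-raw-v1",    # forwarded to load_dataset as the subset name
    split="train",               # forwarded to load_dataset
)
tokens, labels = ds[0]
assert tokens == labels  # labels are an unshifted copy; the recipe handles the shift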