forked from pytorch/torchtune
Support for unstructured text corpus datasets for CPT (pytorch#868)
Showing 9 changed files with 369 additions and 3 deletions.
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import cnn_dailymail_articles_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestCNNDailyMailArticlesDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a pretrained SentencePiece model, trained with the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

    @patch("torchtune.datasets._text_completion.load_dataset")
    @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
    def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
        # Sample data from the CNN / DailyMail dataset
        load_dataset.return_value = [
            {
                "article": "(CNN) -- An American woman died aboard a cruise ship "
                "that docked at Rio de Janeiro on Tuesday, the same ship on which "
                "86 passengers previously fell ill, according to the state-run "
                "Brazilian news agency, Agencia Brasil. The American tourist died "
                "aboard the MS Veendam, owned by cruise operator Holland America. "
                "Federal Police told Agencia Brasil that forensic doctors were "
                "investigating her death. The ship's doctors told police that the "
                "woman was elderly and suffered from diabetes and hypertension, "
                "according the agency. The other passengers came down with diarrhea "
                "prior to her death during an earlier part of the trip, the ship's "
                "doctors said. The Veendam left New York 36 days ago for a South "
                "America tour.",
            }
        ]
        ds = cnn_dailymail_articles_dataset(
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
        )
        input, label = ds[0]
        assert len(input) <= max_seq_len
        assert len(label) <= max_seq_len
        assert len(input) == len(label)
        assert input[0] == tokenizer.bos_id
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

from tests.test_utils import DummyTokenizer

from torchtune.datasets import TextCompletionDataset


class TestTextCompletionDataset:
    expected_tokenized_prompts = [
        [0, 4, 2, 2, 7, 5, -1],
        [0, 4, 2, 7, 7, 5, -1],
    ]

    def get_samples(self):
        return [
            {
                "text": "This is an example text.",
            },
            {
                "text": "This is another example text.",
            },
        ]

    @mock.patch("torchtune.datasets._text_completion.load_dataset")
    def test_get_item(self, mock_load_dataset):
        mock_load_dataset.return_value = self.get_samples()
        expected_labels = self.expected_tokenized_prompts

        dataset = TextCompletionDataset(
            tokenizer=DummyTokenizer(),
            source="iam/agoofy/goober",
            column="text",
            max_seq_len=100,
        )
        assert len(dataset) == 2
        mock_load_dataset.assert_called_once()

        for i in range(len(dataset)):
            prompt, label = dataset[i]
            assert prompt == self.expected_tokenized_prompts[i]
            assert label == expected_labels[i]
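The expected token id lists in this test are consistent with a word-length tokenizer: assuming ``DummyTokenizer`` from ``tests.test_utils`` encodes each whitespace-separated word as its character count and uses ``bos_id=0`` and ``eos_id=-1`` (an assumption, but one that matches the expected ids above), both samples tokenize exactly as listed. A stand-in sketch, for illustration only:

# Illustrative stand-in for tests.test_utils.DummyTokenizer (assumption: the
# real helper encodes each whitespace-separated word as its character length,
# with bos_id=0 and eos_id=-1, which matches the expected ids in the test).
from typing import List


class WordLengthTokenizer:
    bos_id = 0
    eos_id = -1

    def encode(self, text: str, add_bos: bool = True, add_eos: bool = True) -> List[int]:
        # Each word becomes its length; bos/eos ids are prepended/appended on request.
        tokens = [len(word) for word in text.split()]
        if add_bos:
            tokens = [self.bos_id] + tokens
        if add_eos:
            tokens = tokens + [self.eos_id]
        return tokens


# "This is an example text." -> [0, 4, 2, 2, 7, 5, -1]
assert WordLengthTokenizer().encode("This is an example text.") == [0, 4, 2, 2, 7, 5, -1]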
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import wikitext_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestWikiTextDataset:
    @pytest.fixture
    def tokenizer(self):
        # m.model is a pretrained SentencePiece model, trained with the following command:
        # spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
        return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

    @patch("torchtune.datasets._text_completion.load_dataset")
    @pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
    def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
        # Sample data from the wikitext dataset
        load_dataset.return_value = [
            {
                "text": "Bart , like the rest of his family , has yellow skin . "
                "Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . "
                "When the Simpson family goes to church in the episodes , or to school "
                "events or shows , Bart wears a blue suit with a white shirt , a purple "
                "tie , blue shorts and a blue jacket .",
            }
        ]
        ds = wikitext_dataset(
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
        )
        input, label = ds[0]
        assert len(input) <= max_seq_len
        assert len(label) <= max_seq_len
        assert len(input) == len(label)
        assert input[0] == tokenizer.bos_id
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

from torchtune.datasets._text_completion import TextCompletionDataset
from torchtune.modules.tokenizers import Tokenizer


def cnn_dailymail_articles_dataset(
    tokenizer: Tokenizer,
    source: str = "ccdv/cnn_dailymail",
    max_seq_len: Optional[int] = None,
    **load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
    """
    Support for the family of datasets similar to `CNN / DailyMail <https://huggingface.co/datasets/ccdv/cnn_dailymail>`_,
    a corpus of news articles. This builder extracts only the articles, not the highlights, for
    general text completion tasks.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an ``encode`` and ``decode`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
            Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
            and is supported by the model. For example, llama2-7B supports up to 4096 tokens of sequence length.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

    Returns:
        TextCompletionDataset: the configured TextCompletionDataset
    """

    return TextCompletionDataset(
        tokenizer=tokenizer,
        source=source,
        column="article",
        max_seq_len=max_seq_len,
        split="train",
        # This is used to specify the version of the dataset, a required argument
        # of the cnn_dailymail dataset builder:
        # https://huggingface.co/datasets/ccdv/cnn_dailymail/blob/main/cnn_dailymail.py#L80
        name="3.0.0",
        **load_dataset_kwargs,
    )
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping, Optional, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data import truncate
from torchtune.modules.tokenizers import Tokenizer


class TextCompletionDataset(Dataset):
    """
    Freeform dataset for any unstructured text corpus. Quickly load any dataset
    from Hugging Face or local disk and tokenize it for your model.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an ``encode`` and ``decode`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        column (str): name of the column in the sample that contains the text data
        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
            Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
            and is supported by the model. For example, llama2-7B supports up to 4096 tokens of sequence length.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        source: str,
        column: str,
        max_seq_len: Optional[int] = None,
        **load_dataset_kwargs: Dict[str, Any],
    ) -> None:
        self._tokenizer = tokenizer
        self._data = load_dataset(source, **load_dataset_kwargs)
        self.max_seq_len = max_seq_len
        self._column = column

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
        prompt = sample[self._column]
        tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=True)

        # Truncate if needed, but don't coerce EOS id
        if self.max_seq_len is not None:
            tokens = truncate(tokens, self.max_seq_len - 1)

        # No need to offset labels by 1 - this happens in the recipe
        labels = tokens.copy()

        return tokens, labels


def text_completion_dataset(
    tokenizer: Tokenizer,
    source: str,
    column: str,
    max_seq_len: Optional[int] = None,
    **load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
    """
    Build a configurable freeform dataset from an unstructured text corpus. This builder should be
    used to configure a custom text dataset from the yaml config instead of
    using ``TextCompletionDataset`` directly, as it is made to be config friendly.

    Args:
        tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement an ``encode`` and ``decode`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
        column (str): name of the column in the sample that contains the text data
        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
            Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
            and is supported by the model. For example, llama2-7B supports up to 4096 tokens of sequence length.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

    Examples:
        >>> from torchtune.datasets import text_completion_dataset
        >>> dataset = text_completion_dataset(
        ...     tokenizer=tokenizer,
        ...     source="allenai/c4",
        ...     column="text",
        ...     max_seq_len=2096,
        ...     data_dir="realnewslike",
        ... )

    This can also be accomplished via the yaml config::

        dataset:
          _component_: torchtune.datasets.text_completion_dataset
          source: allenai/c4
          column: text
          max_seq_len: 2096
          data_dir: realnewslike

    Returns:
        TextCompletionDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
    """
    return TextCompletionDataset(
        tokenizer=tokenizer,
        source=source,
        column=column,
        max_seq_len=max_seq_len,
        **load_dataset_kwargs,
    )