Support for unstructured text corpus datasets for CPT (#868)
RdoubleA authored May 21, 2024
1 parent 29ae975 commit bbf0010
Showing 9 changed files with 369 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/source/api_ref_datasets.rst
@@ -21,6 +21,8 @@ torchtune supports several widely used datasets to help quickly bootstrap your f
grammar_dataset
samsum_dataset
slimorca_dataset
cnn_dailymail_articles_dataset
wikitext_dataset

Generic dataset builders
------------------------
@@ -34,6 +36,7 @@ These are especially useful for specifying from a YAML config.

instruct_dataset
chat_dataset
text_completion_dataset

Generic dataset classes
-----------------------
@@ -46,5 +49,6 @@ Class representations for the above dataset builders.

InstructDataset
ChatDataset
TextCompletionDataset
ConcatDataset
PackedDataset
51 changes: 51 additions & 0 deletions tests/torchtune/datasets/test_cnn_dailymail_dataset.py
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import cnn_dailymail_articles_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestCNNDailyMailArticlesDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained using the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._text_completion.load_dataset")
@pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
# Sample data from CNN / DailyMail dataset
load_dataset.return_value = [
{
"article": "(CNN) -- An American woman died aboard a cruise ship "
"that docked at Rio de Janeiro on Tuesday, the same ship on which "
"86 passengers previously fell ill, according to the state-run "
"Brazilian news agency, Agencia Brasil. The American tourist died "
"aboard the MS Veendam, owned by cruise operator Holland America. "
"Federal Police told Agencia Brasil that forensic doctors were "
"investigating her death. The ship's doctors told police that the "
"woman was elderly and suffered from diabetes and hypertension, "
"according the agency. The other passengers came down with diarrhea "
"prior to her death during an earlier part of the trip, the ship's "
"doctors said. The Veendam left New York 36 days ago for a South "
"America tour.",
}
]
ds = cnn_dailymail_articles_dataset(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
input, label = ds[0]
assert len(input) <= max_seq_len
assert len(label) <= max_seq_len
assert len(input) == len(label)
assert input[0] == tokenizer.bos_id
47 changes: 47 additions & 0 deletions tests/torchtune/datasets/test_text_completion_dataset.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import mock

from tests.test_utils import DummyTokenizer

from torchtune.datasets import TextCompletionDataset


class TestTextCompletionDataset:
expected_tokenized_prompts = [
[0, 4, 2, 2, 7, 5, -1],
[0, 4, 2, 7, 7, 5, -1],
]

def get_samples(self):
return [
{
"text": "This is an example text.",
},
{
"text": "This is another example text.",
},
]

@mock.patch("torchtune.datasets._text_completion.load_dataset")
def test_get_item(self, mock_load_dataset):
mock_load_dataset.return_value = self.get_samples()
expected_labels = self.expected_tokenized_prompts

dataset = TextCompletionDataset(
tokenizer=DummyTokenizer(),
source="iam/agoofy/goober",
column="text",
max_seq_len=100,
)
assert len(dataset) == 2
mock_load_dataset.assert_called_once()

for i in range(len(dataset)):
prompt, label = dataset[i]
assert prompt == self.expected_tokenized_prompts[i]
assert label == expected_labels[i]
44 changes: 44 additions & 0 deletions tests/torchtune/datasets/test_wikitext_dataset.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from unittest.mock import patch

import pytest

from tests.test_utils import get_assets_path

from torchtune.datasets import wikitext_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer


class TestWikiTextDataset:
@pytest.fixture
def tokenizer(self):
# m.model is a SentencePiece model pretrained using the following command:
# spm.SentencePieceTrainer.train('--input=<TRAIN_FILE> --model_prefix=m --vocab_size=2000')
return SentencePieceTokenizer(str(get_assets_path() / "m.model"))

@patch("torchtune.datasets._text_completion.load_dataset")
@pytest.mark.parametrize("max_seq_len", [128, 512, 1024, 4096])
def test_dataset_get_item(self, load_dataset, tokenizer, max_seq_len):
# Sample data from wikitext dataset
load_dataset.return_value = [
{
"text": "Bart , like the rest of his family , has yellow skin . "
"Bart usually wears a red T @-@ shirt , blue shorts and blue trainers . "
"When the Simpson family goes to church in the episodes , or to school "
"events or shows , Bart wears a blue suit with a white shirt , a purple "
"tie , blue shorts and a blue jacket .",
}
]
ds = wikitext_dataset(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
input, label = ds[0]
assert len(input) <= max_seq_len
assert len(label) <= max_seq_len
assert len(input) == len(label)
assert input[0] == tokenizer.bos_id
6 changes: 3 additions & 3 deletions torchtune/data/_utils.py
@@ -4,18 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List
from typing import Any, List, Optional

from torchtune.data._types import Message


def truncate(
tokens: List[Any],
max_seq_len: int,
eos_id: Any,
eos_id: Optional[Any] = None,
) -> List[Any]:
tokens_truncated = tokens[:max_seq_len]
if tokens_truncated[-1] != eos_id:
if eos_id is not None and tokens_truncated[-1] != eos_id:
tokens_truncated[-1] = eos_id
return tokens_truncated

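For reference, a minimal sketch of how the updated ``truncate`` now behaves with and without an ``eos_id`` (the token ids below are made up for illustration):

from torchtune.data import truncate

tokens = [0, 12, 57, 9, 33, 2]  # toy token ids; assume 2 is the EOS id

# With an eos_id, the truncated sequence is coerced to end in EOS.
print(truncate(tokens, max_seq_len=4, eos_id=2))  # [0, 12, 57, 2]

# With eos_id omitted (the new default of None), the list is simply sliced.
print(truncate(tokens, max_seq_len=4))  # [0, 12, 57, 9]
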
10 changes: 10 additions & 0 deletions torchtune/datasets/__init__.py
@@ -6,13 +6,19 @@

from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
from torchtune.datasets._chat import chat_dataset, ChatDataset
from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._grammar import grammar_dataset
from torchtune.datasets._instruct import instruct_dataset, InstructDataset
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._samsum import samsum_dataset
from torchtune.datasets._slimorca import slimorca_dataset
from torchtune.datasets._stack_exchanged_paired import stack_exchanged_paired_dataset
from torchtune.datasets._text_completion import (
text_completion_dataset,
TextCompletionDataset,
)
from torchtune.datasets._wikitext import wikitext_dataset

__all__ = [
"alpaca_dataset",
@@ -25,6 +31,10 @@
"ChatDataset",
"instruct_dataset",
"chat_dataset",
"text_completion_dataset",
"TextCompletionDataset",
"cnn_dailymail_articles_dataset",
"PackedDataset",
"ConcatDataset",
"wikitext_dataset",
]
48 changes: 48 additions & 0 deletions torchtune/datasets/_cnn_dailymail.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

from torchtune.datasets._text_completion import TextCompletionDataset
from torchtune.modules.tokenizers import Tokenizer


def cnn_dailymail_articles_dataset(
tokenizer: Tokenizer,
source: str = "ccdv/cnn_dailymail",
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
"""
Support for the family of datasets similar to `CNN / DailyMail <https://huggingface.co/datasets/ccdv/cnn_dailymail>`_,
a corpus of news articles. This builder extracts only the articles (not the highlights) for
general text completion tasks.
Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement ``encode`` and ``decode`` methods.
source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Returns:
TextCompletionDataset: the configured TextCompletionDataset
"""

return TextCompletionDataset(
tokenizer=tokenizer,
source=source,
column="article",
max_seq_len=max_seq_len,
split="train",
# This is used to specify the version of the dataset, a required argument
# by the cnn_dailymail dataset builder:
# https://huggingface.co/datasets/ccdv/cnn_dailymail/blob/main/cnn_dailymail.py#L80
name="3.0.0",
**load_dataset_kwargs,
)
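A minimal usage sketch for the new builder (the tokenizer model path here is a placeholder, and loading the dataset requires access to the Hugging Face Hub):

from torchtune.datasets import cnn_dailymail_articles_dataset
from torchtune.modules.tokenizers import SentencePieceTokenizer

# Placeholder path to a trained SentencePiece model file.
tokenizer = SentencePieceTokenizer("/tmp/tokenizer.model")

ds = cnn_dailymail_articles_dataset(
    tokenizer=tokenizer,
    max_seq_len=4096,
)
tokens, labels = ds[0]  # tokenized article and an identical copy used as labels
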
115 changes: 115 additions & 0 deletions torchtune/datasets/_text_completion.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping, Optional, Tuple

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data import truncate
from torchtune.modules.tokenizers import Tokenizer


class TextCompletionDataset(Dataset):
"""
Freeform dataset for any unstructured text corpus. Quickly load any dataset
from Hugging Face or local disk and tokenize it for your model.
Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement ``encode`` and ``decode`` methods.
source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
column (str): name of column in the sample that contains the text data
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
"""

def __init__(
self,
tokenizer: Tokenizer,
source: str,
column: str,
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
self._tokenizer = tokenizer
self._data = load_dataset(source, **load_dataset_kwargs)
self.max_seq_len = max_seq_len
self._column = column

def __len__(self):
return len(self._data)

def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
sample = self._data[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Tuple[List[int], List[int]]:
prompt = sample[self._column]
tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=True)

# Truncate if needed, but don't coerce EOS id
if self.max_seq_len is not None:
tokens = truncate(tokens, self.max_seq_len - 1)

# No need to offset labels by 1 - happens in the recipe
labels = tokens.copy()

return tokens, labels


def text_completion_dataset(
tokenizer: Tokenizer,
source: str,
column: str,
max_seq_len: Optional[int] = None,
**load_dataset_kwargs: Dict[str, Any],
) -> TextCompletionDataset:
"""
Build a configurable dataset from a freeform, unstructured text corpus. This builder should be
used to configure a custom text completion dataset from the yaml config instead of
using ``TextCompletionDataset`` directly, as it is made to be config friendly.
Args:
tokenizer (Tokenizer): Tokenizer used to encode data. The tokenizer must implement ``encode`` and ``decode`` methods.
source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
column (str): name of column in the sample that contains the text data
max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
Default is None, disabling truncation. We recommend setting this to the highest value that fits in memory
and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Examples:
>>> from torchtune.datasets import text_completion_dataset
>>> dataset = text_completion_dataset(
... tokenizer=tokenizer,
... source="allenai/c4",
... column="text",
... max_seq_len=2096,
... data_dir="realnewslike",
... )
This can also be accomplished via the yaml config::
dataset:
_component_: torchtune.datasets.text_completion_dataset
source: allenai/c4
column: text
max_seq_len: 2096
data_dir: realnewslike
Returns:
TextCompletionDataset: the configured :class:`~torchtune.datasets.TextCompletionDataset`
"""
return TextCompletionDataset(
tokenizer=tokenizer,
source=source,
column=column,
max_seq_len=max_seq_len,
**load_dataset_kwargs,
)
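For completeness, a short sketch of using ``TextCompletionDataset`` directly on a local corpus, mirroring what the builders above do; the file name and tokenizer path are hypothetical, and extra keyword arguments such as ``data_files`` are simply forwarded to ``load_dataset``:

from torchtune.datasets import TextCompletionDataset
from torchtune.modules.tokenizers import SentencePieceTokenizer

ds = TextCompletionDataset(
    tokenizer=SentencePieceTokenizer("/tmp/tokenizer.model"),  # placeholder path
    source="json",
    data_files="my_corpus.jsonl",  # hypothetical file with a "text" field per record
    column="text",
    max_seq_len=2048,
    split="train",
)

tokens, labels = ds[0]
assert tokens == labels  # labels are an unshifted copy; the training recipe handles the shift
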
torchtune/datasets/_wikitext.py (diff not loaded)
