Skip to content

Commit

Permalink
Chat dataset + SlimOrca refactor + more templates (pytorch#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
RdoubleA authored and Thomas Capelle committed Apr 5, 2024
1 parent 54a5e2a commit de155dc
Show file tree
Hide file tree
Showing 24 changed files with 1,014 additions and 295 deletions.
2 changes: 1 addition & 1 deletion docs/source/api_ref_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ torchtune.datasets
alpaca_dataset
grammar_dataset
samsum_dataset
SlimOrcaDataset
slimorca_dataset
4 changes: 2 additions & 2 deletions docs/source/examples/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,5 @@ name directly. Any nested fields in the components can be overridden with dot no
.. code-block:: bash
# Change to SlimOrcaDataset and set train_on_input to False
tune full_finetune --config my_config.yaml dataset=torchtune.datasets.SlimOrcaDataset dataset.train_on_input=False
# Change to slimorca_dataset and set train_on_input to False
tune full_finetune --config my_config.yaml dataset=torchtune.datasets.slimorca_dataset dataset.train_on_input=False
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@
}


class DummyTokenizer:
def encode(self, text, **kwargs):
words = text.split()
return [len(word) for word in words]

@property
def eos_id(self):
return -1


def get_assets_path():
return Path(__file__).parent / "assets"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import pytest
from torchtune.config._utils import (
_get_component_from_path,
_get_template,
_merge_yaml_and_cli_args,
InstantiationError,
)
from torchtune.data import AlpacaInstructTemplate
from torchtune.utils.argparse import TuneArgumentParser

_CONFIG = {
Expand Down Expand Up @@ -107,3 +109,33 @@ def test_merge_yaml_and_cli_args(self, mock_load):
ValueError, match="Command-line overrides must be in the form of key=value"
):
_ = _merge_yaml_and_cli_args(yaml_args, cli_args)

def test_get_template(self):
# Test valid template class
template = _get_template("AlpacaInstructTemplate")
assert isinstance(template, AlpacaInstructTemplate)

# Test invalid template class
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template("InvalidTemplate")

# Test valid template strings
valid_templates = [
"Instruction: {instruction}\nInput: {input}",
"Instruction: {instruction}",
"{a}",
]
for template in valid_templates:
assert _get_template(template) == template

# Test invalid template strings
invalid_templates = ["hello", "{}", "a}{b"]
for template in invalid_templates:
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template(template)
84 changes: 84 additions & 0 deletions tests/torchtune/data/test_data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from tests.test_utils import DummyTokenizer
from torchtune.data import tokenize_prompt_and_response, truncate
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX


def test_tokenize_prompt_and_response():
tokenizer = DummyTokenizer()
prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
response = "I always know what I'm doing, do you?"
prompt_length = 11
expected_tokenized_prompt = [
12,
4,
2,
2,
12,
6,
4,
2,
2,
6,
9,
1,
6,
4,
4,
3,
6,
2,
4,
]
expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
1,
6,
4,
4,
3,
6,
2,
4,
]

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_label

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response, train_on_input=True
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_prompt


def test_truncate():
prompt_tokens = [1, 2, 3, 4, -1]
label_tokens = [1, 2, 3, 4, -1]

# Test no truncation
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
max_seq_len=5,
)
assert truncated_prompt_tokens == prompt_tokens
assert truncated_label_tokens == label_tokens

# Test truncated
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
max_seq_len=4,
)
assert truncated_prompt_tokens == [1, 2, 3, -1]
assert truncated_label_tokens == [1, 2, 3, -1]
108 changes: 108 additions & 0 deletions tests/torchtune/data/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import (
AlpacaInstructTemplate,
ChatMLTemplate,
GrammarErrorCorrectionTemplate,
Llama2ChatTemplate,
MistralChatTemplate,
SummarizeTemplate,
)

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = {
"system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950
"user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950
"assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950
}


class TestAlpacaInstructTemplate:
samples = [
Expand Down Expand Up @@ -144,3 +156,99 @@ def test_format_with_column_map(self):
actual = self.template.format(modified_sample, column_map=column_map)

assert actual == expected_prompt


class TestLlama2ChatTemplate:
expected_prompt = (
"[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] "
)

template = Llama2ChatTemplate()

def test_format(self):
actual = self.template.format(CHAT_SAMPLE)
assert actual == self.expected_prompt

def test_format_with_column_map(self):
column_map = {"system": "not_system"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_system"] = modified_sample["system"]
del modified_sample["system"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt


class TestMistralChatTemplate:
expected_prompt = (
"[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] "
)

template = MistralChatTemplate()

def test_format(self):
no_system_sample = CHAT_SAMPLE.copy()
del no_system_sample["system"]
actual = self.template.format(no_system_sample)
assert actual == self.expected_prompt

def test_format_with_system_prompt_raises(self):
with pytest.raises(
ValueError, match="System prompts are not supported in MistralChatTemplate"
):
_ = self.template.format(CHAT_SAMPLE)

def test_format_with_column_map(self):
column_map = {"user": "not_user"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_user"] = modified_sample["user"]
del modified_sample["system"]
del modified_sample["user"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt


class TestChatMLTemplate:
expected_prompt = (
"<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.<|im_end|>\n<|im_start|>user\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary:<|im_end|>\n<|im_start|>assistant\n"
)

template = ChatMLTemplate()

def test_format(self):
actual = self.template.format(CHAT_SAMPLE)
assert actual == self.expected_prompt

def test_format_with_column_map(self):
column_map = {"system": "not_system"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_system"] = modified_sample["system"]
del modified_sample["system"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt
4 changes: 2 additions & 2 deletions tests/torchtune/datasets/test_alpaca_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import pytest

from tests.test_utils import get_assets_path
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

from torchtune.datasets._alpaca import alpaca_dataset
from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import alpaca_dataset
from torchtune.modules.tokenizer import Tokenizer


Expand Down
Loading

0 comments on commit de155dc

Please sign in to comment.