Chat dataset + SlimOrca refactor + more templates #576

Merged: 6 commits, Mar 28, 2024
2 changes: 1 addition & 1 deletion docs/source/api_ref_datasets.rst
@@ -13,4 +13,4 @@ torchtune.datasets
alpaca_dataset
grammar_dataset
samsum_dataset
-   SlimOrcaDataset
+   slimorca_dataset
4 changes: 2 additions & 2 deletions docs/source/examples/configs.rst
@@ -243,5 +243,5 @@ name directly. Any nested fields in the components can be overridden with dot notation

.. code-block:: bash

-   # Change to SlimOrcaDataset and set train_on_input to False
-   tune full_finetune --config my_config.yaml dataset=torchtune.datasets.SlimOrcaDataset dataset.train_on_input=False
+   # Change to slimorca_dataset and set train_on_input to False
+   tune full_finetune --config my_config.yaml dataset=torchtune.datasets.slimorca_dataset dataset.train_on_input=False
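
For reference, the override above corresponds to a dataset section along these lines in the YAML itself (a sketch; field names beyond ``_component_`` depend on the recipe config):

.. code-block:: yaml

    dataset:
      _component_: torchtune.datasets.slimorca_dataset
      train_on_input: False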
10 changes: 10 additions & 0 deletions tests/test_utils.py
@@ -32,6 +32,16 @@
}


class DummyTokenizer:
def encode(self, text, **kwargs):
words = text.split()
return [len(word) for word in words]

@property
def eos_id(self):
return -1


def get_assets_path():
return Path(__file__).parent / "assets"
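
A note on this stub: DummyTokenizer "encodes" a string by splitting on whitespace and emitting each word's character length, with -1 standing in for the EOS id. For example:

.. code-block:: python

    tok = DummyTokenizer()
    tok.encode("Instruction: This is")  # [12, 4, 2]
    tok.eos_id  # -1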

@@ -9,9 +9,11 @@
import pytest
from torchtune.config._utils import (
_get_component_from_path,
_get_template,
_merge_yaml_and_cli_args,
InstantiationError,
)
from torchtune.data import AlpacaInstructTemplate
from torchtune.utils.argparse import TuneArgumentParser

_CONFIG = {
@@ -107,3 +109,33 @@ def test_merge_yaml_and_cli_args(self, mock_load):
ValueError, match="Command-line overrides must be in the form of key=value"
):
_ = _merge_yaml_and_cli_args(yaml_args, cli_args)

def test_get_template(self):
# Test valid template class
template = _get_template("AlpacaInstructTemplate")
assert isinstance(template, AlpacaInstructTemplate)

# Test invalid template class
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template("InvalidTemplate")

# Test valid template strings
valid_templates = [
"Instruction: {instruction}\nInput: {input}",
"Instruction: {instruction}",
"{a}",
]
for template in valid_templates:
assert _get_template(template) == template

# Test invalid template strings
invalid_templates = ["hello", "{}", "a}{b"]
for template in invalid_templates:
with pytest.raises(
ValueError,
match="Must be a PromptTemplate class or a string with placeholders.",
):
_ = _get_template(template)
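
Together these cases pin down the contract of _get_template: a string naming a template class resolves to an instance of it, a string containing at least one well-formed {placeholder} passes through unchanged, and everything else raises. A minimal sketch of that validation, assumed from the tests rather than taken from the actual implementation:

.. code-block:: python

    import re

    from torchtune import data

    def _get_template(template: str):
        # A string naming a template class in torchtune.data is instantiated.
        if hasattr(data, template):
            return getattr(data, template)()
        # Otherwise require at least one non-empty, well-formed {placeholder}.
        if re.search(r"\{\w+\}", template):
            return template
        raise ValueError(
            f"Invalid template '{template}': "
            "Must be a PromptTemplate class or a string with placeholders."
        )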
84 changes: 84 additions & 0 deletions tests/torchtune/data/test_data_utils.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from tests.test_utils import DummyTokenizer
from torchtune.data import tokenize_prompt_and_response, truncate
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX


def test_tokenize_prompt_and_response():
tokenizer = DummyTokenizer()
prompt = "Instruction:\nThis is an instruction.\n\nInput:\nThis is an input.\n\nResponse: "
response = "I always know what I'm doing, do you?"
prompt_length = 11
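# With DummyTokenizer, each token id is simply the character length of the
# corresponding whitespace-split word (11 prompt words, 8 response words).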
expected_tokenized_prompt = [
12,
4,
2,
2,
12,
6,
4,
2,
2,
6,
9,
1,
6,
4,
4,
3,
6,
2,
4,
]
expected_tokenized_label = [CROSS_ENTROPY_IGNORE_IDX] * prompt_length + [
1,
6,
4,
4,
3,
6,
2,
4,
]

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_label

tokenized_prompt, tokenized_label = tokenize_prompt_and_response(
tokenizer, prompt, response, train_on_input=True
)
assert tokenized_prompt == expected_tokenized_prompt
assert tokenized_label == expected_tokenized_prompt
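
These assertions describe the whole contract: prompt and response are tokenized into a single sequence, and unless train_on_input is set, the prompt positions are masked out of the label with CROSS_ENTROPY_IGNORE_IDX. A sketch consistent with the test, assuming no extra BOS/EOS handling (the real implementation may add it):

.. code-block:: python

    def tokenize_prompt_and_response(tokenizer, prompt, response, train_on_input=False):
        prompt_ids = tokenizer.encode(prompt)
        response_ids = tokenizer.encode(response)
        tokens = prompt_ids + response_ids
        if train_on_input:
            labels = list(tokens)  # learn on every position
        else:
            # Mask the prompt so the loss only covers the response.
            labels = [CROSS_ENTROPY_IGNORE_IDX] * len(prompt_ids) + response_ids
        return tokens, labels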


def test_truncate():
prompt_tokens = [1, 2, 3, 4, -1]
label_tokens = [1, 2, 3, 4, -1]

# Test no truncation
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
max_seq_len=5,
)
assert truncated_prompt_tokens == prompt_tokens
assert truncated_label_tokens == label_tokens

# Test truncated
truncated_prompt_tokens, truncated_label_tokens = truncate(
tokenizer=DummyTokenizer(),
prompt_tokens=prompt_tokens,
label_tokens=label_tokens,
max_seq_len=4,
)
assert truncated_prompt_tokens == [1, 2, 3, -1]
assert truncated_label_tokens == [1, 2, 3, -1]
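
The behavior pinned down here: sequences at or under max_seq_len pass through untouched, while longer ones are cut to max_seq_len - 1 tokens with the EOS id re-appended. A minimal sketch under that assumption (not the actual torchtune implementation):

.. code-block:: python

    def truncate(tokenizer, prompt_tokens, label_tokens, max_seq_len):
        def _cut(tokens):
            if len(tokens) <= max_seq_len:
                return tokens
            # Trim to max_seq_len - 1 and keep a terminating EOS.
            return tokens[: max_seq_len - 1] + [tokenizer.eos_id]

        return _cut(prompt_tokens), _cut(label_tokens)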
108 changes: 108 additions & 0 deletions tests/torchtune/data/test_templates.py
@@ -4,12 +4,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from torchtune.data import (
AlpacaInstructTemplate,
ChatMLTemplate,
GrammarErrorCorrectionTemplate,
Llama2ChatTemplate,
MistralChatTemplate,
SummarizeTemplate,
)

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
CHAT_SAMPLE = {
"system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.", # noqa: B950
"user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:", # noqa: B950
"assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.", # noqa: B950
}


class TestAlpacaInstructTemplate:
samples = [
@@ -144,3 +156,99 @@ def test_format_with_column_map(self):
actual = self.template.format(modified_sample, column_map=column_map)

assert actual == expected_prompt


class TestLlama2ChatTemplate:
expected_prompt = (
"[INST] <<SYS>>\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.\n<</SYS>>\n\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] "
)

template = Llama2ChatTemplate()

def test_format(self):
actual = self.template.format(CHAT_SAMPLE)
assert actual == self.expected_prompt

def test_format_with_column_map(self):
column_map = {"system": "not_system"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_system"] = modified_sample["system"]
del modified_sample["system"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt


class TestMistralChatTemplate:
expected_prompt = (
"[INST] Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary: [/INST] "
)

template = MistralChatTemplate()

def test_format(self):
no_system_sample = CHAT_SAMPLE.copy()
del no_system_sample["system"]
actual = self.template.format(no_system_sample)
assert actual == self.expected_prompt

def test_format_with_system_prompt_raises(self):
with pytest.raises(
ValueError, match="System prompts are not supported in MistralChatTemplate"
):
_ = self.template.format(CHAT_SAMPLE)

def test_format_with_column_map(self):
column_map = {"user": "not_user"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_user"] = modified_sample["user"]
del modified_sample["system"]
del modified_sample["user"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt


class TestChatMLTemplate:
expected_prompt = (
"<|im_start|>system\nYou are an AI assistant. User will you give you a task. "
"Your goal is to complete the task as faithfully as you can. While performing "
"the task think step-by-step and justify your steps.<|im_end|>\n<|im_start|>user\nPlease "
"briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old "
"Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? "
"How about on an icy road? Well one father in Russia did just that, and recorded "
"the entire thing. To her credit, the child seemed to be doing a great job. "
"(0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\n"
"Summary:<|im_end|>\n<|im_start|>assistant\n"
)

template = ChatMLTemplate()

def test_format(self):
actual = self.template.format(CHAT_SAMPLE)
assert actual == self.expected_prompt

def test_format_with_column_map(self):
column_map = {"system": "not_system"}
modified_sample = CHAT_SAMPLE.copy()
modified_sample["not_system"] = modified_sample["system"]
del modified_sample["system"]

actual = self.template.format(modified_sample, column_map=column_map)

assert actual == self.expected_prompt
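
The expected prompts in these classes encode the chat formats directly. A rough sketch of the string assembly they imply for the Llama2 and ChatML cases (assumed from the tests; the real template classes also handle column maps and optional fields):

.. code-block:: python

    def llama2_chat_format(sample: dict) -> str:
        # [INST] <<SYS>> system <</SYS>> user [/INST], Llama2-style.
        return (
            f"[INST] <<SYS>>\n{sample['system']}\n<</SYS>>\n\n"
            f"{sample['user']} [/INST] "
        )

    def chatml_format(sample: dict) -> str:
        # <|im_start|>role ... <|im_end|> blocks, ChatML-style.
        return (
            f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
            f"<|im_start|>user\n{sample['user']}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )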
4 changes: 2 additions & 2 deletions tests/torchtune/datasets/test_alpaca_dataset.py
@@ -9,9 +9,9 @@
import pytest

from tests.test_utils import get_assets_path
+from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

-from torchtune.datasets._alpaca import alpaca_dataset
-from torchtune.datasets._common import CROSS_ENTROPY_IGNORE_IDX
+from torchtune.datasets import alpaca_dataset
from torchtune.modules.tokenizer import Tokenizer

