Interleaved image support in tokenizers #1138

Merged
31 commits merged on Jul 4, 2024

Changes from all commits · 31 commits
75dae87
complete tokenizer refactor
RdoubleA Jun 12, 2024
0c20ba9
move tokenizers under data/
RdoubleA Jun 12, 2024
730a2c9
fix all tests
RdoubleA Jun 12, 2024
acf7e81
Merge branch 'main' into tokenizer
RdoubleA Jun 21, 2024
2ae157c
start to address comments
RdoubleA Jun 22, 2024
6a50cd5
load in special tokens, move tokenizer directory back, address comments
RdoubleA Jun 24, 2024
61534d0
fix encode whitespace
RdoubleA Jun 24, 2024
1d6e5e3
updates after manual comparisons
RdoubleA Jun 25, 2024
5712de4
default special tokens
RdoubleA Jun 26, 2024
d84bbda
fix docs
RdoubleA Jun 26, 2024
5a8b82b
fix doc strings
RdoubleA Jun 26, 2024
52643cb
Merge branch 'main' into tokenizer
RdoubleA Jun 26, 2024
a00c1dc
fix tests
RdoubleA Jun 26, 2024
29273ca
fix SP test
RdoubleA Jun 26, 2024
adca77e
Merge branch 'main' into tokenizer
RdoubleA Jul 2, 2024
b204563
update api ref
RdoubleA Jul 2, 2024
e236916
Merge branch 'main' into tokenizer
RdoubleA Jul 2, 2024
93028cf
fix llama3 tokenizer test
RdoubleA Jul 2, 2024
fb12cbb
add image support
RdoubleA Jun 26, 2024
b5bf410
tool support
RdoubleA Jun 26, 2024
00f266f
update tests
RdoubleA Jun 26, 2024
c815069
update tests
RdoubleA Jun 26, 2024
21b3ea8
use images as attachments instead
RdoubleA Jun 27, 2024
adbfb20
update all tests
RdoubleA Jun 27, 2024
1e40a9d
use list of dicts for MM messages
RdoubleA Jun 27, 2024
0d3665c
fix chat formats
RdoubleA Jul 1, 2024
95edf70
run linter
RdoubleA Jul 2, 2024
a3067aa
Merge branch 'main' into tokenizer_updates
RdoubleA Jul 2, 2024
d49febf
merge main
RdoubleA Jul 2, 2024
7da4189
fix chat formats
RdoubleA Jul 3, 2024
fca9031
address comments, fix docs
RdoubleA Jul 3, 2024
1 change: 1 addition & 0 deletions docs/source/api_ref_data.rst
@@ -37,6 +37,7 @@ Types
:nosignatures:

Message
Role

Converters
----------
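
The docs change above exports Role in the data API reference alongside Message. For orientation, a brief sketch of how the two exported names relate — this assumes Role is a string-style typing alias over the accepted role names, which its use in Message suggests but the diff does not show:

    from torchtune.data import Message, Role

    # Assumption: Role is a string alias/literal over the accepted role names.
    role: Role = "assistant"
    msg = Message(role=role, content="Hello!", masked=False)
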
55 changes: 48 additions & 7 deletions tests/test_utils.py
@@ -18,7 +18,7 @@

import torch
from torch import nn
from torchtune.data import truncate
from torchtune.data import ChatFormat, Message, truncate
from torchtune.modules.tokenizers import ModelTokenizer

skip_if_cuda_not_available = unittest.skipIf(
@@ -50,7 +50,7 @@ def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]:
return tokens

def tokenize_messages(
self, messages: List[str], max_seq_len: Optional[int] = None
self, messages: List[Message], max_seq_len: Optional[int] = None
) -> Tuple[List[int], List[bool]]:
"""
A simplified version of Llama2Tokenizer's ``tokenize_messages`` for testing purposes.
@@ -69,11 +69,15 @@ def tokenize_messages
mask.append(message.masked)

# Tokenize current message, append with masks
tokens = self.encode(
message.content,
add_bos=False,
add_eos=False,
)
for item in message.content:
if item["type"] == "text":
tokens = self.encode(
item["content"],
add_bos=False,
add_eos=False,
)
elif item["type"] == "image":
tokens = [self.image_id]

tokenized_messages.extend(tokens)
mask.extend([message.masked] * len(tokens))
@@ -106,6 +110,36 @@ def eos_id(self):
def bos_id(self):
return 0

@property
def image_id(self):
return -2


class DummyChatFormat(ChatFormat):

B_SYS, E_SYS = "System:\n", "\n"
B_INST, E_INST = "User:\n", "\nAssistant:\n"
B_ASST, E_ASST = "", ""
system = f"{B_SYS}{{content}}{E_SYS}"
user = f"{B_INST}{{content}}{E_INST}"
assistant = f"{B_ASST}{{content}}{E_ASST}"

@classmethod
def format(
cls,
messages,
):
formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant}
formatted_dialogue = []
for message in messages:
content = formats.get(message.role).format(
content=message.content[0]["content"]
)
formatted_dialogue.append(
Message(role=message.role, content=content, masked=message.masked),
)
return formatted_dialogue


def get_assets_path():
return Path(__file__).parent / "assets"
@@ -258,3 +292,10 @@ def gen_log_file_name(tmpdir, suffix: Optional[str] = None) -> str:
filename += suffix
filename += ".txt"
return filename


def assert_dialogue_equal(actual, expected):
assert len(actual) == len(expected)
for i in range(len(actual)):
assert actual[i].role == expected[i].role
assert actual[i].text_content == expected[i].text_content
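
The key behavioral change in this file is the tokenize_messages loop: message content is now a list of typed items, with "text" items encoded normally and "image" items emitted as a single placeholder id (image_id, -2 in this dummy). A minimal sketch of the intended call pattern, assuming Message also accepts a plain string and normalizes it into the list-of-dicts form (as DummyChatFormat's message.content[0]["content"] access implies):

    from tests.test_utils import DummyTokenizer
    from torchtune.data import Message

    tokenizer = DummyTokenizer()
    messages = [
        Message(
            role="user",
            content=[
                {"type": "image"},
                {"type": "text", "content": "What is in this image?"},
            ],
            masked=True,
        ),
        Message(role="assistant", content="A golden retriever.", masked=False),
    ]
    tokens, mask = tokenizer.tokenize_messages(messages)
    # tokens interleaves the image placeholder id with the text token ids;
    # mask repeats each message's masked flag once per emitted token.
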
15 changes: 5 additions & 10 deletions tests/torchtune/data/test_chat_formats.py
@@ -4,7 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import pytest
from tests.test_utils import assert_dialogue_equal
from torchtune.data import ChatMLFormat, Llama2ChatFormat, Message, MistralChatFormat

# Taken from Open-Orca/SlimOrca-Dedup on HuggingFace:
@@ -34,13 +36,6 @@
]


def _assert_dialogue_equal(actual, expected):
assert len(actual) == len(expected)
for i in range(len(actual)):
assert actual[i].role == expected[i].role
assert actual[i].content == expected[i].content


class TestLlama2ChatFormat:
expected_dialogue = [
Message(
@@ -65,7 +60,7 @@ class TestLlama2ChatFormat:

def test_format(self):
actual = Llama2ChatFormat.format(CHAT_SAMPLE)
_assert_dialogue_equal(actual, self.expected_dialogue)
assert_dialogue_equal(actual, self.expected_dialogue)


class TestMistralChatFormat:
@@ -90,7 +85,7 @@ class TestMistralChatFormat:
def test_format(self):
no_system_sample = CHAT_SAMPLE[1:]
actual = MistralChatFormat.format(no_system_sample)
_assert_dialogue_equal(actual, self.expected_dialogue)
assert_dialogue_equal(actual, self.expected_dialogue)

def test_format_with_system_prompt_raises(self):
with pytest.raises(
@@ -127,4 +122,4 @@ class TestChatMLFormat:

def test_format(self):
actual = ChatMLFormat.format(CHAT_SAMPLE)
_assert_dialogue_equal(actual, self.expected_dialogue)
assert_dialogue_equal(actual, self.expected_dialogue)
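
These tests now use the shared assert_dialogue_equal helper from tests/test_utils.py instead of a module-local copy. A short hedged sketch of one format under test, grounded in the visible assertions:

    from torchtune.data import Message, MistralChatFormat

    sample = [
        Message(role="user", content="What is 2 + 2?", masked=True),
        Message(role="assistant", content="4.", masked=False),
    ]
    formatted = MistralChatFormat.format(sample)
    # Mistral's template has no system slot, so format() raises if a system
    # message is present (see test_format_with_system_prompt_raises above).
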
25 changes: 7 additions & 18 deletions tests/torchtune/data/test_converters.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from tests.test_utils import assert_dialogue_equal
from torchtune.data import get_openai_messages, get_sharegpt_messages
from torchtune.data._types import Message

@@ -60,15 +61,11 @@ class TestShareGPTToLlama2Messages:

def test_conversion(self):
converted_messages = get_sharegpt_messages(self.samples)
for converted, expected in zip(converted_messages, EXPECTED_MESSAGE):
assert converted == expected
assert_dialogue_equal(converted_messages, EXPECTED_MESSAGE)

def test_conversion_train_on_input(self):
converted_messages = get_sharegpt_messages(self.samples, train_on_input=True)
for converted, expected in zip(
converted_messages, EXPECTED_MESSAGE_TRAIN_ON_INPUT
):
assert converted == expected
assert_dialogue_equal(converted_messages, EXPECTED_MESSAGE_TRAIN_ON_INPUT)


class TestOpenAIToLlama2Messages:
@@ -110,24 +107,16 @@ class TestOpenAIToLlama2Messages:

def test_conversion_conversations_key(self):
converted_messages_1 = get_openai_messages(self.samples_1)
for converted, expected in zip(converted_messages_1, EXPECTED_MESSAGE):
assert converted == expected
assert_dialogue_equal(converted_messages_1, EXPECTED_MESSAGE)

def test_conversion_messages_key(self):
converted_messages_2 = get_openai_messages(self.samples_2)
for converted, expected in zip(converted_messages_2, EXPECTED_MESSAGE):
assert converted == expected
assert_dialogue_equal(converted_messages_2, EXPECTED_MESSAGE)

def test_conversion_conversations_key_train_on_input(self):
converted_messages_1 = get_openai_messages(self.samples_1, train_on_input=True)
for converted, expected in zip(
converted_messages_1, EXPECTED_MESSAGE_TRAIN_ON_INPUT
):
assert converted == expected
assert_dialogue_equal(converted_messages_1, EXPECTED_MESSAGE_TRAIN_ON_INPUT)

def test_conversion_messages_key_train_on_input(self):
converted_messages_2 = get_openai_messages(self.samples_2, train_on_input=True)
for converted, expected in zip(
converted_messages_2, EXPECTED_MESSAGE_TRAIN_ON_INPUT
):
assert converted == expected
assert_dialogue_equal(converted_messages_2, EXPECTED_MESSAGE_TRAIN_ON_INPUT)
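
The converter tests likewise switch to the shared helper. For reference, a hedged sketch of the ShareGPT conversion these tests exercise — the "conversations"/"from"/"value" field names follow the ShareGPT schema, and the masking behavior is inferred from the train_on_input variants above:

    from torchtune.data import get_sharegpt_messages

    sample = {
        "conversations": [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
        ]
    }
    messages = get_sharegpt_messages(sample)
    # "human" maps to the user role and "gpt" to assistant; with the default
    # train_on_input=False, the system and user messages come back masked.
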
90 changes: 34 additions & 56 deletions tests/torchtune/datasets/test_chat_dataset.py
@@ -7,48 +7,12 @@
from unittest import mock

import pytest
from datasets import Dataset
from tests.test_utils import DummyTokenizer
from torchtune.data import ChatFormat, Message
from tests.test_utils import DummyChatFormat, DummyTokenizer
from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ChatDataset


class DummyChatFormat(ChatFormat):

B_SYS, E_SYS = "System:\n", "\n"
B_INST, E_INST = "User:\n", "\nAssistant:\n"
B_ASST, E_ASST = "", ""
system = f"{B_SYS}{{content}}{E_SYS}"
user = f"{B_INST}{{content}}{E_INST}"
assistant = f"{B_ASST}{{content}}{E_ASST}"

@classmethod
def format(
cls,
messages,
):
formats = {"system": cls.system, "user": cls.user, "assistant": cls.assistant}
formatted_dialogue = []
for message in messages:
content = formats.get(message["role"]).format(content=message["content"])
formatted_dialogue.append(
Message(
role=message["role"], content=content, masked=message["masked"]
),
)
return formatted_dialogue


def _are_messages_equal(messages_a, messages_b):
for ma, mb in zip(messages_a, messages_b):
if ma.role != mb.role:
return False
if ma.content != mb.content:
return False
return True


class TestChatDataset:
@pytest.fixture
def chat_format(self):
@@ -59,30 +23,44 @@ def dialogue(self):
return [
{
"dialogue": [
{
"role": "system",
"content": "You are an AI assistant.",
"masked": True,
},
{
"role": "user",
"content": "What is the meaning of life?",
"masked": True,
},
{
"role": "assistant",
"content": "The meaning of life is 42.",
"masked": False,
},
{"role": "user", "content": "That's ridiculous.", "masked": True},
{"role": "assistant", "content": "I agree.", "masked": False},
Message.from_dict(
{
"role": "system",
"content": "You are an AI assistant.",
"masked": True,
}
),
Message.from_dict(
{
"role": "user",
"content": "What is the meaning of life?",
"masked": True,
}
),
Message.from_dict(
{
"role": "assistant",
"content": "The meaning of life is 42.",
"masked": False,
}
),
Message.from_dict(
{
"role": "user",
"content": "That's ridiculous.",
"masked": True,
}
),
Message.from_dict(
{"role": "assistant", "content": "I agree.", "masked": False}
),
],
},
]

@mock.patch("torchtune.datasets._chat.load_dataset")
def test_get_item(self, mock_load_dataset, chat_format, dialogue):
mock_load_dataset.return_value = Dataset.from_list(dialogue)
mock_load_dataset.return_value = dialogue
expected_tokenized_prompts = [
[
0,
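
The dialogue fixture above now builds Message objects via Message.from_dict rather than passing raw dicts through the dataset. A minimal sketch of what that constructor is assumed to do — normalizing plain-string content into the typed list form consumed by tokenize_messages:

    from torchtune.data import Message

    msg = Message.from_dict(
        {"role": "user", "content": "What is the meaning of life?", "masked": True}
    )
    # Assumed normalization:
    #   msg.content == [{"type": "text", "content": "What is the meaning of life?"}]
    # msg.text_content then recovers the plain string, which is what
    # assert_dialogue_equal compares.
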