Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conversation_style of openai format (OpenAI API style) #890

Merged
merged 13 commits into from
Apr 28, 2024
139 changes: 139 additions & 0 deletions tests/torchtune/data/test_converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.data import openai_to_llama2_messages, sharegpt_to_llama2_messages
from torchtune.data._types import Message

# Taken from Open-Orca/SlimOrca-Dedup on Hugging Face:
# https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup
# Single system/user/assistant exchange used as the fixture for every converter test.
CHAT_SAMPLE = {
    "system": "You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.",  # noqa: B950
    "user": "Please briefly summarize this news article:\n\nAOL.com Video - Father Lets 8-Year-Old Drive On Icy Road\n\nDescription:Would you let your 8-year-old drive your car? How about on an icy road? Well one father in Russia did just that, and recorded the entire thing. To her credit, the child seemed to be doing a great job. (0:44)\n\nTags: 8-year-old driver , caught on camera , child driver , pix11\n\nSummary:",  # noqa: B950
    "assistant": "A father in Russia allowed his 8-year-old child to drive his car on an icy road and recorded the event. The child appeared to be handling the situation well, showcasing their driving skills despite the challenging conditions.",  # noqa: B950
}

# Expected conversion when train_on_input=True: `masked` is left at its default
# on every message (presumably False — confirm against Message's definition),
# i.e. the prompt turns also contribute to the loss.
EXPECTED_MESSAGE_TRAIN_ON_INPUT = [
    Message(
        role="system",
        content=CHAT_SAMPLE["system"],
    ),
    Message(
        role="user",
        content=CHAT_SAMPLE["user"],
    ),
    Message(
        role="assistant",
        content=CHAT_SAMPLE["assistant"],
    ),
]

# Expected conversion with the default train_on_input=False: system and user
# turns carry masked=True; only the assistant turn is unmasked.
EXPECTED_MESSAGE = [
    Message(role="system", content=CHAT_SAMPLE["system"], masked=True),
    Message(role="user", content=CHAT_SAMPLE["user"], masked=True),
    Message(
        role="assistant",
        content=CHAT_SAMPLE["assistant"],
    ),
]


class TestShareGPTToLlama2Messages:
    """Tests for sharegpt_to_llama2_messages.

    Feeds a ShareGPT-style sample ("conversations" list of {"from", "value"}
    dicts) through the converter and checks the resulting Message list against
    the module-level expected fixtures, both with and without train_on_input.
    """

    samples = {
        "conversations": [
            {
                "from": "system",
                "value": CHAT_SAMPLE["system"],
            },
            {
                "from": "human",
                "value": CHAT_SAMPLE["user"],
            },
            {
                "from": "gpt",
                "value": CHAT_SAMPLE["assistant"],
            },
        ]
    }

    def test_conversion(self):
        # Default train_on_input=False: prompt turns should come back masked.
        converted_messages = sharegpt_to_llama2_messages(self.samples)
        # Guard against zip() silently truncating on a length mismatch.
        assert len(converted_messages) == len(EXPECTED_MESSAGE)
        for converted, expected in zip(converted_messages, EXPECTED_MESSAGE):
            assert converted == expected

    def test_conversion_train_on_input(self):
        # train_on_input=True: no turn should be masked from the loss.
        converted_messages = sharegpt_to_llama2_messages(
            self.samples, train_on_input=True
        )
        assert len(converted_messages) == len(EXPECTED_MESSAGE_TRAIN_ON_INPUT)
        for converted, expected in zip(
            converted_messages, EXPECTED_MESSAGE_TRAIN_ON_INPUT
        ):
            assert converted == expected


class TestOpenAIToLlama2Messages:
    """Tests for openai_to_llama2_messages.

    The converter accepts the OpenAI chat format under either a "messages" or
    a "conversations" key; both spellings are exercised here, with and without
    train_on_input.
    """

    # Same conversation under the "conversations" key.
    samples_1 = {
        "id": "DUMMY",
        "conversations": [
            {
                "role": "system",
                "content": CHAT_SAMPLE["system"],
            },
            {
                "role": "user",
                "content": CHAT_SAMPLE["user"],
            },
            {
                "role": "assistant",
                "content": CHAT_SAMPLE["assistant"],
            },
        ],
    }

    # Same conversation under the "messages" key.
    samples_2 = {
        "id": "DUMMY",
        "messages": [
            {
                "role": "system",
                "content": CHAT_SAMPLE["system"],
            },
            {
                "role": "user",
                "content": CHAT_SAMPLE["user"],
            },
            {
                "role": "assistant",
                "content": CHAT_SAMPLE["assistant"],
            },
        ],
    }

    def test_conversion_conversations_key(self):
        converted_messages_1 = openai_to_llama2_messages(self.samples_1)
        # Guard against zip() silently truncating on a length mismatch.
        assert len(converted_messages_1) == len(EXPECTED_MESSAGE)
        for converted, expected in zip(converted_messages_1, EXPECTED_MESSAGE):
            assert converted == expected

    def test_conversion_messages_key(self):
        converted_messages_2 = openai_to_llama2_messages(self.samples_2)
        assert len(converted_messages_2) == len(EXPECTED_MESSAGE)
        for converted, expected in zip(converted_messages_2, EXPECTED_MESSAGE):
            assert converted == expected

    def test_conversion_conversations_key_train_on_input(self):
        converted_messages_1 = openai_to_llama2_messages(
            self.samples_1, train_on_input=True
        )
        assert len(converted_messages_1) == len(EXPECTED_MESSAGE_TRAIN_ON_INPUT)
        for converted, expected in zip(
            converted_messages_1, EXPECTED_MESSAGE_TRAIN_ON_INPUT
        ):
            assert converted == expected

    def test_conversion_messages_key_train_on_input(self):
        converted_messages_2 = openai_to_llama2_messages(
            self.samples_2, train_on_input=True
        )
        assert len(converted_messages_2) == len(EXPECTED_MESSAGE_TRAIN_ON_INPUT)
        for converted, expected in zip(
            converted_messages_2, EXPECTED_MESSAGE_TRAIN_ON_INPUT
        ):
            assert converted == expected
6 changes: 5 additions & 1 deletion torchtune/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
MistralChatFormat,
)
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._converters import sharegpt_to_llama2_messages
from torchtune.data._converters import (
openai_to_llama2_messages,
sharegpt_to_llama2_messages,
)
from torchtune.data._instruct_templates import (
AlpacaInstructTemplate,
GrammarErrorCorrectionTemplate,
Expand All @@ -32,6 +35,7 @@
"Llama2ChatFormat",
"MistralChatFormat",
"ChatMLFormat",
"openai_to_llama2_messages",
"sharegpt_to_llama2_messages",
"truncate",
"Message",
Expand Down
57 changes: 57 additions & 0 deletions torchtune/data/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,60 @@ def sharegpt_to_llama2_messages(
masked = (role != "assistant") and (not train_on_input)
messages.append(Message(role=role, content=content, masked=masked))
return messages


def openai_to_llama2_messages(
    sample: Mapping[str, Any],
    train_on_input: bool = False,
) -> List[Message]:
    """
    Convert a chat sample adhering to the OpenAI API standard chat format to the Llama2 chat format.

    OpenAI API `standard chat format <https://platform.openai.com/docs/guides/text-generation/chat-completions-api>`_ follows::

        {
            # key could be "messages" OR "conversations"
            "messages": [
                {
                    "role": <system|user|assistant>,
                    "content": <message>,
                },
                ...
            ]
        }

    Llama2 follows::

        [
            {
                "role": <system|user|assistant>,
                "content": <message>,
            },
            ...
        ]

    Args:
        sample (Mapping[str, Any]): a single data sample with a "messages" or
            "conversations" field pointing to a list of dict messages.
        train_on_input (bool): whether the prompt should remain unmasked. Default: False

    Raises:
        ValueError: If the sample does not contain "messages" or "conversations" key.

    Returns:
        List[Message]: A list of messages with "role" and "content" fields.
    """
    if "messages" in sample:
        messages_key = "messages"
    elif "conversations" in sample:
        messages_key = "conversations"
    else:
        raise ValueError(
            f"Sample does not contain 'messages' or 'conversations' key. Existing keys: {sample.keys()}"
        )
    conversations = sample[messages_key]

    messages = []
    for message in conversations:
        # Copy before annotating so the caller's sample is not mutated in place
        # (writing "masked" into the original dicts leaks state across calls
        # when the same sample object is reused, e.g. shared test fixtures).
        message = dict(message)
        # Only assistant turns contribute to the loss unless train_on_input is set.
        message["masked"] = (message["role"] != "assistant") and (not train_on_input)
        messages.append(Message.from_dict(message))
    return messages
3 changes: 3 additions & 0 deletions torchtune/datasets/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ChatFormat,
CROSS_ENTROPY_IGNORE_IDX,
Message,
openai_to_llama2_messages,
sharegpt_to_llama2_messages,
validate_messages,
)
Expand Down Expand Up @@ -159,6 +160,8 @@ def chat_dataset(
"""
if conversation_style == "sharegpt":
convert_to_messages = sharegpt_to_llama2_messages
elif conversation_style == "openai":
convert_to_messages = openai_to_llama2_messages
else:
raise ValueError(f"Unsupported conversation style: {conversation_style}")

Expand Down
Loading