Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add system prompt #1366

Merged
merged 7 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/torchtune/data/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,19 @@ def test_call_train_on_input(self, sample):
]
assert_dialogue_equal(actual["messages"], expected)

def test_system_prompt(self, sample):
    """Check that ``new_system_prompt`` yields a leading system message.

    The transform is built with a column map pointing at the ``maybe_input``
    and ``maybe_output`` columns of the ``sample`` fixture, and the resulting
    dialogue is compared against a three-message expectation.
    """
    to_messages = InputOutputToMessages(
        new_system_prompt="you are a robot",
        column_map={"input": "maybe_input", "output": "maybe_output"},
    )
    result = to_messages(sample)

    # System and user turns are expected to be masked; the assistant turn is not.
    expected_roles = (
        ("system", "you are a robot", True),
        ("user", "hello world", True),
        ("assistant", "hello world", False),
    )
    expected = [
        Message(role=role, content=text, masked=is_masked, eot=True)
        for role, text, is_masked in expected_roles
    ]
    assert_dialogue_equal(result["messages"], expected)

def test_raise_value_error_when_input_not_in_column_map(self):
with pytest.raises(ValueError, match="Expected a key of 'input'"):
InputOutputToMessages(
Expand Down Expand Up @@ -173,6 +186,29 @@ def test_call_train_on_input(self, sample):
]
assert_dialogue_equal(actual["rejected"], expected_rejected)

def test_system_prompt(self, sample):
    """Check that ``new_system_prompt`` leads both preference dialogues.

    Both the "chosen" and the "rejected" outputs must start with the system
    message, followed by the user turn and the respective assistant reply.
    """

    def build_expected(assistant_reply):
        # Fresh Message objects per dialogue so the two expectations share nothing.
        return [
            Message(role="system", content="you are a robot", masked=True, eot=True),
            Message(role="user", content="hello world", masked=True, eot=True),
            Message(
                role="assistant", content=assistant_reply, masked=False, eot=True
            ),
        ]

    to_messages = ChosenRejectedToMessages(
        new_system_prompt="you are a robot",
        column_map={
            "chosen": "maybe_chosen",
            "rejected": "maybe_rejected",
        },
    )
    result = to_messages(sample)

    # Same order of checks as before: chosen first, then rejected.
    assert_dialogue_equal(result["chosen"], build_expected("hello world"))
    assert_dialogue_equal(result["rejected"], build_expected("bye world"))

def test_raise_value_error_when_chosen_not_in_column_map(self):
with pytest.raises(ValueError, match="Expected a key of 'chosen'"):
ChosenRejectedToMessages(
Expand Down Expand Up @@ -216,6 +252,19 @@ def test_call_train_on_input(self):
converted_messages["messages"], MESSAGE_SAMPLE_TRAIN_ON_INPUT
)

def test_system_prompt(self):
    """Check that ``new_system_prompt`` replaces the first expected message.

    The expectation is the new system message followed by every entry of
    ``MESSAGE_SAMPLE`` except its first one.
    """
    to_messages = ShareGPTToMessages(new_system_prompt="you are a robot")
    result = to_messages(self.samples)

    system_turn = Message(
        role="system", content="you are a robot", masked=True, eot=True
    )
    # MESSAGE_SAMPLE[0] is dropped -- presumably the sample's own system
    # message; confirm against the MESSAGE_SAMPLE definition.
    expected = [system_turn, *MESSAGE_SAMPLE[1:]]
    assert_dialogue_equal(result["messages"], expected)

def test_raise_value_error_when_conversations_not_in_column_map(self):
with pytest.raises(ValueError, match="Expected a key of 'conversations'"):
ShareGPTToMessages(
Expand Down Expand Up @@ -253,6 +302,19 @@ def test_call_train_on_input(self):
converted_messages["messages"], MESSAGE_SAMPLE_TRAIN_ON_INPUT
)

def test_system_prompt(self):
    """Check that ``new_system_prompt`` replaces the first expected message.

    The expectation is the new system message followed by every entry of
    ``MESSAGE_SAMPLE`` except its first one.
    """
    to_messages = JSONToMessages(new_system_prompt="you are a robot")
    result = to_messages(self.samples)

    system_turn = Message(
        role="system", content="you are a robot", masked=True, eot=True
    )
    # MESSAGE_SAMPLE[0] is dropped -- presumably the sample's own system
    # message; confirm against the MESSAGE_SAMPLE definition.
    expected = [system_turn, *MESSAGE_SAMPLE[1:]]
    assert_dialogue_equal(result["messages"], expected)

def test_raise_value_error_when_messages_not_in_column_map(self):
with pytest.raises(ValueError, match="Expected a key of 'messages'"):
JSONToMessages(
Expand Down
92 changes: 83 additions & 9 deletions torchtune/data/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _validate_message(self) -> None:

class InputOutputToMessages(Transform):
"""
Message transform class that converts a sample with "input" and "output" fields,
Message transform class that converts a single sample with "input" and "output" fields,
(or equivalent fields specified in column_map) to user and assistant messages,
respectively. This is useful for datasets that have two columns, one containing
the user prompt and the other containing the model response.
Expand All @@ -129,16 +129,22 @@ class InputOutputToMessages(Transform):
column_map (Optional[Dict[str, str]]): a mapping to change the expected "input"
and "output" column names to the actual column names in the dataset. Default is None,
keeping the default "input" and "output" column names.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Default is None.

Raises:
ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``.
"""

def __init__(
self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
):
self.train_on_input = train_on_input
self.new_system_prompt = new_system_prompt
if column_map:
if "input" not in column_map:
raise ValueError(
Expand Down Expand Up @@ -167,13 +173,19 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
eot=True,
),
]
if self.new_system_prompt is not None:
messages = [
Message(
role="system", content=self.new_system_prompt, masked=True, eot=True
)
] + messages
return {"messages": messages}


class ChosenRejectedToMessages(Transform):
"""
Transform for converting datasets with "chosen" and "rejected" columns containing
conversations to a list of chosen and rejected messages. For example::
Transform for converting a single sample from datasets with "chosen" and "rejected" columns
containing conversations to a list of chosen and rejected messages. For example::

| chosen | rejected |
|----------------------------------------|----------------------------------------|
Expand All @@ -193,22 +205,32 @@ class ChosenRejectedToMessages(Transform):
Message(role="assistant", content="A2"),
]

A single sample typically consists of a single optional system prompt and one or multiple
turns of user and assistant messages.

Args:
train_on_input (bool): Whether the model is trained on the user prompt or not.
Default is False.
column_map (Optional[Dict[str, str]]): a mapping to change the expected
"chosen" and "rejected" column names to the actual column names in the dataset.
Default is None, keeping the default column names.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.

Raises:
ValueError: If ``column_map`` is provided and ``chosen`` not in ``column_map``, or
``rejected`` not in ``column_map``.
"""

def __init__(
self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
):
self.train_on_input = train_on_input
self.new_system_prompt = new_system_prompt
if column_map:
if "chosen" not in column_map:
raise ValueError(
Expand All @@ -225,26 +247,45 @@ def __init__(
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Convert the chosen and rejected conversations of a single sample into
    lists of :class:`Message`.

    Each conversation is read from the column named by ``self._column_map``.
    A message is masked when its role is not "assistant" and
    ``self.train_on_input`` is False. When ``self.new_system_prompt`` is set,
    system messages found in the sample are skipped and a single masked
    system message carrying the new prompt is placed first in both lists.

    The input ``sample`` is not modified: the ``masked`` flag is applied to a
    copy of each message dict rather than written back into the sample.

    Args:
        sample (Mapping[str, Any]): a single sample holding the chosen and
            rejected conversations, each an iterable of message dicts with
            at least a "role" key.

    Returns:
        Mapping[str, Any]: a dict with keys "chosen" and "rejected", each
        mapping to a list of :class:`Message`.
    """
    replace_system = self.new_system_prompt is not None

    def _to_messages(raw_messages):
        # Shared conversion for both columns; the logic is identical for each.
        converted = []
        if replace_system:
            converted.append(
                Message(
                    role="system",
                    content=self.new_system_prompt,
                    masked=True,
                    eot=True,
                )
            )
        for message in raw_messages:
            role = message["role"]
            if replace_system and role == "system":
                # The new system prompt overrides any system message in the data.
                continue
            masked = (role != "assistant") and (not self.train_on_input)
            # Copy instead of assigning message["masked"] so the caller's
            # sample is left untouched.
            converted.append(Message.from_dict({**message, "masked": masked}))
        return converted

    # Chosen is converted before rejected, as before, so a missing column
    # raises the same KeyError in the same order.
    chosen_messages = _to_messages(sample[self._column_map["chosen"]])
    rejected_messages = _to_messages(sample[self._column_map["rejected"]])

    return {"chosen": chosen_messages, "rejected": rejected_messages}


class ShareGPTToMessages(Transform):
"""
Convert a chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message`
Convert a single chat sample adhering to the ShareGPT json structure to torchtune's :class:`~torchtune.data.Message`
structure.

A single sample typically consists of a single optional system prompt and one or multiple
turns of user and assistant messages.

ShareGPT follows::

{
Expand Down Expand Up @@ -272,15 +313,22 @@ class ShareGPTToMessages(Transform):
column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("conversations")
to the new column names in the dataset. If None, assume these are identical.
Default is None.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.

Raises:
ValueError: If ``column_map`` is provided and ``conversations`` not in ``column_map``.
"""

def __init__(
self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
):
self.train_on_input = train_on_input
self.new_system_prompt = new_system_prompt
if column_map:
if "conversations" not in column_map:
raise ValueError(
Expand All @@ -303,8 +351,16 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
"""
role_map = {"system": "system", "human": "user", "gpt": "assistant"}
messages = []
if self.new_system_prompt is not None:
messages.append(
Message(
role="system", content=self.new_system_prompt, masked=True, eot=True
)
)
for message in sample[self._column_map["conversations"]]:
role = role_map[message["from"]]
if role == "system" and self.new_system_prompt is not None:
continue
content = message["value"]
masked = (role != "assistant") and (not self.train_on_input)
messages.append(Message(role=role, content=content, masked=masked))
Expand All @@ -314,9 +370,12 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:

class JSONToMessages(Transform):
"""
Convert a chat sample with identical json structure to torchtune's :class:`~torchtune.data.Message`
Convert a single chat sample with identical json structure to torchtune's :class:`~torchtune.data.Message`
structure. This transform simply creates Message dataclasses from the provided jsons.

A single sample typically consists of a single optional system prompt and one or multiple
turns of user and assistant messages.

For example::

{
Expand Down Expand Up @@ -344,15 +403,22 @@ class JSONToMessages(Transform):
column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("messages")
to the new column names in the dataset. If None, assume these are identical.
Default is None.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.

Raises:
ValueError: If ``column_map`` is provided and ``messages`` not in ``column_map``.
"""

def __init__(
self, train_on_input: bool = False, column_map: Optional[Dict[str, str]] = None
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
):
self.train_on_input = train_on_input
self.new_system_prompt = new_system_prompt
if column_map:
if "messages" not in column_map:
raise ValueError(
Expand All @@ -374,7 +440,15 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
List[Message]: A list of messages with "role" and "content" fields.
"""
updated_messages = []
if self.new_system_prompt is not None:
updated_messages.append(
Message(
role="system", content=self.new_system_prompt, masked=True, eot=True
)
)
for message in sample[self._column_map["messages"]]:
if message["role"] == "system" and self.new_system_prompt is not None:
continue
message["masked"] = (message["role"] != "assistant") and (
not self.train_on_input
)
Expand Down
8 changes: 7 additions & 1 deletion torchtune/datasets/_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def grammar_dataset(
source: str = "liweili/c4_200m",
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
split: str = "train",
) -> Union[SFTDataset, PackedDataset]:
Expand Down Expand Up @@ -47,6 +48,9 @@ def grammar_dataset(
:class:`~torchtune.data.InputOutputToMessages` to the new column names in the dataset. If None, use
the default column names ``"input"`` and ``"output"`` in ``liweili/c4_200m``.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
new_system_prompt (Optional[str]): if specified, prepend a system message to every sample. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
Expand All @@ -63,7 +67,9 @@ def grammar_dataset(
"""

message_transform = InputOutputToMessages(
train_on_input=train_on_input, column_map=column_map
train_on_input=train_on_input,
column_map=column_map,
new_system_prompt=new_system_prompt,
)
ds = SFTDataset(
source=source,
Expand Down
8 changes: 7 additions & 1 deletion torchtune/datasets/_hh_rlhf_helpful.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def hh_rlhf_helpful_dataset(
source: str = "RLHFlow/HH-RLHF-Helpful-standard",
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
split: str = "train",
) -> PreferenceDataset:
"""
Expand All @@ -35,6 +36,9 @@ def hh_rlhf_helpful_dataset(
column_map (Optional[Dict[str, str]]): a mapping from the expected columns in the prompt template
to the new column names in the dataset. If None, assume these are identical.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
new_system_prompt (Optional[str]): if specified, prepend a system message to every sample for both chosen
and rejected. This can serve as instructions to guide the model response. Setting this will OVERRIDE
any system messages already present in the dataset. Default is None.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".

Expand All @@ -43,7 +47,9 @@ def hh_rlhf_helpful_dataset(
"""

message_transform = ChosenRejectedToMessages(
train_on_input=train_on_input, column_map=column_map
train_on_input=train_on_input,
column_map=column_map,
new_system_prompt=new_system_prompt,
)

return PreferenceDataset(
Expand Down
Loading
Loading