Update incorrect data processing in DataCollatorForChatML #2172

Status: Merged (31 commits, Oct 10, 2024)
Diff shown: changes from 14 commits

Commits
5c538fb  Update incorrect data processing in DataCollatorForChatML (ruijunfeng, Oct 4, 2024)
20b50e9  Merge pull request #1 from ruijunfeng/ruijunfeng-patch-for-DataCollat… (ruijunfeng, Oct 4, 2024)
aa50956  Merge branch 'main' into main (kashif, Oct 7, 2024)
d9db42a  Update trl/trainer/utils.py (qgallouedec, Oct 7, 2024)
996face  Merge branch 'main' into main (qgallouedec, Oct 7, 2024)
55e0155  style (qgallouedec, Oct 7, 2024)
78a1850  move comment (qgallouedec, Oct 7, 2024)
ef2453f  add test for DataCollatorForChatML (kashif, Oct 8, 2024)
d10acb1  update comment with more details (ruijunfeng, Oct 8, 2024)
afa84d3  update assert reports and comments, and adds verification that the la… (ruijunfeng, Oct 8, 2024)
2e20011  Merge branch 'main' into main (qgallouedec, Oct 8, 2024)
017220b  new line at the end of file for code quality (ruijunfeng, Oct 8, 2024)
57ebf5f  Update tests/test_utils.py (kashif, Oct 8, 2024)
2d2471b  Update tests/test_utils.py (kashif, Oct 8, 2024)
21fda01  Update tests/test_utils.py (kashif, Oct 8, 2024)
86119b1  update tests (qgallouedec, Oct 8, 2024)
5dd175f  fix test (kashif, Oct 8, 2024)
c09880e  Merge branch 'main' into main (kashif, Oct 8, 2024)
4c56e5a  Update tests/test_utils.py (kashif, Oct 8, 2024)
43aae62  Update tests/test_utils.py (kashif, Oct 8, 2024)
d6686e3  formatting (kashif, Oct 8, 2024)
ed88690  fix typo (kashif, Oct 8, 2024)
7e4006c  simplify (kashif, Oct 8, 2024)
49bc69d  Merge branch 'main' into main (kashif, Oct 8, 2024)
cb7f9be  Revert "simplify" (kashif, Oct 8, 2024)
4166ca0  Merge branch 'main' into main (kashif, Oct 9, 2024)
9c98fae  tokenize full messages (kashif, Oct 9, 2024)
7924cc2  dont add eos (kashif, Oct 9, 2024)
385bdb0  eos is in the last token (kashif, Oct 9, 2024)
de4ea96  simplify DataCollatorForChatML (kashif, Oct 9, 2024)
b4a2e97  Update tests/test_utils.py (kashif, Oct 10, 2024)
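
For context, here is a minimal usage sketch of the collator this PR fixes. It is not part of the PR; the constructor arguments mirror those exercised in the new test below, and the example checkpoint and messages are placeholders.

# Minimal usage sketch (assumed API, mirroring the new test below).
from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.bos_token  # same fallback the test uses

collator = DataCollatorForChatML(
    tokenizer=tokenizer,
    max_length=1024,        # truncation length for prompts and completions
    ignore_index=-100,      # label value for positions excluded from the loss
    messages_key="messages",
)

examples = [
    {
        "messages": [
            {"role": "user", "content": "Is strcpy safe? Return true or false."},  # placeholder prompt
            {"role": "assistant", "content": "true"},
        ]
    }
]

batch = collator(examples)
# batch["input_ids"][0]: the full formatted conversation, tokenized once
# batch["labels"][0]:    prompt positions set to -100, completion tokens kept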
92 changes: 91 additions & 1 deletion tests/test_utils.py
@@ -20,7 +20,13 @@
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
-from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad
+from trl.trainer.utils import (
+    DataCollatorForChatML,
+    decode_and_strip_padding,
+    generate_model_card,
+    get_peft_config,
+    pad,
+)


if is_peft_available():
@@ -169,3 +175,87 @@ def test_val_none(self):
assert "my_model" in card_text
assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
assert "My Trainer" in card_text


class TestDataCollatorForChatML(unittest.TestCase):
    def setUp(self):
        # Initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
        self.tokenizer.pad_token = (
            self.tokenizer.bos_token if self.tokenizer.pad_token is None else self.tokenizer.pad_token
        )

        # Define token IDs
        self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
        self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2
        self.assistant_output_token_id = 1565  # Token ID for "true", which is the last assistant's response in the example
        self.ignore_index = -100
        self.max_length = 1024
        self.messages_key = "messages"

        # Example input
        self.examples = [
            {
                self.messages_key: [
                    {
                        "role": "user",
                        "content": (
                            "Does the following code contain any security vulnerabilities? Return true or false.\n"
                            "char buffer[10];\nchar input[50];\nstrcpy(buffer, input);\n"
                        ),
                    },
                    {"role": "assistant", "content": "true"},
                ]
            }
        ]

        # Initialize the data collator
        self.collator = DataCollatorForChatML(
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            ignore_index=self.ignore_index,
            messages_key=self.messages_key,
        )

    def test_data_collator_for_chatml(self):
        # Process the data
        data = self.collator(self.examples)

        # Decode input_ids and labels for verification
        input_ids = data["input_ids"][0].tolist()
        labels = data["labels"][0].tolist()

        # Expected tokens
        expected_bos = self.bos_token_id
        expected_eos = self.eos_token_id
        expected_assistant_token = self.assistant_output_token_id

        # Verify that input_ids start with a BOS token and there are no extra ones
        self.assertEqual(input_ids[0], expected_bos, "The first token of input_ids should be BOS token.")
        self.assertNotEqual(
            input_ids[1], expected_bos, "The second token of input_ids should not be BOS token (extra BOS)."
        )

        # Verify that the assistant's response token is present in input_ids
        self.assertIn(expected_assistant_token, input_ids, "Assistant's response token should be in input_ids.")

        # Verify that EOS token is at the end of input_ids
        self.assertEqual(input_ids[-1], expected_eos, "The last token of input_ids should be EOS token.")

        # Verify that the labels preserved the target string (last_assistant_response)
        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)

        # Find the start and end of the last assistant's response in the labels
        response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
        response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)

        actual_response = labels[response_start : response_end - 1]
        self.assertEqual(
            actual_response,
            last_assistant_response_tokens,
            "The labels should preserve the last assistant's response tokens.",
        )

        # Verify that EOS token is at the end of labels
        self.assertEqual(labels[-1], expected_eos, "The last token of labels should be EOS token.")
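
To make the masking contract this test checks explicit, here is a small, self-contained illustration in plain Python. The token IDs are hypothetical, and this is not TRL code:

ignore_index = -100
prompt_ids = [1, 306, 4091]        # hypothetical prompt tokens, starting with BOS (1)
completion_ids = [1565, 29871, 2]  # hypothetical "true", space, and EOS (2) tokens

input_ids = prompt_ids + completion_ids
labels = [ignore_index] * len(prompt_ids) + completion_ids

# Loss is computed only on the completion: prompt positions are masked,
# and the EOS token survives at the end of both input_ids and labels.
assert input_ids[0] == 1 and input_ids[-1] == 2
assert labels[-1] == 2
assert all(v == ignore_index for v in labels[: len(prompt_ids)])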
18 changes: 15 additions & 3 deletions trl/trainer/utils.py
@@ -265,17 +265,29 @@ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
            assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
            last_assistant_message = assistant_messages[-1]["content"]
            prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
-            completion = last_assistant_message
+            completion = last_assistant_message + formatted_chat.rsplit(last_assistant_message, 1)[1]

            prompts.append(prompt)
            completions.append(completion)

        # Tokenize prompts and completions
        tokenized_prompts = self.tokenizer(
-            prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
+            prompts,
+            truncation=True,
+            max_length=self.max_length,
+            padding=False,
+            return_tensors=None,
+            # We assume the inputs are already wrapped with BOS and EOS tokens by tokenizer.apply_chat_template, so extra BOS/EOS tokens should not be added
+            add_special_tokens=False,
        )
        tokenized_completions = self.tokenizer(
-            completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
+            completions,
+            truncation=True,
+            max_length=self.max_length,
+            padding=False,
+            return_tensors=None,
+            # We assume the inputs are already wrapped with BOS and EOS tokens by tokenizer.apply_chat_template, so extra BOS/EOS tokens should not be added
+            add_special_tokens=False,
        )

        # Combine prompts and completions
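
The core fix above is the completion construction: keeping everything after the last assistant message preserves the chat template's closing tokens (such as a trailing </s>) instead of dropping them. A sketch of the string manipulation, using a made-up template string rather than real tokenizer output:

formatted_chat = "<s>[INST] Is strcpy unsafe? [/INST] true </s>"  # hypothetical template output
last_assistant_message = "true"

# Split once, from the right, on the last assistant message.
prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]

# Old behavior: completion = last_assistant_message  (drops the trailing " </s>")
# New behavior: also keep the template suffix that follows the message.
completion = last_assistant_message + formatted_chat.rsplit(last_assistant_message, 1)[1]

assert prompt == "<s>[INST] Is strcpy unsafe? [/INST] "
assert completion == "true </s>"
assert prompt + completion == formatted_chat  # nothing from the template is lost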