diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4d26819058..87404070a8 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -205,54 +205,74 @@ def setUp(self):
             ignore_index=self.ignore_index,
         )
 
-    # See https://github.com/huggingface/trl/pull/2287#discussion_r1856594421
-    @unittest.skip("This test must be updated.")
     def test_data_collator_for_chatml(self):
         # Process the data
         data = self.collator(self.examples)
 
+        # Verify basic shapes and types
+        self.assertIn("input_ids", data)
+        self.assertIn("attention_mask", data)
+        self.assertIn("labels", data)
+        self.assertIn("prompts", data)
+        self.assertIn("prompt_attention_mask", data)
+
         # Decode input_ids and labels for verification
         input_ids = data["input_ids"][0].tolist()
         labels = data["labels"][0].tolist()
         prompt_only = data["prompts"][0].tolist()
 
-        # Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
-        first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
-        self.assertEqual(
-            first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
-        )
-        self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")
-
-        # Verify that the assistant's response token is present in input_ids and not in the prompt_only
-        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
-        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
-        response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
-        self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")
+        # Get the last assistant's response for comparison
+        last_message = self.examples[0][self.messages_key][-1]
+        self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant")
+        last_assistant_response = last_message["content"]
 
-        # Check if the last assistant's response tokens are not in prompt_only
-        response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
-        self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")
+        # Verify that input_ids contain both prompt and response
+        decoded_input = self.tokenizer.decode(input_ids)
+        self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response")
 
-        # Verify that EOS token is at the end of input_ids
-        self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")
+        # Verify that prompts only contain the conversation up to the last response
+        decoded_prompt = self.tokenizer.decode(prompt_only)
+        self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response")
 
-        # Verify that the labels preserved the target string (last_assistant_response)
-        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
-        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
+        # Verify labels are -100 for non-assistant parts
+        prompt_length = len(prompt_only)
+        self.assertTrue(
+            all(label == self.ignore_index for label in labels[:prompt_length]),
+            "Labels should be ignore_index for prompt tokens",
+        )
 
-        # Find the start and end of the last assistant's response in the labels
-        response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
-        response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)
+        # Verify labels match assistant response after prompt
+        # Add a filter to remove any trailing tokens after the first <|im_end|>
+        last_assistant_response_with_end = last_assistant_response + self.tokenizer.eos_token
+        last_assistant_response_tokens = self.tokenizer.encode(
+            last_assistant_response_with_end, add_special_tokens=False
+        )
 
-        actual_response = labels[response_start : response_end - 1]
+        response_labels = []
+        for label in labels[prompt_length:]:
+            if label == self.ignore_index:
+                continue
+            response_labels.append(label)
+            if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"):
+                break
         self.assertEqual(
-            actual_response,
+            response_labels,
             last_assistant_response_tokens,
-            "The labels should preserve the last assistant's response tokens.",
+            "Labels should match assistant response tokens",
         )
 
-        # Verify that EOS token is at the end of labels
-        self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
+        # Verify there isn't a generation prompt at the end
+        generation_prompt = "<|im_start|>assistant"
+        self.assertFalse(
+            decoded_input.strip().endswith(generation_prompt),
+            f"Input should not end with generation prompt '{generation_prompt}'",
+        )
+
+        self.assertEqual(
+            response_labels,
+            last_assistant_response_tokens,
+            "Labels should match assistant response tokens",
+        )
 
 
 class TestBatchGeneration(unittest.TestCase):
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index d1cc3a0e9d..1122086ca9 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -274,7 +274,7 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
             if "input_ids" not in example:
                 message = example[self.messages_key]
                 formatted_message = self.tokenizer.apply_chat_template(
-                    message, tokenize=False, add_generation_prompt=True
+                    message, tokenize=False, add_generation_prompt=False
                 )
                 tokenized_message = self.tokenizer(
                     formatted_message,
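Note on the one-line change in trl/trainer/utils.py: with add_generation_prompt=True, a ChatML-style chat template appends an empty assistant header ("<|im_start|>assistant\n") after the last message. That header is meant for inference-time prompting, so the collator was leaving a dangling generation prompt at the end of fully formed training examples; add_generation_prompt=False ends the rendered text right after the assistant's "<|im_end|>", which is what the updated test checks. A minimal sketch of the difference, assuming any ChatML-style tokenizer (the checkpoint name below is only illustrative):

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any tokenizer with a ChatML-style template behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    messages = [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]

    # True: the rendered string ends with a dangling "<|im_start|>assistant\n" (inference prompt).
    with_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # False: the rendered string ends after the assistant's "<|im_end|>" (complete training example).
    without_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

    assert with_prompt.rstrip().endswith("<|im_start|>assistant")
    assert not without_prompt.rstrip().endswith("<|im_start|>assistant")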