From d624ed33ceaacc121a81e85b05d78ad0893c5570 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 11 Dec 2024 09:05:30 -0700
Subject: [PATCH] fix: Remove trailing \n from llama3 <|eot_id|>

The documentation is inconsistent about whether or not a \n should
follow <|eot_id|>; removing it maintains consistency with the previous
formatting.

Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart
---
 tests/test_chat_formatters.py | 33 +++++++++++----------------------
 torchchat/generate.py         |  2 +-
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/tests/test_chat_formatters.py b/tests/test_chat_formatters.py
index feae1c138..2f7f7a955 100644
--- a/tests/test_chat_formatters.py
+++ b/tests/test_chat_formatters.py
@@ -139,44 +139,33 @@ def test_llama2_chat_formatter(messages, expected):
         # single user message (no system prompt)
         (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-"""),
+{USER1}<|eot_id|>"""),
         # sys, usr
         (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-"""),
+{USER1}<|eot_id|>"""),
         # sys, usr, asst
         (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT1}<|eot_id|>
-"""),
+{ASSISTANT1}<|eot_id|>"""),
         # sys, usr, asst, usr, asst
         (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-{SYSTEM_PROMPT}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER1}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT1}<|eot_id|>
-<|start_header_id|>user<|end_header_id|>
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
 
-{USER2}<|eot_id|>
-<|start_header_id|>assistant<|end_header_id|>
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-{ASSISTANT2}<|eot_id|>
-"""),
+{ASSISTANT2}<|eot_id|>"""),
     ]
 )
 @pytest.mark.parametrize("add_generation_prompt", [True, False])

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 274a0cec8..4d2439d2f 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -121,7 +121,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
                     self.tokenizer.encode(content["text"], bos=False, eos=False)
                 )
 
-        tokens.append(self.tokenizer.special_tokens["<|eot_id|>\n"])
+        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
         return tokens
 
     def encode_dialog_prompt(
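Note: for illustration only, the sketch below renders the prompt layout the
updated tests expect, with each <|eot_id|> followed immediately by the next
<|start_header_id|> and no intervening newline. This is a minimal standalone
rendering under assumed names, not the torchchat implementation:
format_llama3_prompt and its message dicts are hypothetical and invented for
this example.

    from typing import Dict, List


    def format_llama3_prompt(messages: List[Dict[str, str]]) -> str:
        """Render messages in the llama3 chat format, with no \n after <|eot_id|>."""
        out = "<|begin_of_text|>"
        for msg in messages:
            # Role header, then a blank line, then the content.
            out += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n"
            # <|eot_id|> is appended directly, matching the updated test strings.
            out += f"{msg['content']}<|eot_id|>"
        return out


    if __name__ == "__main__":
        print(format_llama3_prompt([
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]))

Before this patch, each <|eot_id|> in the expected test strings was followed
by a stray \n, and the token-level path in generate.py looked up the
special-token key "<|eot_id|>\n" instead of the plain "<|eot_id|>" token.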