diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 232fb100e9c13c..975607e66a00a4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -98,6 +98,8 @@ title: Use model-specific APIs - local: custom_models title: Share a custom model + - local: chat_templating + title: Templates for chat models - local: sagemaker title: Run training on Amazon SageMaker - local: serialization diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md new file mode 100644 index 00000000000000..8a9c1e915d4eb6 --- /dev/null +++ b/docs/source/en/chat_templating.md @@ -0,0 +1,255 @@ + + +# Templates for Chat Models + +## Introduction + +An increasingly common use case for LLMs is **chat**. In a chat context, rather than continuing a single string +of text (as is the case with a standard language model), the model instead continues a conversation that consists +of one or more **messages**, each of which includes a **role** as well as message text. + +Most commonly, these roles are "user" for messages sent by the user, and "assistant" for messages sent by the model. +Some models also support a "system" role. System messages are usually sent at the beginning of the conversation +and include directives about how the model should behave in the subsequent chat. + +All language models, including models fine-tuned for chat, operate on linear sequences of tokens and do not intrinsically +have special handling for roles. This means that role information is usually injected by adding control tokens +between messages, to indicate both the message boundary and the relevant roles. + +Unfortunately, there isn't (yet!) a standard for which tokens to use, and so different models have been trained +with wildly different formatting and control tokens for chat. This can be a real problem for users - if you use the +wrong format, then the model will be confused by your input, and your performance will be a lot worse than it should be. +This is the problem that **chat templates** aim to resolve. + +Chat conversations are typically represented as a list of dictionaries, where each dictionary contains `role` +and `content` keys, and represents a single chat message. Chat templates are strings containing a Jinja template that +specifies how to format a conversation for a given model into a single tokenizable sequence. By storing this information +with the tokenizer, we can ensure that models get input data in the format they expect. + +Let's make this concrete with a quick example using the `BlenderBot` model. BlenderBot has an extremely simple default +template, which mostly just adds whitespace between rounds of dialogue: + +```python +>>> from transformers import AutoTokenizer +>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + +>>> chat = [ +... {"role": "user", "content": "Hello, how are you?"}, +... {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, +... {"role": "user", "content": "I'd like to show off how chat templating works!"}, +... ] + +>>> tokenizer.apply_chat_template(chat, tokenize=False) +" Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!" +``` + +Notice how the entire chat is condensed into a single string. If we use `tokenize=True`, which is the default setting, +that string will also be tokenized for us. To see a more complex template in action, though, let's use the +`meta-llama/Llama-2-7b-chat-hf` model. 
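As a quick aside (this snippet is not part of the BlenderBot example above - it's a minimal sketch that assumes `torch` is installed and reuses the `tokenizer` and `chat` objects we just defined), the tokenized output can be passed straight to `generate`:

```python
import torch
from transformers import AutoModelForSeq2SeqLM

# BlenderBot is a seq2seq model, so we load it with the corresponding Auto class
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")

# tokenize=True is the default, so this returns a list of token ids rather than a string
input_ids = tokenizer.apply_chat_template(chat)

# Wrap the ids in a batch dimension and generate a reply
reply_ids = model.generate(torch.tensor([input_ids]), max_new_tokens=60)
print(tokenizer.decode(reply_ids[0], skip_special_tokens=True))
```

Now, back to that more complex template.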
Note that this model has gated access, so you will have to
[request access on the repo](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) if you want to run this code yourself:

```python
>> from transformers import AutoTokenizer
>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

>> chat = [
...   {"role": "user", "content": "Hello, how are you?"},
...   {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
...   {"role": "user", "content": "I'd like to show off how chat templating works!"},
... ]

>> tokenizer.use_default_system_prompt = False
>> tokenizer.apply_chat_template(chat, tokenize=False)
"<s>[INST] Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
```

Note that this time, the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
user messages (but not assistant messages!).

## How do chat templates work?

The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
default template for that model class is used instead. Let's take a look at the template for `BlenderBot`:

```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

>>> tokenizer.default_chat_template
"{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
```

That's kind of intimidating. Let's add some newlines and indentation to make it more readable. Note that
we remove the first newline after each block, as well as any leading whitespace before a block, by default, using the
Jinja `trim_blocks` and `lstrip_blocks` flags. This means that you can write your templates with indentation and
newlines and still have them function correctly!

```
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ ' ' }}
    {% endif %}
    {{ message['content'] }}
    {% if not loop.last %}
        {{ ' ' }}
    {% endif %}
{% endfor %}
{{ eos_token }}
```

If you've never seen one of these before, this is a [Jinja template](https://jinja.palletsprojects.com/en/3.1.x/templates/).
Jinja is a templating language that allows you to write simple code that generates text. In many ways, the code and
syntax resemble Python. In pure Python, this template would look something like this:

```python
for idx, message in enumerate(messages):
    if message['role'] == 'user':
        print(' ')
    print(message['content'])
    if idx != len(messages) - 1:  # Don't add a separator after the last message in the conversation
        print(' ')
print(eos_token)
```

Effectively, the template does three things:
1. For each message, if the message is a user message, add a blank space before it, otherwise print nothing.
2. Add the message content.
3. If the message is not the last message, add a space after it. After the final message, print the EOS token.

This is a pretty simple template - it doesn't add any control tokens, and it doesn't support "system" messages, which
are a common way to give the model directives about how it should behave in the subsequent conversation.
But Jinja gives you a lot of flexibility to do those things!
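If you want to experiment with a template outside of a tokenizer, you can render it by hand with the `jinja2` package, which is the templating engine `transformers` uses under the hood. This is just a debugging sketch - `apply_chat_template` normally fills in special tokens like `eos_token` for you, so here we pass a placeholder value manually:

```python
from jinja2 import Template

# The simple BlenderBot-style template shown above, as a plain Jinja template
template = Template(
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
    "{{ message['content'] }}"
    "{% if not loop.last %}{{ ' ' }}{% endif %}"
    "{% endfor %}{{ eos_token }}"
)

chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]

print(template.render(messages=chat, eos_token="</s>"))
```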
Let's see a Jinja template that can format inputs
similarly to the way LLaMA formats them (note that the real LLaMA template includes handling for default system
messages and slightly different system message handling in general - don't use this one in your actual code!)

```
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}
    {% elif message['role'] == 'system' %}
        {{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ ' ' + message['content'] + ' ' + eos_token }}
    {% endif %}
{% endfor %}
```

Hopefully if you stare at this for a little bit you can see what this template is doing - it adds specific tokens based
on the "role" of each message, which represents who sent it. User, assistant and system messages are clearly
distinguishable to the model because of the tokens they're wrapped in.

## How do I create a chat template?

Simple, just write a Jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
above and add "[ASST]" and "[/ASST]" to assistant messages:

```
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}
    {% elif message['role'] == 'system' %}
        {{ '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}
    {% endif %}
{% endfor %}
```

Now, simply set the `tokenizer.chat_template` attribute. Next time you use [`~PreTrainedTokenizer.apply_chat_template`], it will
use your new template! This attribute will be saved in the `tokenizer_config.json` file, so you can use
[`~utils.PushToHubMixin.push_to_hub`] to upload your new template to the Hub and make sure everyone's using the right
template for your model!

```python
template = tokenizer.chat_template
template = template.replace("SYS", "SYSTEM")  # Change the system token
tokenizer.chat_template = template  # Set the new template
tokenizer.push_to_hub("model_name")  # Upload your new template to the Hub!
```

The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so
once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`].

## What are "default" templates?

Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards
compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a
model does not have a chat template set, but there is a default template for its model class, the [`ConversationalPipeline`]
class and methods like `apply_chat_template` will use the class template instead. You can find out what the default
template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute.

This is something we do purely for backward compatibility reasons, to avoid breaking any existing workflows.
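For example, here's a quick way to check whether a checkpoint already has an explicit template or is still falling back to its class default (a small sketch reusing the BlenderBot checkpoint from earlier):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

if tokenizer.chat_template is None:
    # Nothing was set explicitly, so apply_chat_template will fall back to the class default
    print("Using the class default template:")
    print(tokenizer.default_chat_template)
else:
    print("Using an explicitly set template:")
    print(tokenizer.chat_template)
```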
Even when
the class template is appropriate for your model, we strongly recommend overriding the default template by
setting the `chat_template` attribute explicitly to make it clear to users that your model has been correctly configured
for chat, and to future-proof in case the default templates are ever altered or deprecated.

## What template should I use?

When setting the template for a model that's already been trained for chat, you should ensure that the template
exactly matches the message formatting that the model saw during training, or else you will probably experience
performance degradation. This is true even if you're training the model further - you will probably get the best
performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the
best performance for inference or fine-tuning when you precisely match the tokenization used during training.

If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
input formats. Our default template for models that don't have a class-specific template follows the
[ChatML format](https://github.com/openai/openai-python/blob/main/chatml.md), and this is a good, flexible choice for many use-cases. It looks like this:

```
{% for message in messages %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}
```

If you like this one, here it is in one-liner form, ready to copy into your code:

```
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
```

This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which
allows for flexibility in the roles you train with. The output looks like this:

```
<|im_start|>system
You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
I'm doing great!<|im_end|>
```

The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense,
particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
to these roles - templating is extremely flexible, and any string can be a role.

## I want to use chat templates! How should I get started?

If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using
[`~PreTrainedTokenizer.apply_chat_template`]. This applies even if you're not the model owner - if you're using a model
with an empty chat template, or one that's still using the default class template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to
the model repository so that this attribute can be set properly!

Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that
model, which means it is also automatically supported in places like [`ConversationalPipeline`]!

By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
open-source models.
Formatting mismatches have been haunting the field and silently harming performance for too long - +it's time to put an end to them! \ No newline at end of file diff --git a/docs/source/en/main_classes/tokenizer.md b/docs/source/en/main_classes/tokenizer.md index 251cbb43ea7203..71f96c55cb51c9 100644 --- a/docs/source/en/main_classes/tokenizer.md +++ b/docs/source/en/main_classes/tokenizer.md @@ -58,6 +58,7 @@ to a given token). - batch_decode - decode - encode + - apply_chat_template - push_to_hub - all @@ -71,6 +72,7 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers - batch_decode - decode - encode + - apply_chat_template - push_to_hub - all diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py index cb4a33a3c28bda..d6a70beb30a136 100644 --- a/src/transformers/models/blenderbot/tokenization_blenderbot.py +++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py @@ -17,7 +17,7 @@ import json import os from functools import lru_cache -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple import regex as re @@ -25,9 +25,6 @@ from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - logger = logging.get_logger(__name__) @@ -413,19 +410,16 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: """ return token_ids_0 + [self.eos_token_id] - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - inputs = [] - for is_user, text in conversation.iter_texts(): - if is_user: - # We need to space prefix as it's being done within blenderbot - inputs.append(" " + text) - else: - # Generated responses should contain them already. - inputs.append(text) - - full_string = " ".join(inputs) - input_ids = self.encode(full_string) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.") - return input_ids + @property + def default_chat_template(self): + """ + A very simple chat template that just adds whitespace between messages. + """ + return ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}" + "{{ message['content'] }}" + "{% if not loop.last %}{{ ' ' }}{% endif %}" + "{% endfor %}" + "{{ eos_token }}" + ) diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py index 4737e92617c70d..ebe39ed09f9a35 100644 --- a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py +++ b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""Fast Tokenization class for Blenderbot.""" import json -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple from tokenizers import pre_tokenizers, processors @@ -24,9 +24,6 @@ from .tokenization_blenderbot import BlenderbotTokenizer -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - logger = logging.get_logger(__name__) @@ -297,19 +294,17 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: """ return token_ids_0 + [self.eos_token_id] - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - inputs = [] - for is_user, text in conversation.iter_texts(): - if is_user: - # We need to space prefix as it's being done within blenderbot - inputs.append(" " + text) - else: - # Generated responses should contain them already. - inputs.append(text) - - full_string = " ".join(inputs) - input_ids = self.encode(full_string) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.") - return input_ids + @property + # Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template + def default_chat_template(self): + """ + A very simple chat template that just adds whitespace between messages. + """ + return ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}" + "{{ message['content'] }}" + "{% if not loop.last %}{{ ' ' }}{% endif %}" + "{% endfor %}" + "{{ eos_token }}" + ) diff --git a/src/transformers/models/bloom/tokenization_bloom_fast.py b/src/transformers/models/bloom/tokenization_bloom_fast.py index 8339ece5433bd3..47b78ac723f757 100644 --- a/src/transformers/models/bloom/tokenization_bloom_fast.py +++ b/src/transformers/models/bloom/tokenization_bloom_fast.py @@ -16,17 +16,13 @@ import pickle -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import Optional, Tuple from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"} @@ -166,12 +162,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. 
+ """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" diff --git a/src/transformers/models/code_llama/tokenization_code_llama.py b/src/transformers/models/code_llama/tokenization_code_llama.py index 0cf48b12077227..53a2d3577a1740 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama.py +++ b/src/transformers/models/code_llama/tokenization_code_llama.py @@ -17,7 +17,7 @@ """Tokenization classes for Code LLaMA.""" import os from shutil import copyfile -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import sentencepiece as spm @@ -26,9 +26,6 @@ from ...utils import logging, requires_backends -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} @@ -441,70 +438,57 @@ def create_token_type_ids_from_sequences( return output - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - r"""Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` + @property + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation + The output should look something like: - >>> Conversation( - ... "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`" - ... ) # doctest: +IGNORE_RESULT - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. 
+ [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] """ - if self.use_default_system_prompt: - if len(conversation.past_user_inputs) > 0: - if ( - not conversation.past_user_inputs[0].startswith(B_SYS) - or E_SYS not in conversation.past_user_inputs[0] - ): - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") - - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]] - ): - raise ValueError( - "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" - ) - dialog_tokens: List[int] = [] - dialog_tokens += sum( - [ - [self.bos_token_id] - + self.encode( - f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False - ) - + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2]) - ], - [], + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" ) - dialog_tokens += [self.bos_token_id] + self.encode( - f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False - ) - return dialog_tokens + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template def __getstate__(self): state = self.__dict__.copy() diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py index 030d473bfeb261..91a1896c3c10db 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py +++ 
b/src/transformers/models/code_llama/tokenization_code_llama_fast.py @@ -14,7 +14,7 @@ # limitations under the License. import os from shutil import copyfile -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple from tokenizers import normalizers, processors @@ -23,9 +23,6 @@ from ...utils.versions import require_version -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - require_version("tokenizers>=0.13.3") if is_sentencepiece_available(): @@ -344,6 +341,58 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = return (out_vocab_file,) + @property + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + """ + + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template + def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: @@ -371,69 +420,3 @@ def build_inputs_with_special_tokens( if token_ids_1 is None: return self.bos_token_id + token_ids_0 + 
self.eos_token_id return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id - - # Copied from transformers.models.code_llama.tokenization_code_llama.CodeLlamaTokenizer._build_conversation_input_ids - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - r"""Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` - - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation - - >>> Conversation( - ... "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`" - ... ) # doctest: +IGNORE_RESULT - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. - """ - if self.use_default_system_prompt: - if len(conversation.past_user_inputs) > 0: - if ( - not conversation.past_user_inputs[0].startswith(B_SYS) - or E_SYS not in conversation.past_user_inputs[0] - ): - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") - - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]] - ): - raise ValueError( - "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" - ) - - dialog_tokens: List[int] = [] - dialog_tokens += sum( - [ - [self.bos_token_id] - + self.encode( - f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False - ) - + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2]) - ], - [], - ) - dialog_tokens += [self.bos_token_id] + self.encode( - f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False - ) - return dialog_tokens diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py index 8a778a947cfbca..880ed17d95ef28 100644 --- a/src/transformers/models/deberta/tokenization_deberta.py +++ b/src/transformers/models/deberta/tokenization_deberta.py @@ -16,7 +16,7 @@ import json import os -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple import regex as re @@ -24,9 +24,6 @@ from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} @@ -433,12 +430,3 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): text = " " + text return (text, kwargs) - - # Copied from 
transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids diff --git a/src/transformers/models/deberta/tokenization_deberta_fast.py b/src/transformers/models/deberta/tokenization_deberta_fast.py index c05cf257611ebf..d77f0b39b98486 100644 --- a/src/transformers/models/deberta/tokenization_deberta_fast.py +++ b/src/transformers/models/deberta/tokenization_deberta_fast.py @@ -15,7 +15,7 @@ """ Fast Tokenization class for model DeBERTa.""" import json -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple from tokenizers import pre_tokenizers @@ -25,10 +25,6 @@ from .tokenization_deberta import DebertaTokenizer -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} @@ -288,14 +284,3 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding: def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) - - # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._build_conversation_input_ids - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py index 9a8ce3a4fabd5a..278ff69032585c 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2.py +++ b/src/transformers/models/gpt2/tokenization_gpt2.py @@ -18,7 +18,7 @@ import json import os from functools import lru_cache -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple import regex as re @@ -26,9 +26,6 @@ from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = { @@ -354,10 +351,9 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): text = " " + text return (text, kwargs) - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. 
+ """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py index cf2b8b2cb22c6a..189a3550840885 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py +++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py @@ -16,7 +16,7 @@ import json -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import Optional, Tuple from tokenizers import pre_tokenizers @@ -26,10 +26,6 @@ from .tokenization_gpt2 import GPT2Tokenizer -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} @@ -181,12 +177,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. + """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py index 570b2abaa49fde..f666b97efd2bd0 100644 --- a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py +++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py @@ -14,7 +14,7 @@ # limitations under the License. """Tokenization classes for GPTNeoX.""" import json -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import Optional, Tuple from tokenizers import pre_tokenizers @@ -22,10 +22,6 @@ from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} @@ -133,12 +129,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. 
+ """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" diff --git a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py index c9f4f677cb483e..6ac2f214a16568 100644 --- a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py @@ -17,7 +17,7 @@ import json import os import re -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import Optional, Tuple import numpy as np @@ -25,10 +25,6 @@ from ...utils import logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"} @@ -179,15 +175,14 @@ def convert_tokens_to_string(self, tokens): out_string = "".join(tokens).strip() return out_string - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + def default_chat_template(self): + """ + A simple chat template that just adds BOS/EOS tokens around messages while discarding role information. + """ + return ( + "{% for message in messages %}" "{{ bos_token + eos_token + message.content + eos_token }}" "{% endfor %}" + ) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: index = 0 diff --git a/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py b/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py index f592a2b63eabf5..4874ba732245f0 100644 --- a/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py +++ b/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py @@ -4,7 +4,7 @@ import re import unicodedata from shutil import copyfile -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import sentencepiece as spm @@ -16,10 +16,6 @@ import torch -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} @@ -319,31 +315,18 @@ def decode_fast(self, token_ids: Union[int, List[int]]) -> str: return self.sp_model.decode(token_ids) - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """Builds the input ids for a conversation. - - This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2]. - The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows: - - ``` - User: Jag tycker träd är finaBot: Kul att du tycker det!... - ``` - - Args: - conversation (`Conversation`): - Conversation to build input ids for. - - Returns: - `List[int]`: - Input ids for the conversation. 
- - References: - - [1] https://doi.org/10.48550/arXiv.2305.12987 - - [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct - - [3] https://github.com/openai/openai-python/blob/main/chatml.md + @property + def default_chat_template(self): + """ + This chat template formats messages like an instant messenger chat log, with "User:" and "Bot:" strings + preceding messages. BOS tokens are added between all messages. """ - all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()] - prompt = ( - f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:" + return ( + "{{ eos_token }}{{ bos_token }}" + "{% for message in messages %}" + "{% if message['role'] == 'user' %}{{ 'User: ' + message['content']}}" + "{% else %}{{ 'Bot: ' + message['content']}}{% endif %}" + "{{ message['text'] }}{{ bos_token }}" + "{% endfor %}" + "Bot:" ) - return self.encode(text=prompt) diff --git a/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py index a16e55ec7180d6..c567b6b6003fff 100644 --- a/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py @@ -17,7 +17,7 @@ import json import os import re -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np @@ -33,10 +33,6 @@ from ...utils import PaddingStrategy, logging -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"} @@ -258,16 +254,18 @@ def convert_tokens_to_string(self, tokens): text = "".join(words) return text - # Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer._build_conversation_input_ids - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - """This corresponds to DialoGPT variants of models.""" - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + def default_chat_template(self): + """ + A simple chat template that adds standard BOS, SEP and EOS tokens between messages while discarding role + information. 
+ """ + return ( + "{% for message in messages %}" + "{% if not loop.first %}{{ bos_token}}{% endif %}" + "{{ sep_token }}{{ message.content }} {{ eos_token }}" + "{% endfor %}" + ) # Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer.save_vocabulary def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index f33771995d291c..8db2f9970e199a 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -31,7 +31,6 @@ if TYPE_CHECKING: - from ...pipelines.conversational import Conversation from ...tokenization_utils_base import TextInput logger = logging.get_logger(__name__) @@ -374,67 +373,53 @@ def create_token_type_ids_from_sequences( return output - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - r"""Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation + The output should look something like: - >>> Conversation( - ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?" - ... ) # doctest: +IGNORE_RESULT - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. 
+ [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] """ - if self.use_default_system_prompt: - if len(conversation.past_user_inputs) > 0: - if ( - not conversation.past_user_inputs[0].startswith(B_SYS) - or E_SYS not in conversation.past_user_inputs[0] - ): - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") - - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]] - ): - raise ValueError( - "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" - ) - dialog_tokens: List[int] = [] - dialog_tokens += sum( - [ - [self.bos_token_id] - + self.encode( - f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False - ) - + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2]) - ], - [], - ) - dialog_tokens += [self.bos_token_id] + self.encode( - f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" ) - return dialog_tokens + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py index 09c81417672e80..282a0f06740eaa 100644 --- a/src/transformers/models/llama/tokenization_llama_fast.py +++ b/src/transformers/models/llama/tokenization_llama_fast.py @@ -14,7 +14,7 @@ # limitations under the 
License. import os from shutil import copyfile -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple from tokenizers import processors @@ -23,9 +23,6 @@ from ...utils.versions import require_version -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - require_version("tokenizers>=0.13.3") if is_sentencepiece_available(): @@ -192,67 +189,54 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = return (out_vocab_file,) - def _build_conversation_input_ids(self, conversation: "Conversation"): - """Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer + @property + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer [INST] Prompt [/INST] - ``` - - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation - - >>> Conversation( - ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?" - ... ) - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. 
""" - if self.use_default_system_prompt: - if len(conversation.past_user_inputs) > 0: - if ( - not conversation.past_user_inputs[0].startswith(B_SYS) - or E_SYS not in conversation.past_user_inputs[0] - ): - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") - - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]] - ): - raise ValueError( - "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)" - ) - dialog_tokens = [] - dialog_tokens += sum( - [ - [self.bos_token_id] - + self.encode( - f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False - ) - + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2]) - ], - [], + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" ) - dialog_tokens += [self.bos_token_id] + self.encode( - f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False - ) - return dialog_tokens + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index b3eccddf410355..a22521b4e00dfb 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -16,7 +16,7 @@ import json import os from functools import lru_cache -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import List, 
Optional, Tuple, Union import numpy as np import regex as re @@ -26,10 +26,6 @@ from .english_normalizer import EnglishTextNormalizer -if TYPE_CHECKING: - from ...pipelines.conversational import Conversation - - VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", "tokenizer_file": "tokenizer.json", @@ -751,14 +747,13 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): text = " " + text return (text, kwargs) - # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. + """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps) diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 4ad500bbf1c0ef..cb321f669c7ca6 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -16,7 +16,7 @@ import json import os from functools import lru_cache -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np from tokenizers import pre_tokenizers, processors @@ -28,10 +28,6 @@ from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr -if TYPE_CHECKING: - from ...pipelines.conversational import Conversation - - logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = { @@ -520,14 +516,13 @@ def get_special_tokens_mask( return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones - # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids - def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: - input_ids = [] - for is_user, text in conversation.iter_texts(): - input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) - if len(input_ids) > self.model_max_length: - input_ids = input_ids[-self.model_max_length :] - return input_ids + @property + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template + def default_chat_template(self): + """ + A simple chat template that ignores role information and just concatenates messages with EOS tokens. 
+ """ + return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}" # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_decoder_prompt_ids def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 93d056c88d44ef..c455c75574435b 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -1,5 +1,5 @@ import uuid -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -19,137 +19,153 @@ class Conversation: """ Utility class containing a conversation and its history. This class is meant to be used as an input to the [`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user - inputs and generated model responses. A conversation needs to contain an unprocessed user input before being passed - to the [`ConversationalPipeline`]. This user input is either created when the class is instantiated, or by calling - `conversational_pipeline.append_response("input")` after a conversation turn. + inputs and generated model responses. Arguments: - text (`str`, *optional*): - The initial user input to start the conversation. If not provided, a user input needs to be provided - manually using the [`~Conversation.add_user_input`] method before the conversation can begin. + messages (Union[str, List[Dict[str, str]]], *optional*): + The initial messages to start the conversation, either a string, or a list of dicts containing "role" and + "content" keys. If a string is passed, it is interpreted as a single message with the "user" role. conversation_id (`uuid.UUID`, *optional*): Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the conversation. - past_user_inputs (`List[str]`, *optional*): - Eventual past history of the conversation of the user. You don't need to pass it manually if you use the - pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and - `generated_responses` with equal length lists of strings - generated_responses (`List[str]`, *optional*): - Eventual past history of the conversation of the model. You don't need to pass it manually if you use the - pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and - `generated_responses` with equal length lists of strings Usage: ```python conversation = Conversation("Going to the movies tonight - any suggestions?") - - # Steps usually performed by the model when generating a response: - # 1. Mark the user input as processed (moved to the history) - conversation.mark_processed() - # 2. 
Append a mode response - conversation.append_response("The Big lebowski.") - - conversation.add_user_input("Is it good?") + conversation.add_message({"role": "assistant", "content": "The Big lebowski."}) + conversation.add_message({"role": "user", "content": "Is it good?"}) ```""" def __init__( - self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None + self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs ): if not conversation_id: conversation_id = uuid.uuid4() - if past_user_inputs is None: - past_user_inputs = [] - if generated_responses is None: - generated_responses = [] - self.uuid: uuid.UUID = conversation_id - self.past_user_inputs: List[str] = past_user_inputs - self.generated_responses: List[str] = generated_responses - self.new_user_input: Optional[str] = text + if messages is None: + text = deprecated_kwargs.pop("text", None) + if text is not None: + messages = [{"role": "user", "content": text}] + else: + messages = [] + elif isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + # This block deals with the legacy args - new code should just totally + # avoid past_user_inputs and generated_responses + generated_responses = deprecated_kwargs.pop("generated_responses", None) + past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None) + if generated_responses is not None and past_user_inputs is None: + raise ValueError("generated_responses cannot be passed without past_user_inputs!") + if past_user_inputs is not None: + legacy_messages = [] + if generated_responses is None: + generated_responses = [] + # We structure it this way instead of using zip() because the lengths may differ by 1 + for i in range(max([len(past_user_inputs), len(generated_responses)])): + if i < len(past_user_inputs): + legacy_messages.append({"role": "user", "content": past_user_inputs[i]}) + if i < len(generated_responses): + legacy_messages.append({"role": "assistant", "content": generated_responses[i]}) + messages = legacy_messages + messages + + self.uuid = conversation_id + self.messages = messages def __eq__(self, other): if not isinstance(other, Conversation): return False - if self.uuid == other.uuid: - return True - return ( - self.new_user_input == other.new_user_input - and self.past_user_inputs == other.past_user_inputs - and self.generated_responses == other.generated_responses - ) + return self.uuid == other.uuid or self.messages == other.messages + + def add_message(self, message: Dict[str, str]): + if not set(message.keys()) == {"role", "content"}: + raise ValueError("Message should contain only 'role' and 'content' keys!") + if message["role"] not in ("user", "assistant", "system"): + raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!") + self.messages.append(message) def add_user_input(self, text: str, overwrite: bool = False): """ - Add a user input to the conversation for the next round. This populates the internal `new_user_input` field. - - Args: - text (`str`): The user input for the next conversation round. - overwrite (`bool`, *optional*, defaults to `False`): - Whether or not existing and unprocessed user input should be overwritten when this function is called. + Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must + alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. 
We recommend + just using `add_message` with role "user" instead. """ - if self.new_user_input: + if len(self) > 0 and self[-1]["role"] == "user": if overwrite: logger.warning( - f'User input added while unprocessed input was existing: "{self.new_user_input}" was overwritten ' + f'User input added while unprocessed input was existing: "{self[-1]["content"]}" was overwritten ' f'with: "{text}".' ) - self.new_user_input = text + self[-1]["content"] = text else: logger.warning( - f'User input added while unprocessed input was existing: "{self.new_user_input}" new input ' + f'User input added while unprocessed input was existing: "{self[-1]["content"]}" new input ' f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input' ) else: - self.new_user_input = text + self.messages.append({"role": "user", "content": text}) - def mark_processed(self): + def append_response(self, response: str): """ - Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties - the `new_user_input` field. + This is a legacy method. We recommend just using `add_message` with an appropriate role instead. """ - if self.new_user_input: - self.past_user_inputs.append(self.new_user_input) - self.new_user_input = None + self.messages.append({"role": "assistant", "content": response}) - def append_response(self, response: str): + def mark_processed(self): """ - Append a response to the list of generated responses. - - Args: - response (`str`): The model generated response. + This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between + processed and unprocessed user input. """ - self.generated_responses.append(response) + pass - def iter_texts(self): - """ - Iterates over all blobs of the conversation. + def __iter__(self): + for message in self.messages: + yield message - Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. `is_user` is a `bool`, - `text_chunks` is a `str`. - """ - for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses): - yield True, user_input - yield False, generated_response - if self.new_user_input: - yield True, self.new_user_input + def __getitem__(self, item): + return self.messages[item] + + def __setitem__(self, key, value): + self.messages[key] = value + + def __len__(self): + return len(self.messages) def __repr__(self): """ Generates a string representation of the conversation. - Return: + Returns: `str`: - Example: Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user >> Going to the movies tonight - any - suggestions? bot >> The Big Lebowski + Example: + Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user: Going to the movies tonight - any suggestions? + bot: The Big Lebowski """ - output = f"Conversation id: {self.uuid} \n" - for is_user, text in self.iter_texts(): - name = "user" if is_user else "bot" - output += f"{name} >> {text} \n" + output = f"Conversation id: {self.uuid}\n" + for message in self.messages: + output += f"{message['role']}: {message['content']}\n" return output + def iter_texts(self): + # This is a legacy method for backwards compatibility. It is recommended to just directly access + # conversation.messages instead. + for message in self.messages: + yield message["role"] == "user", message["content"] + + @property + def past_user_inputs(self): + # This is a legacy property for backwards compatibility. 
It is recommended to just directly access + # conversation.messages instead. + return [message["content"] for message in self.messages if message["role"] == "user"] + + @property + def generated_responses(self): + # This is a legacy property for backwards compatibility. It is recommended to just directly access + # conversation.messages instead. + return [message["content"] for message in self.messages if message["role"] == "assistant"] + @add_end_docstrings( PIPELINE_INIT_ARGS, @@ -246,18 +262,7 @@ def __call__(self, conversations: Union[Conversation, List[Conversation]], num_w return outputs def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]: - if not isinstance(conversation, Conversation): - raise ValueError("ConversationalPipeline, expects Conversation as inputs") - if conversation.new_user_input is None: - raise ValueError( - f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. " - "Add user inputs with the conversation's `add_user_input` method" - ) - if hasattr(self.tokenizer, "_build_conversation_input_ids"): - input_ids = self.tokenizer._build_conversation_input_ids(conversation) - else: - # If the tokenizer cannot handle conversations, we default to only the old version - input_ids = self._legacy_parse_and_tokenize(conversation) + input_ids = self.tokenizer.apply_chat_template(conversation) if self.framework == "pt": input_ids = torch.LongTensor([input_ids]) @@ -292,19 +297,5 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=True): clean_up_tokenization_spaces=clean_up_tokenization_spaces, ) conversation = model_outputs["conversation"] - conversation.mark_processed() - conversation.append_response(answer) + conversation.add_message({"role": "assistant", "content": answer}) return conversation - - def _legacy_parse_and_tokenize(self, conversation: Conversation) -> Dict: - eos_token_id = self.tokenizer.eos_token_id - input_ids = [] - for is_user, text in conversation.iter_texts(): - if eos_token_id is not None: - input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id]) - else: - input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False)) - - if len(input_ids) > self.tokenizer.model_max_length: - input_ids = input_ids[-self.tokenizer.model_max_length :] - return input_ids diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a7e36322a6b099..ccdc7eb6dd7e21 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -64,6 +64,7 @@ is_ftfy_available, is_ipex_available, is_jieba_available, + is_jinja_available, is_jumanpp_available, is_keras_nlp_available, is_librosa_available, @@ -336,6 +337,13 @@ def require_jieba(test_case): return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case) +def require_jinja(test_case): + """ + Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed. 
+ """ + return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case) + + def require_tf2onnx(test_case): return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2390ed478f30c4..a65f799a724b13 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -27,6 +27,7 @@ from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass, field +from functools import lru_cache from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -69,6 +70,7 @@ import tensorflow as tf if is_flax_available(): import jax.numpy as jnp # noqa: F401 + from .pipelines.conversational import Conversation if is_tokenizers_available(): @@ -1426,6 +1428,7 @@ def all_special_ids(self) -> List[int]: - **length** -- The length of the inputs (when `return_length=True`) """ + INIT_TOKENIZER_DOCSTRING = r""" Class attributes (overridden by derived classes) @@ -1461,6 +1464,9 @@ def all_special_ids(self) -> List[int]: truncation_side (`str`, *optional*): The side on which the model should have truncation applied. Should be selected between ['right', 'left']. Default value is picked from the class attribute of the same name. + chat_template (`str`, *optional*): + A Jinja template string that will be used to format lists of chat messages. See + https://huggingface.co/docs/transformers/chat_templating for a full description. model_input_names (`List[string]`, *optional*): The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or `"attention_mask"`). Default value is picked from the class attribute of the same name. @@ -1558,6 +1564,10 @@ def __init__(self, **kwargs): {} ) # Use to store when we have already noticed a deprecation warning (avoid overlogging). self._in_target_context_manager = False + + # Stores a Jinja template that formats chat histories into tokenizable strings + self.chat_template = kwargs.pop("chat_template", None) + super().__init__(**kwargs) @property @@ -1627,6 +1637,109 @@ def get_vocab(self) -> Dict[str, int]: """ raise NotImplementedError() + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]], "Conversation"], + chat_template: Optional[str] = None, + tokenize: bool = True, + padding: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **tokenizer_kwargs, + ) -> Union[str, List[int]]: + """ + Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token + ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to + determine the format and control tokens to use when converting. When chat_template is None, it will fall back + to the default_chat_template specified at the class level. + + Args: + conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts + with "role" and "content" keys, representing the chat history so far. + chat_template (str, *optional*): A Jinja template to use for this conversion. If + this is not passed, the model's default chat template will be used instead. + tokenize (`bool`, defaults to `True`): + Whether to tokenize the output. If `False`, the output will be a string. 
+ padding (`bool`, defaults to `False`): + Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`. + truncation (`bool`, defaults to `False`): + Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`. + max_length (`int`, *optional*): + Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If + not specified, the tokenizer's `max_length` attribute will be used as a default. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable + values are: + - `'tf'`: Return TensorFlow `tf.Tensor` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. + + Returns: + `List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This + output is ready to pass to the model, either directly or via methods like `generate()`. + """ + + if hasattr(conversation, "messages"): + # Indicates it's a Conversation object + conversation = conversation.messages + + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template` + if chat_template is None: + if self.chat_template is not None: + chat_template = self.chat_template + else: + chat_template = self.default_chat_template + + # Compilation function uses a cache to avoid recompiling the same template + compiled_template = self._compile_jinja_template(chat_template) + + rendered = compiled_template.render(messages=conversation, **self.special_tokens_map) + + if padding is True: + padding = "max_length" # There's only one sequence here, so "longest" makes no sense + if tokenize: + return self.encode( + rendered, + add_special_tokens=False, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + else: + return rendered + + @lru_cache + def _compile_jinja_template(self, chat_template): + try: + from jinja2.exceptions import TemplateError + from jinja2.sandbox import ImmutableSandboxedEnvironment + except ImportError: + raise ImportError("apply_chat_template requires jinja2 to be installed.") + + def raise_exception(message): + raise TemplateError(message) + + jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) + jinja_env.globals["raise_exception"] = raise_exception + return jinja_env.from_string(chat_template) + + @property + def default_chat_template(self): + """ + This template formats inputs in the standard ChatML format. 
See + https://github.com/openai/openai-python/blob/main/chatml.md + """ + return ( + "{% for message in messages %}" + "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}" + "{% endfor %}" + ) + @classmethod def from_pretrained( cls, @@ -2187,6 +2300,9 @@ def save_pretrained( if hasattr(self, k): tokenizer_config[k] = getattr(self, k) + if self.chat_template is not None: + tokenizer_config["chat_template"] = self.chat_template + if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) for file_id in self.vocab_files_names.keys(): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 68c39c732e3c35..ac9beb7856187b 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -119,6 +119,7 @@ is_in_notebook, is_ipex_available, is_jieba_available, + is_jinja_available, is_jumanpp_available, is_kenlm_available, is_keras_nlp_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 6829ca9ad67e1a..aeb351db96e22b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -91,6 +91,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _ftfy_available = _is_package_available("ftfy") _ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) _jieba_available = _is_package_available("jieba") +_jinja_available = _is_package_available("jinja2") _kenlm_available = _is_package_available("kenlm") _keras_nlp_available = _is_package_available("keras_nlp") _librosa_available = _is_package_available("librosa") @@ -793,6 +794,10 @@ def is_jieba_available(): return _jieba_available +def is_jinja_available(): + return _jinja_available + + # docstyle-ignore DATASETS_IMPORT_ERROR = """ {0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with: @@ -1081,6 +1086,11 @@ def is_jieba_available(): peft`. Please note that you may need to restart your runtime after installation. """ +JINJA_IMPORT_ERROR = """ +{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install +jinja2`. Please note that you may need to restart your runtime after installation. 
+""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -1118,6 +1128,7 @@ def is_jieba_available(): ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)), ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)), ] ) diff --git a/tests/models/blenderbot/test_tokenization_blenderbot.py b/tests/models/blenderbot/test_tokenization_blenderbot.py index 3a9c95027b3221..7fbf2b7603f95e 100644 --- a/tests/models/blenderbot/test_tokenization_blenderbot.py +++ b/tests/models/blenderbot/test_tokenization_blenderbot.py @@ -17,6 +17,7 @@ import unittest from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast +from transformers.testing_utils import require_jinja from transformers.utils import cached_property @@ -50,3 +51,24 @@ def test_3B_tokenization_same_as_parlai(self): def test_3B_tokenization_same_as_parlai_rust_tokenizer(self): assert self.rust_tokenizer_3b.add_prefix_space assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]] + + @require_jinja + def test_tokenization_for_chat(self): + tok = self.tokenizer_3b + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats] + expected_tokens = [ + [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2], + [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2], + [3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2], + ] + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) diff --git a/tests/models/bloom/test_tokenization_bloom.py b/tests/models/bloom/test_tokenization_bloom.py index 576a191c70b5f1..7383eeb668face 100644 --- a/tests/models/bloom/test_tokenization_bloom.py +++ b/tests/models/bloom/test_tokenization_bloom.py @@ -18,7 +18,7 @@ from datasets import load_dataset from transformers import BloomTokenizerFast -from transformers.testing_utils import require_tokenizers +from transformers.testing_utils import require_jinja, require_tokenizers from ...test_tokenization_common import TokenizerTesterMixin @@ -134,6 +134,27 @@ def test_pretrained_model_lists(self): self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1) self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1) + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = self.get_rust_tokenizer() + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + expected_tokens = [ + [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2], + [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 
59414, 4, 2, 229126, 427, 11890, 1152, 17, 2], + [229126, 427, 11890, 1152, 17, 2, 59414, 4, 2], + ] + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) + def test_add_prefix_space_fast(self): tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True) tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False) diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py index 0dd33e776d497e..cceb3b9238b20f 100644 --- a/tests/models/gpt2/test_tokenization_gpt2.py +++ b/tests/models/gpt2/test_tokenization_gpt2.py @@ -20,7 +20,7 @@ from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES -from transformers.testing_utils import require_tokenizers +from transformers.testing_utils import require_jinja, require_tokenizers from ...test_tokenization_common import TokenizerTesterMixin @@ -275,6 +275,27 @@ def test_special_tokens_mask_input_pairs_and_bos_token(self): filtered_sequence = [x for x in filtered_sequence if x is not None] self.assertEqual(encoded_sequence, filtered_sequence) + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname) + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + # fmt: off + expected_tokens = [[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20], + [20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20], + [20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20]] + # fmt: on + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) + @require_tokenizers class OPTTokenizationTest(unittest.TestCase): diff --git a/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py b/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py index b030996e89dcf0..d639c33ef6440b 100644 --- a/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py +++ b/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py @@ -16,7 +16,7 @@ import unittest from transformers import GPTSw3Tokenizer -from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow +from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow from ...test_tokenization_common import TokenizerTesterMixin @@ -128,3 +128,27 @@ def test_tokenizer_integration(self): model_name="AI-Sweden/gpt-sw3-126m", sequences=sequences, ) + + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB) + # This is in English, but it's just here to make sure the chat control tokens are being added properly + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": 
"Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + # fmt: off + expected_tokens = [ + [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ], + [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419, ], + [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ] + ] + # fmt: on + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) diff --git a/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py b/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py index 4352f6425f0d32..489e4f942664e5 100644 --- a/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py +++ b/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py @@ -22,7 +22,7 @@ VOCAB_FILES_NAMES, GPTSanJapaneseTokenizer, ) -from transformers.testing_utils import require_tokenizers, slow +from transformers.testing_utils import require_jinja, require_tokenizers, slow from ...test_tokenization_common import TokenizerTesterMixin @@ -193,3 +193,27 @@ def test_conversion_reversible(self): def test_padding_different_model_input_name(self): # tokenizer has no padding token pass + + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese") + # This is in English, but it's just here to make sure the chat control tokens are being added properly + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + # fmt: off + expected_tokens = [ + [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999], + [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999], + [35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 
35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999], + ] + # fmt: on + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index 942cceaf7cd0d4..ca90eb9641ea80 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -2486,3 +2486,7 @@ def test_layoutlmv2_integration_test(self): @unittest.skip("Doesn't support another framework than PyTorch") def test_np_encode_plus_sent_to_model(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template(self): + pass diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index 58092834e5a160..59efc4b1cf3ba1 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -2439,3 +2439,7 @@ def test_tf_encode_plus_sent_to_model(self): # This should not fail model(encoded_sequence) model(batch_encoded_sequence) + + @unittest.skip("Chat is not supported") + def test_chat_template(self): + pass diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index f7f8329706dff2..0b502748d13123 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -1958,3 +1958,7 @@ def test_sentencepiece_tokenize_and_convert_tokens_to_string(self): @unittest.skip("Doesn't use SentencePiece") def test_sentencepiece_tokenize_and_decode(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template(self): + pass diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py index 9223e626f997fe..231474203032b1 100644 --- a/tests/models/llama/test_tokenization_llama.py +++ b/tests/models/llama/test_tokenization_llama.py @@ -32,6 +32,7 @@ from transformers.testing_utils import ( get_tests_dir, nested_simplify, + require_jinja, require_sentencepiece, require_tokenizers, require_torch, @@ -574,6 +575,32 @@ def test_some_edge_cases(self): # a dummy prefix space is not added by the sp_model as it was de-activated self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str)) + @require_jinja + def test_tokenization_for_chat(self): + tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False) + + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "user", "content": "Hello!"}], + ] + # Matt: The third test case tests the default system message, but if this is ever changed in the + # class/repo code then that test will fail, and the case will need to be updated. 
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + # fmt: off + expected_tokens = [ + [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962], + [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2], + [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962] + ] + # fmt: on + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) + @require_sentencepiece @require_tokenizers diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index 73979b255e08db..331f63a94a5818 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -2311,3 +2311,7 @@ def test_padding_warning_message_fast_tokenizer(self): "Dummy warning", cm.records[0].message, ) + + @unittest.skip("Chat is not supported") + def test_chat_template(self): + pass diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index b4cca18162d806..9d82c468aa3091 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -1274,3 +1274,7 @@ def test_pretrained_model_lists(self): @unittest.skip("Doesn't support another framework than PyTorch") def test_np_encode_plus_sent_to_model(self): pass + + @unittest.skip("Chat is not supported") + def test_chat_template(self): + pass diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 9ab29d29d1de12..ef58768e22776b 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -16,7 +16,7 @@ from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence -from transformers.testing_utils import slow +from transformers.testing_utils import require_jinja, slow from ...test_tokenization_common import TokenizerTesterMixin @@ -473,3 +473,25 @@ def test_offset_decoding(self): output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"] self.assertEqual(output, []) + + @require_jinja + def test_tokenization_for_chat(self): + multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny") + # This is in English, but it's just 
here to make sure the chat control tokens are being added properly + test_chats = [ + [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], + [ + {"role": "system", "content": "You are a helpful chatbot."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Nice to meet you."}, + ], + [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], + ] + tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] + expected_tokens = [ + [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257], + [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257], + [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257], + ] + for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): + self.assertListEqual(tokenized_chat, expected_tokens) diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py index efb2215f491005..dfc42ea4815cce 100644 --- a/tests/pipelines/test_pipelines_conversational.py +++ b/tests/pipelines/test_pipelines_conversational.py @@ -78,17 +78,23 @@ def get_test_pipeline(self, model, tokenizer, processor): def run_pipeline_test(self, conversation_agent, _): # Simple outputs = conversation_agent(Conversation("Hi there!")) - self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)])) + self.assertEqual( + outputs, + Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), + ) # Single list outputs = conversation_agent([Conversation("Hi there!")]) - self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)])) + self.assertEqual( + outputs, + Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), + ) # Batch conversation_1 = Conversation("Going to the movies tonight - any suggestions?") conversation_2 = Conversation("What's the last book you have read?") - self.assertEqual(len(conversation_1.past_user_inputs), 0) - self.assertEqual(len(conversation_2.past_user_inputs), 0) + self.assertEqual(len(conversation_1), 1) + self.assertEqual(len(conversation_2), 1) outputs = conversation_agent([conversation_1, conversation_2]) self.assertEqual(outputs, [conversation_1, conversation_2]) @@ -96,32 +102,35 @@ def run_pipeline_test(self, conversation_agent, _): outputs, [ Conversation( - past_user_inputs=["Going to the movies tonight - any suggestions?"], - generated_responses=[ANY(str)], + [ + {"role": "user", "content": "Going to the movies tonight - any suggestions?"}, + {"role": "assistant", "content": ANY(str)}, + ], + ), + Conversation( + [ + {"role": "user", "content": "What's the last book you have read?"}, + {"role": "assistant", "content": ANY(str)}, + ] ), - Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]), ], ) # One conversation with history - conversation_2.add_user_input("Why do you recommend it?") + conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"}) outputs = conversation_agent(conversation_2) self.assertEqual(outputs, conversation_2) self.assertEqual( outputs, Conversation( - past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"], - generated_responses=[ANY(str), ANY(str)], + [ + {"role": "user", "content": "What's the 
last book you have read?"}, + {"role": "assistant", "content": ANY(str)}, + {"role": "user", "content": "Why do you recommend it?"}, + {"role": "assistant", "content": ANY(str)}, + ] ), ) - with self.assertRaises(ValueError): - conversation_agent("Hi there!") - with self.assertRaises(ValueError): - conversation_agent(Conversation()) - # Conversation have been consumed and are not valid anymore - # Inactive conversations passed to the pipeline raise a ValueError - with self.assertRaises(ValueError): - conversation_agent(conversation_2) @require_torch @slow diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index aec5e493c57c00..fa3bf96d431a8a 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -50,6 +50,7 @@ check_json_file_has_correct_format, get_tests_dir, is_pt_tf_cross_test, + require_jinja, require_tf, require_tokenizers, require_torch, @@ -1052,6 +1053,40 @@ def test_sequence_ids(self): if tokenizer.num_special_tokens_to_add(pair=True): self.assertIn(None, output.sequence_ids()) + @require_jinja + def test_chat_template(self): + dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}" + dummy_conversation = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ] + expected_output = "systemsystem messageuseruser messageassistantassistant message" + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template, tokenize=False + ) + self.assertEqual(output, expected_output) # Test we can pass chat_template arg + # Check that no error raised when tokenize=True + tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True) + + tokenizer.chat_template = dummy_template + self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter + output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False) + self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed + tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised + + with tempfile.TemporaryDirectory() as tmp_dir_name: + tokenizer.save_pretrained(tmp_dir_name) + tokenizer = tokenizer.from_pretrained(tmp_dir_name) + + self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted + output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False) + self.assertEqual(output, expected_output) # Test output is the same after reloading + tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised + def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: