diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 232fb100e9c13c..975607e66a00a4 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -98,6 +98,8 @@
title: Use model-specific APIs
- local: custom_models
title: Share a custom model
+ - local: chat_templating
+ title: Templates for chat models
- local: sagemaker
title: Run training on Amazon SageMaker
- local: serialization
diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md
new file mode 100644
index 00000000000000..8a9c1e915d4eb6
--- /dev/null
+++ b/docs/source/en/chat_templating.md
@@ -0,0 +1,255 @@
+<!--Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+# Templates for Chat Models
+
+## Introduction
+
+An increasingly common use case for LLMs is **chat**. In a chat context, rather than continuing a single string
+of text (as is the case with a standard language model), the model instead continues a conversation that consists
+of one or more **messages**, each of which includes a **role** as well as message text.
+
+Most commonly, these roles are "user" for messages sent by the user, and "assistant" for messages sent by the model.
+Some models also support a "system" role. System messages are usually sent at the beginning of the conversation
+and include directives about how the model should behave in the subsequent chat.
+
+All language models, including models fine-tuned for chat, operate on linear sequences of tokens and do not intrinsically
+have special handling for roles. This means that role information is usually injected by adding control tokens
+between messages, to indicate both the message boundary and the relevant roles.
+
+Unfortunately, there isn't (yet!) a standard for which tokens to use, and so different models have been trained
+with wildly different formatting and control tokens for chat. This can be a real problem for users - if you use the
+wrong format, then the model will be confused by your input, and your performance will be a lot worse than it should be.
+This is the problem that **chat templates** aim to resolve.
+
+Chat conversations are typically represented as a list of dictionaries, where each dictionary contains `role`
+and `content` keys, and represents a single chat message. Chat templates are strings containing a Jinja template that
+specifies how to format a conversation for a given model into a single tokenizable sequence. By storing this information
+with the tokenizer, we can ensure that models get input data in the format they expect.
+
+Let's make this concrete with a quick example using the `BlenderBot` model. BlenderBot has an extremely simple default
+template, which mostly just adds whitespace between rounds of dialogue:
+
+```python
+>>> from transformers import AutoTokenizer
+>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
+
+>>> chat = [
+... {"role": "user", "content": "Hello, how are you?"},
+... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+... {"role": "user", "content": "I'd like to show off how chat templating works!"},
+... ]
+
+>>> tokenizer.apply_chat_template(chat, tokenize=False)
+" Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!"
+```
+
+Notice how the entire chat is condensed into a single string. If we use `tokenize=True`, which is the default setting,
+that string will also be tokenized for us. To see a more complex template in action, though, let's use the
+`meta-llama/Llama-2-7b-chat-hf` model. Note that this model has gated access, so you will have to
+[request access on the repo](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) if you want to run this code yourself:
+
+```python
+>> from transformers import AutoTokenizer
+>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+>> chat = [
+... {"role": "user", "content": "Hello, how are you?"},
+... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+... {"role": "user", "content": "I'd like to show off how chat templating works!"},
+... ]
+
+>> tokenizer.use_default_system_prompt = False
+>> tokenizer.apply_chat_template(chat, tokenize=False)
+"[INST] Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]"
+```
+
+Note that this time, the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
+user messages (but not assistant messages!)
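+
+Since `tokenize=True` is the default, the same call can also return token IDs that are ready to pass straight to a
+model. Here is a minimal sketch of that flow - the generation settings below are placeholder assumptions, not part of
+the templating API:
+
+```python
+>> from transformers import AutoModelForCausalLM
+
+>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+>> input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt")  # tokenize=True is the default
+>> outputs = model.generate(input_ids, max_new_tokens=128)  # Placeholder generation settings
+>> print(tokenizer.decode(outputs[0]))
+```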
+
+## How do chat templates work?
+
+The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
+default template for that model class is used instead. Let's take a look at the template for `BlenderBot`:
+
+```python
+
+>>> from transformers import AutoTokenizer
+>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
+
+>>> tokenizer.default_chat_template
+"{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
+```
+
+That's kind of intimidating. Let's add some newlines and indentation to make it more readable. Note that
+we remove the first newline after each block as well as any preceding whitespace before a block by default, using the
+Jinja `trim_blocks` and `lstrip_blocks` flags. This means that you can write your templates with indentations and
+newlines and still have them function correctly!
+
+```
+{% for message in messages %}
+    {% if message['role'] == 'user' %}
+        {{ ' ' }}
+    {% endif %}
+    {{ message['content'] }}
+    {% if not loop.last %}
+        {{ '  ' }}
+    {% endif %}
+{% endfor %}
+{{ eos_token }}
+```
+
+If you've never seen one of these before, this is a [Jinja template](https://jinja.palletsprojects.com/en/3.1.x/templates/).
+Jinja is a templating language that allows you to write simple code that generates text. In many ways, the code and
+syntax resembles Python. In pure Python, this template would look something like this:
+
+```python
+for idx, message in enumerate(messages):
+    if message['role'] == 'user':
+        print(' ')
+    print(message['content'])
+    if not idx == len(messages) - 1:  # Check for the last message in the conversation
+        print('  ')
+print(eos_token)
+```
+
+Effectively, the template does three things:
+1. For each message, if the message is a user message, add a blank space before it, otherwise print nothing.
+2. Add the message content.
+3. If the message is not the last message, add two spaces after it. After the final message, print the EOS token.
+
+This is a pretty simple template - it doesn't add any control tokens, and it doesn't support "system" messages, which
+are a common way to give the model directives about how it should behave in the subsequent conversation.
+But Jinja gives you a lot of flexibility to do those things! Let's see a Jinja template that can format inputs
+similarly to the way LLaMA formats them (note that the real LLaMA template includes handling for default system
+messages and slightly different system message handling in general - don't use this one in your actual code!)
+
+```
+{% for message in messages %}
+    {% if message['role'] == 'user' %}
+        {{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}
+    {% elif message['role'] == 'system' %}
+        {{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}
+    {% elif message['role'] == 'assistant' %}
+        {{ ' ' + message['content'] + ' ' + eos_token }}
+    {% endif %}
+{% endfor %}
+```
+
+Hopefully if you stare at this for a little bit you can see what this template is doing - it adds specific tokens based
+on the "role" of each message, which represents who sent it. User, assistant and system messages are clearly
+distinguishable to the model because of the tokens they're wrapped in.
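+
+If you want to poke at a template like this outside of a tokenizer, you can render it directly with the `jinja2`
+library. This is just a debugging sketch - the token strings passed to `render` are assumptions for illustration, and
+`apply_chat_template` remains the supported way to apply templates:
+
+```python
+from jinja2 import Environment
+
+env = Environment(trim_blocks=True, lstrip_blocks=True)  # Same whitespace flags described above
+template = env.from_string(
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}"
+    "{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}"
+    "{% endif %}{% endfor %}"
+)
+chat = [
+    {"role": "user", "content": "Hello, how are you?"},
+    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
+]
+print(template.render(messages=chat, bos_token="<s>", eos_token="</s>"))  # Assumed token strings
+# <s>[INST] Hello, how are you? [/INST] I'm doing great. How can I help you today? </s>
+```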
+
+## How do I create a chat template?
+
+Simple, just write a Jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
+existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
+above and add "[ASST]" and "[/ASST]" to assistant messages:
+
+```
+{% for message in messages %}
+    {% if message['role'] == 'user' %}
+        {{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}
+    {% elif message['role'] == 'system' %}
+        {{ '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}
+    {% elif message['role'] == 'assistant' %}
+        {{ '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}
+    {% endif %}
+{% endfor %}
+```
+
+Now, simply set the `tokenizer.chat_template` attribute. Next time you use [`~PreTrainedTokenizer.apply_chat_template`], it will
+use your new template! This attribute will be saved in the `tokenizer_config.json` file, so you can use
+[`~utils.PushToHubMixin.push_to_hub`] to upload your new template to the Hub and make sure everyone's using the right
+template for your model!
+
+```python
+template = tokenizer.chat_template
+template = template.replace("SYS", "SYSTEM") # Change the system token
+tokenizer.chat_template = template # Set the new template
+tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
+```
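+
+Alternatively, here is a sketch of assigning the [ASST] template from above directly, compacted into a single string.
+Note the doubled backslashes, which pass literal `\n` sequences through to the template, matching how the templates in
+the library source are written; the repository name is just a placeholder:
+
+```python
+tokenizer.chat_template = (
+    "{% for message in messages %}"
+    "{% if message['role'] == 'user' %}"
+    "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
+    "{% elif message['role'] == 'system' %}"
+    "{{ '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}"
+    "{% elif message['role'] == 'assistant' %}"
+    "{{ '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}"
+    "{% endif %}"
+    "{% endfor %}"
+)
+tokenizer.push_to_hub("model_name")  # Placeholder repository name
+```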
+
+The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so
+once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`].
+
+## What are "default" templates?
+
+Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards
+compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a
+model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline`
+class and methods like `apply_chat_template` will use the class template instead. You can find out what the default
+template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute.
+
+This is something we do purely for backward compatibility reasons, to avoid breaking any existing workflows. Even when
+the class template is appropriate for your model, we strongly recommend overriding the default template by
+setting the `chat_template` attribute explicitly to make it clear to users that your model has been correctly configured
+for chat, and to future-proof in case the default templates are ever altered or deprecated.
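+
+For example, pinning the current class-level default as an explicit template is a one-liner (the repository name below
+is just a placeholder):
+
+```python
+tokenizer.chat_template = tokenizer.default_chat_template  # Make the implicit default explicit
+tokenizer.push_to_hub("model_name")  # Placeholder repository name
+```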
+
+## What template should I use?
+
+When setting the template for a model that's already been trained for chat, you should ensure that the template
+exactly matches the message formatting that the model saw during training, or else you will probably experience
+performance degradation. This is true even if you're training the model further - you will probably get the best
+performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the
+best performance for inference or fine-tuning when you precisely match the tokenization used during training.
+
+If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
+you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
+input formats. Our default template for models that don't have a class-specific template follows the
+[ChatML format](https://github.com/openai/openai-python/blob/main/chatml.md), and this is a good, flexible choice for many use-cases. It looks like this:
+
+```
+{% for message in messages %}
+    {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+{% endfor %}
+```
+
+If you like this one, here it is in one-liner form, ready to copy into your code:
+
+```
+tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
+```
+
+This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which
+allows for flexibility in the roles you train with. The output looks like this:
+
+```
+<|im_start|>system
+You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>
+<|im_start|>user
+How are you?<|im_end|>
+<|im_start|>assistant
+I'm doing great!<|im_end|>
+```
+
+The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense,
+particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
+to these roles - templating is extremely flexible, and any string can be a role.
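+
+For example, with the ChatML template set as above, nothing stops you from using entirely custom roles - the role
+string is written into the output verbatim:
+
+```python
+chat = [
+    {"role": "narrator", "content": "Our hero approaches the castle."},
+    {"role": "hero", "content": "I draw my sword!"},
+]
+print(tokenizer.apply_chat_template(chat, tokenize=False))
+# <|im_start|>narrator
+# Our hero approaches the castle.<|im_end|>
+# <|im_start|>hero
+# I draw my sword!<|im_end|>
+```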
+
+## I want to use chat templates! How should I get started?
+
+If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using
+[`~PreTrainedTokenizer.apply_chat_template`]. This applies even if you're not the model owner - if you're using a model
+with an empty chat template, or one that's still using the default class template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to
+the model repository so that this attribute can be set properly!
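+
+A quick way to sanity-check a template before opening that pull request is to set it locally and print the result -
+the repository name and template below are placeholders:
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("your-username/your-chat-model")  # Placeholder repository
+tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{{ eos_token }}{% endfor %}"
+print(tokenizer.apply_chat_template([{"role": "user", "content": "Hi there!"}], tokenize=False))
+```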
+
+Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that
+model, which means it is also automatically supported in places like `ConversationalPipeline`!
+
+By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
+open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
+it's time to put an end to them!
\ No newline at end of file
diff --git a/docs/source/en/main_classes/tokenizer.md b/docs/source/en/main_classes/tokenizer.md
index 251cbb43ea7203..71f96c55cb51c9 100644
--- a/docs/source/en/main_classes/tokenizer.md
+++ b/docs/source/en/main_classes/tokenizer.md
@@ -58,6 +58,7 @@ to a given token).
- batch_decode
- decode
- encode
+ - apply_chat_template
- push_to_hub
- all
@@ -71,6 +72,7 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
- batch_decode
- decode
- encode
+ - apply_chat_template
- push_to_hub
- all
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py
index cb4a33a3c28bda..d6a70beb30a136 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py
@@ -17,7 +17,7 @@
import json
import os
from functools import lru_cache
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
import regex as re
@@ -25,9 +25,6 @@
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
logger = logging.get_logger(__name__)
@@ -413,19 +410,16 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
"""
return token_ids_0 + [self.eos_token_id]
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- inputs = []
- for is_user, text in conversation.iter_texts():
- if is_user:
- # We need to space prefix as it's being done within blenderbot
- inputs.append(" " + text)
- else:
- # Generated responses should contain them already.
- inputs.append(text)
-
- full_string = " ".join(inputs)
- input_ids = self.encode(full_string)
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
- return input_ids
+ @property
+ def default_chat_template(self):
+ """
+ A very simple chat template that just adds whitespace between messages.
+ """
+ return (
+ "{% for message in messages %}"
+ "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
+ "{{ message['content'] }}"
+ "{% if not loop.last %}{{ ' ' }}{% endif %}"
+ "{% endfor %}"
+ "{{ eos_token }}"
+ )
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
index 4737e92617c70d..ebe39ed09f9a35 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Fast Tokenization class for Blenderbot."""
import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors
@@ -24,9 +24,6 @@
from .tokenization_blenderbot import BlenderbotTokenizer
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
logger = logging.get_logger(__name__)
@@ -297,19 +294,17 @@ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1:
"""
return token_ids_0 + [self.eos_token_id]
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- inputs = []
- for is_user, text in conversation.iter_texts():
- if is_user:
- # We need to space prefix as it's being done within blenderbot
- inputs.append(" " + text)
- else:
- # Generated responses should contain them already.
- inputs.append(text)
-
- full_string = " ".join(inputs)
- input_ids = self.encode(full_string)
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
- return input_ids
+ @property
+ # Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A very simple chat template that just adds whitespace between messages.
+ """
+ return (
+ "{% for message in messages %}"
+ "{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
+ "{{ message['content'] }}"
+ "{% if not loop.last %}{{ ' ' }}{% endif %}"
+ "{% endfor %}"
+ "{{ eos_token }}"
+ )
diff --git a/src/transformers/models/bloom/tokenization_bloom_fast.py b/src/transformers/models/bloom/tokenization_bloom_fast.py
index 8339ece5433bd3..47b78ac723f757 100644
--- a/src/transformers/models/bloom/tokenization_bloom_fast.py
+++ b/src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -16,17 +16,13 @@
import pickle
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
@@ -166,12 +162,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
diff --git a/src/transformers/models/code_llama/tokenization_code_llama.py b/src/transformers/models/code_llama/tokenization_code_llama.py
index 0cf48b12077227..53a2d3577a1740 100644
--- a/src/transformers/models/code_llama/tokenization_code_llama.py
+++ b/src/transformers/models/code_llama/tokenization_code_llama.py
@@ -17,7 +17,7 @@
"""Tokenization classes for Code LLaMA."""
import os
from shutil import copyfile
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
@@ -26,9 +26,6 @@
from ...utils import logging, requires_backends
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
@@ -441,70 +438,57 @@ def create_token_type_ids_from_sequences(
return output
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- r"""Builds the input ids for a conversation.
- This is the format used in the provided examples. System prompts should be manually added at the beginning of
- the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
- ```
- <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST]
- ```
+ @property
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+ to fine-tune a model with more flexible role ordering!
- If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
- ```python
- >>> from transformers import Conversation
+ The output should look something like:
- >>> Conversation(
- ... "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`"
- ... ) # doctest: +IGNORE_RESULT
- ```
- Args:
- conversation (`Conversation`):
- Conversation to build input ids for.
- Returns:
- `List[int]`:
- Input ids for the conversation.
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
+ <bos>[INST] Prompt [/INST]
"""
- if self.use_default_system_prompt:
- if len(conversation.past_user_inputs) > 0:
- if (
- not conversation.past_user_inputs[0].startswith(B_SYS)
- or E_SYS not in conversation.past_user_inputs[0]
- ):
- conversation.past_user_inputs[0] = (
- B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
- )
- elif conversation.new_user_input:
- if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
- conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
- else:
- raise ValueError("Last message must be from user")
-
- dialogue = list(conversation.iter_texts())
- if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
- [not is_user for is_user, msg in dialogue[1::2]]
- ):
- raise ValueError(
- "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
- )
- dialog_tokens: List[int] = []
- dialog_tokens += sum(
- [
- [self.bos_token_id]
- + self.encode(
- f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
- )
- + [self.eos_token_id]
- for prompt, answer in zip(dialogue[::2], dialogue[1::2])
- ],
- [],
+ template = (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
+ "{% set system_message = messages[0]['content'] %}"
+ "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}"
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}" # Loop over all non-system messages
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
+ "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}"
+ "{% else %}"
+ "{% set content = message['content'] %}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+ "{% elif message['role'] == 'system' %}"
+ "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
+ "{% endif %}"
+ "{% endfor %}"
)
- dialog_tokens += [self.bos_token_id] + self.encode(
- f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
- )
- return dialog_tokens
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+ return template
def __getstate__(self):
state = self.__dict__.copy()
diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
index 030d473bfeb261..91a1896c3c10db 100644
--- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py
+++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -14,7 +14,7 @@
# limitations under the License.
import os
from shutil import copyfile
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
from tokenizers import normalizers, processors
@@ -23,9 +23,6 @@
from ...utils.versions import require_version
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
require_version("tokenizers>=0.13.3")
if is_sentencepiece_available():
@@ -344,6 +341,58 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
return (out_vocab_file,)
+ @property
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+ to fine-tune a model with more flexible role ordering!
+
+ The output should look something like:
+
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
+ <bos>[INST] Prompt [/INST]
+ """
+
+ template = (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
+ "{% set system_message = messages[0]['content'] %}"
+ "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}"
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}" # Loop over all non-system messages
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
+ "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}"
+ "{% else %}"
+ "{% set content = message['content'] %}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+ "{% elif message['role'] == 'system' %}"
+ "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+ return template
+
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
@@ -371,69 +420,3 @@ def build_inputs_with_special_tokens(
if token_ids_1 is None:
return self.bos_token_id + token_ids_0 + self.eos_token_id
return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
-
- # Copied from transformers.models.code_llama.tokenization_code_llama.CodeLlamaTokenizer._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- r"""Builds the input ids for a conversation.
- This is the format used in the provided examples. System prompts should be manually added at the beginning of
- the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
- ```
- <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST]
- ```
-
- If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
- ```python
- >>> from transformers import Conversation
-
- >>> Conversation(
- ... "<>\n Complete the functions without any documentation\n<>\n\n `def remove_non_ascii(s: str) -> str:`"
- ... ) # doctest: +IGNORE_RESULT
- ```
- Args:
- conversation (`Conversation`):
- Conversation to build input ids for.
- Returns:
- `List[int]`:
- Input ids for the conversation.
- """
- if self.use_default_system_prompt:
- if len(conversation.past_user_inputs) > 0:
- if (
- not conversation.past_user_inputs[0].startswith(B_SYS)
- or E_SYS not in conversation.past_user_inputs[0]
- ):
- conversation.past_user_inputs[0] = (
- B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
- )
- elif conversation.new_user_input:
- if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
- conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
- else:
- raise ValueError("Last message must be from user")
-
- dialogue = list(conversation.iter_texts())
- if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
- [not is_user for is_user, msg in dialogue[1::2]]
- ):
- raise ValueError(
- "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
- )
-
- dialog_tokens: List[int] = []
- dialog_tokens += sum(
- [
- [self.bos_token_id]
- + self.encode(
- f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
- )
- + [self.eos_token_id]
- for prompt, answer in zip(dialogue[::2], dialogue[1::2])
- ],
- [],
- )
- dialog_tokens += [self.bos_token_id] + self.encode(
- f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
- )
- return dialog_tokens
diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py
index 8a778a947cfbca..880ed17d95ef28 100644
--- a/src/transformers/models/deberta/tokenization_deberta.py
+++ b/src/transformers/models/deberta/tokenization_deberta.py
@@ -16,7 +16,7 @@
import json
import os
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
import regex as re
@@ -24,9 +24,6 @@
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
@@ -433,12 +430,3 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
-
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
diff --git a/src/transformers/models/deberta/tokenization_deberta_fast.py b/src/transformers/models/deberta/tokenization_deberta_fast.py
index c05cf257611ebf..d77f0b39b98486 100644
--- a/src/transformers/models/deberta/tokenization_deberta_fast.py
+++ b/src/transformers/models/deberta/tokenization_deberta_fast.py
@@ -15,7 +15,7 @@
""" Fast Tokenization class for model DeBERTa."""
import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers
@@ -25,10 +25,6 @@
from .tokenization_deberta import DebertaTokenizer
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -288,14 +284,3 @@ def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
-
- # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py
index 9a8ce3a4fabd5a..278ff69032585c 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2.py
@@ -18,7 +18,7 @@
import json
import os
from functools import lru_cache
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
import regex as re
@@ -26,9 +26,6 @@
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
@@ -354,10 +351,9 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
text = " " + text
return (text, kwargs)
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
index cf2b8b2cb22c6a..189a3550840885 100644
--- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py
+++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py
@@ -16,7 +16,7 @@
import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple
from tokenizers import pre_tokenizers
@@ -26,10 +26,6 @@
from .tokenization_gpt2 import GPT2Tokenizer
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -181,12 +177,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
index 570b2abaa49fde..f666b97efd2bd0 100644
--- a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
+++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Tokenization classes for GPTNeoX."""
import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple
from tokenizers import pre_tokenizers
@@ -22,10 +22,6 @@
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -133,12 +129,10 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
diff --git a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py
index c9f4f677cb483e..6ac2f214a16568 100644
--- a/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py
@@ -17,7 +17,7 @@
import json
import os
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple
import numpy as np
@@ -25,10 +25,6 @@
from ...utils import logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
@@ -179,15 +175,14 @@ def convert_tokens_to_string(self, tokens):
out_string = "".join(tokens).strip()
return out_string
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ def default_chat_template(self):
+ """
+ A simple chat template that just adds BOS/EOS tokens around messages while discarding role information.
+ """
+ return (
+ "{% for message in messages %}" "{{ bos_token + eos_token + message.content + eos_token }}" "{% endfor %}"
+ )
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
diff --git a/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py b/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
index f592a2b63eabf5..4874ba732245f0 100644
--- a/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
+++ b/src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
@@ -4,7 +4,7 @@
import re
import unicodedata
from shutil import copyfile
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
@@ -16,10 +16,6 @@
import torch
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
@@ -319,31 +315,18 @@ def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
return self.sp_model.decode(token_ids)
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """Builds the input ids for a conversation.
-
- This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
- The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
-
- ```
- <|endoftext|><s>User: Jag tycker träd är fina<s>Bot: Kul att du tycker det!<s>...
- ```
-
- Args:
- conversation (`Conversation`):
- Conversation to build input ids for.
-
- Returns:
- `List[int]`:
- Input ids for the conversation.
-
- References:
- - [1] https://doi.org/10.48550/arXiv.2305.12987
- - [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
- - [3] https://github.com/openai/openai-python/blob/main/chatml.md
+ @property
+ def default_chat_template(self):
+ """
+ This chat template formats messages like an instant messenger chat log, with "User:" and "Bot:" strings
+ preceding messages. BOS tokens are added between all messages.
"""
- all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
- prompt = (
- f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
+ return (
+ "{{ eos_token }}{{ bos_token }}"
+ "{% for message in messages %}"
+ "{% if message['role'] == 'user' %}{{ 'User: ' + message['content']}}"
+ "{% else %}{{ 'Bot: ' + message['content']}}{% endif %}"
+ "{{ message['text'] }}{{ bos_token }}"
+ "{% endfor %}"
+ "Bot:"
)
- return self.encode(text=prompt)
diff --git a/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py
index a16e55ec7180d6..c567b6b6003fff 100644
--- a/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py
+++ b/src/transformers/models/gptsan_japanese/tokenization_gptsan_japanese.py
@@ -17,7 +17,7 @@
import json
import os
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
@@ -33,10 +33,6 @@
from ...utils import PaddingStrategy, logging
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
@@ -258,16 +254,18 @@ def convert_tokens_to_string(self, tokens):
text = "".join(words)
return text
- # Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- """This corresponds to DialoGPT variants of models."""
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ def default_chat_template(self):
+ """
+ A simple chat template that adds standard BOS, SEP and EOS tokens between messages while discarding role
+ information.
+ """
+ return (
+ "{% for message in messages %}"
+ "{% if not loop.first %}{{ bos_token}}{% endif %}"
+ "{{ sep_token }}{{ message.content }} {{ eos_token }}"
+ "{% endfor %}"
+ )
# Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py
index f33771995d291c..8db2f9970e199a 100644
--- a/src/transformers/models/llama/tokenization_llama.py
+++ b/src/transformers/models/llama/tokenization_llama.py
@@ -31,7 +31,6 @@
if TYPE_CHECKING:
- from ...pipelines.conversational import Conversation
from ...tokenization_utils_base import TextInput
logger = logging.get_logger(__name__)
@@ -374,67 +373,53 @@ def create_token_type_ids_from_sequences(
return output
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- r"""Builds the input ids for a conversation.
- This is the format used in the provided examples. System prompts should be manually added at the beginning of
- the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
- ```
- <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST]
- ```
+ @property
+ def default_chat_template(self):
+ """
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+ to fine-tune a model with more flexible role ordering!
- If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
- ```python
- >>> from transformers import Conversation
+ The output should look something like:
- >>> Conversation(
- ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?"
- ... ) # doctest: +IGNORE_RESULT
- ```
- Args:
- conversation (`Conversation`):
- Conversation to build input ids for.
- Returns:
- `List[int]`:
- Input ids for the conversation.
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
+ <bos>[INST] Prompt [/INST]
"""
- if self.use_default_system_prompt:
- if len(conversation.past_user_inputs) > 0:
- if (
- not conversation.past_user_inputs[0].startswith(B_SYS)
- or E_SYS not in conversation.past_user_inputs[0]
- ):
- conversation.past_user_inputs[0] = (
- B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
- )
- elif conversation.new_user_input:
- if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
- conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
- else:
- raise ValueError("Last message must be from user")
-
- dialogue = list(conversation.iter_texts())
- if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
- [not is_user for is_user, msg in dialogue[1::2]]
- ):
- raise ValueError(
- "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
- )
- dialog_tokens: List[int] = []
- dialog_tokens += sum(
- [
- [self.bos_token_id]
- + self.encode(
- f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
- )
- + [self.eos_token_id]
- for prompt, answer in zip(dialogue[::2], dialogue[1::2])
- ],
- [],
- )
- dialog_tokens += [self.bos_token_id] + self.encode(
- f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
+ template = (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
+ "{% set system_message = messages[0]['content'] %}"
+ "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}"
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}" # Loop over all non-system messages
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
+ "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}"
+ "{% else %}"
+ "{% set content = message['content'] %}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+ "{% elif message['role'] == 'system' %}"
+ "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
+ "{% endif %}"
+ "{% endfor %}"
)
- return dialog_tokens
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+ return template
diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index 09c81417672e80..282a0f06740eaa 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -14,7 +14,7 @@
# limitations under the License.
import os
from shutil import copyfile
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import Optional, Tuple
from tokenizers import processors
@@ -23,9 +23,6 @@
from ...utils.versions import require_version
-if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
-
require_version("tokenizers>=0.13.3")
if is_sentencepiece_available():
@@ -192,67 +189,54 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
return (out_vocab_file,)
- def _build_conversation_input_ids(self, conversation: "Conversation"):
- """Builds the input ids for a conversation.
- This is the format used in the provided examples. System prompts should be manually added at the beginning of
- the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
- ```
- <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
- <bos>[INST] Prompt [/INST] Answer <eos>
+ @property
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+ to fine-tune a model with more flexible role ordering!
+
+ The output should look something like:
+
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
- ```
-
- If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
- ```python
- >>> from transformers import Conversation
-
- >>> Conversation(
- ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?"
- ... )
- ```
- Args:
- conversation (`Conversation`):
- Conversation to build input ids for.
- Returns:
- `List[int]`:
- Input ids for the conversation.
"""
- if self.use_default_system_prompt:
- if len(conversation.past_user_inputs) > 0:
- if (
- not conversation.past_user_inputs[0].startswith(B_SYS)
- or E_SYS not in conversation.past_user_inputs[0]
- ):
- conversation.past_user_inputs[0] = (
- B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
- )
- elif conversation.new_user_input:
- if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
- conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
- else:
- raise ValueError("Last message must be from user")
-
- dialogue = list(conversation.iter_texts())
- if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
- [not is_user for is_user, msg in dialogue[1::2]]
- ):
- raise ValueError(
- "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
- )
- dialog_tokens = []
- dialog_tokens += sum(
- [
- [self.bos_token_id]
- + self.encode(
- f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
- )
- + [self.eos_token_id]
- for prompt, answer in zip(dialogue[::2], dialogue[1::2])
- ],
- [],
+ template = (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
+ "{% set system_message = messages[0]['content'] %}"
+ "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}"
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% set system_message = false %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}" # Loop over all non-system messages
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
+ "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}"
+ "{% else %}"
+ "{% set content = message['content'] %}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+ "{% elif message['role'] == 'system' %}"
+ "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
+ "{% endif %}"
+ "{% endfor %}"
)
- dialog_tokens += [self.bos_token_id] + self.encode(
- f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
- )
- return dialog_tokens
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+ return template
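To see the template logic above in action without a tokenizer, here is a minimal sketch that renders a condensed Llama-2-style template with Jinja2 directly. The abbreviated template string and the `<s>`/`</s>` token values are illustrative assumptions rather than the exact strings from this patch; `raise_exception` mirrors the helper that `_compile_jinja_template` registers further down in this diff.

```python
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment


def raise_exception(message):
    raise TemplateError(message)


env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["raise_exception"] = raise_exception

# Condensed Llama-2-style template: embed the system message in the first user
# turn, then require strict user/assistant alternation (see the full template above).
template = env.from_string(
    "{% if messages[0]['role'] == 'system' %}"
    "{% set system_message = messages[0]['content'] %}"
    "{% set loop_messages = messages[1:] %}"
    "{% else %}"
    "{% set system_message = false %}"
    "{% set loop_messages = messages %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/...') }}"
    "{% endif %}"
    "{% if loop.index0 == 0 and system_message != false %}"
    "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
    "{% else %}"
    "{% set content = message['content'] %}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ ' ' + content.strip() + ' ' + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)

chat = [
    {"role": "system", "content": "You are a helpful bot."},
    {"role": "user", "content": "Hello!"},
]
print(template.render(messages=chat, bos_token="<s>", eos_token="</s>"))
# -> "<s>[INST] <<SYS>>\nYou are a helpful bot.\n<</SYS>>\n\nHello! [/INST]" (newlines shown escaped)
```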
diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index b3eccddf410355..a22521b4e00dfb 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -16,7 +16,7 @@
import json
import os
from functools import lru_cache
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import regex as re
@@ -26,10 +26,6 @@
from .english_normalizer import EnglishTextNormalizer
-if TYPE_CHECKING:
- from ...pipelines.conversational import Conversation
-
-
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_file": "tokenizer.json",
@@ -751,14 +747,13 @@ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
text = " " + text
return (text, kwargs)
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
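The default template above is role-agnostic: it simply joins message contents with the EOS token. A minimal sketch of what it renders, using jinja2 directly with `<|endoftext|>` standing in for whatever `eos_token` the tokenizer actually defines:

```python
from jinja2 import Environment

# The same template string returned by default_chat_template above
template = Environment().from_string(
    "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
)
chat = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
print(template.render(messages=chat, eos_token="<|endoftext|>"))
# Hello!<|endoftext|>Hi there.<|endoftext|>
```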
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index 4ad500bbf1c0ef..cb321f669c7ca6 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -16,7 +16,7 @@
import json
import os
from functools import lru_cache
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple
import numpy as np
from tokenizers import pre_tokenizers, processors
@@ -28,10 +28,6 @@
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
-if TYPE_CHECKING:
- from ...pipelines.conversational import Conversation
-
-
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
@@ -520,14 +516,13 @@ def get_special_tokens_mask(
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
- # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
- input_ids = []
- for is_user, text in conversation.iter_texts():
- input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
- if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length :]
- return input_ids
+ @property
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+ def default_chat_template(self):
+ """
+ A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+ """
+ return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_decoder_prompt_ids
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py
index 93d056c88d44ef..c455c75574435b 100644
--- a/src/transformers/pipelines/conversational.py
+++ b/src/transformers/pipelines/conversational.py
@@ -1,5 +1,5 @@
import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Union
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -19,137 +19,153 @@ class Conversation:
"""
Utility class containing a conversation and its history. This class is meant to be used as an input to the
[`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user
- inputs and generated model responses. A conversation needs to contain an unprocessed user input before being passed
- to the [`ConversationalPipeline`]. This user input is either created when the class is instantiated, or by calling
- `conversational_pipeline.append_response("input")` after a conversation turn.
+ inputs and generated model responses.
Arguments:
- text (`str`, *optional*):
- The initial user input to start the conversation. If not provided, a user input needs to be provided
- manually using the [`~Conversation.add_user_input`] method before the conversation can begin.
+ messages (Union[str, List[Dict[str, str]]], *optional*):
+ The initial messages to start the conversation, either a string, or a list of dicts containing "role" and
+ "content" keys. If a string is passed, it is interpreted as a single message with the "user" role.
conversation_id (`uuid.UUID`, *optional*):
Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
conversation.
- past_user_inputs (`List[str]`, *optional*):
- Eventual past history of the conversation of the user. You don't need to pass it manually if you use the
- pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
- `generated_responses` with equal length lists of strings
- generated_responses (`List[str]`, *optional*):
- Eventual past history of the conversation of the model. You don't need to pass it manually if you use the
- pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
- `generated_responses` with equal length lists of strings
Usage:
```python
conversation = Conversation("Going to the movies tonight - any suggestions?")
-
- # Steps usually performed by the model when generating a response:
- # 1. Mark the user input as processed (moved to the history)
- conversation.mark_processed()
- # 2. Append a mode response
- conversation.append_response("The Big lebowski.")
-
- conversation.add_user_input("Is it good?")
+ conversation.add_message({"role": "assistant", "content": "The Big lebowski."})
+ conversation.add_message({"role": "user", "content": "Is it good?"})
```"""
def __init__(
- self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
+ self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs
):
if not conversation_id:
conversation_id = uuid.uuid4()
- if past_user_inputs is None:
- past_user_inputs = []
- if generated_responses is None:
- generated_responses = []
- self.uuid: uuid.UUID = conversation_id
- self.past_user_inputs: List[str] = past_user_inputs
- self.generated_responses: List[str] = generated_responses
- self.new_user_input: Optional[str] = text
+ if messages is None:
+ text = deprecated_kwargs.pop("text", None)
+ if text is not None:
+ messages = [{"role": "user", "content": text}]
+ else:
+ messages = []
+ elif isinstance(messages, str):
+ messages = [{"role": "user", "content": messages}]
+
+ # This block deals with the legacy args - new code should just totally
+ # avoid past_user_inputs and generated_responses
+ generated_responses = deprecated_kwargs.pop("generated_responses", None)
+ past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
+ if generated_responses is not None and past_user_inputs is None:
+ raise ValueError("generated_responses cannot be passed without past_user_inputs!")
+ if past_user_inputs is not None:
+ legacy_messages = []
+ if generated_responses is None:
+ generated_responses = []
+ # We structure it this way instead of using zip() because the lengths may differ by 1
+ for i in range(max([len(past_user_inputs), len(generated_responses)])):
+ if i < len(past_user_inputs):
+ legacy_messages.append({"role": "user", "content": past_user_inputs[i]})
+ if i < len(generated_responses):
+ legacy_messages.append({"role": "assistant", "content": generated_responses[i]})
+ messages = legacy_messages + messages
+
+ self.uuid = conversation_id
+ self.messages = messages
def __eq__(self, other):
if not isinstance(other, Conversation):
return False
- if self.uuid == other.uuid:
- return True
- return (
- self.new_user_input == other.new_user_input
- and self.past_user_inputs == other.past_user_inputs
- and self.generated_responses == other.generated_responses
- )
+ return self.uuid == other.uuid or self.messages == other.messages
+
+ def add_message(self, message: Dict[str, str]):
+ if not set(message.keys()) == {"role", "content"}:
+ raise ValueError("Message should contain only 'role' and 'content' keys!")
+ if message["role"] not in ("user", "assistant", "system"):
+ raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!")
+ self.messages.append(message)
def add_user_input(self, text: str, overwrite: bool = False):
"""
- Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.
-
- Args:
- text (`str`): The user input for the next conversation round.
- overwrite (`bool`, *optional*, defaults to `False`):
- Whether or not existing and unprocessed user input should be overwritten when this function is called.
+ Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must
+ alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. We recommend
+ just using `add_message` with role "user" instead.
"""
- if self.new_user_input:
+ if len(self) > 0 and self[-1]["role"] == "user":
if overwrite:
logger.warning(
- f'User input added while unprocessed input was existing: "{self.new_user_input}" was overwritten '
+ f'User input added while unprocessed input was existing: "{self[-1]["content"]}" was overwritten '
f'with: "{text}".'
)
- self.new_user_input = text
+ self[-1]["content"] = text
else:
logger.warning(
- f'User input added while unprocessed input was existing: "{self.new_user_input}" new input '
+ f'User input added while unprocessed input was existing: "{self[-1]["content"]}" new input '
f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input'
)
else:
- self.new_user_input = text
+ self.messages.append({"role": "user", "content": text})
- def mark_processed(self):
+ def append_response(self, response: str):
"""
- Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties
- the `new_user_input` field.
+ This is a legacy method. We recommend just using `add_message` with an appropriate role instead.
"""
- if self.new_user_input:
- self.past_user_inputs.append(self.new_user_input)
- self.new_user_input = None
+ self.messages.append({"role": "assistant", "content": response})
- def append_response(self, response: str):
+ def mark_processed(self):
"""
- Append a response to the list of generated responses.
-
- Args:
- response (`str`): The model generated response.
+ This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
+ processed and unprocessed user input.
"""
- self.generated_responses.append(response)
+ pass
- def iter_texts(self):
- """
- Iterates over all blobs of the conversation.
+ def __iter__(self):
+ for message in self.messages:
+ yield message
- Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. `is_user` is a `bool`,
- `text_chunks` is a `str`.
- """
- for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
- yield True, user_input
- yield False, generated_response
- if self.new_user_input:
- yield True, self.new_user_input
+ def __getitem__(self, item):
+ return self.messages[item]
+
+ def __setitem__(self, key, value):
+ self.messages[key] = value
+
+ def __len__(self):
+ return len(self.messages)
def __repr__(self):
"""
Generates a string representation of the conversation.
- Return:
+ Returns:
`str`:
- Example: Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user >> Going to the movies tonight - any
- suggestions? bot >> The Big Lebowski
+ Example:
+ Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user: Going to the movies tonight - any suggestions?
+ bot: The Big Lebowski
"""
- output = f"Conversation id: {self.uuid} \n"
- for is_user, text in self.iter_texts():
- name = "user" if is_user else "bot"
- output += f"{name} >> {text} \n"
+ output = f"Conversation id: {self.uuid}\n"
+ for message in self.messages:
+ output += f"{message['role']}: {message['content']}\n"
return output
+ def iter_texts(self):
+ # This is a legacy method for backwards compatibility. It is recommended to just directly access
+ # conversation.messages instead.
+ for message in self.messages:
+ yield message["role"] == "user", message["content"]
+
+ @property
+ def past_user_inputs(self):
+ # This is a legacy property for backwards compatibility. It is recommended to just directly access
+ # conversation.messages instead.
+ return [message["content"] for message in self.messages if message["role"] == "user"]
+
+ @property
+ def generated_responses(self):
+ # This is a legacy property for backwards compatibility. It is recommended to just directly access
+ # conversation.messages instead.
+ return [message["content"] for message in self.messages if message["role"] == "assistant"]
+
@add_end_docstrings(
PIPELINE_INIT_ARGS,
@@ -246,18 +262,7 @@ def __call__(self, conversations: Union[Conversation, List[Conversation]], num_w
return outputs
def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
- if not isinstance(conversation, Conversation):
- raise ValueError("ConversationalPipeline, expects Conversation as inputs")
- if conversation.new_user_input is None:
- raise ValueError(
- f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. "
- "Add user inputs with the conversation's `add_user_input` method"
- )
- if hasattr(self.tokenizer, "_build_conversation_input_ids"):
- input_ids = self.tokenizer._build_conversation_input_ids(conversation)
- else:
- # If the tokenizer cannot handle conversations, we default to only the old version
- input_ids = self._legacy_parse_and_tokenize(conversation)
+ input_ids = self.tokenizer.apply_chat_template(conversation)
if self.framework == "pt":
input_ids = torch.LongTensor([input_ids])
@@ -292,19 +297,5 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=True):
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
conversation = model_outputs["conversation"]
- conversation.mark_processed()
- conversation.append_response(answer)
+ conversation.add_message({"role": "assistant", "content": answer})
return conversation
-
- def _legacy_parse_and_tokenize(self, conversation: Conversation) -> Dict:
- eos_token_id = self.tokenizer.eos_token_id
- input_ids = []
- for is_user, text in conversation.iter_texts():
- if eos_token_id is not None:
- input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
- else:
- input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False))
-
- if len(input_ids) > self.tokenizer.model_max_length:
- input_ids = input_ids[-self.tokenizer.model_max_length :]
- return input_ids
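With these changes applied, a `Conversation` is essentially a thin wrapper around a list of message dicts. A short usage sketch of the API introduced above:

```python
from transformers import Conversation

conversation = Conversation("Going to the movies tonight - any suggestions?")  # string -> one user message
conversation.add_message({"role": "assistant", "content": "The Big Lebowski."})
conversation.add_message({"role": "user", "content": "Is it good?"})

len(conversation)                 # 3, via __len__
conversation[-1]                  # {'role': 'user', 'content': 'Is it good?'}, via __getitem__
conversation.past_user_inputs     # legacy view: ['Going to the movies tonight - any suggestions?', 'Is it good?']
conversation.generated_responses  # legacy view: ['The Big Lebowski.']

for message in conversation:      # __iter__ yields message dicts in order
    print(f"{message['role']}: {message['content']}")
```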
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index a7e36322a6b099..ccdc7eb6dd7e21 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -64,6 +64,7 @@
is_ftfy_available,
is_ipex_available,
is_jieba_available,
+ is_jinja_available,
is_jumanpp_available,
is_keras_nlp_available,
is_librosa_available,
@@ -336,6 +337,13 @@ def require_jieba(test_case):
return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
+def require_jinja(test_case):
+ """
+ Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
+ """
+ return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
+
+
def require_tf2onnx(test_case):
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
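Like the other `require_*` decorators in this file, `require_jinja` wraps individual test methods. A minimal sketch of how a test suite uses it (the test class and method names are hypothetical):

```python
import unittest

from transformers.testing_utils import require_jinja


class ChatTemplateTest(unittest.TestCase):
    @require_jinja
    def test_template_rendering(self):
        # Only runs when jinja2 is importable; otherwise unittest reports it as skipped.
        ...
```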
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 2390ed478f30c4..a65f799a724b13 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -27,6 +27,7 @@
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass, field
+from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@@ -69,6 +70,7 @@
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp # noqa: F401
+ from .pipelines.conversational import Conversation
if is_tokenizers_available():
@@ -1426,6 +1428,7 @@ def all_special_ids(self) -> List[int]:
- **length** -- The length of the inputs (when `return_length=True`)
"""
+
INIT_TOKENIZER_DOCSTRING = r"""
Class attributes (overridden by derived classes)
@@ -1461,6 +1464,9 @@ def all_special_ids(self) -> List[int]:
truncation_side (`str`, *optional*):
The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
+ chat_template (`str`, *optional*):
+ A Jinja template string that will be used to format lists of chat messages. See
+ https://huggingface.co/docs/transformers/chat_templating for a full description.
model_input_names (`List[string]`, *optional*):
The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
`"attention_mask"`). Default value is picked from the class attribute of the same name.
@@ -1558,6 +1564,10 @@ def __init__(self, **kwargs):
{}
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
self._in_target_context_manager = False
+
+ # Stores a Jinja template that formats chat histories into tokenizable strings
+ self.chat_template = kwargs.pop("chat_template", None)
+
super().__init__(**kwargs)
@property
@@ -1627,6 +1637,109 @@ def get_vocab(self) -> Dict[str, int]:
"""
raise NotImplementedError()
+ def apply_chat_template(
+ self,
+ conversation: Union[List[Dict[str, str]], "Conversation"],
+ chat_template: Optional[str] = None,
+ tokenize: bool = True,
+ padding: bool = False,
+ truncation: bool = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **tokenizer_kwargs,
+ ) -> Union[str, List[int]]:
+ """
+ Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
+ ids (or, when `tokenize=False`, to a formatted string). This method is intended for use with chat models, and
+ will read the tokenizer's chat_template attribute to determine the format and control tokens to use when
+ converting. When the `chat_template` argument is not passed and `tokenizer.chat_template` is None, it falls
+ back to the default_chat_template specified at the class level.
+
+ Args:
+ conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
+ with "role" and "content" keys, representing the chat history so far.
+ chat_template (str, *optional*): A Jinja template to use for this conversion. If
+ this is not passed, the model's default chat template will be used instead.
+ tokenize (`bool`, defaults to `True`):
+ Whether to tokenize the output. If `False`, the output will be a string.
+ padding (`bool`, defaults to `False`):
+ Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
+ truncation (`bool`, defaults to `False`):
+ Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
+ max_length (`int`, *optional*):
+ Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
+ not specified, the tokenizer's `max_length` attribute will be used as a default.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
+ values are:
+ - `'tf'`: Return TensorFlow `tf.Tensor` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+ **tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
+
+ Returns:
+ `Union[List[int], str]`: A list of token ids representing the tokenized chat so far, including control
+ tokens, or the rendered string if `tokenize=False`. The tokenized output is ready to pass to the model,
+ either directly or via methods like `generate()`.
+ """
+
+ if hasattr(conversation, "messages"):
+ # Indicates it's a Conversation object
+ conversation = conversation.messages
+
+ # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
+ if chat_template is None:
+ if self.chat_template is not None:
+ chat_template = self.chat_template
+ else:
+ chat_template = self.default_chat_template
+
+ # Compilation function uses a cache to avoid recompiling the same template
+ compiled_template = self._compile_jinja_template(chat_template)
+
+ rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
+
+ if padding is True:
+ padding = "max_length" # There's only one sequence here, so "longest" makes no sense
+ if tokenize:
+ return self.encode(
+ rendered,
+ add_special_tokens=False,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ **tokenizer_kwargs,
+ )
+ else:
+ return rendered
+
+ @lru_cache
+ def _compile_jinja_template(self, chat_template):
+ try:
+ from jinja2.exceptions import TemplateError
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
+ except ImportError:
+ raise ImportError("apply_chat_template requires jinja2 to be installed.")
+
+ def raise_exception(message):
+ raise TemplateError(message)
+
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+ jinja_env.globals["raise_exception"] = raise_exception
+ return jinja_env.from_string(chat_template)
+
+ @property
+ def default_chat_template(self):
+ """
+ This template formats inputs in the standard ChatML format. See
+ https://github.com/openai/openai-python/blob/main/chatml.md
+ """
+ return (
+ "{% for message in messages %}"
+ "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
+ "{% endfor %}"
+ )
+
@classmethod
def from_pretrained(
cls,
@@ -2187,6 +2300,9 @@ def save_pretrained(
if hasattr(self, k):
tokenizer_config[k] = getattr(self, k)
+ if self.chat_template is not None:
+ tokenizer_config["chat_template"] = self.chat_template
+
if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
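Putting the new base-class method to work: a minimal sketch, assuming the `gpt2` checkpoint purely for concreteness (any tokenizer inherits `apply_chat_template`). Note the priority order from the code above: the `chat_template` argument wins over `tokenizer.chat_template`, which wins over `default_chat_template`.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

chat = [
    {"role": "system", "content": "You are a helpful chatbot."},
    {"role": "user", "content": "Hello!"},
]

# No chat_template is set on this tokenizer, so the class-level default applies
# (GPT-2 overrides the base ChatML default with its EOS-joined template).
print(tokenizer.apply_chat_template(chat, tokenize=False))

# A per-call template overrides both tokenizer.chat_template and the default:
custom = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
print(tokenizer.apply_chat_template(chat, chat_template=custom, tokenize=False))

# tokenize=True (the default) returns token ids, ready for model.generate()
input_ids = tokenizer.apply_chat_template(chat, chat_template=custom)
```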
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 68c39c732e3c35..ac9beb7856187b 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -119,6 +119,7 @@
is_in_notebook,
is_ipex_available,
is_jieba_available,
+ is_jinja_available,
is_jumanpp_available,
is_kenlm_available,
is_keras_nlp_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 6829ca9ad67e1a..aeb351db96e22b 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -91,6 +91,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_ftfy_available = _is_package_available("ftfy")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
+_jinja_available = _is_package_available("jinja2")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_librosa_available = _is_package_available("librosa")
@@ -793,6 +794,10 @@ def is_jieba_available():
return _jieba_available
+def is_jinja_available():
+ return _jinja_available
+
+
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
@@ -1081,6 +1086,11 @@ def is_jieba_available():
peft`. Please note that you may need to restart your runtime after installation.
"""
+JINJA_IMPORT_ERROR = """
+{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
+jinja2`. Please note that you may need to restart your runtime after installation.
+"""
+
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -1118,6 +1128,7 @@ def is_jieba_available():
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
+ ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
]
)
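Downstream code can guard on the new backend the same way the existing entries in this mapping are used. A sketch under stated assumptions: the class and function names below are hypothetical, and `requires_backends` is the pre-existing helper that this mapping feeds.

```python
from transformers.utils import is_jinja_available, requires_backends


def maybe_render(template_str: str) -> None:  # hypothetical helper, for illustration
    if not is_jinja_available():
        raise ImportError("Rendering chat templates requires jinja2: `pip install jinja2`")


class JinjaBackedComponent:  # hypothetical class, for illustration
    def __init__(self):
        # Raises an ImportError built from JINJA_IMPORT_ERROR when jinja2 is missing
        requires_backends(self, ["jinja"])
```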
diff --git a/tests/models/blenderbot/test_tokenization_blenderbot.py b/tests/models/blenderbot/test_tokenization_blenderbot.py
index 3a9c95027b3221..7fbf2b7603f95e 100644
--- a/tests/models/blenderbot/test_tokenization_blenderbot.py
+++ b/tests/models/blenderbot/test_tokenization_blenderbot.py
@@ -17,6 +17,7 @@
import unittest
from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast
+from transformers.testing_utils import require_jinja
from transformers.utils import cached_property
@@ -50,3 +51,24 @@ def test_3B_tokenization_same_as_parlai(self):
def test_3B_tokenization_same_as_parlai_rust_tokenizer(self):
assert self.rust_tokenizer_3b.add_prefix_space
assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
+
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tok = self.tokenizer_3b
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats]
+ expected_tokens = [
+ [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
+ [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
+ [3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
+ ]
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
diff --git a/tests/models/bloom/test_tokenization_bloom.py b/tests/models/bloom/test_tokenization_bloom.py
index 576a191c70b5f1..7383eeb668face 100644
--- a/tests/models/bloom/test_tokenization_bloom.py
+++ b/tests/models/bloom/test_tokenization_bloom.py
@@ -18,7 +18,7 @@
from datasets import load_dataset
from transformers import BloomTokenizerFast
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -134,6 +134,27 @@ def test_pretrained_model_lists(self):
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tokenizer = self.get_rust_tokenizer()
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ expected_tokens = [
+ [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
+ [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
+ [229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
+ ]
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
+
def test_add_prefix_space_fast(self):
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py
index 0dd33e776d497e..cceb3b9238b20f 100644
--- a/tests/models/gpt2/test_tokenization_gpt2.py
+++ b/tests/models/gpt2/test_tokenization_gpt2.py
@@ -20,7 +20,7 @@
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -275,6 +275,27 @@ def test_special_tokens_mask_input_pairs_and_bos_token(self):
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ # fmt: off
+ expected_tokens = [[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20],
+ [20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20],
+ [20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20]]
+ # fmt: on
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
+
@require_tokenizers
class OPTTokenizationTest(unittest.TestCase):
diff --git a/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py b/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py
index b030996e89dcf0..d639c33ef6440b 100644
--- a/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py
+++ b/tests/models/gpt_sw3/test_tokenization_gpt_sw3.py
@@ -16,7 +16,7 @@
import unittest
from transformers import GPTSw3Tokenizer
-from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -128,3 +128,27 @@ def test_tokenizer_integration(self):
model_name="AI-Sweden/gpt-sw3-126m",
sequences=sequences,
)
+
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB)
+ # This is in English, but it's just here to make sure the chat control tokens are being added properly
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ # fmt: off
+ expected_tokens = [
+ [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ],
+ [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419, ],
+ [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ]
+ ]
+ # fmt: on
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
diff --git a/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py b/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py
index 4352f6425f0d32..489e4f942664e5 100644
--- a/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py
+++ b/tests/models/gptsan_japanese/test_tokenization_gptsan_japanese.py
@@ -22,7 +22,7 @@
VOCAB_FILES_NAMES,
GPTSanJapaneseTokenizer,
)
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_jinja, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -193,3 +193,27 @@ def test_conversion_reversible(self):
def test_padding_different_model_input_name(self):
# tokenizer has no padding token
pass
+
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
+ # This is in English, but it's just here to make sure the chat control tokens are being added properly
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ # fmt: off
+ expected_tokens = [
+ [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
+ [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999],
+ [35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
+ ]
+ # fmt: on
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index 942cceaf7cd0d4..ca90eb9641ea80 100644
--- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -2486,3 +2486,7 @@ def test_layoutlmv2_integration_test(self):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
+
+ @unittest.skip("Chat is not supported")
+ def test_chat_template(self):
+ pass
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
index 58092834e5a160..59efc4b1cf3ba1 100644
--- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -2439,3 +2439,7 @@ def test_tf_encode_plus_sent_to_model(self):
# This should not fail
model(encoded_sequence)
model(batch_encoded_sequence)
+
+ @unittest.skip("Chat is not supported")
+ def test_chat_template(self):
+ pass
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index f7f8329706dff2..0b502748d13123 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -1958,3 +1958,7 @@ def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
@unittest.skip("Doesn't use SentencePiece")
def test_sentencepiece_tokenize_and_decode(self):
pass
+
+ @unittest.skip("Chat is not supported")
+ def test_chat_template(self):
+ pass
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index 9223e626f997fe..231474203032b1 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -32,6 +32,7 @@
from transformers.testing_utils import (
get_tests_dir,
nested_simplify,
+ require_jinja,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -574,6 +575,32 @@ def test_some_edge_cases(self):
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
+
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "user", "content": "Hello!"}],
+ ]
+ # Matt: The third test case tests the default system message, but if this is ever changed in the
+ # class/repo code then that test will fail, and the case will need to be updated.
+ tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ # fmt: off
+ expected_tokens = [
+ [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
+ [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
+ [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962]
+ ]
+ # fmt: on
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
+
@require_sentencepiece
@require_tokenizers
diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py
index 73979b255e08db..331f63a94a5818 100644
--- a/tests/models/markuplm/test_tokenization_markuplm.py
+++ b/tests/models/markuplm/test_tokenization_markuplm.py
@@ -2311,3 +2311,7 @@ def test_padding_warning_message_fast_tokenizer(self):
"Dummy warning",
cm.records[0].message,
)
+
+ @unittest.skip("Chat is not supported")
+ def test_chat_template(self):
+ pass
diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py
index b4cca18162d806..9d82c468aa3091 100644
--- a/tests/models/tapas/test_tokenization_tapas.py
+++ b/tests/models/tapas/test_tokenization_tapas.py
@@ -1274,3 +1274,7 @@ def test_pretrained_model_lists(self):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
+
+ @unittest.skip("Chat is not supported")
+ def test_chat_template(self):
+ pass
diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 9ab29d29d1de12..ef58768e22776b 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -16,7 +16,7 @@
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
-from transformers.testing_utils import slow
+from transformers.testing_utils import require_jinja, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -473,3 +473,25 @@ def test_offset_decoding(self):
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])
+
+ @require_jinja
+ def test_tokenization_for_chat(self):
+ multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
+ # This is in English, but it's just here to make sure the chat control tokens are being added properly
+ test_chats = [
+ [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+ [
+ {"role": "system", "content": "You are a helpful chatbot."},
+ {"role": "user", "content": "Hello!"},
+ {"role": "assistant", "content": "Nice to meet you."},
+ ],
+ [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+ ]
+ tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+ expected_tokens = [
+ [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
+ [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
+ [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
+ ]
+ for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+ self.assertListEqual(tokenized_chat, expected_tokens)
diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py
index efb2215f491005..dfc42ea4815cce 100644
--- a/tests/pipelines/test_pipelines_conversational.py
+++ b/tests/pipelines/test_pipelines_conversational.py
@@ -78,17 +78,23 @@ def get_test_pipeline(self, model, tokenizer, processor):
def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"))
- self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+ self.assertEqual(
+ outputs,
+ Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+ )
# Single list
outputs = conversation_agent([Conversation("Hi there!")])
- self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+ self.assertEqual(
+ outputs,
+ Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+ )
# Batch
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
- self.assertEqual(len(conversation_1.past_user_inputs), 0)
- self.assertEqual(len(conversation_2.past_user_inputs), 0)
+ self.assertEqual(len(conversation_1), 1)
+ self.assertEqual(len(conversation_2), 1)
outputs = conversation_agent([conversation_1, conversation_2])
self.assertEqual(outputs, [conversation_1, conversation_2])
@@ -96,32 +102,35 @@ def run_pipeline_test(self, conversation_agent, _):
outputs,
[
Conversation(
- past_user_inputs=["Going to the movies tonight - any suggestions?"],
- generated_responses=[ANY(str)],
+ [
+ {"role": "user", "content": "Going to the movies tonight - any suggestions?"},
+ {"role": "assistant", "content": ANY(str)},
+ ],
+ ),
+ Conversation(
+ [
+ {"role": "user", "content": "What's the last book you have read?"},
+ {"role": "assistant", "content": ANY(str)},
+ ]
),
- Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
],
)
# One conversation with history
- conversation_2.add_user_input("Why do you recommend it?")
+ conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2)
self.assertEqual(outputs, conversation_2)
self.assertEqual(
outputs,
Conversation(
- past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
- generated_responses=[ANY(str), ANY(str)],
+ [
+ {"role": "user", "content": "What's the last book you have read?"},
+ {"role": "assistant", "content": ANY(str)},
+ {"role": "user", "content": "Why do you recommend it?"},
+ {"role": "assistant", "content": ANY(str)},
+ ]
),
)
- with self.assertRaises(ValueError):
- conversation_agent("Hi there!")
- with self.assertRaises(ValueError):
- conversation_agent(Conversation())
- # Conversation have been consumed and are not valid anymore
- # Inactive conversations passed to the pipeline raise a ValueError
- with self.assertRaises(ValueError):
- conversation_agent(conversation_2)
@require_torch
@slow
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index aec5e493c57c00..fa3bf96d431a8a 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -50,6 +50,7 @@
check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
+ require_jinja,
require_tf,
require_tokenizers,
require_torch,
@@ -1052,6 +1053,40 @@ def test_sequence_ids(self):
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
+ @require_jinja
+ def test_chat_template(self):
+ dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
+ dummy_conversation = [
+ {"role": "system", "content": "system message"},
+ {"role": "user", "content": "user message"},
+ {"role": "assistant", "content": "assistant message"},
+ ]
+ expected_output = "systemsystem messageuseruser messageassistantassistant message"
+ tokenizers = self.get_tokenizers()
+ for tokenizer in tokenizers:
+ with self.subTest(f"{tokenizer.__class__.__name__}"):
+ output = tokenizer.apply_chat_template(
+ dummy_conversation, chat_template=dummy_template, tokenize=False
+ )
+ self.assertEqual(output, expected_output) # Test we can pass chat_template arg
+ # Check that no error raised when tokenize=True
+ tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
+
+ tokenizer.chat_template = dummy_template
+ self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
+ output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+ self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
+ tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ tokenizer.save_pretrained(tmp_dir_name)
+ tokenizer = tokenizer.from_pretrained(tmp_dir_name)
+
+ self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
+ output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+ self.assertEqual(output, expected_output) # Test output is the same after reloading
+ tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
+
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
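The persistence round-trip that the common test above exercises, as a standalone sketch (again assuming the `gpt2` checkpoint just for concreteness):

```python
import tempfile

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = "{% for message in messages %}{{ message['role'] + message['content'] }}{% endfor %}"

with tempfile.TemporaryDirectory() as tmp_dir:
    # save_pretrained writes chat_template into tokenizer_config.json ...
    tokenizer.save_pretrained(tmp_dir)
    # ... and __init__ pops it back out of the config on reload
    reloaded = AutoTokenizer.from_pretrained(tmp_dir)

assert reloaded.chat_template == tokenizer.chat_template
```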