diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt
index e292c32999d63..95a9be7806633 100644
--- a/docs/requirements-docs.txt
+++ b/docs/requirements-docs.txt
@@ -11,4 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
+mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
diff --git a/requirements-common.txt b/requirements-common.txt
index 534d63feec2b8..61daf99819756 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -26,3 +26,4 @@ librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
 importlib_metadata
+mistral_common >= 1.3.4
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 6acc057fe588c..4965354c0016b 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -30,9 +30,11 @@ def test_models(
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype,
+                     tokenizer_mode="mistral") as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
+
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
diff --git a/vllm/config.py b/vllm/config.py
index 74b18341e5ac9..4e014e43d849a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -61,7 +61,8 @@ class ModelConfig:
         output when `served_model_name` is not specified.
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, and "slow" will always use the slow tokenizer.
+            available, "slow" will always use the slow tokenizer, and
+            "mistral" will always use the tokenizer from `mistral_common`.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         dtype: Data type for model weights and activations. The "auto" option
@@ -246,10 +247,10 @@ def _init_multimodal_config(
 
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto' or 'slow'.")
+                "either 'auto', 'slow' or 'mistral'.")
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_embedding_mode(self) -> None:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index efcc646d0e8e2..6e66198e203fc 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -198,10 +198,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow'],
+            choices=['auto', 'slow', 'mistral'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
-            'always use the slow tokenizer.')
+            'always use the slow tokenizer. \n* '
+            '"mistral" will always use the `mistral_common` tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 19d1095084293..c5368ac3bf026 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -267,7 +267,7 @@ def apply_chat_template(
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> str:
+) -> Union[str, List[int]]:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
@@ -280,6 +280,5 @@ def apply_chat_template(
         tokenize=tokenize,
         **kwargs,
     )
-    assert isinstance(prompt, str)
     return prompt
 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ecc3c4004bbfb..0edd4bfaecd6a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -390,15 +390,21 @@ def chat(
         conversations, _ = parse_chat_messages(messages, model_config,
                                                tokenizer)
 
-        prompts = apply_chat_template(
+        prompt = apply_chat_template(
             tokenizer,
             conversations,
             chat_template=chat_template,
             add_generation_prompt=add_generation_prompt)
 
+        inputs: PromptInputs
+        if isinstance(prompt, list) and isinstance(prompt[0], int):
+            inputs = TokensPrompt(prompt_token_ids=prompt)
+        else:
+            inputs = TextPrompt(prompt=prompt)
+
         return self.generate(
-            prompts,
-            sampling_params,
+            inputs,
+            sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 4d8e240a88ee6..d31ac4995fe2f 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -22,7 +22,8 @@
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
-                                                    PromptAdapterPath)
+                                                    PromptAdapterPath,
+                                                    TextTokensPrompt)
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
@@ -130,13 +131,22 @@ async def create_chat_completion(
         guided_decode_logits_processor = (
             await self._guided_decode_logits_processor(request, tokenizer))
 
-        prompt_inputs = self._tokenize_prompt_input(
-            request,
-            tokenizer,
-            prompt,
-            truncate_prompt_tokens=request.truncate_prompt_tokens,
-            add_special_tokens=request.add_special_tokens,
-        )
+        if isinstance(prompt, str):
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+        else:
+            assert isinstance(prompt, list) and isinstance(
+                prompt[0], int
+            ), "Prompt has to be either a string or a list of token ids"
+            prompt_inputs = TextTokensPrompt(
+                prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
+
+        assert prompt_inputs is not None
 
         sampling_params = request.to_sampling_params(
             tokenizer,
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index b7624c471cdb2..d27d7ba9e67bb 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
     prefix_offset = max(
         read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
     # This is required to guard against out-of-vocab prompt token ids
-    _replace_none_with_empty(new_tokens)
+    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
     return new_tokens, prefix_offset, read_offset
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 0271aa809320e..2866975850db3 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -1,4 +1,5 @@
 import os
+import warnings
 from pathlib import Path
 from typing import Optional, Union
 
@@ -9,12 +10,14 @@
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
+                                                MistralTokenizer)
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
+                     MistralTokenizer]
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
@@ -99,45 +102,64 @@ def get_tokenizer(
             kwargs["gguf_file"] = Path(tokenizer_name).name
             tokenizer_name = Path(tokenizer_name).parent
 
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the tokenizer class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        if (not trust_remote_code and
-            ("does not exist or is not currently imported." in str(e)
-             or "requires you to execute the tokenizer file" in str(e))):
-            err_msg = (
-                "Failed to load the tokenizer. If the tokenizer is a custom "
-                "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider setting `trust_remote_code=True` in LLM "
-                "or using the `--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-    except AttributeError as e:
-        if "BaichuanTokenizer" in str(e):
-            # This is for the error "'BaichuanTokenizer' object has no
-            # attribute 'sp_model'".
-            tokenizer = BaichuanTokenizer.from_pretrained(
+    # if tokenizer is from official mistral org
+    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
+    if is_from_mistral_org and tokenizer_mode != "mistral":
+        warnings.warn(
+            'It is strongly recommended to run Mistral models with '
+            '`--tokenizer_mode "mistral"` to ensure correct '
+            'encoding and decoding.',
+            FutureWarning,
+            stacklevel=2)
+
+    if tokenizer_mode == "mistral":
+        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
+                                                     revision=revision)
+    else:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
                 tokenizer_name,
                 *args,
                 trust_remote_code=trust_remote_code,
                 revision=revision,
-                **kwargs)
-        else:
-            raise e
+                **kwargs,
+            )
+        except ValueError as e:
+            # If the error pertains to the tokenizer class not existing or not
+            # currently being imported,
+            # suggest using the --trust-remote-code flag.
+            if not trust_remote_code and (
+                    "does not exist or is not currently imported." in str(e)
+                    or "requires you to execute the tokenizer file" in str(e)):
+                err_msg = ("Failed to load the tokenizer. If the tokenizer "
+                           "is a custom tokenizer not yet available in the "
+                           "HuggingFace transformers library, consider "
+                           "setting `trust_remote_code=True` in LLM or using "
+                           "the `--trust-remote-code` flag in the CLI.")
+                raise RuntimeError(err_msg) from e
+            else:
+                raise e
+        except AttributeError as e:
+            if "BaichuanTokenizer" in str(e):
+                # This is for the error "'BaichuanTokenizer' object has no
+                # attribute 'sp_model'".
+                tokenizer = BaichuanTokenizer.from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
+            else:
+                raise e
+
+        if not isinstance(tokenizer, PreTrainedTokenizerFast):
+            logger.warning(
+                "Using a slow tokenizer. This might cause a significant "
+                "slowdown. Consider using a fast tokenizer instead.")
+    tokenizer = get_cached_tokenizer(tokenizer)
 
-    if not isinstance(tokenizer, PreTrainedTokenizerFast):
-        logger.warning(
-            "Using a slow tokenizer. This might cause a significant "
-            "slowdown. Consider using a fast tokenizer instead.")
-    return get_cached_tokenizer(tokenizer)
+    return tokenizer
 
 
 def get_lora_tokenizer(lora_request: LoRARequest, *args,
diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py
index e6b59722c2591..9433f2d48f6f3 100644
--- a/vllm/transformers_utils/tokenizers/__init__.py
+++ b/vllm/transformers_utils/tokenizers/__init__.py
@@ -1,5 +1,4 @@
 from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
-__all__ = [
-    "BaichuanTokenizer",
-]
+__all__ = ["BaichuanTokenizer", "MistralTokenizer"]
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
new file mode 100644
index 0000000000000..23ecfc0af6be4
--- /dev/null
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -0,0 +1,175 @@
+import os
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from huggingface_hub import HfApi, hf_hub_download
+# yapf: disable
+from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer as PublicMistralTokenizer)
+# yapf: enable
+from mistral_common.tokens.tokenizers.sentencepiece import (
+    SentencePieceTokenizer)
+from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
+                                                     Tekkenizer)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ConversationMessage
+
+
+@dataclass
+class Encoding:
+    input_ids: List[int]
+
+
+def find_tokenizer_file(files: List[str]):
+    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+
+    matched_files = [file for file in files if file_pattern.match(file)]
+    if len(matched_files) > 1:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      f"pattern: {matched_files}. Make sure only one Mistral "
+                      "tokenizer is present in the repo.")
+    elif len(matched_files) == 0:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      f"pattern: {file_pattern.pattern}. Make sure that a "
+                      "Mistral tokenizer is present in the repo.")
+
+    return matched_files[0]
+
+
+class MistralTokenizer:
+
+    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+        self.mistral = tokenizer
+        self.instruct = tokenizer.instruct_tokenizer
+        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
+
+        self.vocab_size = len(self.tokenizer.vocab())
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+        self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
+
+        if self._is_tekken:
+            # Make sure special tokens will not raise
+            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+
+        # the following attributes are set to fit vLLM's design
+        self.is_fast = True
+        self.chat_template = True
+        self.all_special_ids: List[Any] = []
+        self.all_special_tokens: List[Any] = []
+        self.all_special_tokens_extended: List[Any] = []
+
+    @classmethod
+    def from_pretrained(cls,
+                        path_or_repo_id: str,
+                        *,
+                        revision: Optional[str] = None) -> "MistralTokenizer":
+        if not Path(path_or_repo_id).exists():
+            assert len(path_or_repo_id.split("/")) == 2, (
+                "You have either provided a non-existent path: "
+                f"{path_or_repo_id} or an invalid HF Hub repo id.")
+            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
+                path_or_repo_id, revision)
+        elif Path(path_or_repo_id).is_dir():
+            tokenizer_file_name = find_tokenizer_file(
+                os.listdir(path_or_repo_id))
+            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
+        else:
+            assert Path(
+                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
+            tokenizer_file = path_or_repo_id
+
+        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
+        return cls(mistral_tokenizer)
+
+    @staticmethod
+    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
+                                            revision: Optional[str]) -> str:
+        api = HfApi()
+        repo_info = api.model_info(tokenizer_name)
+        files = [s.rfilename for s in repo_info.siblings]
+
+        filename = find_tokenizer_file(files)
+
+        tokenizer_file = hf_hub_download(tokenizer_name,
+                                         filename=filename,
+                                         revision=revision)
+        return tokenizer_file
+
+    def __call__(
+        self,
+        prompt: str,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        # Mistral tokenizers should not add special tokens
+        input_ids = self.encode(prompt)
+
+        if truncation:
+            input_ids = input_ids[:max_length]
+
+        return Encoding(input_ids=input_ids)
+
+    def get_added_vocab(self) -> List[str]:
+        # Mistral tokenizers have no added vocabulary
+        return []
+
+    def encode(self, prompt: str) -> List[int]:
+        # `encode` should only be used for prompt completion;
+        # it should never be used for chat completion.
+        # For chat completion, use `apply_chat_template` instead.
+        return self.tokenizer.encode(prompt, bos=True, eos=False)
+
+    def apply_chat_template(self,
+                            conversation: List["ConversationMessage"],
+                            tools: Optional[Dict[str, Any]] = None,
+                            **kwargs) -> List[int]:
+        assert tools is None, "`tools` are not yet supported."
+
+        request = ChatCompletionRequest(
+            messages=conversation)  # type: ignore[type-var]
+        encoded = self.mistral.encode_chat_completion(request)
+
+        # encode_chat_completion returns the formatted prompt token ids directly
+        return encoded.tokens
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        if self._is_tekken:
+            return "".join(tokens)
+        else:
+            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+
+    def decode(self, ids: Union[List[int], int]) -> str:
+        if isinstance(ids, int):
+            ids = [ids]
+        return self.tokenizer.decode(ids)
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_id
+
+    def convert_ids_to_tokens(
+            self,
+            ids: List[int],
+            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        # TODO(Patrick) - potentially allow special tokens to not be skipped
+        assert (
+            skip_special_tokens
+        ), "skip_special_tokens=False is not supported for Mistral tokenizers."
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+
+        tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+        return tokens
+
+    def __len__(self):
+        return self.vocab_size
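
Usage note: a minimal sketch of how the new tokenizer mode is exercised end to end, assuming `mistral_common >= 1.3.4` is installed (as pinned above); the checkpoint name and prompt are illustrative.

    from vllm import LLM, SamplingParams

    # tokenizer_mode="mistral" makes get_tokenizer() return a MistralTokenizer
    # instead of a Hugging Face AutoTokenizer.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
              tokenizer_mode="mistral")

    # LLM.chat() calls apply_chat_template(), which for MistralTokenizer
    # returns token ids; chat() then wraps them in a TokensPrompt rather
    # than a TextPrompt, as in the llm.py hunk above.
    messages = [{"role": "user", "content": "Tell me a joke."}]
    outputs = llm.chat(messages, SamplingParams(temperature=0.0, max_tokens=64))
    print(outputs[0].outputs[0].text)

The same mode is exposed on the OpenAI-compatible server through the new `--tokenizer-mode mistral` CLI choice added in arg_utils.py.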