diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 55b730812ea94..536a7c96a1e9e 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -4,8 +4,8 @@ import pytest +from vllm.entrypoints.openai.chat_utils import load_chat_template from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( @@ -64,8 +64,7 @@ def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=chatml_jinja_path) + load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) def test_no_load_chat_template_literallike(): @@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike(): tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) template_content = tokenizer.chat_template assert template_content == template @@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index f9dbf69c2eaab..35af0b02747e9 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -6,7 +6,6 @@ import jsonschema import openai # use the official client for correctness check import pytest -import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, prompt="Give an example string that fits this regex", extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): - base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") - - for add_special in [False, True]: - prompt = "This is a test prompt." 
- tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - - response = requests.post(base_url + "/tokenize", - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) - response.raise_for_status() - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): - base_url = str(client.base_url)[:-3] - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") - - prompt = "This is a test prompt." - tokens = tokenizer.encode(prompt, add_special_tokens=False) - - response = requests.post(base_url + "detokenize", - json={ - "model": model_name, - "tokens": tokens - }) - response.raise_for_status() - assert response.json() == {"prompt": prompt} diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py new file mode 100644 index 0000000000000..d33fd222ee150 --- /dev/null +++ b/tests/entrypoints/openai/test_tokenization.py @@ -0,0 +1,128 @@ +import openai # use the official client for correctness check +import pytest +import requests + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def server(): + with RemoteOpenAIServer([ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + ]) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_tokenize_completions(client: openai.AsyncOpenAI, + model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + + for add_special in [False, True]: + prompt = "This is a test prompt." + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post(base_url + "/tokenize", + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": "user", + "content": "Hi there!" + }, { + "role": "assistant", + "content": "Nice to meet you!" + }, { + "role": "user", + "content": "Can I ask a question?" 
+ }] + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + conversation=conversation, + tokenize=False) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post(base_url + "/tokenize", + json={ + "add_generation_prompt": + add_generation, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + + prompt = "This is a test prompt." + tokens = tokenizer.encode(prompt, add_special_tokens=False) + + response = requests.post(base_url + "/detokenize", + json={ + "model": model_name, + "tokens": tokens + }) + response.raise_for_status() + + assert response.json() == {"prompt": prompt} diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 45c634b4a2991..a35dcbbd6545e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -33,6 +33,8 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser @@ -46,6 +48,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +openai_serving_tokenization: OpenAIServingTokenization logger = init_logger('vllm.entrypoints.openai.api_server') @@ -86,7 +89,7 @@ async def health() -> Response: @router.post("/tokenize") async def tokenize(request: TokenizeRequest): - generator = await openai_serving_completion.create_tokenize(request) + generator = await openai_serving_tokenization.create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest): @router.post("/detokenize") async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_completion.create_detokenize(request) + generator = await openai_serving_tokenization.create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -241,6 +244,7 @@ def run_server(args, llm_engine=None): global openai_serving_chat global openai_serving_completion global openai_serving_embedding + global openai_serving_tokenization openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, @@ -252,6 +256,8 @@ def run_server(args, llm_engine=None): args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) + openai_serving_tokenization = OpenAIServingTokenization( + engine, model_config, served_model_names, args.chat_template) app.root_path = args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/chat_utils.py 
b/vllm/entrypoints/openai/chat_utils.py
new file mode 100644
index 0000000000000..27115391d5b27
--- /dev/null
+++ b/vllm/entrypoints/openai/chat_utils.py
@@ -0,0 +1,156 @@
+import codecs
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
+
+from openai.types.chat import (ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartTextParam)
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
+                                              ChatCompletionMessageParam)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.logger import init_logger
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.utils import async_get_and_parse_image
+
+logger = init_logger(__name__)
+
+
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
+        default_factory=list)
+
+
+def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
+    tokenizer = engine.tokenizer
+
+    if chat_template is not None:
+        try:
+            with open(chat_template, "r") as f:
+                tokenizer.chat_template = f.read()
+        except OSError as e:
+            JINJA_CHARS = "{}\n"
+            if not any(c in chat_template for c in JINJA_CHARS):
+                msg = (f"The supplied chat template ({chat_template}) "
+                       f"looks like a file path, but it failed to be "
+                       f"opened. Reason: {e}")
+                raise ValueError(msg) from e
+
+            # If opening the file fails, treat chat_template as the literal
+            # template string; decode it so escapes are interpreted correctly
+            tokenizer.chat_template = codecs.decode(chat_template,
+                                                    "unicode_escape")
+
+        logger.info("Using supplied chat template:\n%s",
+                    tokenizer.chat_template)
+    elif tokenizer.chat_template is not None:
+        logger.info("Using default chat template:\n%s",
+                    tokenizer.chat_template)
+    else:
+        logger.warning("No chat template provided. Chat API will not work.")
+
+
+@lru_cache(maxsize=None)
+def _image_token_str(engine: OpenAIServing) -> Optional[str]:
+    # TODO: Let user specify how to insert image tokens into prompt
+    # (similar to chat template)
+    model_type = engine.model_config.hf_config.model_type
+    if model_type == "phi3_v":
+        # Workaround since this token is not defined in the tokenizer
+        return "<|image_1|>"
+    if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
+        # These models do not use image tokens in the prompt
+        return None
+    if model_type.startswith("llava"):
+        return engine.tokenizer.decode(
+            engine.model_config.hf_config.image_token_index)
+
+    else:
+        raise TypeError(f"Unknown model type: {model_type}")
+
+
+# TODO: Let user specify how to insert image tokens into prompt
+# (similar to chat template)
+def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
+                                text_prompt: str) -> str:
+    """Combine image and text prompts for vision language model"""
+
+    # NOTE: For now we assume all model architectures use the same
+    # image + text prompt format. This may change in the future.
+ return f"{image_token_str}\n{text_prompt}" + + +def _parse_chat_message_content_parts( + engine: OpenAIServing, + role: str, + parts: Iterable[ChatCompletionContentPartParam], +) -> ChatMessageParseResult: + texts: List[str] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] + + for part in parts: + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] + texts.append(text) + elif part_type == "image_url": + if len(mm_futures) > 0: + raise NotImplementedError( + "Multiple 'image_url' input is currently not supported.") + + image_url = cast(ChatCompletionContentPartImageParam, + part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + image_future = async_get_and_parse_image(image_url["url"]) + mm_futures.append(image_future) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + text_prompt = "\n".join(texts) + + if mm_futures: + image_token_str = _image_token_str(engine) + if image_token_str is not None: + if image_token_str in text_prompt: + logger.warning( + "Detected image token string in the text prompt. " + "Skipping prompt formatting.") + else: + text_prompt = _get_full_image_text_prompt( + engine, + image_token_str=image_token_str, + text_prompt=text_prompt, + ) + + messages = [ConversationMessage(role=role, content=text_prompt)] + + return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) + + +def parse_chat_message_content( + engine: OpenAIServing, + message: ChatCompletionMessageParam, +) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[], mm_futures=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages, mm_futures=[]) + + return _parse_chat_message_content_parts(engine, role, content) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b3f0aae6d002d..2faf061192307 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -738,15 +738,17 @@ class BatchRequestOutput(OpenAIBaseModel): class TokenizeRequest(OpenAIBaseModel): + add_generation_prompt: bool = Field(default=True) + add_special_tokens: bool = Field(default=False) + prompt: Optional[str] = Field(default=None) + messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None) model: str - prompt: str - add_special_tokens: bool = Field(default=True) class TokenizeResponse(OpenAIBaseModel): - tokens: List[int] count: int max_model_len: int + tokens: List[int] class DetokenizeRequest(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 010d6f2ebb909..dbd4521073da9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,22 +1,19 @@ -import codecs import time -from dataclasses import dataclass, field -from functools import cached_property -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, - List, Optional) +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, + Optional) from typing import Sequence as GenericSequence -from typing import TypedDict, Union, cast, final +from typing import Union from fastapi import Request -from openai.types.chat import 
(ChatCompletionContentPartImageParam, - ChatCompletionContentPartTextParam) from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.chat_utils import (ConversationMessage, + load_chat_template, + parse_chat_message_content) from vllm.entrypoints.openai.protocol import ( - ChatCompletionContentPartParam, ChatCompletionLogProb, - ChatCompletionLogProbs, ChatCompletionLogProbsContent, - ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, @@ -28,7 +25,6 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import async_get_and_parse_image from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, @@ -38,19 +34,6 @@ logger = init_logger(__name__) -@final # So that it should be compatible with Dict[str, str] -class ConversationMessage(TypedDict): - role: str - content: str - - -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] = field( - default_factory=list) - - class OpenAIServingChat(OpenAIServing): def __init__(self, @@ -66,131 +49,7 @@ def __init__(self, lora_modules=lora_modules) self.response_role = response_role - self._load_chat_template(chat_template) - - def _load_chat_template(self, chat_template: Optional[str]): - tokenizer = self.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e - - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning( - "No chat template provided. 
Chat API will not work.") - - @cached_property - def image_token_str(self) -> Optional[str]: - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - model_type = self.model_config.hf_config.model_type - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", - "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return self.tokenizer.decode( - self.model_config.hf_config.image_token_index) - - else: - raise TypeError("Unknown model type: {model_type}") - - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - def _get_full_image_text_prompt(self, image_token_str: str, - text_prompt: str) -> str: - """Combine image and text prompts for vision language model""" - - # NOTE: For now we assume all model architectures use the same - # image + text prompt format. This may change in the future. - return f"{image_token_str}\n{text_prompt}" - - def _parse_chat_message_content_parts( - self, - role: str, - parts: Iterable[ChatCompletionContentPartParam], - ) -> ChatMessageParseResult: - texts: List[str] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - - for part in parts: - part_type = part["type"] - if part_type == "text": - text = cast(ChatCompletionContentPartTextParam, part)["text"] - texts.append(text) - elif part_type == "image_url": - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple 'image_url' input is currently not supported." - ) - - image_url = cast(ChatCompletionContentPartImageParam, - part)["image_url"] - - if image_url.get("detail", "auto") != "auto": - logger.warning( - "'image_url.detail' is currently not supported and " - "will be ignored.") - - image_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(image_future) - else: - raise NotImplementedError(f"Unknown part type: {part_type}") - - text_prompt = "\n".join(texts) - - if mm_futures: - image_token_str = self.image_token_str - if image_token_str is not None: - if image_token_str in text_prompt: - logger.warning( - "Detected image token string in the text prompt. 
" - "Skipping prompt formatting.") - else: - text_prompt = self._get_full_image_text_prompt( - image_token_str=image_token_str, - text_prompt=text_prompt, - ) - - messages = [ConversationMessage(role=role, content=text_prompt)] - - return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) - - def _parse_chat_message_content( - self, - message: ChatCompletionMessageParam, - ) -> ChatMessageParseResult: - role = message["role"] - content = message.get("content") - - if content is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) - if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) - - return self._parse_chat_message_content_parts(role, content) + load_chat_template(self, chat_template) async def create_chat_completion( self, @@ -216,7 +75,7 @@ async def create_chat_completion( mm_futures: List[Awaitable[MultiModalDataDict]] = [] for msg in request.messages: - chat_parsed_result = self._parse_chat_message_content(msg) + chat_parsed_result = parse_chat_message_content(self, msg) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b53b058b52af3..647fc31410647 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -16,10 +16,7 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - DetokenizeRequest, - DetokenizeResponse, - TokenizeRequest, - TokenizeResponse, UsageInfo) + UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, @@ -457,29 +454,3 @@ def _create_completion_logprobs( tokens=out_tokens, top_logprobs=out_top_logprobs, ) - - async def create_tokenize(self, - request: TokenizeRequest) -> TokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, - prompt=request.prompt, - add_special_tokens=request.add_special_tokens) - - return TokenizeResponse(tokens=input_ids, - count=len(input_ids), - max_model_len=self.max_model_len) - - async def create_detokenize( - self, request: DetokenizeRequest) -> DetokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, prompt_ids=request.tokens) - - return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 0000000000000..f441e940c5e5f --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,73 @@ +from typing import List, Optional + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.chat_utils import (ConversationMessage, + load_chat_template, + parse_chat_message_content) +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + TokenizeRequest, + TokenizeResponse) +from vllm.entrypoints.openai.serving_engine import OpenAIServing + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: 
List[str],
+                 chat_template: Optional[str] = None):
+        super().__init__(engine=engine,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=None)
+
+        load_chat_template(self, chat_template)
+
+    async def create_tokenize(self,
+                              request: TokenizeRequest) -> TokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        if not (request.prompt or request.messages):
+            return self.create_error_response(
+                "Either `prompt` or `messages` should be provided.")
+
+        if (request.prompt and request.messages):
+            return self.create_error_response(
+                "Only one of `prompt` or `messages` should be provided.")
+
+        if request.messages:
+            conversation: List[ConversationMessage] = []
+
+            for message in request.messages:
+                conversation.extend(
+                    parse_chat_message_content(self, message).messages)
+
+            request.prompt = self.tokenizer.apply_chat_template(
+                add_generation_prompt=request.add_generation_prompt,
+                conversation=conversation,
+                tokenize=False)
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request,
+            prompt=request.prompt,
+            add_special_tokens=request.add_special_tokens)
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request, prompt_ids=request.tokens)
+
+        return DetokenizeResponse(prompt=input_text)
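
Example usage of the /tokenize and /detokenize endpoints introduced in this diff. This is a minimal sketch, not part of the change itself: it assumes an OpenAI-compatible vLLM server is already running, so the base URL and model name below are placeholders. The request fields and response shapes mirror TokenizeRequest, TokenizeResponse, DetokenizeRequest and DetokenizeResponse in vllm/entrypoints/openai/protocol.py and the new tests in tests/entrypoints/openai/test_tokenization.py.

# Assumed setup: a vLLM OpenAI-compatible server reachable at BASE_URL and
# serving MODEL; both values are illustrative placeholders.
import requests

BASE_URL = "http://localhost:8000"
MODEL = "HuggingFaceH4/zephyr-7b-beta"

# Tokenize a plain prompt (TokenizeRequest with `prompt` set).
resp = requests.post(BASE_URL + "/tokenize",
                     json={
                         "model": MODEL,
                         "prompt": "This is a test prompt.",
                         "add_special_tokens": False,
                     })
resp.raise_for_status()
print(resp.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}

# Tokenize a chat conversation; the server applies the chat template first
# (TokenizeRequest with `messages` set).
resp = requests.post(BASE_URL + "/tokenize",
                     json={
                         "model": MODEL,
                         "messages": [{
                             "role": "user",
                             "content": "Hi there!"
                         }],
                         "add_generation_prompt": True,
                         "add_special_tokens": False,
                     })
resp.raise_for_status()
tokens = resp.json()["tokens"]

# Round-trip the token IDs back into text via /detokenize.
resp = requests.post(BASE_URL + "/detokenize",
                     json={
                         "model": MODEL,
                         "tokens": tokens,
                     })
resp.raise_for_status()
print(resp.json())  # {"prompt": "..."}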