[Frontend] Support for chat completions input in the tokenize endpoint (
sasha0552 authored Jul 16, 2024
1 parent d970115 commit 7a3d2a5
Showing 9 changed files with 386 additions and 244 deletions.
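
This commit extends the /tokenize endpoint so that, in addition to the completions-style "prompt" field, it accepts chat-completions-style "messages", which the server renders with the model's chat template before tokenizing. The request and response fields below mirror the new tests added in this commit; the snippet itself is only an illustrative sketch that assumes a vLLM OpenAI-compatible server is already running at http://localhost:8000 (an assumed address) and serving HuggingFaceH4/zephyr-7b-beta.

import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM OpenAI-compatible server
MODEL = "HuggingFaceH4/zephyr-7b-beta"

# New in this commit: chat-completions-style input. The server applies the
# model's chat template to the messages, then tokenizes the rendered prompt.
chat_payload = {
    "model": MODEL,
    "messages": [
        {"role": "user", "content": "Hi there!"},
        {"role": "assistant", "content": "Nice to meet you!"},
        {"role": "user", "content": "Can I ask a question?"},
    ],
    "add_generation_prompt": True,
    "add_special_tokens": False,
}
resp = requests.post(BASE_URL + "/tokenize", json=chat_payload)
resp.raise_for_status()
print(resp.json())  # {"tokens": [...], "count": <int>, "max_model_len": <int>}

# Completions-style input keeps working as before.
prompt_payload = {
    "model": MODEL,
    "prompt": "This is a test prompt.",
    "add_special_tokens": True,
}
resp = requests.post(BASE_URL + "/tokenize", json=prompt_payload)
resp.raise_for_status()
print(resp.json())
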
14 changes: 5 additions & 9 deletions tests/async_engine/test_chat_template.py
@@ -4,8 +4,8 @@

import pytest

+from vllm.entrypoints.openai.chat_utils import load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
-from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
@@ -64,8 +64,7 @@ def test_load_chat_template():
    # Testing chatml template
    tokenizer = MockTokenizer()
    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)

    template_content = tokenizer.chat_template

@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
    mock_serving_chat = MockServingChat(tokenizer)

    with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        load_chat_template(mock_serving_chat, chat_template=template)


def test_no_load_chat_template_literallike():
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
    tokenizer = MockTokenizer()

    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)
    template_content = tokenizer.chat_template

    assert template_content == template
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
49 changes: 0 additions & 49 deletions tests/entrypoints/openai/test_completion.py
@@ -6,7 +6,6 @@
import jsonschema
import openai  # use the official client for correctness check
import pytest
-import requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    for add_special in [False, True]:
-        prompt = "This is a test prompt."
-        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-        response = requests.post(base_url + "/tokenize",
-                                 json={
-                                     "add_special_tokens": add_special,
-                                     "model": model_name,
-                                     "prompt": prompt
-                                 })
-        response.raise_for_status()
-        assert response.json() == {
-            "tokens": tokens,
-            "count": len(tokens),
-            "max_model_len": 8192
-        }
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    prompt = "This is a test prompt."
-    tokens = tokenizer.encode(prompt, add_special_tokens=False)
-
-    response = requests.post(base_url + "detokenize",
-                             json={
-                                 "model": model_name,
-                                 "tokens": tokens
-                             })
-    response.raise_for_status()
-    assert response.json() == {"prompt": prompt}
128 changes: 128 additions & 0 deletions tests/entrypoints/openai/test_tokenization.py
@@ -0,0 +1,128 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+    ]) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question?"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}
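
For the /detokenize endpoint exercised by test_detokenize above, a minimal round trip could look like the following sketch (same assumed local server address and model as in the earlier example):

import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM OpenAI-compatible server
MODEL = "HuggingFaceH4/zephyr-7b-beta"

# Tokenize a prompt, then ask the server to map the token ids back to text.
tokenize_resp = requests.post(BASE_URL + "/tokenize",
                              json={
                                  "model": MODEL,
                                  "prompt": "This is a test prompt.",
                                  "add_special_tokens": False
                              })
tokenize_resp.raise_for_status()
tokens = tokenize_resp.json()["tokens"]

detokenize_resp = requests.post(BASE_URL + "/detokenize",
                                json={
                                    "model": MODEL,
                                    "tokens": tokens
                                })
detokenize_resp.raise_for_status()
print(detokenize_resp.json())  # expected: {"prompt": "This is a test prompt."}
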
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/api_server.py
@@ -33,6 +33,8 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
@@ -46,6 +48,7 @@
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization

logger = init_logger('vllm.entrypoints.openai.api_server')

@@ -86,7 +89,7 @@ async def health() -> Response:

@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
+    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(engine, model_config,
                                            served_model_names,
@@ -252,6 +256,8 @@
        args.prompt_adapters)
    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                      served_model_names)
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine, model_config, served_model_names, args.chat_template)
    app.root_path = args.root_path

    logger.info("Available routes are:")