[Frontend] Add tokenize/detokenize endpoints
sasha0552 authored May 30, 2024
1 parent 87d41c8 commit 23a6b41
Showing 5 changed files with 97 additions and 4 deletions.
35 changes: 35 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -8,6 +8,7 @@
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
@@ -1154,5 +1155,39 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 17


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

response = requests.post("http://localhost:8000/tokenize",
json={
"add_special_tokens": add_special,
"prompt": prompt
})
assert response.json() == {"tokens": tokens}


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=False)

response = requests.post("http://localhost:8000/detokenize",
json={"tokens": tokens})
assert response.json() == {"prompt": prompt}


if __name__ == "__main__":
pytest.main([__file__])
20 changes: 19 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -23,7 +23,11 @@
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingRequest, ErrorResponse)
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +89,20 @@ async def health() -> Response:
return Response(status_code=200)


@app.post("/tokenize")
async def tokenize(request: TokenizeRequest):
response = openai_serving_completion.create_tokenize(request)
assert isinstance(response, TokenizeResponse)
return JSONResponse(content=response.model_dump())


@app.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
response = openai_serving_completion.create_detokenize(request)
assert isinstance(response, DetokenizeResponse)
return JSONResponse(content=response.model_dump())


@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
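For reference, the two new routes can be exercised over plain HTTP in the same way the tests above do. A minimal client-side sketch, assuming a vLLM OpenAI-compatible server is already running on http://localhost:8000 (as in the tests); the payload fields mirror TokenizeRequest and DetokenizeRequest from this commit:

import requests

BASE_URL = "http://localhost:8000"  # assumed local server, as in the tests
prompt = "This is a test prompt."

# /tokenize takes a prompt plus an optional add_special_tokens flag
# (defaults to True) and returns {"tokens": [...]}.
tokens = requests.post(f"{BASE_URL}/tokenize",
                       json={
                           "prompt": prompt,
                           "add_special_tokens": False
                       }).json()["tokens"]

# /detokenize takes a token list and returns {"prompt": "..."}; with
# add_special_tokens=False this should round-trip to the original prompt
# for typical tokenizers.
decoded = requests.post(f"{BASE_URL}/detokenize",
                        json={"tokens": tokens}).json()["prompt"]
print(decoded)
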
17 changes: 17 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -602,3 +602,20 @@ class BatchRequestOutput(OpenAIBaseModel):
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]


class TokenizeRequest(OpenAIBaseModel):
prompt: str
add_special_tokens: bool = Field(default=True)


class TokenizeResponse(OpenAIBaseModel):
tokens: List[int]


class DetokenizeRequest(OpenAIBaseModel):
tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
prompt: str
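
The new schema classes are small request/response models; judging by the Field default here and the model_dump() calls in api_server.py, OpenAIBaseModel appears to be a pydantic v2 model. A brief usage sketch under that assumption:

from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)

# add_special_tokens defaults to True when the client omits it.
req = TokenizeRequest(prompt="This is a test prompt.")
assert req.add_special_tokens is True

# The API server serializes responses with model_dump() before returning JSON.
resp = TokenizeResponse(tokens=[1, 2, 3])
assert resp.model_dump()["tokens"] == [1, 2, 3]

# DetokenizeRequest carries only the token ids to decode back into text.
detok = DetokenizeRequest(tokens=[1, 2, 3])
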
18 changes: 17 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -15,7 +15,10 @@
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo)
DetokenizeRequest,
DetokenizeResponse,
TokenizeRequest,
TokenizeResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
@@ -413,3 +416,16 @@ def _create_completion_logprobs(
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)

def create_tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
(input_ids, input_text) = self._validate_prompt_and_tokenize(
request,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)
return TokenizeResponse(tokens=input_ids)

def create_detokenize(self,
request: DetokenizeRequest) -> DetokenizeResponse:
(input_ids, input_text) = self._validate_prompt_and_tokenize(
request, prompt_ids=request.tokens)
return DetokenizeResponse(prompt=input_text)
11 changes: 9 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -10,9 +10,10 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission)
ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
@@ -125,7 +126,8 @@ def _maybe_get_lora(
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
EmbeddingRequest],
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
@@ -171,6 +173,11 @@ def _validate_prompt_and_tokenize(
f"generation. Please reduce the length of the input.", )
return input_ids, input_text

# Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
# and do not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text

if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
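The early return above means tokenize/detokenize requests skip the max_tokens and context-length checks applied to other request types, since they carry no max_tokens field. A condensed, self-contained sketch of that control flow; the function name and arguments here are hypothetical stand-ins, not the real method body:

from typing import List, Tuple

from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              TokenizeRequest)


def validate_sketch(request, input_ids: List[int], input_text: str,
                    max_model_len: int) -> Tuple[List[int], str]:
    # Tokenize/detokenize requests have no max_tokens, so they return early
    # and never hit the context-length validation below.
    if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
        return input_ids, input_text

    # Other request types are still validated against the model context length.
    if len(input_ids) >= max_model_len:
        raise ValueError("This model's maximum context length is "
                         f"{max_model_len} tokens, but the prompt is longer.")
    return input_ids, input_text
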
