From a4df9256d1139bc3187b66ddb9ed11ee37b93bbd Mon Sep 17 00:00:00 2001
From: Xwdit <44023235+Xwdit@users.noreply.github.com>
Date: Sun, 5 May 2024 15:49:35 +0200
Subject: [PATCH] Tool Call Updated
Co-Authored-By: FlorianJoncour <148003496+florianjoncour@users.noreply.github.com>
---
examples/openai_tools_calls.py | 241 +++++++++++
vllm/entrypoints/openai/api_server.py | 125 ++++--
vllm/entrypoints/openai/cli_args.py | 19 +
vllm/entrypoints/openai/protocol.py | 116 ++++-
vllm/entrypoints/openai/serving_chat.py | 405 ++++++++++++++----
vllm/entrypoints/openai/serving_completion.py | 18 +-
vllm/entrypoints/openai/serving_engine.py | 15 +-
vllm/entrypoints/openai/tools.py | 291 +++++++++++++
.../guided_decoding/outlines_decoding.py | 17 +-
9 files changed, 1115 insertions(+), 132 deletions(-)
create mode 100644 examples/openai_tools_calls.py
create mode 100644 vllm/entrypoints/openai/tools.py
diff --git a/examples/openai_tools_calls.py b/examples/openai_tools_calls.py
new file mode 100644
index 0000000000000..8987cc35181d0
--- /dev/null
+++ b/examples/openai_tools_calls.py
@@ -0,0 +1,241 @@
+"""
+Inspired by the OpenAI example found here:
+ https://platform.openai.com/docs/guides/function-calling/parallel-function-calling
+"""
+
+from openai import OpenAI
+import datetime
+import json
+
+client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
+models = client.models.list()
+model = models.data[0].id
+temperature = 0.1
+stream = True
+
+# Can be used to reset the tokenizer and function templates. vLLM has to be launched with the --debug argument:
+# import httpx
+# httpx.post('http://localhost:8000/debug/reload-server')
+
+# This template can be set to None, and the server will use a generic template. It is only defined here as an example.
+# The generic template is defined in vllm/entrypoints/openai/protocol.py:VllmToolsTemplate.
+# Most values can be empty (except for call_token_start) but cannot be None.
+# This template is used internally and will not be returned to the user, but it can influence the quality of the responses provided by the LLM.
+TOOLS_TEMPLATE = {
+ # Keywords used by the model to call functions. Must be defined to catch function calls:
+ "call_token_start":
+ "",
+ "call_token_end":
+ "",
+
+    # Keywords used to define functions. Used to present the list of functions to the LLM
+ "tool_token_start":
+ "",
+ "tool_token_end":
+ "",
+
+ # Response keywords. Used to present the values returned by the functions
+ "response_token_start":
+ "",
+ "response_token_end":
+ "",
+
+ # Instructions (guided generation if tool_choice is defined on a specific function)
+ "function_guided":
+ "You must call the following function at least one time to answer the question. You may call it multiple times if needed:",
+
+ # Instructions (auto mode, if tool_choice equals "auto" or None)
+ "function_list_start":
+ "The following is a list of external functions that may be called to complete certain tasks:",
+ "function_list_end":
+ """End of list
+
+* Whenever the user asks you something, you can either respond directly or invoke a function if it is present in the previous list.
+* The decision to invoke a function is yours; only invoke a function if it is necessary to answer the user's question.
+* If you need to call at least one function, your message should contain only a list of function calls and nothing else; the function calls are the response.""",
+
+    # Instructions on how to call functions. Must follow call_token_start and call_token_end to make the parser work
+ "function_call_instruct":
+ """For each function call return a valid json object (using quotes) with function name and arguments within { } XML tags as follows::
+* With arguments:
+{ "name": "function_name", "arguments": {"argument_name": "value"} }
+* Without arguments:
+{ "name": "function_name", "arguments": null }
+
+End of functions instructions"""
+}
+
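+# Note: stop_token_ids is model specific (32000 here presumably matches the served
+# model's end-of-turn token); adjust or drop it for other models. tool_params
+# overrides the server-side VllmToolsTemplate defaults on a per-request basis.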
+EXTRA_BODY_OPENAI = {"stop_token_ids": [32000], "tool_params": TOOLS_TEMPLATE}
+
+
+# Example dummy function hard coded to return the same weather
+# In production, this could be your backend API or an external API
+def get_current_weather(location, unit="celsius"):
+ """Get the current weather in a given location"""
+ if unit is None:
+ unit = "celsius"
+ print("Calling get_current_weather client side : (\"%s\", %s)" %
+ (str(location), unit))
+ if isinstance(location, str):
+ if "tokyo" in location.lower():
+ temperature = "50" if unit.lower() == "fahrenheit" else "10"
+ return json.dumps({
+ "location": "Tokyo",
+ "temperature": temperature,
+ "unit": unit
+ })
+ elif "san francisco" in location.lower():
+ temperature = "75" if unit.lower() == "fahrenheit" else "24"
+ return json.dumps({
+ "location": "San Francisco",
+ "temperature": temperature,
+ "unit": unit
+ })
+ elif "paris" in location.lower():
+ temperature = "72" if unit.lower() == "fahrenheit" else "22"
+ return json.dumps({
+ "location": "Paris",
+ "temperature": temperature,
+ "unit": unit
+ })
+ return json.dumps({"location": str(location), "temperature": "unknown"})
+
+
+def get_current_date_utc():
+ print("Calling get_current_date_utc client side.")
+ return datetime.datetime.now(datetime.timezone.utc).strftime(
+ "The current UTC datetime is (day: %A, date (day/month/year): %d/%m/%Y, time: %H:%M)."
+ )
+
+
+def run_conversation(question: str, tool_choice_param):
+ # Step 1: send the conversation and available functions to the model
+ # messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
+ messages = [{"role": "user", "content": question}]
+ tools = [{
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather in a given location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type":
+ "string",
+ "description":
+ "The city and state, e.g. San Francisco, CA as a string",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"]
+ },
+ },
+ "required": ["location"],
+ },
+ },
+ }, {
+ "type": "function",
+ "function": {
+ "name": "get_current_date_utc",
+ "description": "Get the current UTC time",
+ },
+ }]
+ response = client.chat.completions.create(model=model,
+ messages=messages,
+ tools=tools,
+ stream=stream,
+ tool_choice=tool_choice_param,
+ temperature=temperature,
+ extra_body=EXTRA_BODY_OPENAI)
+ response_message = ""
+ tool_calls = []
+ if stream:
+ text_message = ""
+ for chunk in response:
+ if chunk.choices[0].finish_reason is not None:
+ if chunk.choices[0].finish_reason == "tool_calls":
+ tool_calls += chunk.choices[0].delta.tool_calls
+ # print("TEST : %s" % chunk.choices[0].delta.tool_calls)
+ break
+ if chunk.choices[0].delta.content is not None:
+ text_message += chunk.choices[0].delta.content
+ response_message = {
+ "role": "assistant",
+ "content": text_message,
+ "tool_calls": tool_calls
+ }
+ # print(str(response_message))
+ else:
+ if not len(response.choices):
+ return None
+ response_message = response.choices[0].message
+ if response_message.tool_calls is not None:
+ tool_calls = response_message.tool_calls
+ else:
+ print("The tool_calls response is null ?!")
+
+ # Step 2: check if the model wanted to call a function
+ if len(tool_calls):
+ # Step 3: call the function
+ # Note: the JSON response may not always be valid; be sure to handle errors
+ available_functions = {
+ "get_current_weather": get_current_weather,
+ "get_current_date_utc": get_current_date_utc,
+ }
+ messages.append(
+ response_message) # extend conversation with assistant's reply
+ # Step 4: send the info for each function call and function response to the model
+ for tool_call in tool_calls:
+ function_name = tool_call.function.name
+ if function_name in available_functions:
+ function_to_call = available_functions[function_name]
+ if function_name == "get_current_weather":
+ function_args = json.loads(tool_call.function.arguments)
+ function_response = function_to_call(
+ location=function_args.get("location"),
+ unit=function_args.get("unit"),
+ )
+ else:
+ function_response = function_to_call()
+ else:
+ print("The model halucinated a function : %s" % function_name)
+ continue
+
+ messages.append({
+ "tool_call_id": tool_call.id,
+ "role": "tool",
+ "name": function_name,
+ "content": function_response,
+ }) # extend conversation with function response
+ second_response = client.chat.completions.create(
+ model=model, messages=messages, extra_body=EXTRA_BODY_OPENAI
+ ) # get a new response from the model where it can see the function response
+
+ for it_msg, msg in enumerate(messages):
+ print("Message %i:\n %s\n" % (it_msg, str(msg)))
+
+ return second_response
+
+
+print("#############################################################")
+question = "What's the weather like in San Francisco, Tokyo, and Paris ? We also need to know the current date."
+# question = "What's the weather like in Paris ? We also need to know the current date."
+print("New request using templates: %s" % question)
+auto_result = run_conversation(question=question, tool_choice_param="auto")
+print("Final response (tool_choice=\"auto\"):\n%s" % auto_result)
+print("#############################################################\n")
+
+print("#############################################################")
+question = "What's the weather like in Paris ?"
+print("New request using guided generation: %s" % question)
+guided_result = run_conversation(question=question,
+ tool_choice_param={
+ "type": "function",
+ "function": {
+ "name": "get_current_weather"
+ }
+ })
+print("Final response (tool_choice=\"get_current_weather\"):\n%s" %
+ guided_result)
+print("#############################################################\n")
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 7cd51b959a0ea..495deda3afbe1 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,4 +1,5 @@
import asyncio
+import sys
import importlib
import inspect
import re
@@ -27,15 +28,17 @@
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.tools import OpenAIToolsPrompter
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-
+vllm_engine: Optional[AsyncLLMEngine] = None
+vllm_engine_args: Optional[AsyncEngineArgs] = None
+openai_serving_chat: Optional[OpenAIServingChat] = None
+openai_serving_completion: Optional[OpenAIServingCompletion] = None
+openai_serving_embedding: Optional[OpenAIServingEmbedding] = None
logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task] = set()
@@ -47,9 +50,9 @@ async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
- await engine.do_log_stats()
+ await vllm_engine.do_log_stats()
- if not engine_args.disable_log_stats:
+ if not vllm_engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
@@ -65,6 +68,56 @@ def parse_args():
return parser.parse_args()
+def _loadServingServices():
+ """ Load or reload the OpenAI service.
+ This function should only be called once on initialization, but may be called to reload the API internals.
+ Reloading must be used for development purpose only. """
+ global openai_serving_chat
+ global openai_serving_completion
+ global openai_serving_embedding
+ if openai_serving_chat is not None:
+ del openai_serving_chat
+ if openai_serving_completion is not None:
+ del openai_serving_completion
+ if openai_serving_embedding is not None:
+ del openai_serving_embedding
+
+ event_loop: Optional[asyncio.AbstractEventLoop]
+ try:
+ event_loop = asyncio.get_running_loop()
+ except RuntimeError:
+ event_loop = None
+
+ if event_loop is not None and event_loop.is_running():
+        # If the current process is instanced by Ray Serve,
+        # there is already a running event loop
+ model_config = event_loop.run_until_complete(
+ vllm_engine.get_model_config())
+ else:
+ # When using single vLLM without engine_use_ray
+ model_config = asyncio.run(vllm_engine.get_model_config())
+
+ openai_tools_prompter = OpenAIToolsPrompter(
+ debug=args.debug) if args.enable_api_tools else None
+ openai_serving_chat = OpenAIServingChat(
+ engine=vllm_engine,
+ model_config=model_config,
+ served_model_names=served_model_names,
+ response_role=args.response_role,
+ tools_role=args.tools_role,
+ lora_modules=args.lora_modules,
+ chat_template=args.chat_template,
+ openai_tools_prompter=openai_tools_prompter,
+ debug=args.debug,
+ tools_response_merge=args.enable_tools_response_merge)
+
+ openai_serving_completion = OpenAIServingCompletion(
+ vllm_engine, model_config, served_model_names, args.lora_modules)
+ openai_serving_embedding = OpenAIServingEmbedding(vllm_engine,
+ model_config,
+ served_model_names)
+
+
# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
@@ -73,7 +126,11 @@ def parse_args():
@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
+async def validation_exception_handler(req: Request, exc: Exception):
+ if "--debug" in sys.argv:
+ logger.warning("Request error (headers) : %s" % str(dict(req.headers)))
+ logger.warning("Request error (body) : %s" % str(
+ (await req.body()).decode("utf-8")))
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@@ -85,6 +142,16 @@ async def health() -> Response:
return Response(status_code=200)
+if "--debug" in sys.argv:
+
+ @app.post("/debug/reload-server")
+ async def debug_reload_server() -> Response:
+ """Reload the API internals. Danger !"""
+ logger.warning("Debugging server reload called.")
+ _loadServingServices()
+ return Response(status_code=200)
+
+
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
@@ -169,44 +236,32 @@ async def authentication(request: Request, call_next):
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
- raise ValueError(f"Invalid middleware {middleware}. "
- f"Must be a function or a class.")
+ raise ValueError(
+ f"Invalid middleware {middleware}. Must be a function or a class."
+ )
logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)
+ if args.debug:
+ logger.warning(
+ "\n"
+ "##########################################################################\n"
+ "Debugging mode enabled. This should only be used for development purpose.\n"
+ "If It's not the case, you should disable this !\n"
+ "##########################################################################\n"
+ )
+
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
- engine_args = AsyncEngineArgs.from_cli_args(args)
- engine = AsyncLLMEngine.from_engine_args(
- engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-
- event_loop: Optional[asyncio.AbstractEventLoop]
- try:
- event_loop = asyncio.get_running_loop()
- except RuntimeError:
- event_loop = None
-
- if event_loop is not None and event_loop.is_running():
- # If the current is instanced by Ray Serve,
- # there is already a running event loop
- model_config = event_loop.run_until_complete(engine.get_model_config())
- else:
- # When using single vLLM without engine_use_ray
- model_config = asyncio.run(engine.get_model_config())
+ vllm_engine_args = AsyncEngineArgs.from_cli_args(args)
+ vllm_engine = AsyncLLMEngine.from_engine_args(
+ vllm_engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+ _loadServingServices()
- openai_serving_chat = OpenAIServingChat(engine, model_config,
- served_model_names,
- args.response_role,
- args.lora_modules,
- args.chat_template)
- openai_serving_completion = OpenAIServingCompletion(
- engine, model_config, served_model_names, args.lora_modules)
- openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
- served_model_names)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 4c0cb1e4f3e49..50e6e9fd39e50 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -70,11 +70,23 @@ def make_arg_parser():
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
+ parser.add_argument("--enable-api-tools",
+ action="store_true",
+ help="Enable OpenAI-like tools API "
+ "(only function calls are currently supported)")
+ parser.add_argument("--enable-tools-response-merge",
+ action="store_true",
+ help="Enable merging of tools response "
+ "into one message, if multiple tools are used.")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
+ parser.add_argument("--tools-role",
+ type=nullable_str,
+ default="tool",
+ help="The role name of the response of tools.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
@@ -93,6 +105,13 @@ def make_arg_parser():
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
+ parser.add_argument(
+ "--debug",
+ action="store_true",
+ help=
+ "Enable API internals and templates reloading but do not deallocate the engine, "
+ "and also enable extra logs. This should only be used for development purpose."
+ )
parser.add_argument(
"--root-path",
type=nullable_str,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 139c5716c7cea..7c60d2bf8ff88 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1,10 +1,10 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union, Iterable
import torch
-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
@@ -62,6 +62,99 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0
+class Function(OpenAIBaseModel):
+ name: str
+ arguments: str
+
+
+class ChatCompletionMessageToolCall(OpenAIBaseModel):
+ index: int
+ id: str
+ type: str
+ function: Function
+
+
+class FunctionDefinition(OpenAIBaseModel):
+ name: str
+ description: str
+ parameters: Optional[Dict[str, object]] = None
+ # See : https://json-schema.org/understanding-json-schema/reference/object
+
+
+class ChatCompletionToolParam(OpenAIBaseModel):
+ type: str = "function"
+    function: Optional[FunctionDefinition] = None
+
+
+class ChatCompletionSystemMessage(OpenAIBaseModel):
+ role: str = "system"
+ content: str
+ name: Optional[str] = None
+
+
+class ChatCompletionUserMessage(OpenAIBaseModel):
+ role: str = "user"
+ content: Union[str, Iterable[ChatCompletionContentPartParam]]
+ name: Optional[str] = None
+
+
+class ChatCompletionAssistantMessage(OpenAIBaseModel):
+ role: str = "assistant"
+ content: Optional[str] = None
+ name: Optional[str] = None
+ tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
+
+
+class ChatCompletionToolMessage(OpenAIBaseModel):
+ role: str = "tool"
+ content: Optional[str] = None
+ name: Optional[str] = None
+ tool_call_id: str
+
+
+class ChatCompletionNamedFunction(OpenAIBaseModel):
+ name: str
+
+
+class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
+ function: ChatCompletionNamedFunction
+ type: Optional[Literal["function"]] = None
+
+
+class VllmToolsTemplate(OpenAIBaseModel):
+ # Extension to define the tools template. The strings may be empty but not None
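+    # These defaults can be overridden per request through the "tool_params" field
+    # of ChatCompletionRequest (see below).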
+ call_token_start: str = ""
+ call_token_end: str = ""
+ tool_token_start: str = ""
+ tool_token_end: str = ""
+ response_token_start: str = ""
+ response_token_end: str = ""
+
+ function_guided: str = "You have the capability to call functions, you must call the following function (tool) at least one time to complete certain tasks or answer questions, and you may call it multiple times if needed:"
+
+ function_list_start: str = """You have the capability to call functions, the following is a list of external functions (tools) that you can called proactively to complete certain tasks or answer questions, if you want to call at least one function, your message/response should contain only function calls in XML format and nothing else:
+"""
+
+ function_list_end: str = """"""
+
+ function_call_instruct: str = '''For each function call you always need to return a valid json object (using quotes) with function name and arguments within { } XML tags as follows:
+* If you are calling a function with arguments:
+"""
+{ "name": "function_name", "arguments": {"argument_name": "value"} }
+"""
+
+* If you are calling a function without arguments:
+"""
+{ "name": "function_name", "arguments": null }
+"""
+
+* If you are calling multiple functions in parallel with/without arguments:
+"""
+{ "name": "function_name_1", "arguments": {"argument_name": "value"} }
+{ "name": "function_name_2", "arguments": null }
+"""'''
+
+
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
@@ -70,7 +163,9 @@ class ResponseFormat(OpenAIBaseModel):
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
- messages: List[ChatCompletionMessageParam]
+ messages: List[Union[ChatCompletionSystemMessage,
+ ChatCompletionAssistantMessage,
+ ChatCompletionUserMessage, ChatCompletionToolMessage]]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -88,8 +183,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
user: Optional[str] = None
+ tools: Optional[List[ChatCompletionToolParam]] = None
+ tool_choice: Optional[Union[Literal["auto", "none"],
+ ChatCompletionNamedToolChoiceParam]] = "auto"
# doc: begin-chat-completion-sampling-params
+ tool_params: Optional[VllmToolsTemplate] = None
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
@@ -452,7 +551,8 @@ class EmbeddingResponse(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
- content: str
+ content: Optional[str] = None
+ tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
@@ -472,9 +572,17 @@ class ChatCompletionResponse(OpenAIBaseModel):
usage: UsageInfo
+class ChoiceDeltaToolCall(OpenAIBaseModel):
+ index: int
+ id: str
+ type: str
+ function: Function
+
+
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
+ tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 1b469fc59b076..95c61677e7500 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -12,6 +12,8 @@
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+ VllmToolsTemplate, ChatCompletionAssistantMessage,
+ ChatCompletionToolMessage, ChatCompletionNamedToolChoiceParam,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@@ -21,6 +23,7 @@
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid
+from vllm.entrypoints.openai.tools import OpenAIToolsPrompter, ChatPromptCapture
logger = init_logger(__name__)
@@ -38,14 +41,22 @@ def __init__(self,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
+ tools_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
- chat_template: Optional[str] = None):
+ chat_template: Optional[str] = None,
+ openai_tools_prompter: OpenAIToolsPrompter = None,
+ debug: bool = False,
+ tools_response_merge: bool = False):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
- lora_modules=lora_modules)
-
+ lora_modules=lora_modules,
+ debug=debug)
self.response_role = response_role
+ self.tools_role = tools_role
+ self.default_tools_template = VllmToolsTemplate()
+ self.openai_tools_prompter = openai_tools_prompter
+ self.tools_response_merge = tools_response_merge
self._load_chat_template(chat_template)
def _load_chat_template(self, chat_template: Optional[str]):
@@ -105,23 +116,65 @@ async def create_chat_completion(
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
- See https://platform.openai.com/docs/api-reference/chat/create
- for the API specification. This API mimics the OpenAI
- ChatCompletion API.
-
- NOTE: Currently we do not support the following feature:
- - function_call (Users should implement this by themselves)
+ See https://platform.openai.com/docs/api-reference/chat/create
+ for the API specification. This API mimics the OpenAI ChatCompletion API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
+ if self.openai_tools_prompter is not None:
+ if isinstance(request.tool_params, VllmToolsTemplate):
+ if len(request.tool_params.call_token_start) == 0:
+ raise ValueError(
+ "Error, the tool_params.call_token_start can't be empty !"
+ )
+ else:
+                # No need to allocate it for each request; if this param is not set, we use the default value.
+ request.tool_params = self.default_tools_template
+ self.openai_tools_prompter.inject_prompt(request)
+
+        # FIXME: As of Dec 2023, the tokenizer only accepts "role" and "content" attributes.
+        # FIXME: So we manually copy other attributes into "content" when needed.
+ merged_messages = []
+ last_tool_message = None
+
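+            # When --enable-tools-response-merge is set, consecutive tool messages
+            # are merged into a single message before templating.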
+ for m in request.messages:
+ if isinstance(m, ChatCompletionToolMessage
+ ) and m.tool_call_id is not None:
+ m.content = self.openai_tools_prompter.content_from_tool(
+ m, request.tool_params)
+ m.role = self.tools_role
+ if self.tools_response_merge:
+ if last_tool_message is None:
+ last_tool_message = m
+ else:
+ last_tool_message.content += "\n" + m.content
+ else:
+ merged_messages.append(m)
+ else:
+ if last_tool_message is not None:
+ merged_messages.append(last_tool_message)
+ last_tool_message = None
+
+ if isinstance(m, ChatCompletionAssistantMessage
+ ) and m.tool_calls is not None:
+ m.content = self.openai_tools_prompter.content_from_assistant(
+ m, request.tool_params)
+
+ merged_messages.append(m)
+
+ if last_tool_message is not None:
+ merged_messages.append(last_tool_message)
+
+ request.messages = merged_messages
+
try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
- m["role"], m["content"])
+ m.role, m.content)
conversation.extend(messages)
@@ -134,6 +187,15 @@ async def create_chat_completion(
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
+        if self.debug:  # eases template development
+ logger.info("\n######## Development infos (debug) ########")
+ logger.info("API tools status: %s" %
+ str(self.openai_tools_prompter is not None))
+ logger.info("- Request:\n%s" % str(request.dict()))
+ logger.info("")
+ logger.info("- Prompt:\n%s" % str(prompt))
+ logger.info("##############################################")
+
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
@@ -176,7 +238,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
- return request.messages[-1]["role"]
+ return request.messages[-1].role
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
@@ -188,6 +250,28 @@ async def chat_completion_stream_generator(
chunk_object_type = "chat.completion.chunk"
first_iteration = True
+ if isinstance(
+ request.tool_choice,
+ ChatCompletionNamedToolChoiceParam): # Guided function call
+ tools_capture_texts = [
+ ChatPromptCapture(self.openai_tools_prompter,
+ request.tool_params)
+ for i in range(request.n)
+ ]
+ is_tools_guided_generation = True
+ else:
+ is_tools_guided_generation = False
+ if self.openai_tools_prompter is not None and (
+ isinstance(request.tool_choice, str)
+ and request.tool_choice == "auto"):
+ tools_capture_texts = [
+ ChatPromptCapture(self.openai_tools_prompter,
+ request.tool_params)
+ for i in range(request.n)
+ ]
+ else:
+ tools_capture_texts = None
+
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
@@ -199,8 +283,7 @@ async def chat_completion_stream_generator(
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
- # Send first response for each request.n (index) with
- # the role
+ # Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
@@ -217,8 +300,7 @@ async def chat_completion_stream_generator(
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
- # Send response to echo the input portion of the
- # last message
+ # Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if conversation and conversation[-1].get(
@@ -228,12 +310,11 @@ async def chat_completion_stream_generator(
if last_msg_content:
for i in range(request.n):
- choice_data = (
- ChatCompletionResponseStreamChoice(
- index=i,
- delta=DeltaMessage(
- content=last_msg_content),
- finish_reason=None))
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(
+ content=last_msg_content),
+ finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
@@ -252,6 +333,17 @@ async def chat_completion_stream_generator(
if finish_reason_sent[i]:
continue
+ current_capture = tools_capture_texts[
+ i] if tools_capture_texts is not None else None
+
+ if current_capture is not None and current_capture.after_new_function_call:
+ current_capture.after_new_function_call = False
+ # If the last token is a new line char right after a function call, we ignore it.
+ # Otherwise, each function call creates a line break in the content part of the response.
+ if output.text[len(previous_texts[i]):] == "\n":
+ previous_texts[i] = output.text
+ continue
+
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
@@ -266,51 +358,153 @@ async def chat_completion_stream_generator(
else:
logprobs = None
- delta_text = output.text[len(previous_texts[i]):]
- previous_texts[i] = output.text
- previous_num_tokens[i] = len(output.token_ids)
- if output.finish_reason is None:
- # Send token-by-token response for each request.n
- choice_data = ChatCompletionResponseStreamChoice(
- index=i,
- delta=DeltaMessage(content=delta_text),
- logprobs=logprobs,
- finish_reason=None)
- chunk = ChatCompletionStreamResponse(
- id=request_id,
- object=chunk_object_type,
- created=created_time,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- else:
- # Send the finish response for each request.n only once
- prompt_tokens = len(res.prompt_token_ids)
- final_usage = UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=previous_num_tokens[i],
- total_tokens=prompt_tokens +
- previous_num_tokens[i],
- )
- choice_data = ChatCompletionResponseStreamChoice(
- index=i,
- delta=DeltaMessage(content=delta_text),
- logprobs=logprobs,
- finish_reason=output.finish_reason,
- stop_reason=output.stop_reason)
- chunk = ChatCompletionStreamResponse(
- id=request_id,
- object=chunk_object_type,
- created=created_time,
- choices=[choice_data],
- model=model_name)
- if final_usage is not None:
- chunk.usage = final_usage
- data = chunk.model_dump_json(exclude_unset=True,
- exclude_none=True)
- yield f"data: {data}\n\n"
- finish_reason_sent[i] = True
+                if is_tools_guided_generation:  # Manage tool calling when request.tool_choice names a function
+ if len(current_capture.content) == 0:
+ current_capture.startNamedFunction(
+ request.tool_choice)
+ current_token: str = output.text[len(previous_texts[i]
+ ):]
+ if len(current_token):
+ current_capture.content += current_token
+ if current_capture.checkBracketsFunctionCall(
+ request.tool_params
+ ): # We have the complete call block
+ previous_texts[i] = output.text
+ current_capture.closeNamedFunction()
+ current_capture.make_calls_list()
+ current_capture.reset(False)
+ current_capture.after_new_function_call = True
+                else:  # Manage tool calling when request.tool_choice is "auto"
+ if (self.openai_tools_prompter is not None) and \
+ (current_capture is not None) and \
+ (request.tools is not None) and \
+ (output.finish_reason is None):
+ if len(current_capture.content) == 0:
+ current_token: str = output.text[
+ len(previous_texts[i]):]
+ if current_capture.func_call_token_pre(
+ ) in current_token:
+ start_pos: int = current_token.index(
+ current_capture.func_call_token_pre())
+ current_capture.content = current_token[
+ start_pos:] # With some models the completion may start by a space.
+ current_capture.prefix_size = len(
+ output.text) - len(
+ current_capture.content)
+ current_capture.maybe_function_call = True
+ else: # Maybe a function call...
+ current_token: str = output.text[
+ len(current_capture.content) +
+ current_capture.prefix_size:]
+ current_capture.content += current_token
+ if len(
+ current_capture.content
+ ) < current_capture.func_call_token_size():
+ pass
+ elif not current_capture.is_function_call:
+ if current_capture.content.startswith(
+ current_capture.func_call_token(
+ )): # Function call !
+ current_capture.is_function_call = True
+ else: # This is not a function call...
+ current_capture.reset(False)
+ else: # Currently extracting the function call
+ if current_capture.checkBracketsFunctionCall(
+ request.tool_params
+ ): # We have the complete call block
+ previous_texts[i] = output.text
+ current_capture.make_calls_list()
+ current_capture.reset(False)
+ current_capture.after_new_function_call = True
+ else:
+ pass
+
+ if current_capture is None or (
+ isinstance(current_capture, ChatPromptCapture)
+ and not current_capture.maybe_function_call):
+ delta_text = output.text[len(previous_texts[i]):]
+ previous_texts[i] = output.text
+ previous_num_tokens[i] = len(output.token_ids)
+ if output.finish_reason is None:
+ if len(delta_text) > 0:
+ # Send token-by-token response for each request.n
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=delta_text),
+ logprobs=logprobs,
+ finish_reason=None)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(
+ exclude_unset=True)
+ yield f"data: {data}\n\n"
+ else:
+ if output.finish_reason == "stop" and (
+ isinstance(current_capture,
+ ChatPromptCapture) and
+ (current_capture.num_calls() > 0)):
+ tools_calls_list = current_capture.to_ChoiceDeltaToolCallList(
+ )
+
+ if self.debug:
+ for t in tools_calls_list:
+ logger.warning(
+ "Calling tools: %s" %
+ str(t.model_dump_json()))
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(
+ content=None,
+ tool_calls=tools_calls_list),
+ finish_reason="tool_calls",
+ stop_reason=output.stop_reason)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+ chunk.usage = UsageInfo(
+ prompt_tokens=len(res.prompt_token_ids),
+ completion_tokens=len(output.token_ids),
+ total_tokens=len(res.prompt_token_ids) +
+ len(output.token_ids),
+ )
+ data = chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True)
+ yield f"data: {data}\n\n"
+ else:
+ # Send the finish response for each request.n only once
+ prompt_tokens = len(res.prompt_token_ids)
+ final_usage = UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=previous_num_tokens[i],
+ total_tokens=prompt_tokens +
+ previous_num_tokens[i],
+ )
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=delta_text),
+ logprobs=logprobs,
+ finish_reason=output.finish_reason,
+ stop_reason=output.stop_reason)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+ if final_usage is not None:
+ chunk.usage = final_usage
+ data = chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True)
+ yield f"data: {data}\n\n"
+ finish_reason_sent[i] = True
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
@@ -337,11 +531,12 @@ async def chat_completion_full_generator(
assert final_res is not None
choices = []
-
role = self.get_chat_request_role(request)
+
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs
+ tools_calls_validation = False
if request.logprobs:
logprobs = self._create_logprobs(
@@ -352,14 +547,74 @@ async def chat_completion_full_generator(
else:
logprobs = None
- choice_data = ChatCompletionResponseChoice(
- index=output.index,
- message=ChatMessage(role=role, content=output.text),
- logprobs=logprobs,
- finish_reason=output.finish_reason,
- stop_reason=output.stop_reason,
- )
- choices.append(choice_data)
+            # Manage tool calling
+ if self.openai_tools_prompter is not None and \
+ request.tools is not None:
+ current_capture = ChatPromptCapture(self.openai_tools_prompter,
+ request.tool_params)
+
+ if isinstance(request.tool_choice,
+ ChatCompletionNamedToolChoiceParam
+ ): # Guided function call
+ current_capture.startNamedFunction(request.tool_choice)
+ current_capture.content += output.text
+ current_capture.closeNamedFunction()
+ current_capture.make_calls_list()
+ current_capture.reset(False)
+ else:
+ start_pos = 0
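+                    # Scan the completed output for call_token_start markers and
+                    # extract each balanced { ... } block that follows one.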
+ while True:
+ pos = output.text.find(
+ current_capture.func_call_token(), start_pos, -1)
+ if pos < 0:
+ break
+ start_bloc = output.text.find("{", pos, -1)
+ if start_bloc < 0:
+ break
+ if (start_bloc -
+ (pos +
+ current_capture.func_call_token_size())) > 1:
+ break
+ count = 1
+ bloc_end = start_bloc + 1
+ for it_ch in range(start_bloc + 1, len(output.text),
+ 1):
+ ch = output.text[it_ch]
+ bloc_end += 1
+ if ch == "{":
+ count += 1
+ elif ch == "}":
+ count -= 1
+ if count == 0: # We have the complete call block
+ current_capture.content = output.text[
+ start_bloc:bloc_end]
+ current_capture.make_calls_list()
+ current_capture.reset(False)
+ break
+ start_pos = bloc_end + 1
+
+ if current_capture.num_calls() > 0:
+ tools_calls_validation = True
+ tools_calls_list = current_capture.to_ChatCompletionMessageToolCallList(
+ )
+ message = ChatMessage(role=role,
+ content=None,
+ tool_calls=tools_calls_list)
+ choice_data = ChatCompletionResponseChoice(
+ index=output.index,
+ message=message,
+ logprobs=logprobs,
+ finish_reason="tool_calls",
+ stop_reason=output.stop_reason)
+ choices.append(choice_data)
+ if not tools_calls_validation:
+ choice_data = ChatCompletionResponseChoice(
+ index=output.index,
+ message=ChatMessage(role=role, content=output.text),
+ logprobs=logprobs,
+ finish_reason=output.finish_reason,
+ stop_reason=output.stop_reason)
+ choices.append(choice_data)
if request.echo:
last_msg_content = ""
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 158d8ed7fbbf5..a4e71c0ef7e30 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -46,8 +46,9 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
- raise ValueError("prompt must be a string, array of strings, "
- "array of tokens, or array of token arrays")
+ raise ValueError(
+ "prompt must be a string, array of strings, array of tokens, or array of token arrays"
+ )
return prompt_is_tokens, prompts
@@ -59,7 +60,8 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
- lora_modules=lora_modules)
+ lora_modules=lora_modules,
+ debug=False)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@@ -133,8 +135,7 @@ async def create_completion(self, request: CompletionRequest,
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
- # results. In addition, we do not stream the results when use
- # beam search.
+    # results. In addition, we do not stream the results when using beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
@@ -202,8 +203,7 @@ async def completion_stream_generator(
for output in res.outputs:
i = output.index + prompt_idx * request.n
- # TODO(simon): optimize the performance by avoiding full
- # text O(n^2) sending.
+ # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
@@ -212,8 +212,8 @@ async def completion_stream_generator(
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
- elif (request.echo and request.max_tokens > 0
- and not has_echoed[i]):
+ elif request.echo and request.max_tokens > 0 and not has_echoed[
+ i]:
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 58a1c2f7e73fe..d645ea19ef091 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -29,11 +29,15 @@ class LoRAModulePath:
class OpenAIServing:
- def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
+ def __init__(self,
+ engine: AsyncLLMEngine,
+ model_config: ModelConfig,
served_model_names: List[str],
- lora_modules: Optional[List[LoRAModulePath]]):
+ lora_modules: Optional[List[LoRAModulePath]],
+ debug: bool = False):
super().__init__()
+ self.debug = debug
self.engine = engine
self.max_model_len = model_config.max_model_len
@@ -124,6 +128,8 @@ def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+ if self.debug:
+ logger.warning("Error response : %s" % message)
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)
@@ -214,9 +220,8 @@ def _validate_prompt_and_tokenize(
if token_num + request.max_tokens > self.max_model_len:
raise ValueError(
- f"This model's maximum context length is "
- f"{self.max_model_len} tokens. However, you requested "
- f"{request.max_tokens + token_num} tokens "
+ f"This model's maximum context length is {self.max_model_len} tokens. "
+ f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
diff --git a/vllm/entrypoints/openai/tools.py b/vllm/entrypoints/openai/tools.py
new file mode 100644
index 0000000000000..bbda8a6d4698e
--- /dev/null
+++ b/vllm/entrypoints/openai/tools.py
@@ -0,0 +1,291 @@
+import json
+from typing import List, Dict, Union
+from vllm.logger import init_logger
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest, ChatCompletionToolParam, VllmToolsTemplate,
+ ChoiceDeltaToolCall, ChatCompletionMessageToolCall, Function,
+ ChatCompletionAssistantMessage, ChatCompletionToolMessage,
+ ChatCompletionNamedToolChoiceParam)
+
+logger = init_logger(__name__)
+
+
+class ToolsCallsTemplate:
+ """ This template system is only used when the tool_choice is set to "auto" """
+
+ def __init__(self):
+ pass
+
+ def render_toolcalls(self, tool_calls: List[ChatCompletionMessageToolCall],
+ tool_params: VllmToolsTemplate) -> str:
+ parts = []
+ for call in tool_calls:
+ try:
+ if call.function.arguments:
+ _arguments = json.loads(call.function.arguments)
+ else:
+ _arguments = None
+
+ _call = {"name": call.function.name, "arguments": _arguments}
+ part = f'{tool_params.call_token_start}{json.dumps(_call)}{tool_params.call_token_end}\n'
+ parts.append(part)
+ except json.JSONDecodeError:
+ continue
+
+ return ''.join(parts).strip()
+
+ def render_toolresponse(self, message: ChatCompletionToolMessage,
+ tool_params: VllmToolsTemplate) -> str:
+ return tool_params.response_token_start + str(
+ message.content.strip()) + tool_params.response_token_end
+
+ def render_tool(self, tool: ChatCompletionToolParam,
+ tool_params: VllmToolsTemplate) -> str:
+ if tool.function.parameters is None or len(
+ tool.function.parameters) == 0:
+ return f"""{tool_params.tool_token_start}{{"name": "{tool.function.name}",""" \
+ f""""description": "{tool.function.description}", "parameters": null}}{tool_params.tool_token_end}"""
+ else:
+ json_params = json.dumps(tool.function.parameters)
+ return f"""{tool_params.tool_token_start}{{"name": "{tool.function.name}",""" \
+ f""""description": "{tool.function.description}", "parameters": {json_params}}}{tool_params.tool_token_end}"""
+
+ def render_toolslist(self, tool_choice: Union[
+ str, ChatCompletionNamedToolChoiceParam],
+ tools_list: List[ChatCompletionToolParam],
+ tool_params: VllmToolsTemplate) -> str:
+ if isinstance(tool_choice, str) and (tool_choice == "auto"
+ or tool_choice == "none"):
+ tool_choice = None
+ if tool_choice is not None: # Guided generation
+ for tool in tools_list:
+ # Search if the tool_choice is in the tools_list
+ if tool.type == "function" and tool.function.name == tool_choice:
+ instructions = tool_params.function_guided + "\n" + self.render_tool(
+ tool, tool_params=tool_params) + "\n"
+ instructions += tool_params.function_call_instruct
+ return instructions.strip()
+ return "" # Tool not found. What should we do ?
+ else:
+ instructions = tool_params.function_list_start + "\n"
+ for tool in tools_list:
+ instructions += self.render_tool(
+ tool, tool_params=tool_params) + "\n"
+ instructions = instructions.strip()
+ instructions += "\n" + tool_params.function_list_end + "\n"
+ instructions += tool_params.function_call_instruct
+ return instructions.strip()
+
+
+class OpenAIToolsPrompter:
+ """
+ https://platform.openai.com/docs/assistants/tools
+ """
+
+ def __init__(self, debug: bool):
+ self.debug = debug
+ self.template = ToolsCallsTemplate()
+
+ def content_from_assistant(self, message: ChatCompletionAssistantMessage,
+ tool_params: VllmToolsTemplate) -> str:
+ text = self.template.render_toolcalls(message.tool_calls,
+ tool_params=tool_params)
+ if message.content is None or len(message.content.strip()) == 0:
+ return text
+ else:
+ return message.content.strip() + "\n\n" + text
+
+ def content_from_tool(self, message: ChatCompletionToolMessage,
+ tool_params: VllmToolsTemplate) -> str:
+ return self.template.render_toolresponse(message,
+ tool_params=tool_params)
+
+ def inject_prompt(self, request: ChatCompletionRequest):
+ """ Generate and inject the prompt for tools calls. """
+ if request.tools is not None and len(request.tools):
+ if (isinstance(request.tool_choice,
+ ChatCompletionNamedToolChoiceParam)):
+ if request.tool_choice.type == "function":
+ select_tool_choice = request.tool_choice.function.name
+ else:
+ select_tool_choice = None
+ else:
+ select_tool_choice = None
+ text_inject = self.template.render_toolslist(
+ tool_choice=select_tool_choice,
+ tools_list=request.tools,
+ tool_params=request.tool_params)
+ if isinstance(request.messages, str):
+ request.messages = request.messages.strip(
+ ) + "\n\n" + text_inject
+ elif isinstance(request.messages,
+ List) and len(request.messages) >= 1:
+ request.messages[0].content = request.messages[
+ 0].content.strip() + "\n\n" + text_inject
+
+
+class ChatPromptCapture:
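+    """ Captures function call text emitted by the model and converts it into
+    OpenAI-style tool_calls objects (used by both the streaming and the
+    non-streaming chat paths). """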
+
+ def __init__(self, prompter: OpenAIToolsPrompter,
+ tool_params: VllmToolsTemplate):
+ self.content: str = ""
+ self.prompter = prompter
+ self.maybe_function_call = False
+ self.is_function_call = False
+ self.prefix_size = 0
+ self.calls_list: List[dict] = []
+ self.after_new_function_call = False
+ self.named_function_call = False # True if the function call is forced using request.tool_choice
+ self.tool_params = tool_params
+
+ def __str__(self):
+ """ Show current state. For debugging purpose. """
+ return (
+ f"ChatPromptCapture {{\n"
+ f" maybe_function_call={self.maybe_function_call},\n"
+ f" is_function_call={self.is_function_call},\n"
+ f" prefix_size={self.prefix_size},\n"
+ f" after_new_function_call={self.after_new_function_call},\n"
+ f" content={self.content},\n"
+ f" calls_list={self.calls_list},\n"
+ f"}}")
+
+ def reset(self, reset_calls_list=False):
+ self.content = ""
+ self.maybe_function_call = False
+ self.is_function_call = False
+ self.named_function_call = False
+ self.prefix_size = 0
+ if reset_calls_list:
+ self.calls_list = []
+
+ def func_call_token_pre(self) -> str:
+ return self.tool_params.call_token_start[0]
+ # return self.call_token_pre
+
+ def func_call_token_size(self) -> int:
+ return len(self.tool_params.call_token_start)
+ # return len(self.call_token_str)
+
+ def func_call_token(self) -> str:
+ return self.tool_params.call_token_start
+ # return self.call_token_str
+
+ def num_calls(self):
+ return len(self.calls_list)
+
+ def startNamedFunction(self,
+ tool_choice: ChatCompletionNamedToolChoiceParam):
+ # Should not have to be templated since it's defined by the OpenAI reference:
+ self.content = "{ \"name\": \"" + tool_choice.function.name + "\", \"arguments\": "
+ self.named_function_call = True
+ self.prefix_size = 0
+ self.is_function_call = True
+
+ def closeNamedFunction(self):
+ self.content += "}"
+
+ def checkBracketsFunctionCall(self,
+ tool_params: VllmToolsTemplate) -> bool:
+ """ Count brackets in a string to check if the function call is complete. """
+ if self.named_function_call:
+ if self.content.rfind("}", -6) != -1:
+ c1 = self.content.count("{")
+ c2 = self.content.count("}")
+ return c1 == (c2 + 1)
+ else:
+ content_end = self.content[-(len(tool_params.call_token_end) +
+ 6):].rstrip()
+ if tool_params.call_token_end in content_end and content_end.find(
+ "}") != -1:
+ c1 = self.content.count("{")
+ c2 = self.content.count("}")
+ return c1 == c2 # We have the complete call block
+
+ def make_calls_list(self):
+ """ Convert the extracted text to json function calls. """
+ if self.named_function_call:
+ if self._add_calls_list(self.content) == 0:
+ return
+ else:
+ calls_list = self.content.split(self.tool_params.call_token_start)
+ for v_call in calls_list:
+ if len(self.tool_params.call_token_end):
+ content = v_call.split(self.tool_params.call_token_end)[0]
+ else:
+ content = v_call
+ self._add_calls_list(content)
+
+ def _add_calls_list(self, content: str) -> int:
+ """ Returns the number of added tools calls. """
+ count = len(self.calls_list)
+ if len(content) > 1:
+ try:
+ call_data = json.loads(content)
+ except json.decoder.JSONDecodeError as exc:
+                # Simply ignore invalid function calls...
+                if self.named_function_call:
+                    logger.warning(
+                        "Error in parsing the function call. This should not happen since it is guided generation: %s"
+                        % str(exc))
+                else:
+                    logger.warning(
+                        "Error in parsing the function call. The model has probably generated invalid syntax: %s"
+ % str(exc))
+ return 0
+ if isinstance(call_data, List):
+ for call_elem in call_data:
+ if isinstance(call_elem, Dict):
+ if "name" in call_elem:
+ self.calls_list.append(call_elem)
+ elif isinstance(call_data, Dict):
+ if "name" in call_data:
+ self.calls_list.append(call_data)
+ if self.prompter.debug:
+ logger.info("Catched tool call : %s" % str(call_data))
+ return len(self.calls_list) - count
+
+ def to_ChatCompletionMessageToolCall(
+ self, call_id: int) -> Union[ChatCompletionMessageToolCall, None]:
+ if len(self.calls_list) and call_id < len(self.calls_list):
+ call = self.calls_list[call_id]
+ arguments = call["arguments"] if "arguments" in call else None
+ if arguments is None and "parameters" in call:
+ arguments = call["parameters"]
+ function_call = Function(name=call["name"],
+ arguments=json.dumps(arguments)
+ if arguments is not None else "")
+ return ChatCompletionMessageToolCall(index=call_id,
+ id="call_" + call["name"] +
+ "_" + str(call_id),
+ type="function",
+ function=function_call)
+ return None
+
+ def to_ChatCompletionMessageToolCallList(
+ self) -> List[ChatCompletionMessageToolCall]:
+ calls_count = self.num_calls()
+ tools_calls_list = []
+ for call_id in range(calls_count):
+ tools_calls_list.append(
+ self.to_ChatCompletionMessageToolCall(call_id=call_id))
+ return tools_calls_list
+
+ def to_ChoiceDeltaToolCall(
+ self, call_id: int) -> Union[ChoiceDeltaToolCall, None]:
+ mesg = self.to_ChatCompletionMessageToolCall(call_id)
+ if mesg is not None:
+ return ChoiceDeltaToolCall(index=call_id,
+ id=mesg.id,
+ type=mesg.type,
+ function=mesg.function)
+ return None
+
+ def to_ChoiceDeltaToolCallList(
+ self) -> List[Union[ChoiceDeltaToolCall, None]]:
+ calls_count = self.num_calls()
+ tools_calls_list = []
+ for call_id in range(calls_count):
+ tools_calls_list.append(
+ self.to_ChoiceDeltaToolCall(call_id=call_id))
+ return tools_calls_list
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index 8403604286903..1539a83ade595 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -10,8 +10,9 @@
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
- CompletionRequest)
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest, CompletionRequest,
+ ChatCompletionNamedToolChoiceParam)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
@@ -86,8 +87,16 @@ async def get_outlines_guided_decoding_logits_processor(
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
-
- if request.guided_json:
+ if isinstance(request, ChatCompletionRequest) and isinstance(
+ request.tool_choice, ChatCompletionNamedToolChoiceParam):
+ # Guided generation for tools/functions parameters
+ if request.tool_choice.type == "function":
+ for tool in request.tools:
+ if tool.type == "function" and tool.function.name == request.tool_choice.function.name:
+ json = json_dumps(tool.function.parameters, sort_keys=True)
+ return json, GuidedDecodingMode.JSON
+ return None, None
+ elif request.guided_json:
json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string